In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from abc import ABC, abstractmethod

In [3]:
class ParamInitialiser(ABC):
    """Base class turning a pytree of shapes into a pytree of arrays.

    Calling an instance with a pytree whose leaves are shape tuples, e.g.
    ``{"w": (256, 10), "b": (10,)}``, returns a pytree with the same
    structure in which each shape has been replaced by whatever
    ``param_generator(shape, key)`` returns for it.

    The instance owns a PRNG key (``self.prng_key``) that is advanced on
    every call, so successive calls produce different random draws.

    Args:
        seed: integer seed for the initial PRNG key. Defaults to 0, which
            matches the previous hard-coded behaviour.
    """

    def __init__(self, seed=0):
        self.prng_key = jax.random.PRNGKey(seed)
        # Exposed as an instance attribute because __call__ (and possibly
        # external code) reads `self.is_leaf`.
        self.is_leaf = self._is_shape

    @staticmethod
    def _is_shape(x):
        """Return True when `x` should be treated as a shape leaf.

        A leaf is a tuple whose first element is not itself a container.
        The empty tuple is accepted as well: it is the shape of a scalar,
        and indexing ``x[0]`` on it would otherwise raise IndexError.
        """
        if not isinstance(x, tuple):
            return False
        return len(x) == 0 or not isinstance(x[0], (tuple, dict, list, set))

    def __call__(self, pytree):
        """Generate one array per shape leaf of `pytree`.

        Args:
            pytree: pytree whose leaves are shape tuples.

        Returns:
            A pytree with the structure of `pytree`, holding the values
            returned by `param_generator` for each leaf.
        """
        # flatten the pytree, stopping at shape tuples
        leaves, struct = jax.tree.flatten(pytree, is_leaf=self.is_leaf)

        # Split off one key per leaf PLUS one carry-over key. The carry-over
        # key is never handed to a generator, so no key is used twice (the
        # old code reused the first leaf's key as the next parent key). The
        # extra key also keeps this working when there are no leaves at all.
        keys = jax.random.split(self.prng_key, len(leaves) + 1)
        next_key, leaf_keys = keys[0], keys[1:]

        # generate a value of the desired shape for every leaf
        inits = [
            self.param_generator(shape, key)
            for (shape, key) in zip(leaves, leaf_keys)
        ]

        # restore the original pytree structure
        inits = jax.tree_util.tree_unflatten(struct, inits)

        # advance the stored key so the next call draws fresh randomness
        self.prng_key = next_key

        return inits

    @abstractmethod
    def param_generator(self, shape, key):
        """Return the initial value for one parameter.

        Args:
            shape: the shape leaf taken from the input pytree.
            key: PRNG key reserved for this leaf.
        """
        pass

## Zero initialiser

In [4]:
class Zeros(ParamInitialiser):
    """Initialiser that sets every parameter to zero."""

    def __init__(self):
        super().__init__()

    def param_generator(self, shape, key):
        """Return a float32 array of zeros with the given shape.

        `key` is accepted only to satisfy the initialiser interface; no
        randomness is involved.
        """
        zeros = jnp.zeros(shape, dtype=jnp.float32)
        return zeros

In [5]:
# Build the zero initialiser; constructing it creates the instance's PRNG key.
init = Zeros()



Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



I0000 00:00:1739969737.500557 8723108 service.cc:145] XLA service 0x11dc39620 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739969737.500568 8723108 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1739969737.501592 8723108 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1739969737.501600 8723108 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


In [6]:
# Shapes are passed as tuples keyed by parameter name; the result is a dict
# with the same keys holding arrays of those shapes.
init({"w_yx": (256, 10), "b_y": (10,)})

{'b_y': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'w_yx': Array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)}

## Gaussian initialiser

In [7]:
class Gaussian(ParamInitialiser):
    """Initialiser drawing parameters from a normal distribution.

    Args:
        mean: value added to every standard-normal draw (default 0).
        std: factor every standard-normal draw is multiplied by (default 1).
    """

    def __init__(self, mean=0, std=1):
        super().__init__()
        self.mean = mean
        self.std = std

    def param_generator(self, shape, key):
        """Return `mean + std * z`, with `z` standard-normal of shape `shape`."""
        standard_draw = jax.random.normal(key=key, shape=shape)
        scaled_draw = standard_draw * self.std
        return self.mean + scaled_draw

## He initialiser

In [8]:
class He(ParamInitialiser):
    """Initialiser drawing zero-mean normals with std ``sqrt(2 / fan_in)``."""

    def __init__(self):
        super().__init__()

    def param_generator(self, shape, key):
        """Return a float32 normal sample of shape `shape`, scaled for He init.

        `fan_in` is the first dimension when the shape has at least two
        dimensions; for shapes with fewer dimensions it falls back to 1.
        """
        if len(shape) > 1:
            fan_in = shape[0]
        else:
            # fewer than two dimensions: no meaningful fan-in, use 1
            fan_in = 1

        scale = jnp.sqrt(2.0 / fan_in)
        noise = jax.random.normal(key, shape, dtype=jnp.float32)
        return noise * scale

## Xavier initialiser

In [9]:
class Xavier(ParamInitialiser):
    """Initialiser drawing zero-mean normals with std
    ``sqrt(2 / (fan_in + fan_out))``."""

    def __init__(self):
        super().__init__()

    def param_generator(self, shape, key):
        """Return a float32 normal sample of shape `shape`, scaled for Xavier init.

        `fan_in` and `fan_out` are the first and second dimensions when the
        shape has at least two dimensions; otherwise both fall back to 1.
        """
        if len(shape) > 1:
            fan_in, fan_out = shape[0], shape[1]
        else:
            # fewer than two dimensions: use 1 for both fans
            fan_in, fan_out = 1, 1

        scale = jnp.sqrt(2.0 / (fan_in + fan_out))
        noise = jax.random.normal(key, shape, dtype=jnp.float32)
        return noise * scale

## Orthonormal initialiser

In [None]:
class Orthonormal(ParamInitialiser):
    """Initialiser producing 2-D matrices with orthonormal columns or rows.

    For a requested shape ``(n, m)`` the result has orthonormal columns when
    ``n > m`` and orthonormal rows otherwise, so all of its singular values
    are 1 up to floating-point error.
    """

    def __init__(self):
        super().__init__()

    def param_generator(self, shape, key):
        """Return an ``(n, m)`` matrix taken from the Q factor of a QR
        decomposition of a Gaussian matrix.

        Args:
            shape: two-element shape ``(n, m)``.
            key: PRNG key used to draw the Gaussian matrix.

        Raises:
            NotImplementedError: if `shape` does not have exactly two
                dimensions.
        """
        if len(shape) != 2:
            # Trailing spaces matter: the pieces are concatenated verbatim.
            raise NotImplementedError(
                "Orthonormal initialisation not implemented for arrays "
                "that are not two dimensional; passed shape implies "
                f"{len(shape)} dimension(s).")

        n, m = shape

        # Reduced QR of an (r x c) matrix with r >= c gives Q: r x c with
        # orthonormal columns. When n > m we can factor an n x m matrix
        # directly; otherwise we factor an m x n matrix and transpose Q.
        is_tall = n > m
        tall_shape = (n, m) if is_tall else (m, n)
        norm = jax.random.normal(key=key, shape=tall_shape)

        # NOTE: the original author reports a jax bug with qr on metal, so
        # the decomposition is run on the CPU device and the result is moved
        # back afterwards.
        norm_cpu = jax.device_put(norm, jax.devices("cpu")[0])
        q, _ = jnp.linalg.qr(norm_cpu)

        # Move back to the default backend's first device rather than a
        # hard-coded "METAL" device, so this also runs on machines without
        # Metal (on a Metal machine this is the same device as before).
        q = jax.device_put(q, jax.devices()[0])

        return q if is_tall else q.T


In [14]:
# Rebinds `init` (previously the Zeros instance) to an orthonormal initialiser.
init = Orthonormal()

In [15]:
# Request a single 64 x 50 weight matrix and confirm the returned shape.
p0 = init({'w_yx': (64, 50)})
p0["w_yx"].shape

(64, 50)

In [288]:
jnp.linalg.svd(p0["w_yx"]).S # all singular values are ~1, as expected for orthonormal columns

Array([1.0000002, 1.0000002, 1.0000002, 1.0000001, 1.0000001, 1.0000001,
       1.0000001, 1.0000001, 1.0000001, 1.0000001, 1.0000001, 1.0000001,
       1.0000001, 1.0000001, 1.0000001, 1.0000001, 1.       , 1.       ,
       1.       , 1.       , 1.       , 1.       , 1.       , 1.       ,
       1.       , 1.       , 1.       , 1.       , 1.       , 1.       ,
       1.       , 1.       , 0.9999999, 0.9999999, 0.9999999, 0.9999999,
       0.9999999, 0.9999999, 0.9999999, 0.9999999, 0.9999999, 0.9999999,
       0.9999999, 0.9999999, 0.9999999, 0.9999999, 0.9999999, 0.9999999,
       0.9999999, 0.9999999], dtype=float32)