
In [1]:
import jax.numpy as jnp

## Matrix operations

### Element-wise operations (broadcastable)

In [None]:
x = jnp.array([[0, -1, 1], [-2, 0, 1]])

In [None]:
jnp.maximum(0, x)

Array([[0, 0, 1],
       [0, 0, 1]], dtype=int32)
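Broadcasting compares shapes starting from the trailing dimension. Two dimensions are compatible when they are equal or one of them is 1, and missing leading dimensions are treated as 1. The sketch below applies this rule to `x` from above. The operands `row` and `col` are illustrative names made up for this example.

```python
import jax.numpy as jnp

x = jnp.array([[0, -1, 1], [-2, 0, 1]])  # shape (2, 3)

# A scalar broadcasts against every element (this is what jnp.maximum(0, x) does).
relu = jnp.maximum(0, x)

# A shape-(3,) row broadcasts across both rows of x.
row = jnp.array([10, 20, 30])
row_sum = x + row  # [[10, 19, 31], [8, 20, 31]]

# A shape-(2, 1) column broadcasts across all three columns of x.
col = jnp.array([[100], [200]])
col_sum = x + col  # [[100, 99, 101], [198, 200, 201]]

print(relu.shape, row_sum.shape, col_sum.shape)  # (2, 3) (2, 3) (2, 3)
```

A shape-`(2,)` operand would fail here. Its only dimension is compared against the trailing size 3 of `x`, and neither value is 1.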

### Matrix-specific operations

Binary (or n-ary) operators that combine several arrays

In [2]:
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

In [3]:
jnp.stack((a, b))

Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
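`jnp.stack` is easy to confuse with `jnp.concatenate`. The sketch below compares them on the same `a` and `b`; the names `stacked`, `joined` and `vstacked` are illustrative. `stack` adds a new axis and requires every input to have the same shape. `concatenate` joins the inputs along an axis that already exists.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# stack inserts a NEW axis: two (3,) arrays become one (2, 3) array.
stacked = jnp.stack((a, b))
# concatenate joins along an EXISTING axis: two (3,) arrays become one (6,) array.
joined = jnp.concatenate((a, b))
# vstack promotes 1-D inputs to rows first, so here it matches stack(axis=0).
vstacked = jnp.vstack((a, b))

print(stacked.shape, joined.shape, vstacked.shape)  # (2, 3) (6,) (2, 3)
```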

The `axis` parameter is the index of the **new axis** in the output (default `0`). The size of the new axis equals the number of arrays in the tuple passed as the first argument.

In [4]:
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
jnp.stack((a, b), axis=0)

Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)

In [5]:
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
jnp.stack((a, b), axis=1)

Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
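The same rule holds for higher-dimensional inputs. The output has one more dimension than each input, and the new dimension sits at position `axis`. Its size equals the number of stacked arrays. Here is a sketch with two illustrative `(4, 3)` arrays, `p` and `q`, that also checks the transpose relationship visible in the two 1-D examples above:

```python
import jax.numpy as jnp

# Two arrays of shape (4, 3); jnp.stack requires all inputs to share one shape.
p = jnp.zeros((4, 3))
q = jnp.ones((4, 3))

# The new axis has size 2 (one slot per input) and is placed at index `axis`.
shapes = {axis: jnp.stack((p, q), axis=axis).shape for axis in (0, 1, 2, -1)}
print(shapes)  # {0: (2, 4, 3), 1: (4, 2, 3), 2: (4, 3, 2), -1: (4, 3, 2)}

# For 1-D inputs, stacking along axis=1 is the transpose of stacking along axis=0.
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
print(bool(jnp.array_equal(jnp.stack((a, b), axis=1), jnp.stack((a, b), axis=0).T)))  # True
```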

# Linen

In [20]:
import flax.linen as nn
import jax.random as random
import jax

In [7]:
layer = nn.Dense(features=4)

In [28]:
params = layer.init(random.key(0), jnp.ones((200, 5)))

In [29]:
params

{'params': {'kernel': Array([[ 0.16268499,  0.7846524 , -0.08340393,  0.62642825],
         [-0.1928301 , -0.05645721,  0.59652334, -0.30912825],
         [ 0.47221094, -0.7885982 , -0.04464982,  0.6568747 ],
         [-0.07706029,  0.16432254, -0.2558527 ,  0.00360778],
         [-0.16997725,  0.22439648,  0.70202154,  0.5716631 ]],      dtype=float32),
  'bias': Array([0., 0., 0., 0.], dtype=float32)}}

In [21]:
jax.tree_util.tree_map(jnp.shape, params)

{'params': {'bias': (4,), 'kernel': (5, 4)}}

In [11]:
layer = nn.Dense(features=4,
                 kernel_init=nn.initializers.xavier_uniform(),  # weights: Xavier (Glorot) uniform init
                 bias_init=nn.initializers.zeros)               # biases: all zeros