In [7]:
import jax
from jax import jit
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

# 1. Linear Regression with Flax

In [8]:
# Create a Dense (fully connected) layer instance with 5 output features.
# Only the output size is specified here; the parameters themselves are
# produced later by `model.init`, which uses a sample input for shape inference.
model = nn.Dense(features=5)

### Model parameters & initialization

In [11]:
# Split the root key so the dummy input and the parameter initialisation
# each receive their own PRNG key.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10, )) # this x is used to trigger shape inference
params = model.init(key2, x)
# Show only the shape of every leaf in the parameter pytree.
# `jax.tree_map` is a deprecated alias (removed in recent JAX releases);
# `jax.tree_util.tree_map` is the stable spelling. The lambda argument is
# named `leaf` so it does not shadow the sample input `x` defined above.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

Parameters are stored in a `FrozenDict` instance, which deals with the functional nature of JAX by preventing any mutation of the underlying dict and making the user aware of it.

In [12]:
# Print the full parameter pytree (the actual values, not just the shapes).
print(params)

FrozenDict({
    params: {
        kernel: Array([[ 2.35571593e-01, -1.71652630e-01, -4.45728898e-02,
                -4.68043625e-01,  4.54595298e-01],
               [-6.87736511e-01,  3.67835432e-01, -1.79262117e-01,
                 1.29276216e-01, -2.42580175e-01],
               [ 2.02303097e-01, -2.49465629e-01,  2.74955630e-01,
                 4.73488301e-01, -1.98002532e-01],
               [ 2.74478376e-01, -1.21369645e-01, -2.25361690e-01,
                -4.78193611e-01, -9.63979959e-02],
               [-6.19886220e-02, -1.72743499e-01,  2.96947401e-04,
                -7.17593431e-01,  2.00894251e-01],
               [-5.60321212e-01,  3.27208459e-01,  1.06281511e-01,
                 1.28758654e-01,  1.16973273e-01],
               [ 1.82219014e-01,  1.11444041e-01, -1.62924170e-01,
                 3.24953273e-02, -1.67053357e-01],
               [ 4.31294084e-01,  2.08004534e-01,  1.47714198e-01,
                -8.51502791e-02, -1.26487076e-01],
               [ 3.29

In [14]:
# Print the module's repr, which lists the Dense layer's configured attributes.
print(model)

Dense(
    # attributes
    features = 5
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = dot_general
)


### Gradient Descent

In [13]:
# Set problem dimensions
n_samples = 20
x_dim = 10
y_dim = 5

# Derive one independent PRNG key per random draw in a single split.
# A JAX key must not be reused once it has been consumed: passing the same
# key to `random.normal` (for W) and then to `random.split` (for the sample
# and noise keys) would break the independence of the resulting streams.
key = random.PRNGKey(0)
k1, k2, key_sample, key_noise = random.split(key, 4)

# Generate random ground truth W and b
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Store the parameters in a pytree with the same layout as the model's params
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples: y = x @ W + b, plus Gaussian noise scaled by 0.1
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = (
    jnp.dot(x_samples, W)
    + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)

# Sanity-check the generated data before it is used downstream.
assert x_samples.shape == (n_samples, x_dim)
assert y_samples.shape == (n_samples, y_dim)
print('x shape: ', x_samples.shape, '; y shape', y_samples.shape)

x shape:  (20, 10) ; y shape (20, 5)
