In [8]:
## Standard libraries
import os
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
# NOTE(review): the line echoed in this cell's output looks like the tail of a
# DeprecationWarning for `IPython.display.set_matplotlib_formats`. The suggested
# replacement is `matplotlib_inline.backend_inline.set_matplotlib_formats`
# -- confirm against the installed IPython version before migrating.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp
print("Using jax", jax.__version__)

import flax

from flax import linen as nn

Using jax 0.4.27


  set_matplotlib_formats('svg', 'pdf') # For export


This is a very simple multi-layer perceptron (MLP) with a single hidden layer and a tanh activation.

In [4]:
class SimpleClassifier(nn.Module):
    """Small perceptron (Dense -> tanh -> Dense) whose layers are declared in ``setup``."""
    num_hidden : int   # Width of the hidden Dense layer
    num_outputs : int  # Width of the output Dense layer

    def setup(self):
        # Submodules are registered under their attribute names, so the
        # parameter tree is keyed by "linear1" and "linear2".
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_outputs)

    def __call__(self, x):
        """Run the forward pass and return the output layer's raw values."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)

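You can avoid declaring the layers in `setup` and then referring to them again in `__call__` by using the `nn.compact` decorator, which lets you define the submodules inline, right where they are used.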

In [5]:
class SimpleClassifierCompact(nn.Module):
    """Same Dense -> tanh -> Dense network, with layers defined inline via ``nn.compact``."""
    num_hidden : int   # Width of the hidden Dense layer
    num_outputs : int  # Width of the output Dense layer

    @nn.compact  # Lets submodules be created directly inside __call__
    def __call__(self, x):
        """Create the two Dense layers on the fly and return the network output."""
        # Hidden layer followed by its non-linearity
        hidden = nn.tanh(nn.Dense(features=self.num_hidden)(x))
        # Output layer; no activation is applied to the result
        return nn.Dense(features=self.num_outputs)(hidden)

In [6]:
# Build the setup-style classifier: 8 hidden units feeding a single output
model = SimpleClassifier(num_outputs=1, num_hidden=8)
# The printed representation lists the module's attributes
print(model)

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)


In [10]:
# Seed a root key, then split off separate keys for the example input and for
# parameter initialisation.
rng = jax.random.PRNGKey(42)
rng, inp_rng, init_rng = jax.random.split(rng, num=3)
# Example batch: 8 samples with 2 features each, drawn from a standard normal
inp = jax.random.normal(inp_rng, shape=(8, 2))
# Create the model's parameters from the init key and the example batch
params = model.init(init_rng, inp)
print(params)

{'params': {'linear1': {'kernel': Array([[ 0.5564613 ,  0.9367376 ,  0.2285179 , -0.23255277, -0.25101846,
        -0.48948383,  0.11607227,  0.40487856],
       [-0.3619682 ,  0.9271343 ,  0.6478837 ,  0.26224074,  0.34578732,
         1.1132734 ,  0.06098709,  0.49297702]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'linear2': {'kernel': Array([[ 0.4818003 ],
       [-0.35573798],
       [-0.62196773],
       [ 0.28606406],
       [-0.79486924],
       [ 0.5573447 ],
       [-0.1400483 ],
       [ 0.41512278]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}


In [11]:
# Forward pass: evaluate the model on the example batch with the initialised parameters
model.apply(params, inp)

Array([[ 0.13819844],
       [ 0.6173139 ],
       [-0.19211891],
       [ 0.00855249],
       [ 0.12030913],
       [-0.34759673],
       [ 0.07192342],
       [ 0.11894515]], dtype=float32)