In [1]:
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax optax transformers datasets


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.7/129.7 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

1️⃣ Import Required Libraries

In [5]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax  # Optimizers


2️⃣ Define a Simple Neural Network using Flax

In [6]:
class MLP(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)  # Output layer
        return x




3️⃣ Initialize Model Parameters

In [7]:
model = MLP(hidden_dim=32)
key = jax.random.PRNGKey(0)
x_dummy = jnp.ones((1, 10))  # Example input
params = model.init(key, x_dummy)  # Initialize model
print(params)



{'params': {'Dense_0': {'kernel': Array([[-0.5223456 ,  0.23860869, -0.43568492, -0.26943007,  0.29074085,
         0.40591595,  0.3301583 , -0.01470341,  0.1336994 ,  0.24582365,
        -0.57535076,  0.2701447 , -0.2233426 ,  0.16106607,  0.38631102,
        -0.08694729, -0.4690676 , -0.41746357, -0.22493699, -0.01054574,
        -0.10714039, -0.19746391, -0.6667231 , -0.17042542, -0.36339986,
        -0.19222786,  0.17231458, -0.23188609,  0.40331432,  0.38735247,
         0.5382161 , -0.46305057],
       [ 0.19625618,  0.5367912 ,  0.27752233, -0.07770313, -0.00775636,
        -0.2685969 ,  0.35266262,  0.15011781,  0.18526915, -0.08636144,
        -0.03711447,  0.21603984,  0.15343831,  0.00115734,  0.12622738,
        -0.33444613,  0.12262806, -0.44731817, -0.2748209 ,  0.2535692 ,
         0.33146128,  0.30883512,  0.6163292 , -0.01071304, -0.13781956,
         0.33463055,  0.15033457, -0.12188746, -0.39209256,  0.2970293 ,
        -0.16921943, -0.20661804],
       [ 0.48238912,

What is flax.linen.Module?
In Flax, flax.linen.Module is the base class for defining neural networks.
It follows an object-oriented approach, where each model is a Python class that defines:
Layers (e.g., Dense, Conv)
Activation functions (e.g., ReLU, Sigmoid)
The forward pass (__call__ method)


🔹 Explanation of Components in flax.linen.Module
Component- Explanation
hidden_dim: int-	Defines the number of hidden neurons (passed as an argument when creating the model).
@nn.compact	-A decorator that allows layers to be defined inside __call__ instead of setup().
__call__(self, x)-	The forward function where data flows through layers.
nn.Dense(self.hidden_dim)(x)-	A fully connected layer with hidden_dim neurons.
nn.relu(x)	Activation function (ReLU introduces non-linearity).
nn.Dense(1)(x)-	Final output layer with 1 neuron (useful for regression tasks).
model.init(key, x_dummy)-	Initializes weights using a random key.



🔹 Alternative Way: Using setup() Instead of @nn.compact
Flax models can also be defined using setup(), where layers are declared as instance attributes.

In [8]:
class MLP(nn.Module):
    hidden_dim: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_dim)
        self.dense2 = nn.Dense(1)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x


✅ Difference?

setup() stores layers as attributes (self.dense1), making them reusable in different parts of the model.
@nn.compact is simpler and more concise but does not allow layer reuse.

