# JAX MLP with Flax NNX on Colab TPU

This notebook demonstrates how to define, train, and evaluate a Multi-Layer Perceptron (MLP) using JAX, Flax NNX, and Optax. It's designed to run on a Colab TPU environment.

**Key Steps:**
1. Install necessary libraries.
2. Import dependencies.
3. Define the MLP model using `flax.nnx`.
4. Generate synthetic classification data.
5. Define training components: loss function, optimizer, and the training step function.
6. Initialize the model and optimizer.
7. Run the training loop.
8. Evaluate the trained model.

## 1. Install Dependencies

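First, install the required Python packages: `jax` with the `tpu` extra (which pulls in a matching `jaxlib` and `libtpu`), `flax`, `optax`, and `scikit-learn`. Colab TPU runtimes usually ship with JAX preinstalled, so this step mainly ensures compatible, up-to-date versions.

In [None]:
!pip install -q -U "jax[tpu]" flax optax scikit-learn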

## 2. Import Libraries

Now, let's import all the necessary modules.

In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.nnx as nnx  # Using the stable nnx module
import optax
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")  # the version lives on the top-level flax package
print(f"Optax version: {optax.__version__}")
try:
    print(f"Detected TPUs: {jax.devices('tpu')}")
except RuntimeError:
    print("No TPU detected. Ensure TPU runtime is selected in Colab: Runtime > Change runtime type > Hardware accelerator > TPU")

## 3. MLP Definition using Flax NNX

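We define our Multi-Layer Perceptron (MLP) model as a subclass of `flax.nnx.Module`. In the Flax Linen API, a module is a stateless description and its parameters live in a separate variables dictionary. An NNX module instead holds its parameters directly as attributes, which makes model code more stateful and Pythonic.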

- The `__init__` method constructs the layers (a series of hidden linear layers and an output linear layer). It takes the input size, a list of hidden layer sizes, and the output size as arguments. `nnx.Rngs` is used for initializing layer parameters.
- The `__call__` method defines the forward pass of the model. Data flows through hidden layers with ReLU activations, and finally through the output layer.

In [None]:
class MLP(nnx.Module):
    def __init__(self, input_size: int, hidden_sizes: list[int], output_size: int, *, rngs: nnx.Rngs):
        current_in_features = input_size
        self.hidden_layers = []
        for h_size in hidden_sizes:
            layer = nnx.Linear(in_features=current_in_features, out_features=h_size, rngs=rngs)
            self.hidden_layers.append(layer)
            current_in_features = h_size
        
        self.output_layer = nnx.Linear(in_features=current_in_features, out_features=output_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        for layer in self.hidden_layers:
            x = layer(x)
            x = jax.nn.relu(x)
        x = self.output_layer(x)
        return x

## 4. Synthetic Data Generation

This function generates a simple synthetic classification dataset with `scikit-learn`'s `make_classification`. It returns three JAX arrays: the features, the one-hot encoded labels (used by the cross-entropy loss), and the original integer labels (used later to compute accuracy).

In [None]:
def generate_synthetic_data(n_samples=200, n_features=2, n_classes=2, random_state=42):
    """Generates simple synthetic data for classification."""
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_features,
        n_redundant=0,
        n_repeated=0,
        n_classes=n_classes,
        n_clusters_per_class=1,
        flip_y=0.01,
        class_sep=1.0,
        random_state=random_state
    )
    y_one_hot = jax.nn.one_hot(y, num_classes=n_classes)
    return jnp.array(X), jnp.array(y_one_hot), jnp.array(y)
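A minimal check of the data pipeline, assuming the configuration used later (500 samples, 4 features, 3 classes). The sketch calls `make_classification` with the same settings as `generate_synthetic_data`. It then verifies two things that the evaluation step relies on: the one-hot labels have one column per class, and `argmax` recovers the integer labels. Note that `make_classification` requires `n_classes * n_clusters_per_class <= 2**n_informative`. Setting `n_informative=n_features` keeps that condition satisfied here.

```python
import jax
import jax.numpy as jnp
from sklearn.datasets import make_classification

# Same arguments as generate_synthetic_data, with the later configuration plugged in.
X, y = make_classification(
    n_samples=500, n_features=4, n_informative=4, n_redundant=0, n_repeated=0,
    n_classes=3, n_clusters_per_class=1, flip_y=0.01, class_sep=1.0, random_state=42,
)
X = jnp.array(X)
y_one_hot = jax.nn.one_hot(y, num_classes=3)

# Each row has exactly one 1, and argmax maps back to the integer label.
recovered = jnp.argmax(y_one_hot, axis=1)
print(X.shape, y_one_hot.shape)
```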

## 5. Training Loop Components

These functions define the core components needed for the training loop:

### Loss Function (Cross-Entropy)
Computes the cross-entropy loss, suitable for classification tasks. It uses `jax.nn.log_softmax` for numerical stability.

### Optimizer (Adam)
Returns an Adam optimizer instance from Optax with a specified learning rate.

### Training Step Function (`train_step`)
This function performs a single training step. It's JIT-compiled with `jax.jit` for performance.
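- **NNX State Handling**: For Flax NNX, the model is split with `nnx.split`, which returns two parts in this order:
    - a static graph definition (`static`), holding the module structure;
    - a state containing the trainable arrays (`params`).

  Inside `loss_for_grad` (the function being differentiated), the model is temporarily reconstructed with `nnx.merge` to perform the forward pass. `nnx.merge` takes the graph definition first and the state second. This is necessary because JAX transformations such as `jax.grad` operate on pure functions of JAX arrays, so the arrays (`params`) must be passed in explicitly.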
- **Gradient Calculation**: `jax.value_and_grad` computes both the loss value and the gradients of the loss with respect to the trainable parameters.
- **Optimizer Update**: The Optax optimizer calculates parameter updates based on the gradients and its internal state.
- **Parameter Application**: `optax.apply_updates` applies these updates to the model parameters.

### Prediction Function (`predict`)
Makes predictions using the trained model. Similar to `train_step`, it merges `params` and `static` parts to reconstruct the model before the forward pass. It's also JIT-compiled.
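`train_step` receives `loss_fn` and `optimizer_update_fn` as Python callables, which is why they are listed in `static_argnames`. `jax.jit` cannot trace a function object as an array argument. It can, however, treat the callable as a compile-time constant and cache one compiled version per distinct value. The sketch below counts how often tracing happens; `double`, `square`, and `apply_fn` are illustrative helpers.

```python
import functools
import jax
import jax.numpy as jnp

trace_count = 0

def double(v):
    return 2 * v

def square(v):
    return v * v

@functools.partial(jax.jit, static_argnames=("fn",))
def apply_fn(x, fn):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    return fn(x)

x = jnp.arange(3.0)
a = apply_fn(x, fn=double)
b = apply_fn(x + 1.0, fn=double)  # same static fn, same shape/dtype: reuses the cache
c = apply_fn(x, fn=square)        # new static value: traced and compiled again
print(trace_count)
```

In the notebook, `cross_entropy_loss` and `optimizer.update` are the same objects on every call, and every batch has the same shape. So `train_step` should be traced only once.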

In [None]:
# Loss Function (Cross-Entropy for classification)
def cross_entropy_loss(logits, labels_one_hot):
    """Computes cross-entropy loss."""
    return -jnp.sum(labels_one_hot * jax.nn.log_softmax(logits), axis=-1).mean()

# Optimizer
def get_optimizer(learning_rate=1e-3):
    return optax.adam(learning_rate)

# Training Step Function
@functools.partial(jax.jit, static_argnames=('loss_fn', 'optimizer_update_fn'))
def train_step(params, static, opt_state, loss_fn, optimizer_update_fn, X_batch, y_batch):
    """Performs a single training step with Flax NNX using split state."""
    def loss_for_grad(current_params):
        # Reconstruct the model for the forward pass
        model_for_forward_pass = nnx.merge(static, current_params) # graphdef first, then state
        logits = model_for_forward_pass(X_batch) # Call the model directly
        return loss_fn(logits, y_batch)

    loss_val, grads = jax.value_and_grad(loss_for_grad)(params) # Grads w.r.t. params
    updates, new_opt_state = optimizer_update_fn(grads, opt_state, params) # Pass params for optimizer
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_val

# Prediction function
@jax.jit
def predict(params, static, X):
    """Makes predictions using the model with split state."""
    model_for_prediction = nnx.merge(static, params) # graphdef first, then state
    return model_for_prediction(X)

## 6. Main Execution: Configuration, Initialization, Training, and Evaluation

This section brings everything together:
1.  **Configuration**: Define hyperparameters like dataset size, features, classes, network architecture, learning rate, epochs, and batch size.
2.  **PRNG Keys**: JAX requires explicit management of pseudo-random number generator (PRNG) keys. We create and split keys for data generation, model initialization, and shuffling within the training loop.
3.  **Data Preparation**: Generate the synthetic dataset and split it into training and testing sets.
4.  **Model and Optimizer Initialization**:
    - Create an instance of the `MLP` model. `nnx.Rngs(params=key_init)` provides the necessary PRNG key for parameter initialization within the NNX model.
    - Optionally pass a dummy input through the model as a sanity check. NNX creates parameters eagerly in `__init__`, so this forward pass is not needed for initialization. It only confirms that the input and output shapes line up.
    - Split the NNX model with `static, params = nnx.split(model)`. `nnx.split` returns two parts in this order:
        - `static`: the static graph definition, holding the module structure and non-array attributes.
        - `params`: the state holding the trainable arrays (weights and biases).
    - Initialize the Optax optimizer with the trainable `params`.
5.  **Training Loop**:
    - Iterate for the specified number of epochs.
    - In each epoch, shuffle the training data.
    - Iterate over the full batches of the shuffled data. Leftover samples that do not fill a complete batch are skipped for that epoch.
    - Call the `train_step` function to update model parameters and optimizer state for each batch.
    - Accumulate the loss and print the per-epoch average periodically.
6.  **Evaluation**:
    - After training, use the `predict` function to get logits for the test set.
    - Calculate the classification accuracy by comparing the predicted class labels (derived from `argmax` of logits) with the true labels.

In [None]:
# Configuration
N_SAMPLES = 500
N_FEATURES = 4
N_CLASSES = 3
HIDDEN_SIZES = [64, 32]
LEARNING_RATE = 0.005
EPOCHS = 100
BATCH_SIZE = 32
PRINT_EVERY_EPOCHS = 10
SEED = 42

key = jax.random.PRNGKey(SEED)
key_data, key_init, key_shuffle = jax.random.split(key, 3)

# Generate Data
X_data, y_data_one_hot, y_data_orig = generate_synthetic_data(
    n_samples=N_SAMPLES, n_features=N_FEATURES, n_classes=N_CLASSES,
    random_state=int(jax.random.randint(key_data, (), 0, 2**31 - 1)) # Derive sklearn's integer seed from the JAX key
)
X_train, X_test, y_train_one_hot, y_test_one_hot, y_train_orig, y_test_orig = train_test_split(
    X_data, y_data_one_hot, y_data_orig, test_size=0.2, random_state=SEED # sklearn's random_state
)

print(f"X_train shape: {X_train.shape}, y_train_one_hot shape: {y_train_one_hot.shape}")
print(f"X_test shape: {X_test.shape}, y_test_one_hot shape: {y_test_one_hot.shape}")
print(f"Number of classes: {N_CLASSES}")

# Initialize Model and Optimizer with Flax NNX
# A single training example for a shape sanity check
dummy_x = X_train[:1]

# NNX model initialization (parameters are created eagerly here)
model = MLP(input_size=N_FEATURES, hidden_sizes=HIDDEN_SIZES, output_size=N_CLASSES, rngs=nnx.Rngs(params=key_init))

# Optional sanity check: the forward pass should produce one logit per class
assert model(dummy_x).shape == (1, N_CLASSES)

# nnx.split returns (graphdef, state): static structure first, trainable arrays second
static, params = nnx.split(model)

optimizer = get_optimizer(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(params) # Optimizer initializes with trainable parameters

# Training Loop
num_train_samples = X_train.shape[0]
num_batches = num_train_samples // BATCH_SIZE

print(f"\nStarting training for {EPOCHS} epochs...")
for epoch in range(EPOCHS):
    key_shuffle, key_loop = jax.random.split(key_shuffle) # Get a new key for this epoch's shuffle
    permutation = jax.random.permutation(key_loop, num_train_samples)
    shuffled_X_train = X_train[permutation]
    shuffled_y_train_one_hot = y_train_one_hot[permutation]

    epoch_loss = 0.0
    for i in range(num_batches):
        start_idx = i * BATCH_SIZE
        end_idx = start_idx + BATCH_SIZE
        X_batch = shuffled_X_train[start_idx:end_idx]
        y_batch = shuffled_y_train_one_hot[start_idx:end_idx]

        params, opt_state, loss_val = train_step(
            params,         
            static,         
            opt_state,
            cross_entropy_loss, 
            optimizer.update,   
            X_batch,
            y_batch
        )
        epoch_loss += loss_val

    avg_epoch_loss = epoch_loss / num_batches
    if (epoch + 1) % PRINT_EVERY_EPOCHS == 0 or epoch == 0:
        print(f"Epoch {epoch + 1}/{EPOCHS}, Avg Loss: {avg_epoch_loss:.4f}")

print("Training finished.")

# Evaluation (simple accuracy)
test_logits = predict(params, static, X_test)
predicted_classes = jnp.argmax(test_logits, axis=1)
accuracy = jnp.mean(predicted_classes == y_test_orig)
print(f"\nTest Accuracy: {accuracy:.4f}")
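The batching and evaluation arithmetic can be checked on its own, using the notebook's configuration: 400 training samples after the 80/20 split, and `BATCH_SIZE = 32`. The sketch reuses the training loop's key-splitting and slicing pattern, and the same `argmax`-based accuracy. The logits and labels are made-up values for illustration.

```python
import jax
import jax.numpy as jnp

num_train_samples, batch_size = 400, 32
num_batches = num_train_samples // batch_size           # full batches only
dropped = num_train_samples - num_batches * batch_size  # samples skipped each epoch

# Same pattern as the loop: split off a fresh key, then permute the indices.
key_shuffle = jax.random.PRNGKey(0)
key_shuffle, key_loop = jax.random.split(key_shuffle)
permutation = jax.random.permutation(key_loop, num_train_samples)
batch_indices = [permutation[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
covered = jnp.concatenate(batch_indices)

# Accuracy from logits: argmax over the class axis, compared with integer labels.
logits = jnp.array([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, 4.0], [5.0, 0.0, 0.0]])
labels = jnp.array([0, 1, 2, 1])
accuracy = jnp.mean(jnp.argmax(logits, axis=1) == labels)
print(num_batches, dropped, float(accuracy))  # → 12 16 0.75
```

Because the permutation changes every epoch, the 16 skipped samples differ from epoch to epoch, so no training example is permanently excluded.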

## 7. Conclusion

This notebook demonstrated the process of building, training, and evaluating an MLP using JAX with the Flax NNX API and Optax for optimization. The use of `nnx.split` and `nnx.merge` is key to integrating the stateful NNX model style with JAX's functional transformations and Optax's parameter-based optimization.