In [None]:
!pip install -q -U torch 'jax[tpu]' ipykernel jupyter optax flax 'numpy<2.0' 'datasets[audio]' transformers orbax matplotlib seaborn tqdm

In [None]:
# Standard libraries
import math
import os
import time

# Imports for plotting
import matplotlib.pyplot as plt
import numpy as np

# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline
from IPython.display import set_matplotlib_formats

# Request SVG and PDF as the figure formats produced by the inline backend.
# NOTE(review): recent IPython releases deprecate this import location in favour of
# `matplotlib_inline.backend_inline.set_matplotlib_formats` -- confirm against the installed version.
set_matplotlib_formats("svg", "pdf")  # For export
import seaborn as sns
from matplotlib.colors import to_rgba

# Apply seaborn's default plot styling to all subsequent figures.
sns.set()

# Progress bar
from tqdm.auto import tqdm

In [None]:
import flax.linen as nn
import flax.nnx as nnx
import jax
import jax.numpy as jnp

# Record which JAX version produced this notebook's outputs.
print("Using jax", jax.__version__)

In [None]:
class MyModule(nnx.Module):
    """Small two-layer perceptron producing one scalar output per input row.

    Layout: Linear(dim_in -> dim_out), ReLU, Linear(dim_out -> 1).
    """

    def __init__(self, dim_in: int, dim_out: int, rngs):
        """
        Inputs:
            dim_in - Number of input features
            dim_out - Width of the hidden layer
            rngs - RNG container handed to both `nnx.Linear` layers for parameter initialisation
        """
        # Hidden layer first, output layer second: both draw their initial weights from `rngs`.
        self.fc1 = nnx.Linear(dim_in, dim_out, rngs=rngs)
        self.fc2 = nnx.Linear(dim_out, 1, rngs=rngs)

    def __call__(self, x):
        """Apply the hidden layer, a ReLU, and the output layer to `x`."""
        hidden = nnx.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
# Derive two keys from seed 0: `rng` is carried forward, `inp_rng` is spent on the dummy input.
rng, inp_rng = jax.random.split(jax.random.PRNGKey(0))
# Random stand-in input: a batch of 8 samples with 2 features each.
inp = jax.random.normal(inp_rng, shape=(8, 2))

In [None]:
import torch.utils.data as data


class XORDataset(data.Dataset):
    """Synthetic two-class dataset of noisy XOR points.

    Each sample is a 2-D point whose coordinates are drawn from {0, 1} and then perturbed
    with Gaussian noise; its label is the XOR of the two clean coordinates. All samples are
    generated once, when the dataset is constructed.
    """

    def __init__(self, size, seed, std=0.1):
        """
        Inputs:
            size - Number of data points to generate
            seed - Seed of the NumPy PRNG state used to generate the data points
            std - Standard deviation of the Gaussian noise added to every coordinate
        """
        super().__init__()
        self.size, self.std = size, std
        self.np_rng = np.random.RandomState(seed=seed)
        self.generate_continuous_xor()

    def generate_continuous_xor(self):
        """Draw `self.size` noisy XOR samples and store them as `self.data` / `self.label`."""
        # Clean corner coordinates: both variables are independently 0 or 1.
        corners = self.np_rng.randint(low=0, high=2, size=(self.size, 2)).astype(np.float32)
        # XOR is 1 exactly when the two binary coordinates differ. Labels are taken from the
        # clean coordinates, before any noise is applied.
        targets = (corners[:, 0] != corners[:, 1]).astype(np.int32)
        # Blur the corners with Gaussian noise; the in-place add keeps the array float32.
        corners += self.np_rng.normal(loc=0.0, scale=self.std, size=corners.shape)

        self.data, self.label = corners, targets

    def __len__(self):
        # Number of generated samples (equals self.data.shape[0] and self.label.shape[0]).
        return self.size

    def __getitem__(self, idx):
        # One sample is the pair (point, label) taken at position `idx`.
        return self.data[idx], self.label[idx]

In [None]:
# Small demo dataset (200 points, fixed seed) used for inspection and plotting below.
dataset = XORDataset(size=200, seed=42)
print("Size of dataset:", len(dataset))
# Indexing returns a (data point, label) tuple.
print("Data point 0:", dataset[0])

In [None]:
def visualize_samples(data, label):
    """Scatter-plot the samples of class 0 and class 1 on a new 4x4 inch figure.

    Inputs:
        data - Array of points; column 0 is plotted on the x-axis, column 1 on the y-axis
        label - Array compared element-wise against 0 and 1 to select the rows of each class
    """
    # Split the points per class up front, pairing every subset with its legend entry.
    groups = [(data[label == cls], legend_name) for cls, legend_name in ((0, "Class 0"), (1, "Class 1"))]

    plt.figure(figsize=(4, 4))
    for points, legend_name in groups:
        plt.scatter(points[:, 0], points[:, 1], edgecolor="#333", label=legend_name)
    plt.title("Dataset samples")
    plt.xlabel(r"$x_1$")
    plt.ylabel(r"$x_2$")
    plt.legend()

In [None]:
# Plot every point of the demo dataset, coloured by class, and render the figure.
visualize_samples(dataset.data, dataset.label)
plt.show()

In [None]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch tensors.

    - A batch of arrays is stacked along a new leading axis.
    - A batch of tuples/lists is collated field by field (recursively), giving a list with
      one collated entry per field.
    - Anything else is converted with `np.array`.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...).
        return list(map(numpy_collate, zip(*batch)))
    return np.array(batch)


# Loader over the demo dataset: shuffled batches of 8, collated to NumPy arrays by `numpy_collate`.
# NOTE(review): `num_workers=2` loads batches in worker subprocesses. Classes/functions defined
# inside a notebook may not be picklable under the "spawn" start method, and forking after JAX
# has been initialised can be problematic -- confirm on the target platform.
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=numpy_collate, num_workers=2)

In [None]:
# next(iter(...)) fetches the first batch of the data loader.
# Since the loader shuffles, this returns a different batch every time we run this cell.
# To iterate over the whole dataset, we can simply use "for batch in data_loader: ..."
data_inputs, data_labels = next(iter(data_loader))

# The shape of each output is [batch_size, d_1,...,d_N] where d_1,...,d_N are the
# dimensions of the data point returned from the dataset class
print("Data inputs", data_inputs.shape, "\n", data_inputs)
print("Data labels", data_labels.shape, "\n", data_labels)

In [None]:
import optax

# Classifier with 2 inputs (the two XOR coordinates) and 10 hidden units; parameters are
# initialised from `nnx.Rngs(0)`.
model = MyModule(2, 10, nnx.Rngs(0))
# AdamW with a constant learning rate of 0.1, wrapped in an NNX optimizer for `model`.
# NOTE(review): newer Flax releases appear to require an explicit `wrt=` filter here and use a
# different `optimizer.update` signature -- confirm against the installed Flax version.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=0.1))

In [None]:
def loss_fn(model, batch):
    """Compute the mean binary cross-entropy loss and the accuracy of `model` on one batch.

    Inputs:
        model - Callable mapping the batch inputs to logits with a trailing axis of size 1
        batch - Pair (inputs, labels)
    Returns:
        (loss, acc) - Mean sigmoid BCE over the batch and the fraction of correct predictions
    """
    inputs, targets = batch
    # Drop the trailing singleton axis so there is one logit per sample.
    logits = jnp.squeeze(model(inputs), axis=-1)
    per_sample_loss = optax.sigmoid_binary_cross_entropy(logits, targets)
    # A positive logit means sigmoid(logit) > 0.5, i.e. the model predicts class 1.
    hard_predictions = (logits > 0).astype(jnp.float32)
    accuracy = (hard_predictions == targets).mean()
    return per_sample_loss.mean(), accuracy


# Smoke test: run the loss function once on a single batch from the demo loader. As the
# cell's last expression, the resulting (loss, acc) pair is displayed.
batch = next(iter(data_loader))
loss_fn(model, batch)

In [None]:
@nnx.jit
def train_step(model, optimizer, batch):
    """Run one jitted optimisation step on `batch` and return its (loss, acc).

    Gradients of the loss with respect to `model` are computed and handed to
    `optimizer.update`; accuracy is carried along as the auxiliary output of `loss_fn`.
    """
    # has_aux=True: loss_fn returns (loss, acc) and only the loss is differentiated.
    (loss, acc), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, batch)
    optimizer.update(grads)
    return loss, acc

In [None]:
@nnx.jit  # Compiled once, then reused for every evaluation batch
def eval_step(model, batch):
    """Return the accuracy of `model` on `batch`; the loss value is discarded."""
    return loss_fn(model, batch)[1]

In [None]:
# Training set: 2500 generated XOR points, served in shuffled batches of 128 that are
# collated to NumPy arrays by `numpy_collate`.
train_dataset = XORDataset(size=2500, seed=42)
train_data_loader = data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, collate_fn=numpy_collate, num_workers=2
)

In [None]:
# Train for a fixed number of passes over the training set.
num_epochs = 20
for epoch in tqdm(range(num_epochs)):
    # Iterate the 2500-sample `train_data_loader`; the 200-sample `data_loader` is only the
    # demo loader used for inspection earlier in the notebook.
    for batch_idx, batch in enumerate(train_data_loader):
        loss, acc = train_step(model, optimizer, batch)
        # Report loss/accuracy on every 50th batch (always including the first of each epoch).
        if batch_idx % 50 == 0:
            print(loss, acc)