# Haliax CNN Example: Training a Convolutional Neural Network on MNIST

This notebook demonstrates how to build, train, and evaluate a simple Convolutional Neural Network (CNN) using Haliax. We'll use the MNIST dataset of handwritten digits.

**Key Haliax Concepts Demonstrated:**
- Defining named axes (`hax.Axis`).
- Building neural network layers (`hax.nn.Conv`, `hax.nn.Linear`, `hax.nn.relu`); `hax.nn.max_pool` is discussed but not yet used in the model (see the note on max pooling).
- Manipulating named tensors (reshaping, dot products implicitly handled by named axes).
- Using Equinox for module structure (`eqx.Module`).
- Basic training loop with JAX, Equinox, and Optax (`eqx.filter_value_and_grad`, `eqx.filter_jit`, `optax.adam`).

## 1. Setup and Imports

In [None]:
import equinox as eqx
import haliax as hax
import haliax.nn as hnn
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax  # For optimizers
from datasets import load_dataset  # To load MNIST
import numpy as np  # For data manipulation

# Input and output axes.
Batch = hax.Axis("batch", 32)  # batch size used throughout the notebook
Height = hax.Axis("height", 28)
Width = hax.Axis("width", 28)
Channels = hax.Axis("channels", 1)  # MNIST is grayscale
Classes = hax.Axis("classes", 10)  # digits 0-9

# Output-channel axes of the two convolution layers.
Conv1Out = hax.Axis("conv1_out", 16)
Conv2Out = hax.Axis("conv2_out", 32)

# Spatial axes after each convolution. These are module-level because the
# SimpleCNN definition refers to them. The sizes assume that a 3x3 kernel
# without 'SAME' padding shrinks each spatial axis by 2.
# After conv1: 28x28 -> 26x26
Conv1Height, Conv1Width = (axis.resize(axis.size - 2) for axis in (Height, Width))
# After conv2: 26x26 -> 24x24
PostConvHeight, PostConvWidth = (axis.resize(axis.size - 2) for axis in (Conv1Height, Conv1Width))

# Single axis that replaces (PostConvHeight, PostConvWidth, Conv2Out) when flattening.
FlattenedFeatures = hax.Axis(
    "flattened_features",
    PostConvHeight.size * PostConvWidth.size * Conv2Out.size,
)

# Note: the dev_test_cnn.py script used smaller Conv1Out/Conv2Out and Batch sizes for faster testing.
# The notebook uses these larger, more standard sizes for the example.

## 2. Load and Preprocess MNIST Dataset

In [None]:
# Load MNIST from Hugging Face and keep a handle on each split.
mnist = load_dataset("mnist")
train_data, test_data = mnist["train"], mnist["test"]

def preprocess_images(examples):
    """Turn a batch of images into a normalized float32 array with a channel axis.

    Args:
        examples: Mapping with an "image" sequence (each item convertible by
            `np.array`) and a "label" entry.

    Returns:
        A dict whose "image" is a float32 array scaled by 1/255 with a trailing
        channel dimension, i.e. (N, H, W, 1) for 2-D inputs, and whose "label"
        is passed through unchanged.
    """
    pixel_arrays = [np.array(img) for img in examples["image"]]
    scaled = np.asarray(pixel_arrays, dtype=np.float32) / 255.0
    # Append the channel dimension: (N, H, W) -> (N, H, W, C)
    with_channel = scaled[..., np.newaxis]
    return {"image": with_channel, "label": examples["label"]}

# Register preprocess_images as the transform of both splits so that the
# train and test data go through the same image preprocessing.
train_data.set_transform(preprocess_images)
test_data.set_transform(preprocess_images)

# Create a simple data loader (iterating over numpy arrays)
# For more advanced data loading, consider libraries like PyTorch DataLoader or tf.data.
def numpy_collate(batch):
    """Stack a list of samples into NumPy arrays.

    A list of dicts becomes a dict of arrays (one array per key of the first
    sample); any other list is converted to a single array.
    """
    first = batch[0]
    if not isinstance(first, dict):
        return np.array(batch)

    collated = {}
    for key in first:
        collated[key] = np.array([sample[key] for sample in batch])
    return collated

def dataloader(dataset, batch_size, shuffle=True, key=None):
    """Yield full batches of `dataset` as dicts of Haliax NamedArrays.

    Only complete batches are produced: trailing examples that do not fill a
    whole batch are dropped, and a dataset smaller than `batch_size` yields
    nothing. This is a generator, so argument validation runs on the first
    iteration rather than at call time.

    Args:
        dataset: Indexable dataset; `dataset[i]` must return a dict with an
            "image" entry shaped (Height, Width, Channels) and a "label" entry.
        batch_size: Number of examples per yielded batch.
        shuffle: Whether to visit the examples in a random order.
        key: JAX PRNG key for the permutation; required when `shuffle` is true.

    Yields:
        A dict with "image" named over (batch, Height, Width, Channels) and
        "label" named over (batch,). The batch axis has the name of the global
        `Batch` axis and size `batch_size`.

    Raises:
        ValueError: If `batch_size` is not positive, or if `shuffle` is true
            and no `key` was given.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    indices = np.arange(len(dataset))
    if shuffle:
        if key is None:
            raise ValueError("A JAX random key must be provided for shuffling.")
        # Permute with JAX so the order is determined by `key`, then convert
        # back to NumPy for plain integer indexing.
        indices = np.asarray(jax.random.permutation(key, indices))

    # Size the batch axis from the argument, so the yielded arrays stay
    # consistent even when `batch_size` differs from `Batch.size`.
    batch_axis = Batch.resize(batch_size)

    # Integer division counts complete batches only, so a partial batch is
    # never built, including when it would be the first one.
    num_full_batches = len(indices) // batch_size
    for batch_number in range(num_full_batches):
        start = batch_number * batch_size
        batch_indices = indices[start:start + batch_size]

        batch_samples = [dataset[int(j)] for j in batch_indices]
        collated_batch = numpy_collate(batch_samples)

        # Convert to Haliax NamedArrays
        images = hax.named(collated_batch["image"], (batch_axis, Height, Width, Channels))
        labels = hax.named(collated_batch["label"], (batch_axis,))
        yield {"image": images, "label": labels}

# Fetch a single shuffled batch to inspect its shapes (shuffling requires a key).
dummy_key = jrandom.PRNGKey(0)
train_loader_example = dataloader(train_data, Batch.size, shuffle=True, key=dummy_key)
batch_ex = next(train_loader_example, None)
if batch_ex is not None:
    print("Image batch shape:", batch_ex["image"].shape)
    print("Label batch shape:", batch_ex["label"].shape)

## 3. Define the CNN Model

In [None]:
# Model Definition (Core logic validated with dev_test_cnn.py)
# This version accounts for spatial dimension reduction by convolutions
# and temporarily omits max pooling layers.
class SimpleCNN(eqx.Module):
    """Two 3x3 convolutions, each followed by ReLU, then a linear classifier.

    Max pooling is left out for now, so the spatial size only shrinks through
    the convolutions. Every axis used here (Channels, Conv1Out, Conv2Out,
    Classes, Height, Width, Conv1Height, Conv1Width, PostConvHeight,
    PostConvWidth, FlattenedFeatures) is a module-level global.
    """

    conv1: hnn.Conv
    conv2: hnn.Conv
    linear_out: hnn.Linear

    def __init__(self, *, key: jrandom.PRNGKey):
        """Initialize the three layers, each from its own sub-key of `key`."""
        conv1_key, conv2_key, linear_key = jrandom.split(key, 3)

        # `Spatial` names the spatial axes each convolution expects on its input.
        self.conv1 = hnn.Conv.init(
            In=Channels,
            Out=Conv1Out,
            kernel_size=(3, 3),
            Spatial=(Height, Width),
            key=conv1_key,
        )
        self.conv2 = hnn.Conv.init(
            In=Conv1Out,
            Out=Conv2Out,
            kernel_size=(3, 3),
            Spatial=(Conv1Height, Conv1Width),
            key=conv2_key,
        )
        self.linear_out = hnn.Linear.init(In=FlattenedFeatures, Out=Classes, key=linear_key)

    def __call__(self, x: hax.NamedArray) -> hax.NamedArray:
        """Map images over [Batch, Height, Width, Channels] to logits over [Batch, Classes]."""
        # TODO: add a 2x2 max pooling layer after each ReLU once the
        # `hnn.max_pool` API is clarified. Doing so requires new pooled axes
        # (each spatial size halved), a matching `Spatial` for conv2, and a
        # recomputed FlattenedFeatures.

        # conv1 + ReLU -> [Batch, Conv1Height, Conv1Width, Conv1Out]
        hidden = hnn.relu(self.conv1(x))

        # conv2 + ReLU -> [Batch, PostConvHeight, PostConvWidth, Conv2Out]
        hidden = hnn.relu(self.conv2(hidden))

        # Merge the spatial and channel axes -> [Batch, FlattenedFeatures]
        features = hidden.flatten_axes((PostConvHeight, PostConvWidth, Conv2Out), FlattenedFeatures)

        # Linear classifier -> [Batch, Classes]
        return self.linear_out(features)

# Build a throwaway model instance for a shape check; the training cells
# below create their own model.
model_key_example = jrandom.PRNGKey(1)
model_example = SimpleCNN(key=model_key_example)

# Run the model once on random inputs to confirm the output axes.
dummy_key_for_data = jrandom.PRNGKey(2)
# The dummy data uses the notebook's global Batch, Height, Width and Channels axes.
dummy_images = hax.random.uniform(dummy_key_for_data, (Batch, Height, Width, Channels)) 
output_logits_example = model_example(dummy_images)
# Print the resulting shape together with every axis size the model depends on.
print(f"Example model output logits shape: {output_logits_example.shape}")
print(f"Using axes: Batch={Batch.size}, Height={Height.size}, Width={Width.size}, Channels={Channels.size}")
print(f"Conv1Out={Conv1Out.size}, Conv2Out={Conv2Out.size}")
print(f"Conv1Height={Conv1Height.size}, Conv1Width={Conv1Width.size}")
print(f"PostConvHeight={PostConvHeight.size}, PostConvWidth={PostConvWidth.size}")
print(f"FlattenedFeatures={FlattenedFeatures.size}, Classes={Classes.size}")

### Note on Max Pooling
The `SimpleCNN` model above currently does not include max pooling layers. This is a temporary simplification
due to difficulties in ascertaining the exact API for `haliax.nn.max_pool` compatible with the
JAX version (0.4.26) and Haliax version (1.3) used during development of this example.

Max pooling layers are standard in CNNs and would typically be added after each ReLU activation
to reduce spatial dimensions, e.g., using a 2x2 window and stride.

```python
# Hypothetical max pooling after first conv + relu:
# Pool1H = Conv1Height.resize(Conv1Height.size // 2)
# Pool1W = Conv1Width.resize(Conv1Width.size // 2)
# x = hnn.max_pool(x, ..., new_axes=(Pool1H, Pool1W))  # "..." stands for the correct window/stride arguments
```

If you intend to use max pooling, please consult the latest Haliax documentation for the correct
function signature of `haliax.nn.max_pool` and adjust the model definition accordingly. This would involve:
1. Defining new pooled axes (e.g., `Pool1H`, `Pool1W`, `Pool2H`, `Pool2W`).
2. Updating the `Spatial` argument for `conv2` if it follows a pooling layer.
3. Recalculating `FlattenedFeatures` based on the final pooled spatial dimensions.

---
## 4. Loss Function, Optimizer, and Training Step

In [None]:
# Loss Function (validated with dev_test_cnn.py)
# Uses global Batch and Classes axes.
def cross_entropy_loss(logits: hax.NamedArray, labels: hax.NamedArray) -> jax.Array:
    """Mean softmax cross-entropy over the global Batch axis.

    Args:
        logits: Unnormalized scores with a Classes axis.
        labels: Integer class labels, one-hot encoded here over Classes.

    Returns:
        The loss as a raw JAX array (the `.array` of the reduced NamedArray).
    """
    targets = hnn.one_hot(labels, Classes)
    per_example = -hax.sum(targets * hnn.log_softmax(logits, axis=Classes), axis=Classes)
    return hax.mean(per_example, axis=Batch).array

# Optimizer: Adam with a fixed learning rate.
learning_rate = 1e-3
optimizer_instance = optax.adam(learning_rate)


class TrainingState(eqx.Module):
    """Bundle of everything a training step reads and produces."""

    # Current model parameters.
    model: SimpleCNN
    # Optimizer state matching the array leaves of `model`.
    opt_state: optax.OptState
    # Marked static so it is treated as configuration rather than traced data.
    # `eqx.field(static=True)` replaces the deprecated `eqx.static_field()`.
    optimizer: optax.GradientTransformation = eqx.field(static=True)


# Initialize the training state used by the training loop.
actual_model_key = jrandom.PRNGKey(42)
actual_model = SimpleCNN(key=actual_model_key)  # uses the notebook's global axes
# The optimizer state is created from the array leaves of the model only.
initial_opt_state = optimizer_instance.init(eqx.filter(actual_model, eqx.is_array))
state = TrainingState(model=actual_model, opt_state=initial_opt_state, optimizer=optimizer_instance)

# Training Step (validated with dev_test_cnn.py)
@eqx.filter_jit
def train_step(current_state: TrainingState, batch_data: dict):
    """Run one optimization step on a single batch.

    Args:
        current_state: Model, optimizer state and optimizer before the step.
        batch_data: Dict with "image" and "label" NamedArrays.

    Returns:
        A `(loss, new_state)` tuple: the loss computed before the update, and
        a new TrainingState holding the updated model and optimizer state.
    """
    images = batch_data["image"]
    labels = batch_data["label"]
    model = current_state.model

    def compute_loss_for_grad(model_to_train):
        return cross_entropy_loss(model_to_train(images), labels)

    loss_val, grads = eqx.filter_value_and_grad(compute_loss_for_grad)(model)

    # Pass the optimizer the same filtered parameter tree its state was
    # initialized from (array leaves only), so optimizers that read the
    # parameters see a tree matching the gradients.
    params = eqx.filter(model, eqx.is_array)
    updates, new_opt_state = current_state.optimizer.update(grads, current_state.opt_state, params)
    new_model = eqx.apply_updates(model, updates)

    return loss_val, TrainingState(model=new_model, opt_state=new_opt_state, optimizer=current_state.optimizer)

# Evaluation step
@eqx.filter_jit
def eval_step(model: SimpleCNN, batch: dict):
    """Return the fraction of examples in `batch` that `model` classifies correctly.

    The predicted class is the argmax of the logits over Classes. The count of
    matches is divided by the size of the batch's Batch axis.
    """
    inputs, targets = batch["image"], batch["label"]
    predictions = hax.argmax(model(inputs), axis=Classes)
    num_correct = hax.sum(predictions == targets)
    # Divide by the batch size taken from the images themselves.
    return num_correct / inputs.axis_size(Batch)

## 5. Training Loop

In [None]:
num_epochs = 3  # kept small so the example runs quickly
train_loader_key = jrandom.PRNGKey(123)

# Loss of every training step, in order; used for the loss curve.
losses = []

for epoch in range(num_epochs):
    # Derive a fresh shuffling key for this epoch and carry the other half forward.
    epoch_key, train_loader_key = jrandom.split(train_loader_key)
    train_dl = dataloader(train_data, Batch.size, shuffle=True, key=epoch_key)

    epoch_loss = 0
    num_batches = 0
    for batch in train_dl:
        loss_val, state = train_step(state, batch)
        step_loss = loss_val.item()  # read the scalar once and reuse it
        losses.append(step_loss)
        epoch_loss += step_loss
        num_batches += 1

    if num_batches > 0:
        avg_epoch_loss = epoch_loss / num_batches
    else:
        avg_epoch_loss = 0
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}")

print("Training finished.")

## 6. Plot Loss Curve

In [None]:
# Plot the per-step training loss; matplotlib is optional for this notebook.
try:
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses)
    ax.set_title("Training Loss Curve")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Cross-Entropy Loss")
    ax.grid(True)
    plt.show()
except ImportError:
    print("Matplotlib not available. Skipping loss plot.")

## 7. Evaluation on Test Set

In [None]:
test_loader_key = jrandom.PRNGKey(456)  # not used below: the test set is read in order
test_dl = dataloader(test_data, Batch.size, shuffle=False)  # no shuffling, so no key is needed

# The dataloader only yields batches of exactly Batch.size examples, so every
# batch accuracy carries the same weight and a plain mean over batches is valid.
# Examples in a trailing partial batch are dropped and not evaluated. Handling
# them would mean padding the batch (and masking the padded entries), or giving
# that batch its own "batch" axis of the smaller size and making sure
# eval_step and the model accept it.
batch_accuracies = [eval_step(state.model, batch).item() for batch in test_dl]

total_accuracy = sum(batch_accuracies)
num_test_batches = len(batch_accuracies)

avg_test_accuracy = total_accuracy / num_test_batches if num_test_batches > 0 else 0
print(f"Test Accuracy: {avg_test_accuracy:.4f}")