In [250]:
# import JAX to use
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from sklearn.datasets import fetch_openml

In [251]:
# Implement a dense feedforward neural network from scratch. Make your implementation flexible with
# respect to the input and output dimensions, the number of hidden layers and neurons per hidden
# layer, and the activation functions used. Use this implementation for the following two questions.
# Choose a suitable initialization for the network parameters.

In [None]:
# We start with a function that initializes the network parameters.
def init_net_params(layer_widths, key):
    """
    Initialize the network parameters.
    layer_widths: list of integers, the number of neurons in each layer
    key: jax.random.PRNGKey, the random key
    returns: list of jax.numpy.ndarray, the network parameters
    """
    params = []
    keys = random.split(key, len(layer_widths) - 1)

    for i, (n_in, n_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:])):
        w_key = keys[i]
        w = random.normal(w_key, shape=(n_in, n_out))
        b = jnp.zeros((n_out,))
        params.append({'w': w, 'b': b})

    return params


In [253]:
# Next, we define a forward pass function that computes the output of the network for a given input.
def forward(params, x, activation):
    """
    Forward pass of the network.
    """

    # dictionary of activation functions
    activations = {
        'relu': jax.nn.relu,
        'sigmoid': jax.nn.sigmoid,
        'tanh': jax.nn.tanh,
        'softmax': jax.nn.softmax
    }
    activation = activations[activation]

    for layer in params[:-1]:
        x = x @ layer['w'] + layer['b']
        x = activation(x)

    # output layer
    final_layer = params[-1]
    return jnp.dot(x, final_layer['w']) + final_layer['b']

def get_batches(x, y, batch_size=256):
    """
    Returns a list of tuples (x_batch, y_batch), each of size batch_size
    (last batch may be smaller).
    """
    n = x.shape[0]
    
    # Use provided key or create a new one
    key = random.PRNGKey(0)
    perm = jax.random.permutation(key, n)
    x_shuffled = x[perm]
    y_shuffled = y[perm]

    batches = []
    for i in range(0, n, batch_size):
        x_batch = x_shuffled[i:i+batch_size]
        y_batch = y_shuffled[i:i+batch_size]
        batches.append((x_batch, y_batch))

    return batches

def get_splits(x, y, train=0.8, valid=0.1):
    """
    This return a jax array of the training, validation, and test splits
    """
    n = x.shape[0]
    x = jnp.array(x)
    y = jnp.array(y)
    # Calculate split indices (as integers)
    train_end = int(train * n)
    valid_end = train_end + int(valid * n)
    
    # Split the data
    x_train = x[:train_end]
    y_train = y[:train_end]
    
    x_valid = x[train_end:valid_end]
    y_valid = y[train_end:valid_end]
    
    x_test = x[valid_end:]
    y_test = y[valid_end:]
    
    return x_train, y_train, x_valid, y_valid, x_test, y_test

In [254]:
# Next, we define the MSE loss function, we can have other loss functions
def mse_loss(params, x, y, activation):
    """
    MSE loss function for the network.
    """
    batched_forward = vmap(forward, in_axes=(None, 0, None))
    preds = batched_forward(params, x, activation)
    return jnp.mean((preds - y) ** 2)

def class_loss(params, x, y, activation):
    """
    Classification cross-entroy loss function
    """
    batched_forward = vmap(forward, in_axes=(None, 0, None))
    logits = batched_forward(params, x, activation)
    log_probs = jax.nn.softmax(logits, axis=1)
    
    nll = -log_probs[jnp.arange(y.shape[0]), y]
    loss = jnp.mean(nll)

    return loss

def evaluate_model(params, x, y, activation, classification):
    """
    Evaluate model on a dataset and return accuracy and loss.
    """
    batched_forward = vmap(forward, in_axes=(None, 0, None))
    logits = batched_forward(params, x, activation)

    if classification:
        preds = jnp.argmax(logits, axis=1)
        
        accuracy = jnp.mean(preds == y)
        loss = class_loss(params, x, y, activation)
        
        return accuracy, loss
    else:
        return jnp.mean((logits - y) ** 2)

In [255]:
# We now define an update function that updates the network parameters.
def update(params, x, y, activation, lr, classification):
    """
    Update function for the network parameters (basic gradient descent).
    """
    loss_fn = class_loss if classification else mse_loss
    grads = grad(loss_fn)(params, x, y, activation)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params

In [256]:
# Consider a standard benchmark dataset for classification: train a neural network to classify handwrit-
# ten digits into the ten classes 0, 1, . . . , 9. As input for your model, use flattened vector representations
# of the MNIST images
mnist = fetch_openml('mnist_784')
print(f"We have {mnist.data.shape[0]} images")
print(f"Each image has {mnist.data.shape[1]} pixels (features)")

We have 70000 images
Each image has 784 pixels (features)


In [None]:
# Training the neural network
key = random.PRNGKey(42)
layer_widths = [784, 256, 256, 10]
epochs = 500
params = init_net_params(layer_widths, key)
x = mnist.data.to_numpy()
y = mnist.target.astype(int).to_numpy()
x_train, y_train, x_valid, y_valid, x_test, y_test = get_splits(x, y)
activation = 'relu'
batch_size = 128
lr = 0.01

print("Starting training...")
for epoch in range(epochs):
    # Get batches for this epoch (shuffles data each epoch)
    batches = get_batches(x_train, y_train, batch_size=batch_size)
    
    for x_batch, y_batch in batches:
        params = update(params, x_batch, y_batch, activation, lr, classification=True) 
    # Print progress every 100 epochs
    if epoch % 100 == 0:
        train_loss = class_loss(params, x_train, y_train, activation)
        val_loss = class_loss(params, x_valid, y_valid, activation)
        print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

print("Training complete!")

Starting training...
Epoch 0, Train Loss: -0.1059, Val Loss: -0.0936
Epoch 100, Train Loss: -0.1059, Val Loss: -0.0936


In [None]:
# We now test the model
# Test on test set
test_acc, test_loss = evaluate_model(params, x_test, y_test, activation, classification=True)
print(f"Test Accuracy: {test_acc * 100:.2f}%")
print(f"Test Loss: {test_loss:.2f}")

# Also check validation set
val_acc, val_loss = evaluate_model(params, x_valid, y_valid, activation, classification=True)
print(f"Validation Accuracy: {val_acc * 100:.2f}%")
print(f"Validation Loss: {val_loss:.2f}")

Test Accuracy: 33.03%
Test Loss: -0.33
Validation Accuracy: 31.93%
Validation Loss: -0.32
