In [49]:
# import JAX to use for differentiation
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from sklearn.datasets import fetch_openml

In [50]:
# Implement a dense feedforward neural network from scratch. Make your implementation flexible with
# respect to the input and output dimensions, the number of hidden layers and neurons per hidden
# layer, and the activation functions used. Use this implementation for the following two questions.
# Choose a suitable initialization for the network parameters.

In [51]:
# We start with a function that initializes the network parameters.
def init_net_params(layer_widths, key):
    """
    Initialize the network parameters.
    layer_widths: list of integers, the number of neurons in each layer
    key: jax.random.PRNGKey, the random key
    returns: list of jax.numpy.ndarray, the network parameters
    """
    params = []
    keys = random.split(key, len(layer_widths) - 1)

    for i, (n_in, n_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:])):
        w_key = keys[i]
        w = random.normal(w_key, shape=(n_in, n_out))
        b = jnp.zeros((n_out,))
        params.append({'w': w, 'b': b})

    return params


In [52]:
# Next, we define a forward pass function that computes the output of the network for a given input.
def forward(params, x, activation):
    """
    Forward pass of the network.
    """

    # dictionary of activation functions
    activations = {
        'relu': jax.nn.relu,
        'sigmoid': jax.nn.sigmoid,
        'tanh': jax.nn.tanh,
        'softmax': jax.nn.softmax
    }
    activation = activations[activation]

    for layer in params[:-1]:
        x = x @ layer['w'] + layer['b']
        x = activation(x)

    # output layer
    final_layer = params[-1]
    return jnp.dot(x, final_layer['w']) + final_layer['b']


In [53]:
# Next, we define the loss function.
def loss(params, x, y, activation):
    """
    Loss function for the network.
    """
    batched_forward = vmap(forward, in_axes=(None, 0, None))
    preds = batched_forward(params, x, activation)
    return jnp.mean((preds - y) ** 2)


In [54]:
# We now define an update function that updates the network parameters.
def update(params, x, y, activation, lr):
    """
    Update function for the network parameters.
    """
    grads = grad(loss)(params, x, y, activation)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params

In [55]:
# we test it now
key = random.PRNGKey(42)
layer_widths = [2, 3, 1]
epochs = 1000
params = init_net_params(layer_widths, key)
x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = jnp.array([0.0, 1.0, 1.0, 0.0])
activation = 'relu'

print("Starting training...")
for epoch in range(epochs):
    params = update(params, x, y, activation, lr=0.1)
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss(params, x, y, activation)}")

print("Training complete!")


Starting training...
Epoch 0, Loss: 0.5238784551620483
Epoch 100, Loss: 0.250282347202301
Epoch 200, Loss: 0.25010988116264343
Epoch 300, Loss: 0.2500782012939453
Epoch 400, Loss: 0.25005677342414856
Epoch 500, Loss: 0.2500416338443756
Epoch 600, Loss: 0.2500308156013489
Epoch 700, Loss: 0.2500229477882385
Epoch 800, Loss: 0.2500171661376953
Epoch 900, Loss: 0.25001290440559387
Training complete!


In [None]:
# Consider a standard benchmark dataset for classification: train a neural network to classify handwrit-
# ten digits into the ten classes 0, 1, . . . , 9. As input for your model, use flattened vector representations
# of the MNIST images
mnist = fetch_openml('mnist_784')
