# Install and Import Required Libraries
Install JAX and import necessary libraries such as jax, jax.numpy, and optax.

In [3]:
# Install JAX
# Note: Uncomment the following line if JAX, Optax, or TensorFlow (used below to load MNIST) is not installed.
#!pip install jax jaxlib optax tensorflow

# Import Required Libraries
import jax
import jax.numpy as jnp
from jax import random, grad, jit
import optax

# Load and Preprocess the MNIST Dataset
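Load the MNIST dataset, which comes already split into training and test sets. Then scale the pixel values to [0, 1], flatten each 28×28 image into a 784-element vector, and one-hot encode the labels.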

In [4]:
# Load MNIST Dataset
from tensorflow.keras.datasets import mnist

# Load data: load_data() returns ((train_images, train_labels), (test_images, test_labels)).
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize images to [0, 1].
# Cast to float32 first. Dividing an integer array by 255.0 makes NumPy produce
# float64, which doubles host memory. JAX would then downcast every minibatch
# back to float32 (its default) on each call.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Flatten each 28x28 image into a 784-element row vector.
train_images = train_images.reshape(-1, 28 * 28)
test_images = test_images.reshape(-1, 28 * 28)

# Convert integer class labels to one-hot vectors over the 10 digit classes.
train_labels = jax.nn.one_hot(train_labels, num_classes=10)
test_labels = jax.nn.one_hot(test_labels, num_classes=10)

ModuleNotFoundError: No module named 'tensorflow'

# Define the Neural Network Model
Define a simple feedforward neural network using JAX functions.

In [None]:
# Define Neural Network Model
def neural_network(params, x):
    """Run a dense MLP forward pass and return raw logits.

    Args:
        params: sequence of ``(w, b)`` layer tuples, applied in order.
        x: input batch; each layer computes ``jnp.dot(x, w) + b``.

    Returns:
        Output of the final affine layer. A ReLU follows every layer except the last.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    activations = x
    for weight, bias in hidden_layers:
        activations = jax.nn.relu(jnp.dot(activations, weight) + bias)
    out_weight, out_bias = output_layer
    return jnp.dot(activations, out_weight) + out_bias

# Initialize Parameters
Initialize the model parameters using JAX's random number generator.

In [None]:
# Initialize Parameters
def initialize_parameters(layer_sizes, key):
    """Create He-scaled random weights and zero biases for a dense MLP.

    Args:
        layer_sizes: sequence of layer widths, input first (e.g. ``[784, 128, 64, 10]``).
        key: ``jax.random`` PRNG key. It is split once per layer so each layer's
            weights come from an independent subkey.

    Returns:
        List of ``(w, b)`` tuples, one per consecutive pair of widths, where ``w``
        has shape ``(n_in, n_out)`` and ``b`` has shape ``(n_out,)``. With fewer
        than two widths, the list is empty.
    """
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = random.split(key)
        # Standard-normal draws scaled by sqrt(2 / fan_in) (He initialization),
        # which pairs with the ReLU hidden activations of the network.
        w = random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros((n_out,))
        params.append((w, b))
    return params

# Build the network: a fixed seed makes parameter initialization reproducible.
key = random.PRNGKey(0)
# 784 flattened pixels -> 128 -> 64 -> 10 class logits
layer_sizes = [784, 128, 64, 10]
params = initialize_parameters(layer_sizes, key)

# Define the Loss Function
Implement the cross-entropy loss function for classification tasks.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy of ``neural_network`` logits against one-hot ``y``.

    Args:
        params: list of ``(w, b)`` layer tuples.
        x: input batch.
        y: one-hot targets, with classes along axis 1.

    Returns:
        Scalar loss, averaged over the batch.
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    per_example_log_likelihood = jnp.sum(y * log_probs, axis=1)
    return -jnp.mean(per_example_log_likelihood)

# Define the Training Loop
Write a minibatch training loop that updates the model parameters with the Adam optimizer from Optax.

In [None]:
# Define Training Loop
def update(params, x, y, opt_state, optimizer):
    """Apply one optimizer step to ``params`` using a single minibatch.

    Args:
        params: list of ``(w, b)`` layer tuples.
        x: minibatch of flattened images.
        y: one-hot labels for ``x``.
        opt_state: current Optax state for ``optimizer``.
        optimizer: Optax gradient transformation (e.g. ``optax.adam(...)``).

    Returns:
        ``(new_params, new_opt_state)``.
    """
    grads = grad(cross_entropy_loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


# `optimizer` holds Python functions, not arrays, so jit cannot trace it as an
# input. Declaring it static makes it a hashable compile-time constant that is
# part of the compilation cache key. Positional calls still work because jit
# maps static_argnames onto the signature.
update = jit(update, static_argnames=("optimizer",))

# Initialize optimizer
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# Training loop
num_epochs = 10
batch_size = 128

# Convert the training images to a JAX array once, instead of converting a
# NumPy slice on every update call.
train_x = jnp.asarray(train_images)
num_train = train_x.shape[0]

# Separate key for data shuffling, so it does not reuse the parameter-init key.
shuffle_key = random.PRNGKey(1)

for epoch in range(num_epochs):
    # Draw a fresh permutation each epoch. A fixed order would feed the
    # optimizer identical minibatches in the same sequence every epoch.
    shuffle_key, perm_key = random.split(shuffle_key)
    perm = random.permutation(perm_key, num_train)
    epoch_x, epoch_y = train_x[perm], train_labels[perm]
    for start in range(0, num_train, batch_size):
        batch_x = epoch_x[start:start + batch_size]
        batch_y = epoch_y[start:start + batch_size]
        params, opt_state = update(params, batch_x, batch_y, opt_state, optimizer)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and calculate accuracy.

In [None]:
# Evaluate Model
def accuracy(params, x, y):
    """Fraction of rows whose predicted class matches the one-hot target.

    Args:
        params: list of ``(w, b)`` layer tuples.
        x: input batch.
        y: one-hot targets, with classes along axis 1.

    Returns:
        Scalar JAX array in [0, 1].
    """
    return jnp.mean(
        jnp.argmax(neural_network(params, x), axis=1) == jnp.argmax(y, axis=1)
    )

# Score the trained model on the whole held-out set in a single forward pass.
test_accuracy = accuracy(params, x=test_images, y=test_labels)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# Install and Import Required Libraries
Install JAX and other required libraries (if not already installed). Import libraries such as jax, jax.numpy, and tensorflow_datasets.

In [None]:
# Install Required Libraries
# Uncomment the following line if the libraries are not already installed
# (tensorflow provides the tf.data pipeline that tensorflow_datasets returns)
# !pip install jax jaxlib tensorflow tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load MNIST with its predefined training and test splits, and normalize the images to [0, 1].

In [None]:
# Load the MNIST Dataset
import tensorflow as tf  # preprocess runs inside a tf.data pipeline and needs TF ops

ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
datasets = ds_builder.as_dataset(as_supervised=True)


def preprocess(image, label):
    """Scale one image to float32 in [0, 1] and flatten it; pass the label through.

    ``Dataset.map`` traces this function into a TensorFlow graph, so it receives
    symbolic tensors. Converting those with ``jnp.array`` fails, which is why
    only TensorFlow ops are used here. Flattening produces the vector input that
    the 784-unit first layer of the network expects.
    """
    image = tf.reshape(tf.cast(image, tf.float32) / 255.0, [-1])
    return image, label


train_ds = datasets["train"].map(preprocess)
test_ds = datasets["test"].map(preprocess)

# Define the Neural Network Model
Define a simple feedforward neural network using JAX functions.

In [None]:
# Define the Neural Network Model
def neural_net(params, x):
    """Forward pass of a dense MLP that returns unnormalized class logits.

    Every layer computes ``jnp.dot(h, w) + b``. All layers except the last are
    followed by a ReLU.
    """
    *_, (w_out, b_out) = params[:-1], params[-1]
    h = x
    for layer_w, layer_b in params[:-1]:
        pre_activation = jnp.dot(h, layer_w) + layer_b
        h = jax.nn.relu(pre_activation)
    logits = jnp.dot(h, w_out) + b_out
    return logits

# Initialize Parameters
Initialize the weights and biases of the neural network using JAX's random number generator.

In [None]:
# Initialize Parameters
def initialize_params(layer_sizes, key):
    """Return ``[(w, b), ...]`` for an MLP with the given layer widths.

    Each weight matrix is drawn from a standard normal distribution (using a
    fresh subkey split from ``key``) and scaled by ``sqrt(2 / fan_in)``. Each
    bias starts at zero.
    """
    layers = []
    fan_ins, fan_outs = layer_sizes[:-1], layer_sizes[1:]
    for fan_in, fan_out in zip(fan_ins, fan_outs):
        key, layer_key = jax.random.split(key)
        scale = jnp.sqrt(2.0 / fan_in)
        weights = jax.random.normal(layer_key, (fan_in, fan_out)) * scale
        layers.append((weights, jnp.zeros(fan_out)))
    return layers

# Widths: flattened 28x28 input, two hidden layers, 10 output classes.
layer_sizes = [28 * 28, 128, 64, 10]
key = jax.random.PRNGKey(0)  # fixed seed for reproducible initialization
params = initialize_params(layer_sizes, key)

# Define the Loss Function
Implement a cross-entropy loss function to compute the difference between predictions and true labels.

In [None]:
# Define the Loss Function
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy of ``neural_net`` logits against integer labels.

    Args:
        params: list of ``(w, b)`` layer tuples.
        x: input batch.
        y: integer class indices, one-hot encoded here over 10 classes.

    Returns:
        Scalar loss, averaged over the batch.
    """
    targets = jax.nn.one_hot(y, num_classes=10)
    log_probs = jax.nn.log_softmax(neural_net(params, x))
    log_likelihood = jnp.sum(targets * log_probs, axis=-1)
    return -jnp.mean(log_likelihood)

# Define the Training Loop
Write a training loop that updates the model parameters using gradient descent and JAX's grad function.

In [None]:
# Define the Training Loop
@jax.jit
def update(params, x, y, lr):
    """Take one plain gradient-descent step on a minibatch.

    Returns a new list of ``(w, b)`` tuples, where each tensor has moved
    ``lr * gradient`` downhill.
    """
    grads = jax.grad(cross_entropy_loss)(params, x, y)

    def sgd_step(layer, layer_grad):
        (w, b), (dw, db) = layer, layer_grad
        return w - lr * dw, b - lr * db

    return list(map(sgd_step, params, grads))

# Train the Model
Run the training loop for a specified number of epochs and monitor the training loss.

In [None]:
# Train the Model
epochs = 10
learning_rate = 0.01

# Build the input pipeline once. With reshuffle_each_iteration=True, each new
# pass over the dataset gets a different order.
train_batches = train_ds.shuffle(10_000, seed=0, reshuffle_each_iteration=True).batch(32)

for epoch in range(epochs):
    # A jitted function cannot take tf.Tensor arguments, so iterate NumPy
    # batches instead.
    for images, labels in train_batches.as_numpy_iterator():
        params = update(params, images, labels, learning_rate)
    # Loss on the epoch's final minibatch: a cheap, noisy progress signal.
    last_batch_loss = cross_entropy_loss(params, images, labels)
    print(f"Epoch {epoch + 1} completed. Last-batch loss: {float(last_batch_loss):.4f}")

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate the Model
def accuracy(params, dataset, batch_size=32):
    """Classification accuracy of ``neural_net`` over a ``tf.data`` dataset.

    Args:
        params: list of ``(w, b)`` layer tuples.
        dataset: dataset of un-batched ``(image, label)`` pairs with integer labels.
        batch_size: examples per forward pass (defaults to the previous fixed value, 32).

    Returns:
        Fraction of examples whose argmax prediction equals the label.

    Raises:
        ValueError: if ``dataset`` yields no examples.
    """
    correct = 0
    total = 0
    # JAX functions reject tf.Tensor inputs, so consume NumPy batches.
    for images, labels in dataset.batch(batch_size).as_numpy_iterator():
        predictions = jnp.argmax(neural_net(params, images), axis=-1)
        correct += jnp.sum(predictions == labels)
        total += len(labels)
    if total == 0:
        raise ValueError("accuracy() requires a non-empty dataset")
    return correct / total

# Report accuracy on the held-out test split.
test_accuracy = accuracy(params, dataset=test_ds)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")