# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib). Import necessary modules from JAX and other libraries.

In [6]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX is not installed
# !pip install jax jaxlib numpy matplotlib

# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Load and Preprocess the MNIST Dataset
Download the MNIST dataset, normalize the pixel values, and split the data into training and testing sets.

In [7]:
# Load and Preprocess the MNIST Dataset
from tensorflow.keras.datasets import mnist

# Load MNIST dataset: unpack the (train, test) pairs returned by the Keras loader.
# NOTE(review): the output recorded for this cell is
# "ModuleNotFoundError: No module named 'tensorflow'", so in that run none of the
# names bound below ever existed (later cells report NameError for them).
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize pixel values to the range [0, 1]
# (assumes the raw pixels span 0-255 — TODO confirm against the loader)
train_images = train_images / 255.0
test_images = test_images / 255.0

# Flatten the images for input into the neural network: every image becomes one
# row of 28 * 28 values, the row length the first weight matrix is built for.
train_images = train_images.reshape(-1, 28 * 28)
test_images = test_images.reshape(-1, 28 * 28)

# Convert labels to one-hot encoding with 10 classes
train_labels = jax.nn.one_hot(train_labels, 10)
test_labels = jax.nn.one_hot(test_labels, 10)

ModuleNotFoundError: No module named 'tensorflow'

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network for MNIST classification.

In [8]:
# Define the Neural Network Model
def neural_network(params, x):
    """Forward pass of a one-hidden-layer perceptron; returns raw logits.

    Args:
        params: mapping with entries 'W1', 'b1', 'W2' and 'b2'.
        x: input array whose last axis matches the first axis of ``params['W1']``.

    Returns:
        The output-layer scores. No softmax is applied here.
    """
    pre_activation = jnp.dot(x, params['W1']) + params['b1']
    activated = jax.nn.relu(pre_activation)
    return jnp.dot(activated, params['W2']) + params['b2']

# Initialize Model Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [9]:
# Initialize Model Parameters
def initialize_parameters(key, input_dim, hidden_dim, output_dim):
    """Create the weights and biases of a one-hidden-layer network.

    Weights are drawn from a normal distribution and scaled by ``sqrt(2 / fan_in)``
    (He initialization). Unscaled unit-variance weights on a 784-wide input give
    very large logits, which saturates the softmax from the first step.

    Args:
        key: JAX PRNG key; it is split so each weight matrix gets its own subkey.
        input_dim: number of input features.
        hidden_dim: number of hidden units.
        output_dim: number of output classes.

    Returns:
        dict with 'W1' (input_dim, hidden_dim), 'b1' (hidden_dim,),
        'W2' (hidden_dim, output_dim) and 'b2' (output_dim,).
    """
    w1_key, w2_key = jax.random.split(key, 2)
    # max(..., 1) only guards the degenerate zero-width case against division by zero.
    w1_scale = (2.0 / max(input_dim, 1)) ** 0.5
    w2_scale = (2.0 / max(hidden_dim, 1)) ** 0.5
    params = {
        'W1': jax.random.normal(w1_key, (input_dim, hidden_dim)) * w1_scale,
        'b1': jnp.zeros(hidden_dim),
        'W2': jax.random.normal(w2_key, (hidden_dim, output_dim)) * w2_scale,
        'b2': jnp.zeros(output_dim)
    }
    return params

# Initialize parameters
key = jax.random.PRNGKey(0)
params = initialize_parameters(key, input_dim=28*28, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement the cross-entropy loss function to measure the model's performance.

In [10]:
# Define the Loss Function
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy between the network's logits and targets ``y``.

    ``y`` is multiplied element-wise with the log-probabilities, so it must have
    the same shape as the logits (for example one-hot rows).
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    per_example = jnp.sum(y * log_probs, axis=1)
    return -jnp.mean(per_example)

# Implement the Training Loop
Write a training loop to optimize the model parameters using gradient descent or a similar optimization algorithm.

In [11]:
# Implement the Training Loop
from jax import grad

# Define the optimizer step
def update_parameters(params, grads, learning_rate):
    """Take one plain gradient-descent step.

    Args:
        params: mapping from parameter name to value.
        grads: mapping with a gradient for every name in ``params``.
        learning_rate: step size multiplied into each gradient.

    Returns:
        A new dict with ``params[name] - learning_rate * grads[name]`` per name.
    """
    stepped = {}
    for name, value in params.items():
        stepped[name] = value - learning_rate * grads[name]
    return stepped

# Training loop
def train_model(params, train_images, train_labels, epochs, learning_rate):
    """Run full-batch gradient descent for ``epochs`` steps and return the parameters.

    Each epoch computes the gradient of ``cross_entropy_loss`` on the whole training
    set, applies one ``update_parameters`` step, then prints the loss measured
    *after* that step.
    """
    loss_gradient = grad(cross_entropy_loss)
    for epoch in range(epochs):
        grads = loss_gradient(params, train_images, train_labels)
        params = update_parameters(params, grads, learning_rate)
        loss = cross_entropy_loss(params, train_images, train_labels)
        print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")
    return params

# Train the model
params = train_model(params, train_images, train_labels, epochs=10, learning_rate=0.01)

NameError: name 'train_images' is not defined

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and calculate the accuracy.

In [12]:
# Evaluate the Model on Test Data
def evaluate_model(params, test_images, test_labels):
    """Return the fraction of test examples whose arg-max logit matches the label.

    ``test_labels`` is reduced with ``argmax`` over axis 1, so it is expected to be
    a 2-D array with one score/indicator column per class.
    """
    predicted_classes = jnp.argmax(neural_network(params, test_images), axis=1)
    true_classes = jnp.argmax(test_labels, axis=1)
    return jnp.mean(predicted_classes == true_classes)

# Calculate accuracy
accuracy = evaluate_model(params, test_images, test_labels)
print(f"Test Accuracy: {accuracy:.4f}")

NameError: name 'test_images' is not defined

# Install and Import Required Libraries
Install JAX and import the necessary libraries: jax, jax.numpy, optax, and tensorflow_datasets.

In [13]:
# Install JAX
# Note: Uncomment the following line if running in an environment where JAX is not installed.
# !pip install jax jaxlib optax tensorflow-datasets

In [14]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use TensorFlow Datasets or another library to load the MNIST dataset and preprocess it by normalizing pixel values and splitting into training and test sets.

In [15]:
# Load and Preprocess the MNIST Dataset
def preprocess_data(data):
    """Normalize pixel values, flatten images and convert labels to one-hot encoding.

    Args:
        data: mapping with an 'image' array whose first axis indexes examples and a
            'label' array of integer class ids.

    Returns:
        ``(images, labels)`` where ``images`` has shape (num_examples, features) with
        values divided by 255, and ``labels`` has shape (num_examples, 10).
    """
    images = data['image'] / 255.0
    # The network multiplies its input by a (28 * 28, hidden) matrix, so each image
    # has to be a single feature row; inputs that are already flat keep their shape.
    images = images.reshape(images.shape[0], -1)
    labels = jax.nn.one_hot(data['label'], num_classes=10)
    return images, labels

# Load dataset: build/download MNIST with TensorFlow Datasets and convert each split
# to NumPy structures.
# NOTE(review): the output recorded for this cell ends in an ExtractError caused by
# "No module named 'tensorflow'", so this cell did not complete in that run.
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
# batch_size=-1 presumably returns the whole split as a single batch — confirm in the tfds docs.
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))

# Preprocess dataset: preprocess_data reads the 'image' and 'label' entries of each split
train_images, train_labels = preprocess_data(train_ds)
test_images, test_labels = preprocess_data(test_ds)

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /home/codespace/tensorflow_datasets/mnist/3.0.1...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Extraction completed...: 0 file [00:00, ? file/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]0 url/s][A
Dl Completed...: 100%|██████████| 1/1 [00:00<00:00, 84.60 url/s]
Dl Completed...: 100%|██████████| 1/1 [00:00<00:00, 63.32 url/s]
Dl Completed...: 100%|██████████| 1/1 [00:00<00:00, 49.81 url/s]
[A
Dl Completed...: 100%|██████████| 1/1 [00:00<00:00, 49.81 url/s]
Dl Completed...:  50%|█████     | 1/2 [00:00<00:00, 36.44 url/s]
Dl Completed...: 100%|██████████| 2/2 [00:00<00:00, 61.57 url/s]

Dl Completed...:  50%|█████     | 1/2 [00:00<00:00, 36.44 url/s]
Dl Completed...: 100%|██████████| 2/2 [00:00<00:00, 61.57 url/s]
Dl Completed...: 100%|██████████| 2/2 [00:00<00:00, 53.13 url/s][A
Dl Completed...: 100%|██████████| 2/2 [00:00<00:00, 47.85 url/s]
Dl Completed...:  67%|██████▋   | 2/3 [00:00<00:00, 42.13 url/s]
Dl Completed...: 100%|██████████| 2/2 [00:00<00:00, 47.85 url

ExtractError: Error while extracting /home/codespace/tensorflow_datasets/downloads/mnist/cvdf-datasets_mnist_t10k-images-idx3-ubytejUIsewocHHkkWlvPB_6G4z7q_ueSuEWErsJ29aLbxOY.gz to /home/codespace/tensorflow_datasets/downloads/extracted/GZIP.cvdf-datasets_mnist_t10k-images-idx3-ubytejUIsewocHHkkWlvPB_6G4z7q_ueSuEWErsJ29aLbxOY.gz: No module named 'tensorflow'

# Define the Neural Network Model
Define a simple feedforward neural network using JAX functions.

In [None]:
# Define the Neural Network Model
def neural_network(params, inputs):
    """Compute class logits with one ReLU hidden layer.

    Args:
        params: mapping with entries 'w1', 'b1', 'w2' and 'b2'.
        inputs: array whose last axis matches the first axis of ``params['w1']``.

    Returns:
        Unnormalised output scores (no softmax applied).
    """
    hidden_pre = jnp.dot(inputs, params['w1']) + params['b1']
    hidden_act = jax.nn.relu(hidden_pre)
    return jnp.dot(hidden_act, params['w2']) + params['b2']

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_parameters(key, input_dim, hidden_dim, output_dim):
    """Initialize weights and biases for the neural network.

    Only two random tensors are drawn, so the key is split into exactly two
    subkeys. Weights are scaled by ``sqrt(2 / fan_in)`` (He initialization) so the
    initial logits stay in a range where the softmax is not saturated.

    Args:
        key: JAX PRNG key.
        input_dim: number of input features.
        hidden_dim: number of hidden units.
        output_dim: number of output classes.

    Returns:
        dict with 'w1' (input_dim, hidden_dim), 'b1' (hidden_dim,),
        'w2' (hidden_dim, output_dim) and 'b2' (output_dim,).
    """
    w1_key, w2_key = jax.random.split(key, 2)
    # max(..., 1) only guards the degenerate zero-width case against division by zero.
    w1_scale = (2.0 / max(input_dim, 1)) ** 0.5
    w2_scale = (2.0 / max(hidden_dim, 1)) ** 0.5
    params = {
        'w1': jax.random.normal(w1_key, (input_dim, hidden_dim)) * w1_scale,
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(w2_key, (hidden_dim, output_dim)) * w2_scale,
        'b2': jnp.zeros(output_dim)
    }
    return params

# Set dimensions
input_dim = 28 * 28
hidden_dim = 128
output_dim = 10

# Initialize parameters
key = jax.random.PRNGKey(0)
params = initialize_parameters(key, input_dim, hidden_dim, output_dim)

# Define the Loss Function and Accuracy Metric
Implement a cross-entropy loss function and an accuracy metric to evaluate the model's performance.

In [None]:
# Define the Loss Function and Accuracy Metric
def cross_entropy_loss(params, inputs, targets):
    """Compute the mean softmax cross-entropy of the network against ``targets``.

    ``targets`` is multiplied element-wise with the log-probabilities, so it must
    have the same shape as the logits (for example one-hot rows).
    """
    log_probs = jax.nn.log_softmax(neural_network(params, inputs))
    per_example = jnp.sum(targets * log_probs, axis=1)
    return -jnp.mean(per_example)

def compute_accuracy(params, inputs, targets):
    """Return the fraction of rows whose arg-max logit equals the arg-max target.

    ``targets`` is reduced with ``argmax`` over axis 1, so it is expected to be a
    2-D array with one column per class.
    """
    predicted_classes = jnp.argmax(neural_network(params, inputs), axis=1)
    expected_classes = jnp.argmax(targets, axis=1)
    hits = predicted_classes == expected_classes
    return jnp.mean(hits)

# Implement the Training Loop
Write a training loop that performs forward and backward passes, updates parameters using an optimizer, and logs training progress.

In [None]:
# Implement the Training Loop
def train_model(params, train_images, train_labels, test_images, test_labels, num_epochs, learning_rate):
    """Train the network with full-batch Adam steps and return the final parameters.

    Every epoch performs exactly one optimizer step on the whole training set and
    prints the loss together with train/test accuracy.
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)
    loss_and_grad = jax.value_and_grad(cross_entropy_loss)

    for epoch in range(num_epochs):
        # The loss reported below is evaluated at the parameters *before* this
        # epoch's update; the accuracies are evaluated after it.
        loss, grads = loss_and_grad(params, train_images, train_labels)

        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        train_acc = compute_accuracy(params, train_images, train_labels)
        test_acc = compute_accuracy(params, test_images, test_labels)

        print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")

    return params

# Train the model
num_epochs = 10
learning_rate = 0.001
trained_params = train_model(params, train_images, train_labels, test_images, test_labels, num_epochs, learning_rate)

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and report the final accuracy.

In [None]:
# Evaluate the Model on Test Data
# Uses the parameters returned by train_model (trained_params), not the initial `params`.
final_test_accuracy = compute_accuracy(trained_params, test_images, test_labels)
print(f"Final Test Accuracy: {final_test_accuracy:.4f}")

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import all necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
# as_supervised=True presumably yields (image, label) pairs — confirm in the tfds docs.
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and split the dataset
def preprocess(image, label):
    """Scale an image by 1/255 and return it with its label, both as JAX arrays.

    NOTE(review): this function is passed to ``Dataset.map`` below. That call likely
    invokes it with symbolic TensorFlow tensors, which ``jnp.array`` may not be able
    to convert — confirm this cell actually runs.
    """
    image = jnp.array(image) / 255.0  # Normalize to [0, 1]
    label = jnp.array(label)
    return image, label

# The labels stay integer class ids here; no one-hot encoding is applied.
train_dataset = datasets["train"].map(preprocess)
test_dataset = datasets["test"].map(preprocess)

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with layers for MNIST classification.

In [None]:
# Define Neural Network Model
def relu(x):
    """Element-wise rectified linear unit, ``max(0, x)``."""
    return jnp.maximum(0, x)

def softmax(x):
    """Softmax over the last axis, computed independently for every row.

    The maximum of *each row* is subtracted before exponentiating. Shifting by one
    global maximum instead lets rows far below that maximum underflow to all zeros,
    and the normalisation then evaluates 0 / 0 = NaN.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / exp_x.sum(axis=-1, keepdims=True)

def neural_network(params, x):
    """Forward pass of a one-hidden-layer network; returns class probabilities.

    Args:
        params: mapping with entries "W1", "b1", "W2" and "b2".
        x: input array whose last axis matches the first axis of ``params["W1"]``.

    Returns:
        Softmax probabilities over the last axis of the output layer.
    """
    hidden = relu(jnp.dot(x, params["W1"]) + params["b1"])
    logits = jnp.dot(hidden, params["W2"]) + params["b2"]
    return softmax(logits)

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_params(key):
    """Create the parameters of a 784-128-10 network.

    The key is split into exactly the two subkeys that are used. Weights are scaled
    by ``sqrt(2 / fan_in)`` (He initialization); with unscaled unit-variance weights
    the hand-written softmax saturates and ``log`` of its output becomes -inf.

    Args:
        key: JAX PRNG key.

    Returns:
        dict with "W1" (784, 128), "b1" (128,), "W2" (128, 10) and "b2" (10,).
    """
    input_size = 28 * 28  # MNIST images are 28x28
    hidden_size = 128
    output_size = 10  # 10 classes for digits 0-9

    w1_key, w2_key = jax.random.split(key, 2)
    params = {
        "W1": jax.random.normal(w1_key, (input_size, hidden_size)) * (2.0 / input_size) ** 0.5,
        "b1": jnp.zeros(hidden_size),
        "W2": jax.random.normal(w2_key, (hidden_size, output_size)) * (2.0 / hidden_size) ** 0.5,
        "b2": jnp.zeros(output_size),
    }
    return params

key = jax.random.PRNGKey(0)
params = initialize_params(key)

# Define the Loss Function
Implement a cross-entropy loss function to compute the error between predictions and true labels.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy between predicted probabilities and integer labels ``y``.

    Args:
        params: network parameters forwarded to ``neural_network``.
        x: batch of inputs.
        y: integer class ids; they are one-hot encoded with 10 classes here.

    Returns:
        Scalar loss averaged over the batch.
    """
    predictions = neural_network(params, x)
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    # A probability that underflows to exactly 0 gives log(0) = -inf, and
    # 0 * -inf = NaN even for classes the one-hot mask is meant to ignore.
    # Clamping keeps the loss and its gradient finite.
    log_probs = jnp.log(jnp.clip(predictions, 1e-12, 1.0))
    return -jnp.mean(jnp.sum(one_hot_y * log_probs, axis=-1))

# Define the Training Loop
Write a training loop that uses JAX's grad function to compute gradients and update the model parameters using gradient descent.

In [None]:
# Define Training Loop
@jax.jit
def update_params(params, x, y, learning_rate):
    """Apply one jit-compiled gradient-descent step and return the new parameters.

    The gradient of ``cross_entropy_loss`` with respect to ``params`` is taken on
    the batch ``(x, y)`` and subtracted leaf by leaf, scaled by ``learning_rate``.
    """
    gradients = jax.grad(cross_entropy_loss)(params, x, y)

    def descend(weight, gradient):
        return weight - learning_rate * gradient

    return jax.tree_util.tree_map(descend, params, gradients)

# Train the Model
Run the training loop for a specified number of epochs and monitor the training loss.

In [None]:
# Train the Model: mini-batch gradient descent over train_dataset in batches of 32.
epochs = 10
learning_rate = 0.01

for epoch in range(epochs):
    for batch in train_dataset.batch(32):
        images, labels = batch
        # NOTE(review): assumes the batched images expose a NumPy-style .reshape —
        # confirm for the tensor type this dataset yields.
        images = images.reshape(-1, 28 * 28)  # Flatten images
        # `params` is rebound on every step; re-running this cell continues training
        # from the current values rather than from the initial ones.
        params = update_params(params, images, labels, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate the Model
def compute_accuracy(params, dataset):
    """Return the share of examples in ``dataset`` that the network classifies correctly.

    The dataset is consumed in batches of 32; each batch is flattened to rows of
    28 * 28 values before being fed to ``neural_network``. Labels are compared
    directly against the arg-max prediction, so they must be integer class ids.
    """
    n_correct = 0
    n_seen = 0
    for images, labels in dataset.batch(32):
        flat_images = images.reshape(-1, 28 * 28)
        predicted_labels = jnp.argmax(neural_network(params, flat_images), axis=-1)
        n_correct += jnp.sum(predicted_labels == labels)
        n_seen += len(labels)
    return n_correct / n_seen

accuracy = compute_accuracy(params, test_dataset)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import all necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load the MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
# as_supervised=True presumably yields (image, label) pairs — confirm in the tfds docs.
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and split the data
def preprocess(image, label):
    """Scale an image by 1/255 and return it with its label, both as JAX arrays.

    The label is passed through unchanged (no one-hot encoding) and the image is
    not reshaped.

    NOTE(review): this function is passed to ``Dataset.map`` below. That call likely
    invokes it with symbolic TensorFlow tensors, which ``jnp.array`` may not be able
    to convert — confirm this cell actually runs.
    """
    image = jnp.array(image) / 255.0  # Normalize to [0, 1]
    label = jnp.array(label)
    return image, label

train_dataset = datasets["train"].map(preprocess)
test_dataset = datasets["test"].map(preprocess)

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with layers for MNIST classification.

In [None]:
# Define the Neural Network Model
def relu(x):
    """Element-wise rectified linear unit, ``max(0, x)``."""
    return jnp.maximum(0, x)

def softmax(x):
    """Softmax over the last axis, computed independently for every row.

    Both the stabilising maximum and the normalising sum are taken per row.
    Summing over the whole array would make an entire batch share one
    normalisation, so individual rows would not sum to 1.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

def neural_network(params, x):
    """Forward pass of a one-hidden-layer network; returns class probabilities.

    Args:
        params: mapping with entries "W1", "b1", "W2" and "b2".
        x: input array whose last axis matches the first axis of ``params["W1"]``.

    Returns:
        Softmax probabilities over the last axis of the output layer.
    """
    hidden = relu(jnp.dot(x, params["W1"]) + params["b1"])
    logits = jnp.dot(hidden, params["W2"]) + params["b2"]
    return softmax(logits)

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_parameters(rng, input_size, hidden_size, output_size):
    """Create the weights and biases of a one-hidden-layer network.

    ``rng`` is split so that each weight matrix is drawn from its own subkey;
    reusing one key for both draws makes the two matrices statistically dependent.
    Weights are scaled by ``sqrt(2 / fan_in)`` (He initialization) so the initial
    logits do not saturate the softmax.

    Args:
        rng: JAX PRNG key.
        input_size: number of input features.
        hidden_size: number of hidden units.
        output_size: number of output classes.

    Returns:
        dict with "W1" (input_size, hidden_size), "b1" (hidden_size,),
        "W2" (hidden_size, output_size) and "b2" (output_size,).
    """
    w1_key, w2_key = jax.random.split(rng)
    # max(..., 1) only guards the degenerate zero-width case against division by zero.
    w1_scale = (2.0 / max(input_size, 1)) ** 0.5
    w2_scale = (2.0 / max(hidden_size, 1)) ** 0.5
    params = {
        "W1": jax.random.normal(w1_key, (input_size, hidden_size)) * w1_scale,
        "b1": jnp.zeros(hidden_size),
        "W2": jax.random.normal(w2_key, (hidden_size, output_size)) * w2_scale,
        "b2": jnp.zeros(output_size),
    }
    return params

rng = jax.random.PRNGKey(0)
params = initialize_parameters(rng, input_size=784, hidden_size=128, output_size=10)

# Define the Loss Function and Accuracy Metric
Implement the cross-entropy loss function and accuracy metric using JAX functions.

In [None]:
# Define Loss Function and Accuracy Metric
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy between predicted probabilities and the targets ``y``.

    Args:
        params: network parameters forwarded to ``neural_network``.
        x: batch of inputs.
        y: either one-hot / probability rows shaped like the predictions, or
            integer class ids with one fewer dimension. Integer ids are one-hot
            encoded here, since the data pipeline in this notebook yields plain
            integer labels.

    Returns:
        Scalar loss averaged over the batch.
    """
    predictions = neural_network(params, x)
    y = jnp.asarray(y)
    if y.ndim == predictions.ndim - 1:
        y = jax.nn.one_hot(y, predictions.shape[-1])
    # A probability of exactly 0 gives log(0) = -inf and 0 * -inf = NaN; clamping
    # keeps the loss and its gradient finite.
    log_probs = jnp.log(jnp.clip(predictions, 1e-12, 1.0))
    return -jnp.mean(jnp.sum(y * log_probs, axis=-1))

def accuracy(params, x, y):
    """Fraction of examples whose most probable class equals the true class.

    Args:
        params: network parameters forwarded to ``neural_network``.
        x: batch of inputs.
        y: either one-hot rows shaped like the predictions (reduced with argmax) or
            integer class ids with one fewer dimension (used as they are).

    Returns:
        Scalar accuracy in [0, 1].
    """
    predictions = neural_network(params, x)
    predicted_labels = jnp.argmax(predictions, axis=-1)
    y = jnp.asarray(y)
    # Integer labels have no class axis to reduce; argmax over axis 1 would fail on them.
    true_labels = jnp.argmax(y, axis=-1) if y.ndim == predictions.ndim else y
    return jnp.mean(predicted_labels == true_labels)

# Implement the Training Loop
Write a training loop to update the model parameters using gradient descent and JAX's grad function.

In [None]:
# Training Loop
@jax.jit
def update_parameters(params, x, y, learning_rate):
    """Apply one jit-compiled gradient-descent step and return the new parameter dict.

    The gradient of ``cross_entropy_loss`` with respect to ``params`` is evaluated
    on the batch ``(x, y)``; every entry is moved against its gradient, scaled by
    ``learning_rate``.
    """
    gradients = jax.grad(cross_entropy_loss)(params, x, y)
    stepped = {}
    for name in params:
        stepped[name] = params[name] - learning_rate * gradients[name]
    return stepped

# Example Training Loop: mini-batch gradient descent over train_dataset in batches of 32.
learning_rate = 0.01
epochs = 10

for epoch in range(epochs):
    for batch in train_dataset.batch(32):
        images, labels = batch
        # NOTE(review): unlike the other training loops in this file, the images are
        # not reshaped to rows of 784 values before the update, although W1 was
        # created with input_size=784 — confirm the batch shape is compatible.
        params = update_parameters(params, images, labels, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and calculate the accuracy.

In [None]:
# Evaluate the Model
# Pull a single batch of up to 10000 examples from the test set and score it in one call.
# NOTE(review): presumably 10000 covers the whole MNIST test split — confirm.
test_images, test_labels = next(iter(test_dataset.batch(10000)))
test_acc = accuracy(params, test_images, test_labels)
print(f"Test Accuracy: {test_acc * 100:.2f}%")

# Install and Import Required Libraries
Install JAX and import necessary libraries such as jax, jax.numpy, and optax.

In [None]:
# Install JAX
# Note: Uncomment the following line to install JAX in your environment
# !pip install jax jaxlib optax

# Import Required Libraries
import jax
import jax.numpy as jnp
from jax import random
import optax

# Load and Preprocess the MNIST Dataset
Load the MNIST dataset, normalize the pixel values, and split the data into training and testing sets.

In [None]:
# Load and Preprocess the MNIST Dataset
from tensorflow.keras.datasets import mnist

# Load dataset: unpack the (train, test) pairs returned by the Keras loader.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize pixel values to the range [0, 1]
# (assumes the raw pixels span 0-255 — TODO confirm against the loader)
train_images = train_images / 255.0
test_images = test_images / 255.0

# Flatten images: every image becomes one row of 28 * 28 values, matching the
# first entry of layer_sizes used to build the network.
train_images = train_images.reshape(-1, 28 * 28)
test_images = test_images.reshape(-1, 28 * 28)

# Convert labels to one-hot encoding with 10 classes
train_labels = jax.nn.one_hot(train_labels, num_classes=10)
test_labels = jax.nn.one_hot(test_labels, num_classes=10)

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network for MNIST classification.

In [None]:
# Define the Neural Network Model
def neural_network(params, x):
    """Forward pass of a multi-layer perceptron; returns raw logits.

    Args:
        params: sequence of ``(weights, bias)`` pairs, one per layer.
        x: input array whose last axis matches the first axis of the first weights.

    Returns:
        Output of the final affine layer. ReLU is applied after every layer except
        the last, and no softmax is applied.
    """
    hidden_layers = params[:-1]
    activations = x
    for weights, bias in hidden_layers:
        activations = jax.nn.relu(jnp.dot(activations, weights) + bias)
    out_weights, out_bias = params[-1]
    return jnp.dot(activations, out_weights) + out_bias

# Initialize Parameters
Initialize the model parameters using JAX's random number generation utilities.

In [None]:
# Initialize Parameters
def initialize_parameters(layer_sizes, key):
    """Build He-initialised ``(weights, bias)`` pairs for consecutive layer sizes.

    Args:
        layer_sizes: widths of the layers, input first and output last.
        key: JAX PRNG key; it is split once per layer.

    Returns:
        list with one ``(w, b)`` tuple per pair of neighbouring sizes, where ``w``
        has shape (fan_in, fan_out), is scaled by ``sqrt(2 / fan_in)``, and ``b``
        is a zero vector of length fan_out.
    """
    layers = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, layer_key = random.split(key)
        weights = random.normal(layer_key, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)
        bias = jnp.zeros((fan_out,))
        layers.append((weights, bias))
    return layers

# Define layer sizes
layer_sizes = [28 * 28, 128, 64, 10]
key = random.PRNGKey(0)
params = initialize_parameters(layer_sizes, key)

# Define the Loss Function
Implement a cross-entropy loss function to compute the difference between predictions and true labels.

In [None]:
# Define the Loss Function
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy between the network's logits and targets ``y``.

    ``y`` is multiplied element-wise with the log-probabilities, so it must have
    the same shape as the logits (for example one-hot rows).
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    per_example = jnp.sum(y * log_probs, axis=1)
    return -jnp.mean(per_example)

# Define the Training Loop
Write a training loop that updates model parameters using gradient descent or an optimizer from Optax.

In [None]:
# Define the Training Loop
from functools import partial


# The optimizer is a container of plain Python functions, not arrays, so it cannot
# be traced as a jit argument; marking it static makes jit treat it as a
# compile-time constant (it must therefore be hashable).
@partial(jax.jit, static_argnames=("optimizer",))
def update(params, x, y, opt_state, optimizer):
    """Perform one jit-compiled optimizer step on a mini-batch.

    Args:
        params: current model parameters.
        x: batch of inputs.
        y: batch of targets accepted by ``cross_entropy_loss``.
        opt_state: optimizer state matching ``params``.
        optimizer: object exposing ``update(grads, opt_state)``; treated as static.

    Returns:
        ``(params, opt_state)`` after applying the update.
    """
    grads = jax.grad(cross_entropy_loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Train the Model
Run the training loop on the training dataset and monitor the loss and accuracy.

In [None]:
# Train the Model: mini-batch Adam over contiguous slices of the training arrays.
batch_size = 128
num_epochs = 10
learning_rate = 0.001

# Initialize optimizer (the state is created from the current `params`)
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

# Training loop
for epoch in range(num_epochs):
    # Batches are taken in the same fixed order every epoch (no shuffling); the last
    # slice is shorter when the dataset size is not a multiple of batch_size.
    for i in range(0, len(train_images), batch_size):
        x_batch = train_images[i:i + batch_size]
        y_batch = train_labels[i:i + batch_size]
        params, opt_state = update(params, x_batch, y_batch, opt_state, optimizer)
    # Compute loss for the epoch on the full training set, after all updates of the epoch
    epoch_loss = cross_entropy_loss(params, train_images, train_labels)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the final accuracy.

In [None]:
# Evaluate the Model
def accuracy(params, x, y):
    """Fraction of rows whose arg-max logit equals the arg-max of the target row.

    ``y`` is reduced with ``argmax`` over axis 1, so it is expected to be a 2-D
    array with one column per class.
    """
    logits = neural_network(params, x)
    predicted_classes = jnp.argmax(logits, axis=1)
    true_classes = jnp.argmax(y, axis=1)
    return jnp.mean(predicted_classes == true_classes)

test_accuracy = accuracy(params, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import necessary modules.

In [None]:
# Install Required Libraries
!pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
def load_and_preprocess_data():
    """Load the MNIST train and test splits and attach the normalisation step.

    Returns:
        ``(ds_train, ds_test)``: the two splits returned by ``tfds.load`` with
        ``normalize_img`` mapped over them. Labels stay integer class ids.
    """
    ds_train, ds_test = tfds.load('mnist', split=['train', 'test'], as_supervised=True)
    
    def normalize_img(image, label):
        # NOTE(review): Dataset.map likely calls this with symbolic TensorFlow
        # tensors, which jnp.array may not be able to convert — confirm this runs.
        image = jnp.array(image) / 255.0  # Normalize to [0, 1]
        label = jnp.array(label)
        return image, label
    
    ds_train = ds_train.map(normalize_img)
    ds_test = ds_test.map(normalize_img)
    
    return ds_train, ds_test

ds_train, ds_test = load_and_preprocess_data()

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with layers for MNIST classification.

In [None]:
# Define Neural Network Model
def relu(x):
    """Element-wise rectified linear unit, ``max(0, x)``."""
    return jnp.maximum(0, x)

def softmax(x):
    """Softmax over the last axis, computed independently for every row.

    The maximum of *each row* is subtracted before exponentiating. Shifting by one
    global maximum instead lets rows far below that maximum underflow to all zeros,
    and the normalisation then evaluates 0 / 0 = NaN.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / exp_x.sum(axis=-1, keepdims=True)

def neural_net(params, x):
    """Forward pass of a one-hidden-layer network; returns class probabilities.

    Args:
        params: mapping with entries 'W1', 'b1', 'W2' and 'b2'.
        x: input array whose last axis matches the first axis of ``params['W1']``.

    Returns:
        Softmax probabilities over the last axis of the output layer.
    """
    hidden = relu(jnp.dot(x, params['W1']) + params['b1'])
    logits = jnp.dot(hidden, params['W2']) + params['b2']
    return softmax(logits)

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_parameters(rng, input_dim, hidden_dim, output_dim):
    """Create the weights and biases of a one-hidden-layer network.

    ``rng`` is split so that each weight matrix is drawn from its own subkey;
    reusing one key for both draws makes the two matrices statistically dependent.
    Weights are scaled by ``sqrt(2 / fan_in)`` (He initialization) so the initial
    logits do not saturate the softmax.

    Args:
        rng: JAX PRNG key.
        input_dim: number of input features.
        hidden_dim: number of hidden units.
        output_dim: number of output classes.

    Returns:
        dict with 'W1' (input_dim, hidden_dim), 'b1' (hidden_dim,),
        'W2' (hidden_dim, output_dim) and 'b2' (output_dim,).
    """
    w1_key, w2_key = jax.random.split(rng)
    # max(..., 1) only guards the degenerate zero-width case against division by zero.
    w1_scale = (2.0 / max(input_dim, 1)) ** 0.5
    w2_scale = (2.0 / max(hidden_dim, 1)) ** 0.5
    params = {
        'W1': jax.random.normal(w1_key, (input_dim, hidden_dim)) * w1_scale,
        'b1': jnp.zeros(hidden_dim),
        'W2': jax.random.normal(w2_key, (hidden_dim, output_dim)) * w2_scale,
        'b2': jnp.zeros(output_dim)
    }
    return params

rng = jax.random.PRNGKey(0)
params = initialize_parameters(rng, input_dim=784, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement a cross-entropy loss function to compute the error between predictions and true labels.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy between predicted probabilities and integer labels ``y``.

    Args:
        params: network parameters forwarded to ``neural_net``.
        x: batch of inputs.
        y: integer class ids; they are one-hot encoded with 10 classes here.

    Returns:
        Scalar loss averaged over the batch.
    """
    predictions = neural_net(params, x)
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    # A probability that underflows to exactly 0 gives log(0) = -inf, and
    # 0 * -inf = NaN even for classes the one-hot mask is meant to ignore.
    # Clamping keeps the loss and its gradient finite.
    log_probs = jnp.log(jnp.clip(predictions, 1e-12, 1.0))
    return -jnp.mean(jnp.sum(one_hot_y * log_probs, axis=-1))

# Define the Training Loop
Write a training loop that uses JAX's grad function to compute gradients and update parameters using gradient descent.

In [None]:
# Define Training Loop
@jax.jit
def update_parameters(params, x, y, learning_rate):
    """Apply one jit-compiled gradient-descent step and return the new parameters.

    The gradient of ``cross_entropy_loss`` with respect to ``params`` is taken on
    the batch ``(x, y)`` and subtracted leaf by leaf, scaled by ``learning_rate``.
    """
    gradients = jax.grad(cross_entropy_loss)(params, x, y)

    def descend(weight, gradient):
        return weight - learning_rate * gradient

    return jax.tree_util.tree_map(descend, params, gradients)

# Train the Model
Run the training loop for a specified number of epochs and track the training loss.

In [None]:
# Train the Model
def train_model(params, ds_train, epochs, learning_rate):
    """Train with mini-batch gradient descent and return the final parameters.

    ``ds_train`` is consumed in batches of 128 once per epoch; each image batch is
    flattened to rows of 784 values before ``update_parameters`` is applied.
    """
    for epoch in range(epochs):
        for images, labels in ds_train.batch(128):
            flat_images = images.reshape(-1, 784)
            params = update_parameters(params, flat_images, labels, learning_rate)
        print(f"Epoch {epoch + 1} completed.")
    return params

params = train_model(params, ds_train, epochs=10, learning_rate=0.01)

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate the Model
def evaluate_model(params, ds_test):
    """Print the classification accuracy of the network on ``ds_test``.

    The dataset is consumed in batches of 128; each batch is flattened to rows of
    784 values. Labels are compared directly with the arg-max prediction, so they
    must be integer class ids. Nothing is returned.
    """
    n_correct = 0
    n_seen = 0
    for images, labels in ds_test.batch(128):
        flat_images = images.reshape(-1, 784)
        predicted_labels = jnp.argmax(neural_net(params, flat_images), axis=-1)
        n_correct += jnp.sum(predicted_labels == labels)
        n_seen += len(labels)
    accuracy = n_correct / n_seen
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

evaluate_model(params, ds_test)

# Install and Import Required Libraries
Install JAX and import the necessary libraries: jax, jax.numpy, optax, and tensorflow_datasets.

In [None]:
# Install JAX
# Note: Uncomment the following line to install JAX if not already installed.
# !pip install jax jaxlib optax tensorflow-datasets

# Import Required Libraries
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use TensorFlow Datasets to load the MNIST dataset and preprocess it by normalizing pixel values and converting labels to one-hot encoding.

In [None]:
# Load MNIST Dataset
def load_and_preprocess_data():
    """Load the MNIST train and test splits and attach the preprocessing step.

    Returns:
        ``(ds_train, ds_test)``: the two splits returned by ``tfds.load`` with
        ``preprocess`` mapped over them. The datasets are not batched here.
    """
    ds_train, ds_test = tfds.load('mnist', split=['train', 'test'], as_supervised=True)
    
    def preprocess(image, label):
        # NOTE(review): Dataset.map likely calls this with symbolic TensorFlow
        # tensors, which jnp.array / jax.nn.one_hot may not accept — confirm this runs.
        # The image is scaled but not reshaped.
        image = jnp.array(image) / 255.0  # Normalize pixel values
        label = jax.nn.one_hot(label, 10)  # Convert labels to one-hot encoding
        return image, label
    
    ds_train = ds_train.map(preprocess)
    ds_test = ds_test.map(preprocess)
    return ds_train, ds_test

train_data, test_data = load_and_preprocess_data()

# Define the Neural Network Model
Define a simple feedforward neural network using JAX functions.

In [None]:
# Define Neural Network Model
def neural_network(params, x):
    """Forward pass of a multi-layer perceptron; returns raw logits.

    Args:
        params: sequence of ``(weights, bias)`` pairs, one per layer.
        x: input array whose last axis matches the first axis of the first weights.

    Returns:
        Output of the final affine layer. ReLU follows every layer except the last.
    """
    hidden_layers = params[:-1]
    signal = x
    for weights, bias in hidden_layers:
        signal = jax.nn.relu(jnp.dot(signal, weights) + bias)
    out_weights, out_bias = params[-1]
    return jnp.dot(signal, out_weights) + out_bias

# Initialize Parameters
Initialize the weights and biases of the neural network using JAX's random number generator.

In [None]:
# Initialize Parameters
def initialize_parameters(layer_sizes, key):
    """Build He-initialised ``(weights, bias)`` pairs for consecutive layer sizes.

    Weights are scaled by ``sqrt(2 / fan_in)``, matching the other list-based
    initializer in this file. Unscaled unit-variance weights make activations grow
    with every layer and give very large initial logits.

    Args:
        layer_sizes: widths of the layers, input first and output last.
        key: JAX PRNG key; it is split once per layer.

    Returns:
        list with one ``(w, b)`` tuple per pair of neighbouring sizes, where ``w``
        has shape (fan_in, fan_out) and ``b`` is a zero vector of length fan_out.
    """
    params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        # max(..., 1) only guards the degenerate zero-width case against division by zero.
        scale = (2.0 / max(fan_in, 1)) ** 0.5
        w = jax.random.normal(subkey, (fan_in, fan_out)) * scale
        b = jnp.zeros((fan_out,))
        params.append((w, b))
    return params

key = jax.random.PRNGKey(0)
layer_sizes = [784, 128, 64, 10]  # Input layer, two hidden layers, output layer
params = initialize_parameters(layer_sizes, key)

# Define the Loss Function
Implement the cross-entropy loss function to compute the difference between predictions and true labels.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean softmax cross-entropy between the network's logits and targets ``y``.

    ``y`` is multiplied element-wise with the log-probabilities and summed over the
    last axis, so it must be shaped like the logits (for example one-hot rows).
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    per_example = jnp.sum(y * log_probs, axis=-1)
    return -jnp.mean(per_example)

# Implement the Training Loop
Write a training loop that updates the model parameters using gradient descent and evaluates the loss on the training data.

In [None]:
# Training Loop
def train_model(params, train_data, learning_rate, epochs):
    """Train with optax SGD, taking one step per element of ``train_data``.

    After each epoch the loss of the *last* processed element is printed.

    NOTE(review): ``train_data`` is iterated directly, so each ``(x, y)`` is whatever
    one element of the dataset is. Confirm that element has the batching and shape
    that ``neural_network`` expects.
    """
    optimizer = optax.sgd(learning_rate)
    opt_state = optimizer.init(params)
    loss_and_grad = jax.value_and_grad(cross_entropy_loss)

    for epoch in range(epochs):
        for features, targets in train_data:
            loss, grads = loss_and_grad(params, features, targets)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        print(f"Epoch {epoch + 1}, Loss: {loss}")
    return params

learning_rate = 0.01
epochs = 5
trained_params = train_model(params, train_data, learning_rate, epochs)

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate Model
def evaluate_model(params, test_data):
    """Print the classification accuracy of the network on ``test_data``.

    Both the logits and the targets are reduced with ``argmax`` over the last axis
    before being compared. Nothing is returned.

    NOTE(review): ``len(labels)`` requires the arg-max result to have at least one
    dimension, i.e. batched elements — confirm what ``test_data`` yields.
    """
    n_correct = 0
    n_seen = 0
    for features, targets in test_data:
        predicted = jnp.argmax(neural_network(params, features), axis=-1)
        labels = jnp.argmax(targets, axis=-1)
        n_correct += jnp.sum(predicted == labels)
        n_seen += len(labels)
    accuracy = n_correct / n_seen
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

evaluate_model(trained_params, test_data)

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import the necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
# as_supervised=True presumably yields (image, label) pairs — confirm in the tfds docs.
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and split the data
def preprocess(image, label):
    """Scale an image by 1/255 and return it with its label, both as JAX arrays.

    The label is passed through unchanged (no one-hot encoding) and the image is
    not reshaped.

    NOTE(review): this function is passed to ``Dataset.map`` below. That call likely
    invokes it with symbolic TensorFlow tensors, which ``jnp.array`` may not be able
    to convert — confirm this cell actually runs.
    """
    image = jnp.array(image) / 255.0  # Normalize to [0, 1]
    label = jnp.array(label)
    return image, label

train_ds = datasets["train"].map(preprocess)
test_ds = datasets["test"].map(preprocess)

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with layers for MNIST classification.

In [None]:
# Define Neural Network Model
def relu(x):
    """Element-wise rectified linear unit, ``max(0, x)``."""
    return jnp.maximum(0, x)

def softmax(x):
    """Softmax over the last axis, computed independently for every row.

    Both the stabilising maximum and the normalising sum are taken per row.
    Summing over the whole array would make an entire batch share one
    normalisation, so individual rows would not sum to 1.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

def neural_net(params, x):
    """Forward pass of a one-hidden-layer network; returns class probabilities.

    Args:
        params: mapping with entries "W1", "b1", "W2" and "b2".
        x: input array whose last axis matches the first axis of ``params["W1"]``.

    Returns:
        Softmax probabilities over the last axis of the output layer.
    """
    hidden = relu(jnp.dot(x, params["W1"]) + params["b1"])
    logits = jnp.dot(hidden, params["W2"]) + params["b2"]
    return softmax(logits)

# Initialize Model Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_params(key, input_dim, hidden_dim, output_dim):
    """Create He-scaled weights and zero biases for the two-layer network.

    Args:
        key: JAX PRNG key; split once so each weight matrix gets its own key.
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of output classes.

    Returns:
        Dict with "W1" (input_dim, hidden_dim), "b1" (hidden_dim,),
        "W2" (hidden_dim, output_dim) and "b2" (output_dim,).
    """
    key1, key2 = jax.random.split(key)
    # Scale by sqrt(2 / fan_in): with unit-variance weights the pre-activations
    # grow with sqrt(fan_in) and push the softmax into saturation.
    params = {
        "W1": jax.random.normal(key1, (input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim),
        "b1": jnp.zeros(hidden_dim),
        "W2": jax.random.normal(key2, (hidden_dim, output_dim)) * jnp.sqrt(2.0 / hidden_dim),
        "b2": jnp.zeros(output_dim),
    }
    return params

key = jax.random.PRNGKey(0)
params = initialize_params(key, input_dim=28*28, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement a cross-entropy loss function to compute the model's loss during training.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy between ``neural_net`` probabilities and integer labels.

    Args:
        params: parameter dict passed through to ``neural_net``.
        x: batch of inputs, one example per row.
        y: integer class labels, one-hot encoded here over 10 classes.

    Returns:
        Scalar loss averaged over the batch.
    """
    predictions = neural_net(params, x)
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    # A softmax output can underflow to exactly 0; log(0) = -inf and
    # 0 * -inf = NaN would then poison the loss and every gradient.
    log_probs = jnp.log(jnp.maximum(predictions, 1e-12))
    return -jnp.mean(jnp.sum(one_hot_y * log_probs, axis=1))

# Implement the Training Loop
Write a training loop that uses JAX's grad function to compute gradients and update the model parameters using gradient descent.

In [None]:
# Training Loop
@jax.jit
def update_params(params, x, y, learning_rate):
    """Take one plain gradient-descent step on ``cross_entropy_loss``.

    Returns a new parameter pytree with the same structure as ``params``,
    where every leaf has moved against its gradient by ``learning_rate``.
    """
    loss_gradient = jax.grad(cross_entropy_loss)
    gradients = loss_gradient(params, x, y)

    def descend(value, gradient):
        return value - learning_rate * gradient

    return jax.tree_util.tree_map(descend, params, gradients)

# Example training loop (pseudo-code, adjust for batching)
# for epoch in range(num_epochs):
#     for batch_x, batch_y in train_ds:
#         params = update_params(params, batch_x, batch_y, learning_rate=0.01)

# Evaluate the Model on Test Data
Evaluate the trained model on the test dataset and compute the accuracy. Visualize some predictions.

In [None]:
# Evaluate Model
def evaluate_model(params, test_ds):
    """Return the fraction of examples in ``test_ds`` classified correctly.

    Args:
        params: parameter dict accepted by ``neural_net``; "W1" fixes the
            number of input features.
        test_ds: iterable of ``(images, labels)`` pairs. Each pair may be a
            single example or a batch, as NumPy arrays or TensorFlow tensors.

    Returns:
        ``correct / total`` as a JAX scalar.
    """
    input_dim = params["W1"].shape[0]
    correct = 0
    total = 0
    for x, y in test_ds:
        # Go through NumPy so TensorFlow tensors are accepted, and flatten
        # each image to a row of ``input_dim`` features.
        x = np.asarray(x).reshape(-1, input_dim)
        # A single unbatched label is a scalar; make it a length-1 vector.
        y = np.atleast_1d(np.asarray(y))
        predictions = neural_net(params, x)
        predicted_labels = jnp.argmax(predictions, axis=1)
        correct += jnp.sum(predicted_labels == y)
        total += len(y)
    return correct / total

# Example evaluation (pseudo-code, adjust for batching)
# accuracy = evaluate_model(params, test_ds)
# print(f"Test Accuracy: {accuracy * 100:.2f}%")

In [None]:
# Visualize Predictions
def visualize_predictions(params, test_ds, num_images=5):
    """Plot the first ``num_images`` test digits with true and predicted labels.

    Args:
        params: parameter dict accepted by ``neural_net``.
        test_ds: dataset exposing ``take``; yields one ``(image, label)`` pair
            per element.
        num_images: how many examples to display.
    """
    for x, y in test_ds.take(num_images):
        # Convert through NumPy: TensorFlow tensors have no ``.reshape``.
        image = np.asarray(x).reshape(28, 28)
        # The network expects a flat feature vector, not a 2-D image.
        predictions = neural_net(params, image.reshape(-1))
        predicted_label = int(jnp.argmax(predictions))
        plt.imshow(image, cmap="gray")
        plt.title(f"True: {int(np.asarray(y))}, Predicted: {predicted_label}")
        plt.show()

# Example visualization (pseudo-code)
# visualize_predictions(params, test_ds)

# Install and Import Required Libraries
Install JAX and import necessary libraries such as jax, jax.numpy, and optax.

In [None]:
# Install JAX
# Note: Uncomment the following line to install JAX if not already installed
# !pip install jax jaxlib optax tensorflow-datasets

# Import Required Libraries
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Load the MNIST dataset using TensorFlow Datasets or another library, normalize the images, and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
def load_and_preprocess_data():
    """Load MNIST via TFDS and return ``(train_data, test_data)``.

    Each element of the returned datasets is ``(image, label)``, where the
    image is a flat float32 vector of 784 values in [0, 1].
    """
    # Local import: TensorFlow is only needed inside the tf.data pipeline.
    import tensorflow as tf

    dataset = tfds.load('mnist', as_supervised=True)
    train_data, test_data = dataset['train'], dataset['test']

    def normalize_img(image, label):
        # Traced by tf.data on symbolic tensors, so only TensorFlow ops work
        # here; uint8 pixels must be cast before dividing by a float.
        image = tf.cast(image, tf.float32) / 255.0
        # Flatten 28x28x1 -> 784: the training and evaluation loops feed
        # batches straight into the network, whose first layer has 784 inputs.
        return tf.reshape(image, [-1]), label

    train_data = train_data.map(normalize_img)
    test_data = test_data.map(normalize_img)
    return train_data, test_data

train_data, test_data = load_and_preprocess_data()

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network for MNIST classification.

In [None]:
# Define Neural Network Model
def neural_network(params, x):
    """Forward pass of a fully connected network; returns raw logits.

    ``params`` is a sequence of ``(weights, bias)`` pairs. Every layer except
    the last is followed by a ReLU; the last layer is purely affine.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    activation = x
    for weights, bias in hidden_layers:
        activation = jax.nn.relu(jnp.dot(activation, weights) + bias)
    out_weights, out_bias = output_layer
    return jnp.dot(activation, out_weights) + out_bias

# Initialize Parameters
Initialize the model parameters using JAX's random number generation utilities.

In [None]:
# Initialize Parameters
def initialize_parameters(layer_sizes, key):
    """Create He-scaled weights and zero biases for each consecutive layer pair.

    Args:
        layer_sizes: layer widths, input first and output last.
        key: JAX PRNG key; split once per layer.

    Returns:
        List of ``(weights, bias)`` tuples, one per layer.
    """
    params = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        # Scale by sqrt(2 / fan_in) so activations keep a comparable scale
        # from layer to layer instead of growing with the layer width.
        w = jax.random.normal(subkey, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)
        b = jnp.zeros(fan_out)
        params.append((w, b))
    return params

key = jax.random.PRNGKey(0)
layer_sizes = [784, 128, 64, 10]  # Input layer, two hidden layers, output layer
params = initialize_parameters(layer_sizes, key)

# Define the Loss Function
Implement the cross-entropy loss function for classification tasks.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Cross-entropy of the network's logits against integer labels.

    The negative log-likelihoods are summed over the whole batch and divided
    by the number of rows in ``x``.
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    targets = jax.nn.one_hot(y, num_classes=10)
    total_nll = -jnp.sum(targets * log_probs)
    return total_nll / x.shape[0]

# Define the Training Loop
Write a training loop that updates the model parameters using gradient descent or an optimizer from Optax.

In [None]:
# Define Training Loop
def train_model(params, train_data, optimizer, num_epochs, batch_size):
    """Train ``params`` with an Optax optimizer and return the final parameters.

    Args:
        params: initial parameter pytree.
        train_data: dataset exposing ``batch``; yields ``(inputs, labels)``.
        optimizer: Optax gradient transformation (``init`` / ``update``).
        num_epochs: number of full passes over ``train_data``.
        batch_size: examples per gradient step.
    """

    @jax.jit
    def update(params, opt_state, x, y):
        # One optimizer step: gradients -> Optax updates -> new parameters.
        gradients = jax.grad(cross_entropy_loss)(params, x, y)
        updates, next_opt_state = optimizer.update(gradients, opt_state)
        return optax.apply_updates(params, updates), next_opt_state

    opt_state = optimizer.init(params)

    for epoch in range(num_epochs):
        batches = tfds.as_numpy(train_data.batch(batch_size))
        for x_batch, y_batch in batches:
            params, opt_state = update(params, opt_state, x_batch, y_batch)
        print(f"Epoch {epoch + 1} completed.")
    return params

optimizer = optax.adam(learning_rate=0.001)
params = train_model(params, train_data, optimizer, num_epochs=5, batch_size=32)

# Evaluate the Model
Evaluate the trained model on the test dataset and calculate accuracy.

In [None]:
# Evaluate Model
def evaluate_model(params, test_data):
    """Print the classification accuracy of ``params`` on ``test_data``.

    The dataset is consumed in batches of 32; a prediction counts as correct
    when the arg-max logit equals the integer label.
    """
    num_correct = 0
    num_seen = 0
    for images, targets in tfds.as_numpy(test_data.batch(32)):
        predicted = jnp.argmax(neural_network(params, images), axis=1)
        num_correct += jnp.sum(predicted == targets)
        num_seen += targets.shape[0]
    accuracy = num_correct / num_seen
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

evaluate_model(params, test_data)

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_train, ds_test = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

# Normalize and preprocess the dataset
def preprocess(image, label):
    """Scale one MNIST image to float32 values in [0, 1]; pass the label through.

    This function is applied with ``Dataset.map``, which traces it on symbolic
    TensorFlow tensors. JAX conversions such as ``jnp.array`` cannot consume
    symbolic tensors, so only TensorFlow ops are used here.
    """
    # Local import: TensorFlow is only needed inside the tf.data pipeline.
    import tensorflow as tf

    # The raw pixels are uint8 and must be cast before dividing by a float.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

ds_train = ds_train.map(preprocess)
ds_test = ds_test.map(preprocess)

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with layers for MNIST classification.

In [None]:
# Define a Feedforward Neural Network
def relu(x):
    """Element-wise rectified linear unit: ``max(0, x)``."""
    return jnp.maximum(0, x)


def softmax(x):
    """Numerically stable softmax over the last axis.

    Each row is shifted by its own maximum. Shifting by the global maximum
    lets a row with much smaller logits underflow to all zeros, which turns
    the normalisation into 0 / 0.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / exp_x.sum(axis=-1, keepdims=True)


def neural_network(params, x):
    """Two-layer network: affine -> ReLU -> affine -> softmax.

    Args:
        params: dict with weight matrices 'W1', 'W2' and biases 'b1', 'b2'.
        x: input features; the last axis must match the rows of 'W1'.

    Returns:
        Class probabilities that sum to 1 along the last axis.
    """
    hidden = relu(jnp.dot(x, params['W1']) + params['b1'])
    logits = jnp.dot(hidden, params['W2']) + params['b2']
    return softmax(logits)

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_params(key, input_dim, hidden_dim, output_dim):
    """Create He-scaled weights and zero biases for the two-layer network.

    Args:
        key: JAX PRNG key; split once so each weight matrix gets its own key.
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of output classes.

    Returns:
        Dict with 'W1' (input_dim, hidden_dim), 'b1' (hidden_dim,),
        'W2' (hidden_dim, output_dim) and 'b2' (output_dim,).
    """
    key1, key2 = jax.random.split(key)
    # Scale by sqrt(2 / fan_in): with unit-variance weights the pre-activations
    # grow with sqrt(fan_in) and push the softmax into saturation.
    params = {
        'W1': jax.random.normal(key1, (input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim),
        'b1': jnp.zeros(hidden_dim),
        'W2': jax.random.normal(key2, (hidden_dim, output_dim)) * jnp.sqrt(2.0 / hidden_dim),
        'b2': jnp.zeros(output_dim)
    }
    return params

key = jax.random.PRNGKey(0)
params = initialize_params(key, input_dim=784, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement a cross-entropy loss function to measure the model's performance.

In [None]:
# Cross-Entropy Loss Function
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy between ``neural_network`` probabilities and labels.

    Args:
        params: parameter dict passed through to ``neural_network``.
        x: input features.
        y: integer class labels, one-hot encoded here over 10 classes.

    Returns:
        Scalar loss averaged over the examples.
    """
    predictions = neural_network(params, x)
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    # A softmax output can underflow to exactly 0; log(0) = -inf and
    # 0 * -inf = NaN would then poison the loss and every gradient.
    log_probs = jnp.log(jnp.maximum(predictions, 1e-12))
    return -jnp.mean(jnp.sum(one_hot_y * log_probs, axis=-1))

# Define the Training Loop
Write a training loop that uses JAX's grad function to compute gradients and update parameters using gradient descent.

In [None]:
# Training Loop
@jax.jit
def update_params(params, x, y, learning_rate):
    """Apply one gradient-descent update to every leaf of ``params``.

    The gradient of ``cross_entropy_loss`` is taken with respect to ``params``
    and subtracted after scaling by ``learning_rate``.
    """
    gradients = jax.grad(cross_entropy_loss)(params, x, y)

    def sgd_step(value, gradient):
        return value - learning_rate * gradient

    stepped = jax.tree_util.tree_map(sgd_step, params, gradients)
    return stepped

# Train the Model
Run the training loop for a specified number of epochs and track the training loss.

In [None]:
# Training the Model
epochs = 10
learning_rate = 0.01

for epoch in range(epochs):
    # Iterating the tf.data pipeline directly yields TensorFlow tensors, which
    # have no ``.reshape``; ``tfds.as_numpy`` yields NumPy arrays instead.
    for batch in tfds.as_numpy(ds_train.batch(32)):
        images, labels = batch
        images = images.reshape(-1, 784)  # Flatten images
        params = update_params(params, images, labels, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Import Required Libraries
Import JAX, NumPy, and other necessary libraries such as matplotlib and tensorflow_datasets.

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Load the MNIST dataset using tensorflow_datasets, normalize the images, and split the data into training and testing sets.

In [None]:
# Load and Preprocess the MNIST Dataset
def preprocess_data(data):
    """Turn one example dict into a ``(image, label)`` pair.

    The 'image' entry is divided by 255.0 so pixel values fall in [0, 1];
    the 'label' entry is returned unchanged.
    """
    return data['image'] / 255.0, data['label']

# Fetch both splits as dictionaries (as_supervised=False)
ds_train, ds_test = tfds.load('mnist', split=['train', 'test'], as_supervised=False)

# Materialize every example as a normalized (image, label) pair
train_data = list(map(preprocess_data, tfds.as_numpy(ds_train)))
test_data = list(map(preprocess_data, tfds.as_numpy(ds_test)))

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network with one or more hidden layers.

In [None]:
# Define the Neural Network Model
def relu(x):
    """Rectified linear unit: negative entries become 0, the rest pass through."""
    return jnp.maximum(x, 0)


def neural_network(params, x):
    """Two-layer perceptron returning raw logits (no softmax applied).

    ``params`` holds 'W1'/'b1' for the hidden layer and 'W2'/'b2' for the
    output layer.
    """
    pre_activation = jnp.dot(x, params['W1']) + params['b1']
    hidden = relu(pre_activation)
    return jnp.dot(hidden, params['W2']) + params['b2']

# Initialize Parameters
Randomly initialize the weights and biases for the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_parameters(key, input_dim, hidden_dim, output_dim):
    """Create He-scaled weights and zero biases for the two-layer network.

    Args:
        key: JAX PRNG key; split into one key per weight matrix.
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of output classes.

    Returns:
        Dict with 'W1' (input_dim, hidden_dim), 'b1' (hidden_dim,),
        'W2' (hidden_dim, output_dim) and 'b2' (output_dim,).
    """
    keys = jax.random.split(key, 2)
    # Scale by sqrt(2 / fan_in): with unit-variance weights the logits grow
    # with the layer width, and the summed (not averaged) loss used for
    # training makes the resulting gradients very large.
    params = {
        'W1': jax.random.normal(keys[0], (input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim),
        'b1': jnp.zeros(hidden_dim),
        'W2': jax.random.normal(keys[1], (hidden_dim, output_dim)) * jnp.sqrt(2.0 / hidden_dim),
        'b2': jnp.zeros(output_dim)
    }
    return params

key = jax.random.PRNGKey(0)
params = initialize_parameters(key, input_dim=784, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement the cross-entropy loss function to measure the model's performance.

In [None]:
# Define the Loss Function
def cross_entropy_loss(params, x, y):
    """Summed cross-entropy of the network's logits against integer labels.

    The result is a sum, not a mean: it is not divided by the number of
    examples.
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    targets = jax.nn.one_hot(y, num_classes=10)
    return -jnp.sum(targets * log_probs)

# Define the Training Step
Use JAX's grad function to compute gradients and update the model parameters using gradient descent.

In [None]:
# Define the Training Step
@jax.jit
def update_parameters(params, x, y, learning_rate):
    """Return ``params`` after one gradient-descent step on ``cross_entropy_loss``."""
    gradient_fn = jax.grad(cross_entropy_loss)
    grads = gradient_fn(params, x, y)

    def step(value, gradient):
        # Move each leaf against its gradient.
        return value - learning_rate * gradient

    return jax.tree_util.tree_map(step, params, grads)

# Train the Model
Iteratively train the model over multiple epochs, updating the parameters and tracking the loss.

In [None]:
# Train the Model
# Hyperparameters for plain gradient descent.
epochs = 10
learning_rate = 0.01

# One parameter update per training example: train_data is iterated pair by
# pair, and `params` is rebound to the updated values after every step.
for epoch in range(epochs):
    for x, y in train_data:
        x = x.reshape(-1)  # Flatten the image into a 1-D feature vector
        params = update_parameters(params, x, y, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model
Evaluate the trained model on the test dataset and calculate the accuracy.

In [None]:
# Evaluate the Model
def evaluate_model(params, test_data):
    """Return the fraction of ``test_data`` examples classified correctly.

    ``test_data`` yields ``(image, label)`` pairs, one example at a time. Each
    image is flattened before being passed to ``neural_network``, and the
    arg-max logit is taken as the predicted class.
    """
    hits = 0
    seen = 0

    for image, label in test_data:
        scores = neural_network(params, image.reshape(-1))
        if jnp.argmax(scores) == label:
            hits += 1
        seen += 1

    return hits / seen

# Score the current parameters on test_data and report the result as a percentage.
accuracy = evaluate_model(params, test_data)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import all necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and Split Data
def preprocess(image, label):
    """Scale one MNIST image to float32 values in [0, 1]; pass the label through.

    This function is applied with ``Dataset.map``, which traces it on symbolic
    TensorFlow tensors. JAX conversions such as ``jnp.array`` cannot consume
    symbolic tensors, so only TensorFlow ops are used here.
    """
    # Local import: TensorFlow is only needed inside the tf.data pipeline.
    import tensorflow as tf

    # The raw pixels are uint8 and must be cast before dividing by a float.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_dataset = datasets["train"].map(preprocess)
test_dataset = datasets["test"].map(preprocess)

# Define the Neural Network Model
Define a simple feedforward neural network using JAX's functional programming paradigm.

In [None]:
# Define Neural Network Model
def neural_network(params, x):
    """Run ``x`` through a stack of dense layers and return the logits.

    ``params`` is a sequence of ``(weights, bias)`` pairs; a ReLU follows
    every layer except the final one.
    """
    last_weights, last_bias = params[-1]
    features = x
    for weights, bias in params[:-1]:
        features = jax.nn.relu(jnp.dot(features, weights) + bias)
    return jnp.dot(features, last_weights) + last_bias

# Initialize Parameters
Randomly initialize the weights and biases of the neural network using JAX's random module.

In [None]:
# Initialize Parameters
def initialize_params(layer_sizes, key):
    """Build He-scaled ``(weights, bias)`` pairs for consecutive layer sizes.

    The PRNG key is split once per layer. Weights are standard-normal draws
    scaled by ``sqrt(2 / fan_in)``; biases start at zero.
    """
    layers = []
    for index in range(len(layer_sizes) - 1):
        fan_in = layer_sizes[index]
        fan_out = layer_sizes[index + 1]
        key, layer_key = jax.random.split(key)
        he_scale = jnp.sqrt(2.0 / fan_in)
        weights = jax.random.normal(layer_key, (fan_in, fan_out)) * he_scale
        layers.append((weights, jnp.zeros(fan_out)))
    return layers

key = jax.random.PRNGKey(0)
layer_sizes = [784, 128, 64, 10]  # Input layer, two hidden layers, output layer
params = initialize_params(layer_sizes, key)

# Define the Loss Function
Implement the cross-entropy loss function to measure the model's performance.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Mean cross-entropy of the network's logits against integer labels.

    Labels are one-hot encoded over 10 classes; the per-example negative
    log-likelihood is averaged over the batch.
    """
    log_probs = jax.nn.log_softmax(neural_network(params, x))
    targets = jax.nn.one_hot(y, num_classes=10)
    per_example_nll = -jnp.sum(targets * log_probs, axis=-1)
    return jnp.mean(per_example_nll)

# Define the Training Loop
Write a training loop that computes gradients using JAX's grad function and updates the model parameters using gradient descent.

In [None]:
# Define Training Loop
@jax.jit
def update_params(params, x, y, lr):
    """One gradient-descent step over a list of ``(weights, bias)`` layers.

    Returns a new list in which each weight matrix and bias vector has moved
    against its gradient by ``lr``.
    """
    grads = jax.grad(cross_entropy_loss)(params, x, y)
    updated = []
    for (weights, bias), (d_weights, d_bias) in zip(params, grads):
        updated.append((weights - lr * d_weights, bias - lr * d_bias))
    return updated

# Train the Model
Run the training loop for a specified number of epochs and monitor the training loss.

In [None]:
# Train Model
epochs = 10
learning_rate = 0.01

for epoch in range(epochs):
    # Iterating the tf.data pipeline directly yields TensorFlow tensors, which
    # have no ``.reshape``; ``tfds.as_numpy`` yields NumPy arrays instead.
    for batch in tfds.as_numpy(train_dataset.batch(32)):
        images, labels = batch
        images = images.reshape(-1, 784)  # Flatten images
        params = update_params(params, images, labels, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate Model
def compute_accuracy(params, dataset):
    """Return the fraction of examples in ``dataset`` classified correctly.

    Args:
        params: layer list accepted by ``neural_network``.
        dataset: dataset exposing ``batch``; yields ``(images, labels)``.

    Returns:
        ``correct / total`` as a JAX scalar.
    """
    correct = 0
    total = 0
    for batch in dataset.batch(32):
        images, labels = batch
        # Convert through NumPy first: TensorFlow tensors have no ``.reshape``.
        images = np.asarray(images).reshape(-1, 784)  # Flatten images
        labels = np.asarray(labels)
        logits = neural_network(params, images)
        predictions = jnp.argmax(logits, axis=-1)
        correct += jnp.sum(predictions == labels)
        total += len(labels)
    return correct / total

# Score the current parameters on test_dataset and report the result as a percentage.
accuracy = compute_accuracy(params, test_dataset)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import all necessary modules.

In [3]:
# Install Required Libraries
!pip install jax jaxlib tensorflow-datasets matplotlib

Collecting tensorflow-datasets
  Downloading tensorflow_datasets-4.9.8-py3-none-any.whl.metadata (11 kB)
Collecting tensorflow-datasets
  Downloading tensorflow_datasets-4.9.8-py3-none-any.whl.metadata (11 kB)
Collecting absl-py (from tensorflow-datasets)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting absl-py (from tensorflow-datasets)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting array_record>=0.5.0 (from tensorflow-datasets)
  Downloading array_record-0.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (877 bytes)
Collecting array_record>=0.5.0 (from tensorflow-datasets)
  Downloading array_record-0.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (877 bytes)
Collecting dm-tree (from tensorflow-datasets)
  Downloading dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting etils>=1.9.1 (from etils[edc,enp,epath,epy,etree]>=1.9.1; python_ver

In [4]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and split the dataset
def preprocess(image, label):
    """Scale one MNIST image to float32 values in [0, 1]; pass the label through.

    This function is applied with ``Dataset.map``, which traces it on symbolic
    TensorFlow tensors. JAX conversions such as ``jnp.array`` cannot consume
    symbolic tensors, so only TensorFlow ops are used here.
    """
    # Local import: TensorFlow is only needed inside the tf.data pipeline.
    import tensorflow as tf

    # The raw pixels are uint8 and must be cast before dividing by a float.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = datasets["train"].map(preprocess)
test_ds = datasets["test"].map(preprocess)

  from .autonotebook import tqdm as notebook_tqdm


Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /home/codespace/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]

Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]

Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/3 [00:00<?, ? url/s][A
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]

Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  4.80 url/s]

Dl Completed...:  25%|██▌       | 1/4 [00:00<00:00,  4.80 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00,  8.94 url/s]

Dl Completed...:



***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by default when you install TFDS. This allows you to choose to install either `tf-nightly` or `tensorflow`. Please install the most recent version of TensorFlow, by following instructions at https://tensorflow.org/install.
***************************************************************




***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by default when you install TFDS. This allows you to choose to install either `tf-nightly` or `tensorflow`. Please install the most recent version of TensorFlow, by following instructions at https://tensorflow.org/install.
***************************************************************




***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by

ExtractError: Error while extracting /home/codespace/tensorflow_datasets/downloads/mnist/cvdf-datasets_mnist_t10k-images-idx3-ubytejUIsewocHHkkWlvPB_6G4z7q_ueSuEWErsJ29aLbxOY.gz to /home/codespace/tensorflow_datasets/downloads/extracted/GZIP.cvdf-datasets_mnist_t10k-images-idx3-ubytejUIsewocHHkkWlvPB_6G4z7q_ueSuEWErsJ29aLbxOY.gz: No module named 'tensorflow'



***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by default when you install TFDS. This allows you to choose to install either `tf-nightly` or `tensorflow`. Please install the most recent version of TensorFlow, by following instructions at https://tensorflow.org/install.
***************************************************************




***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by default when you install TFDS. This allows you to choose to install either `tf-nightly` or `tensorflow`. Please install the most recent version of TensorFlow, by following instructions at https://tensorflow.org/install.
***************************************************************




***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by

# Define the Neural Network Model
Use JAX to define a simple feedforward neural network model for MNIST classification.

In [None]:
# Define Neural Network Model
def relu(x):
    """Element-wise rectified linear unit: ``max(0, x)``."""
    return jnp.maximum(0, x)


def softmax(x):
    """Numerically stable softmax over the last axis.

    The maximum and the normalising sum are taken along ``axis=-1``, so the
    function is correct for a single 1-D logit vector (result unchanged) and
    for a batch of shape ``(batch, classes)``.
    """
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)


def neural_net(params, x):
    """Two-layer network: affine -> ReLU -> affine -> softmax.

    Args:
        params: dict with weight matrices "W1", "W2" and biases "b1", "b2".
        x: input features; the last axis must match the rows of "W1".

    Returns:
        Class probabilities that sum to 1 along the last axis.
    """
    hidden = relu(jnp.dot(x, params["W1"]) + params["b1"])
    logits = jnp.dot(hidden, params["W2"]) + params["b2"]
    return softmax(logits)

# Initialize Parameters
Initialize the weights and biases of the neural network using JAX's random number generation utilities.

In [None]:
# Initialize Parameters
def initialize_params(key, input_dim, hidden_dim, output_dim):
    """Create He-scaled weights and zero biases for the two-layer network.

    Args:
        key: JAX PRNG key; split once so each weight matrix gets its own key.
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of output classes.

    Returns:
        Dict with "W1" (input_dim, hidden_dim), "b1" (hidden_dim,),
        "W2" (hidden_dim, output_dim) and "b2" (output_dim,).
    """
    key1, key2 = jax.random.split(key)
    # Scale by sqrt(2 / fan_in): with unit-variance weights the pre-activations
    # grow with sqrt(fan_in) and push the softmax into saturation.
    params = {
        "W1": jax.random.normal(key1, (input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim),
        "b1": jnp.zeros(hidden_dim),
        "W2": jax.random.normal(key2, (hidden_dim, output_dim)) * jnp.sqrt(2.0 / hidden_dim),
        "b2": jnp.zeros(output_dim),
    }
    return params

key = jax.random.PRNGKey(0)
params = initialize_params(key, input_dim=784, hidden_dim=128, output_dim=10)

# Define the Loss Function
Implement the cross-entropy loss function to measure the model's performance.

In [None]:
# Define Loss Function
def cross_entropy_loss(params, x, y):
    """Summed cross-entropy between ``neural_net`` probabilities and one-hot ``y``.

    Args:
        params: parameter dict passed through to ``neural_net``.
        x: input features.
        y: one-hot targets with the same shape as the network output.

    Returns:
        Scalar loss (a sum, not a mean).
    """
    predictions = neural_net(params, x)
    # A softmax output can underflow to exactly 0; log(0) = -inf and
    # -inf * 0 = NaN would then poison the loss and every gradient.
    log_probs = jnp.log(jnp.maximum(predictions, 1e-12))
    return -jnp.sum(log_probs * y)

# Define the Training Loop
Write a training loop that uses JAX's grad function to compute gradients and update the model parameters using gradient descent.

In [None]:
# Define Training Loop
@jax.jit
def update_params(params, x, y, learning_rate):
    """One gradient-descent step over a dict of parameters.

    Returns a new dict with the same keys as ``params``, each value moved
    against its gradient by ``learning_rate``.
    """
    grads = jax.grad(cross_entropy_loss)(params, x, y)
    updated_params = {}
    for name, value in params.items():
        updated_params[name] = value - learning_rate * grads[name]
    return updated_params

# Training Loop
def train(params, train_ds, epochs, learning_rate):
    """Train with one gradient-descent update per example; return final params.

    Args:
        params: initial parameter dict.
        train_ds: iterable of ``(image, label)`` pairs, one example each, as
            NumPy arrays or TensorFlow tensors.
        epochs: number of full passes over ``train_ds``.
        learning_rate: step size passed to ``update_params``.
    """
    for epoch in range(epochs):
        for image, label in train_ds:
            # Convert through NumPy first: TensorFlow tensors have no
            # ``.flatten``.
            x = np.asarray(image).reshape(-1)
            y = jax.nn.one_hot(np.asarray(label), 10)
            params = update_params(params, x, y, learning_rate)
        print(f"Epoch {epoch + 1} completed.")
    return params

# Evaluate the Model
Evaluate the trained model on the test dataset and calculate the accuracy.

In [None]:
# Evaluate Model
def evaluate(params, test_ds):
    """Print and return the classification accuracy on ``test_ds``.

    Args:
        params: parameter dict accepted by ``neural_net``.
        test_ds: iterable of ``(image, label)`` pairs, one example each, as
            NumPy arrays or TensorFlow tensors.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    correct = 0
    total = 0
    for image, label in test_ds:
        # Convert through NumPy first: TensorFlow tensors have no ``.flatten``.
        x = np.asarray(image).reshape(-1)
        y_pred = int(jnp.argmax(neural_net(params, x)))
        if y_pred == int(np.asarray(label)):
            correct += 1
        total += 1
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Example Usage
# Train for five epochs with learning rate 0.01, then report accuracy on test_ds.
trained_params = train(params, train_ds, epochs=5, learning_rate=0.01)
evaluate(trained_params, test_ds)

# Install and Import Required Libraries
Install JAX and other required libraries (e.g., numpy, matplotlib, tensorflow_datasets). Import necessary modules.

In [None]:
# Install Required Libraries
# Uncomment the following lines if running in an environment where JAX and other libraries are not installed.
# !pip install jax jaxlib numpy matplotlib tensorflow-datasets

In [None]:
# Import Required Libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

ModuleNotFoundError: No module named 'tensorflow_datasets'

# Load and Preprocess the MNIST Dataset
Use tensorflow_datasets to load the MNIST dataset. Normalize the images and split the data into training and testing sets.

In [None]:
# Load MNIST Dataset
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
datasets = ds_builder.as_dataset(as_supervised=True)

# Normalize and split the dataset
def preprocess(image, label):
    """Scale one MNIST image to float32 values in [0, 1]; pass the label through.

    This function is applied with ``Dataset.map``, which traces it on symbolic
    TensorFlow tensors. JAX conversions such as ``jnp.array`` cannot consume
    symbolic tensors, so only TensorFlow ops are used here.
    """
    # Local import: TensorFlow is only needed inside the tf.data pipeline.
    import tensorflow as tf

    # The raw pixels are uint8 and must be cast before dividing by a float.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_dataset = datasets["train"].map(preprocess)
test_dataset = datasets["test"].map(preprocess)

# Define the Neural Network Model
Define a simple feedforward neural network using JAX's functional programming paradigm.

In [None]:
# Define the Neural Network Model
def neural_net(params, x):
    """Forward pass through a list of dense layers; returns raw logits.

    Every ``(weights, bias)`` pair except the last is followed by a ReLU; the
    last pair is applied as a plain affine map.
    """
    hidden_layers = params[:-1]
    activation = x
    for weights, bias in hidden_layers:
        pre_activation = jnp.dot(activation, weights) + bias
        activation = jax.nn.relu(pre_activation)
    out_weights, out_bias = params[-1]
    return jnp.dot(activation, out_weights) + out_bias

# Initialize Model Parameters
Initialize the weights and biases of the neural network using JAX's random number generation utilities.

In [None]:
# Initialize Model Parameters
def initialize_params(layer_sizes, key):
    """Return one He-scaled ``(weights, bias)`` pair per pair of adjacent sizes.

    A fresh sub-key is split off ``key`` for each layer. Weights are
    standard-normal draws multiplied by ``sqrt(2 / fan_in)``; biases are zeros.
    """
    layers = []
    num_layers = len(layer_sizes) - 1
    for index in range(num_layers):
        fan_in, fan_out = layer_sizes[index], layer_sizes[index + 1]
        key, layer_key = jax.random.split(key)
        draws = jax.random.normal(layer_key, (fan_in, fan_out))
        layers.append((draws * jnp.sqrt(2.0 / fan_in), jnp.zeros(fan_out)))
    return layers

key = jax.random.PRNGKey(0)
layer_sizes = [784, 128, 64, 10]  # Input layer, two hidden layers, output layer
params = initialize_params(layer_sizes, key)

# Define the Loss Function
Implement a cross-entropy loss function to evaluate the model's performance.

In [None]:
# Define the Loss Function
def cross_entropy_loss(params, x, y):
    """Batch-averaged cross-entropy of ``neural_net`` logits vs. integer labels.

    Labels are expanded to one-hot vectors over 10 classes, and the negative
    log-probability of the true class is averaged over the examples.
    """
    targets = jax.nn.one_hot(y, num_classes=10)
    log_probs = jax.nn.log_softmax(neural_net(params, x))
    nll_per_example = -jnp.sum(targets * log_probs, axis=-1)
    return jnp.mean(nll_per_example)

# Define the Training Loop
Write a training loop that performs forward and backward passes, computes gradients, and updates model parameters using an optimizer.

In [None]:
# Define the Training Loop
@jax.jit
def update(params, x, y, learning_rate):
    """Gradient-descent step for a list of ``(weights, bias)`` layers.

    Differentiates ``cross_entropy_loss`` with respect to ``params`` and
    returns a new list with each tensor moved against its gradient.
    """
    grads = jax.grad(cross_entropy_loss)(params, x, y)
    updated_params = []
    for (weights, bias), (d_weights, d_bias) in zip(params, grads):
        new_weights = weights - learning_rate * d_weights
        new_bias = bias - learning_rate * d_bias
        updated_params.append((new_weights, new_bias))
    return updated_params

# Train the Model
Run the training loop for a specified number of epochs and monitor the training loss.

In [None]:
# Train the Model
learning_rate = 0.01
num_epochs = 10

for epoch in range(num_epochs):
    # Iterating the tf.data pipeline directly yields TensorFlow tensors, which
    # have no ``.reshape``; ``tfds.as_numpy`` yields NumPy arrays instead.
    for batch in tfds.as_numpy(train_dataset.batch(32)):
        images, labels = batch
        images = images.reshape(-1, 784)  # Flatten images
        params = update(params, images, labels, learning_rate)
    print(f"Epoch {epoch + 1} completed.")

# Evaluate the Model
Evaluate the trained model on the test dataset and compute the accuracy.

In [None]:
# Evaluate the Model
def accuracy(params, dataset):
    """Return the fraction of examples in ``dataset`` classified correctly.

    Args:
        params: layer list accepted by ``neural_net``.
        dataset: dataset exposing ``batch``; yields ``(images, labels)``.

    Returns:
        ``correct / total`` as a JAX scalar.
    """
    correct = 0
    total = 0
    for batch in dataset.batch(32):
        images, labels = batch
        # Convert through NumPy first: TensorFlow tensors have no ``.reshape``.
        images = np.asarray(images).reshape(-1, 784)  # Flatten images
        labels = np.asarray(labels)
        logits = neural_net(params, images)
        predictions = jnp.argmax(logits, axis=-1)
        correct += jnp.sum(predictions == labels)
        total += len(labels)
    return correct / total

# Score the current parameters on test_dataset and report the result as a percentage.
test_accuracy = accuracy(params, test_dataset)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")