In [1]:
# from chatGPT
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax
from flax.training import train_state

def create_model(num_classes):
    model = nn.Sequential([
        nn.Dense(128),
        nn.relu,
        nn.Dense(128),
        nn.relu,
        nn.Dense(num_classes),
        nn.log_softmax,
    ])
    return model

def create_loss_fn(model):
    def loss_fn(params, batch):
        inputs, targets = batch
        logits = model.apply({'params': params}, inputs)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return loss
    return loss_fn

def create_metrics():
    metrics = {
        'accuracy': flax.metrics.accuracy,
    }
    return metrics

def train_step(state, batch):
    model = state.model
    optimizer = state.optimizer
    loss_fn = state.loss_fn
    grad_fn = jax.value_and_grad(loss_fn)

    inputs, targets = batch
    logits, grads = grad_fn(model.params, (inputs, targets))
    model = model.apply_gradients(grads=grads)

    metrics = state.metrics
    metrics = flax.metrics.update_metrics(metrics, logits, targets)

    new_state = state.replace(model=model, metrics=metrics)
    return new_state

def eval_step(state, batch):
    model = state.model
    loss_fn = state.loss_fn

    inputs, targets = batch
    logits = model.apply({'params': model.params}, inputs)
    loss = loss_fn(model.params, (inputs, targets))
    metrics = state.metrics
    metrics = flax.metrics.update_metrics(metrics, logits, targets)

    return metrics

def main():
    # Training configuration
    num_classes = 10
    batch_size = 128
    learning_rate = 0.001
    num_epochs = 10

    # Prepare the dataset
    train_dataset, test_dataset = load_mnist_dataset()  # Your MNIST dataset loading function

    # Create the model
    model = create_model(num_classes)
    params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 784]))['params']

    # Create the optimizer
    optimizer = optax.adam(learning_rate).create(params)

    # Create the loss function and metrics
    loss_fn = create_loss_fn(model)
    metrics = create_metrics()

    # Create the initial training state
    initial_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        loss_fn=loss_fn,
        metrics=metrics,
    )

    # Training loop
    for epoch in range(num_epochs):
        # Training
        train_metrics = initial_state.metrics
        for batch in train_loader:
            initial_state = train_step(initial_state, batch)
            train_metrics = flax.metrics.merge_metrics(train_metrics, initial_state.metrics)

        # Evaluation
        eval_metrics = initial_state.metrics
        for batch in test_loader:
            eval_metrics = eval_step(initial_state, batch)

        # Print epoch summary
        train_summary = flax.metrics.to_scalar(train_metrics)
        eval_summary = flax.metrics.to_scalar(eval_metrics)
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_summary['loss']:.4f} | "
              f"Train Accuracy: {train_summary['accuracy']:.4f} | "
              f"Eval Loss: {eval_summary['loss']:.4f} | "
              f"Eval Accuracy: {eval_summary['accuracy']:.4f}")

In [5]:
import gzip
import os
import numpy as np
import urllib

def load_mnist_dataset():
    # Define the paths to the MNIST dataset files
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    files = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    paths = []
    for file in files:
        path = os.path.join('mnist', file)
        if not os.path.exists(path):
            os.makedirs('mnist', exist_ok=True)
            url = base_url + file
            print(f"Downloading {url}...")
            urllib.request.urlretrieve(url, path)
        paths.append(path)

    # Load the training data
    with gzip.open(paths[0], 'rb') as f:
        x_train = np.frombuffer(f.read(), dtype=np.uint8, offset=16).reshape(-1, 28*28)
    with gzip.open(paths[1], 'rb') as f:
        y_train = np.frombuffer(f.read(), dtype=np.uint8, offset=8)

    # Load the test data
    with gzip.open(paths[2], 'rb') as f:
        x_test = np.frombuffer(f.read(), dtype=np.uint8, offset=16).reshape(-1, 28*28)
    with gzip.open(paths[3], 'rb') as f:
        y_test = np.frombuffer(f.read(), dtype=np.uint8, offset=8)

    # Normalize pixel values
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # Convert labels to one-hot encoding
    y_train = np.eye(10)[y_train]
    y_test = np.eye(10)[y_test]

    # Create datasets
    train_dataset = list(zip(x_train, y_train))
    test_dataset = list(zip(x_test, y_test))

    return train_dataset, test_dataset

In [6]:
main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...


AttributeError: module 'flax' has no attribute 'data'