In [87]:
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax import random, jit, value_and_grad
import optax
import tensorflow_datasets as tfds
import flax
from flax import linen as nn
import matplotlib.pyplot as plt
from flax.training import train_state
from typing import *

In [38]:
BATCH_SIZE = 128
BUFFER_SIZE = 20_000

In [157]:
def load_data(batch_size, buffer_size):
    train_ds = tfds.load('mnist', split='train', shuffle_files=True)
    test_ds = tfds.load('mnist', split='test', shuffle_files = True)

    def preprocess(batch):
        image, label = batch["image"], batch["label"]
        
        image = tf.cast(image, dtype = tf.float32) / 255.0

        return {"image": image, "label" : label}

    train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(buffer_size=buffer_size)
    train_ds = train_ds.batch(batch_size = batch_size)
    train_ds = train_ds.prefetch(buffer_size = tf.data.AUTOTUNE)
    
    test_ds = test_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    test_ds = test_ds.batch(batch_size = batch_size)
    test_ds = test_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    
    return train_ds, test_ds

In [173]:
train_data, test_data = load_data(batch_size = BATCH_SIZE, buffer_size = BUFFER_SIZE)

In [174]:
exmp = next(iter(train_data))

2024-07-16 20:19:41.256568: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


## Model Architecture

In [185]:
class CNNClasifier(nn.Module):
    in_channels: int
    num_classes: int

    def setup(self):
        self.conv1 = nn.Conv(features = self.in_channels, kernel_size = (3,3), strides = (1,1), padding = "SAME")
        self.norm1 = nn.LayerNorm()

        self.conv2 = nn.Conv(features = self.in_channels * 2, kernel_size = (3,3), strides = (2,2), padding = "SAME")
        self.norm2 = nn.LayerNorm()

        self.out = nn.Dense(features = self.num_classes)

    @nn.compact
    def __call__(self, x, train: bool):
        batch_size, _, _, _ = x.shape
        x = self.conv1(x)
        x = self.norm1(x)
        x = nn.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = nn.relu(x)
        x = jnp.mean(x, axis = (1, 2))
        x = x.reshape(batch_size, -1)
        x = self.out(x)

        return x


In [188]:
model = CNNClasifier(in_channels = 32, num_classes = 10)
rng_key = jax.random.key(10)
variables = model.init(rngs = rng_key, x = exmp["image"], train = True)

In [189]:
print(model.tabulate(rngs = rng_key, x = exmp["image"], train = True, compute_flops = True))


[3m                              CNNClasifier Summary                              [0m
┏━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath [0m[1m [0m┃[1m [0m[1mmodule      [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mflops[0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━┩
│       │ CNNClasifier │ train: True   │ [2mfloat32[0m[128,… │ 0     │               │
│       │              │ x:            │               │       │               │
│       │              │ "<tf.Tensor:  │               │       │               │
│       │              │ shape=(128,   │               │       │               │
│       │              │ 28, 28, 1),   │               │       │               │
│       │              │ dtype=float3… │               │       │               │
│       │              │ nump

In [190]:
out = model.apply(variables, exmp["image"], train = True)
out.shape

(128, 10)

Array(2.3528588, dtype=float32)

In [256]:

def cross_entropy_loss(logits, labels):
    """
    logits : unnormalized probability distribution
    labels : one hot encoding of labels 
    """
    labels = nn.one_hot(labels, num_classes = 10)
    loss = optax.losses.softmax_cross_entropy(logits = logits, labels = labels)
    return jnp.mean(loss)


def compute_metrics(logits, labels):
    """
    logits : unnormalized probability distribution
    labels : one hot encoding of labels 
    """

    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    
    return {'loss': loss, 'accuracy': accuracy}

In [257]:
cross_entropy_loss(out, exmp["label"].numpy())

Array(2.3528588, dtype=float32)

In [263]:
# Now is the time to create train state for the model training

class TrainState(train_state.TrainState):
    batch_stats: Dict

LEARNING_RATE = 1e-3
tx = optax.lamb(learning_rate = LEARNING_RATE)

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
    batch_stats=variables.get('batch_stats', {})
)

In [268]:
# Writing train step

@jit
def train_step(state, batch):
    def loss_fn(params):
        logits = model.apply({"params": params}, batch["image"], train = True)
        loss = cross_entropy_loss(logits, batch["label"])
        
        return loss, logits
        
    grad_fn = value_and_grad(loss_fn, has_aux = True)
    (loss, logits), grads = grad_fn(state.params)
    
    state = state.apply_gradients(grads = grads)
    
    metrics = compute_metrics(logits, batch["label"])

    return state, loss, metrics

def train_model(state, train_ds, test_ds, num_epochs):
    for epoch in range(num_epochs):
        tr_loss = 0
        tr_acc = 0
        for i, batch in enumerate(train_ds):
            image = batch["image"].numpy()
            label = batch["label"].numpy()
            
            batch = {
                "image" : jnp.array(image),
                "label" : jnp.array(label)
                 
            }
            state, loss, metrics = train_step(state, batch)
            tr_loss += loss
            tr_acc += metrics["accuracy"]
            
            if i % 10 == 0:
                print(f"Epoch {epoch + 1}, Batch {i}, Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}")
        tr_loss = tr_loss / len(train_ds)
        tr_acc = tr_acc / len(train_ds)
        
        print(f"Epoch {epoch + 1}, Train Accuracy: {tr_acc:.4f}, Test Loss: {tr_loss:.04f}")

        # Validation loop
        vl_loss = 0
        vl_acc = 0
        for batch in test_ds:
            image = batch["image"].numpy()
            label = batch["label"].numpy()
            
            batch = {
                "image" : jnp.array(image),
                "label" : jnp.array(label)
                 
            }
            val_logits = model.apply({'params': state.params}, batch['image'], True)
            val_metrics = compute_metrics(val_logits, batch['label'])
            
            vl_loss += val_metrics["loss"]
            vl_acc += val_metrics["accuracy"]
            
        vl_loss = vl_loss / len(test_ds)
        vl_acc = vl_acc / len(test_ds)
        
        print(f"Epoch {epoch + 1}, Validation Accuracy: {vl_acc:.4f}, Validation Loss: {vl_loss:.04f}")

    return state

In [269]:
train_model(state, train_data, test_data, num_epochs = 100)

Epoch 1, Batch 0, Loss: 2.3794, Accuracy: 0.0938
Epoch 1, Batch 10, Loss: 2.3296, Accuracy: 0.0781
Epoch 1, Batch 20, Loss: 2.2975, Accuracy: 0.1484
Epoch 1, Batch 30, Loss: 2.2882, Accuracy: 0.1953
Epoch 1, Batch 40, Loss: 2.2655, Accuracy: 0.2109
Epoch 1, Batch 50, Loss: 2.2719, Accuracy: 0.1250
Epoch 1, Batch 60, Loss: 2.2521, Accuracy: 0.2031
Epoch 1, Batch 70, Loss: 2.2256, Accuracy: 0.2266
Epoch 1, Batch 80, Loss: 2.2304, Accuracy: 0.1719
Epoch 1, Batch 90, Loss: 2.2160, Accuracy: 0.1875
Epoch 1, Batch 100, Loss: 2.2029, Accuracy: 0.2422
Epoch 1, Batch 110, Loss: 2.2235, Accuracy: 0.2031
Epoch 1, Batch 120, Loss: 2.2384, Accuracy: 0.1172
Epoch 1, Batch 130, Loss: 2.1823, Accuracy: 0.2188
Epoch 1, Batch 140, Loss: 2.1589, Accuracy: 0.2031
Epoch 1, Batch 150, Loss: 2.1431, Accuracy: 0.2031
Epoch 1, Batch 160, Loss: 2.1471, Accuracy: 0.2656
Epoch 1, Batch 170, Loss: 2.1018, Accuracy: 0.2344
Epoch 1, Batch 180, Loss: 2.1376, Accuracy: 0.1953
Epoch 1, Batch 190, Loss: 2.1133, Accuracy

KeyboardInterrupt: 