<a href="https://colab.research.google.com/github/sambhav-antriksh/AIML-Project/blob/main/Equinox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating a Convolutional Neural Network on MNIST
Equinox is one such library which provides neural network capability to JAX.

In [3]:
!pip install equinox

Collecting equinox
  Downloading equinox-0.13.1-py3-none-any.whl.metadata (19 kB)
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.3.3-py3-none-any.whl.metadata (7.8 kB)
Collecting wadler-lindig>=0.1.0 (from equinox)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading equinox-0.13.1-py3-none-any.whl (179 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.3.3-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.9/55.9 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, equinox
Successfully installed equinox-0.13.1 jaxtyping-0.3.3 wadler-lindig-0.1.7


After running the above cell to install `equinox`, you should be able to run the original cell without the `ModuleNotFoundError`. The error occurred because Python couldn't find the `equinox` library when trying to import it. Installing the library makes it available for use.

In [4]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from jaxtyping import Array,Float,Int,PyTree
import numpy as np

In [5]:
# Hyperparameters
BATCH_SIZE=64
LEARNING_RATE=3E-4
STEPS=300
PRINT_EVERY=30
SEED=5678
key=jax.random.PRNGKey(SEED)

# The DATSET

In [6]:
normalise_data=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,),(0.5,))])
train_dataset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=normalise_data)
test_dataset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=normalise_data)
trainloader= torch.utils.data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)  #iterator (images,label)
testloader= torch.utils.data.DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 35.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.41MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.4MB/s]


In [7]:
# checking our data
# trainloader gives me the batches of the data
dummy_x,dummy_y=next(iter(trainloader))
dummy_x=dummy_x.numpy()  #converts pytorch tensor to numpy array #np would not work because no such method
dummy_y=dummy_y.numpy()
print(dummy_x.shape)
print(dummy_y.shape)
print(dummy_y)


(64, 1, 28, 28)
(64,)
[7 6 0 9 4 3 8 8 2 6 3 3 9 8 7 0 1 7 2 9 9 9 8 5 4 9 8 9 8 8 5 3 9 4 9 5 7
 9 3 9 8 5 1 9 8 4 9 8 0 4 2 1 6 8 8 1 8 1 0 8 2 8 8 1]


In [8]:
for images,labels in trainloader:
  print(images.shape)
  print(labels.shape)
  break

torch.Size([64, 1, 28, 28])
torch.Size([64])


# The Model

In [24]:
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

In [31]:
def loss(model: CNN,x:Float[Array,"batch 1 28 28"],y: Int[Array,'batch']) -> Float[Array, ""]:
  pred_y=jax.vmap(model)(x)            # Arguments passed are vectorized in the leading axis now it passes over batch
  return cross_entropy(y,pred_y)

def cross_entropy(y: Int[Array,'batch'],pred_y:Float[Array,'batch 10'])-> Float[Array,""]:
  pred_y= jnp.take_along_axis(pred_y,jnp.expand_dims(y,1),axis=1)
  return -jnp.mean(pred_y)

# Example loss
loss_value=loss(model,dummy_x,dummy_y)
print(loss_value)
# Example inference
output=jax.vmap(model)(dummy_x)
print(output.shape)  #batch of predictions


2.3046346
(64, 10)


# Evaluation

In [32]:
loss= eqx.filter_jit(loss)
@eqx.filter_jit

def compute_accuracy(model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, "batch"]) -> Float[Array, ""]:
  '''Compute the average accuracy on a batch '''
  pred_y=jax.vmap(model)(x)
  pred_y=jnp.argmax(pred_y,axis=1)
  return jnp.mean(y==pred_y)

In [41]:
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
  avg_loss= 0
  avg_acc= 0
  for x,y in testloader:
    x=x.numpy()
    y=y.numpy()
    avg_loss+=loss(model,x,y)
    avg_acc+=compute_accuracy(model,x,y)
    #print(y)
  return avg_loss/len(testloader),avg_acc/len(testloader)

In [42]:
evaluate(model,testloader)

(Array(2.311523, dtype=float32), Array(0.10360271, dtype=float32))

# Train our Model

In [43]:
optim= optax.adamw(LEARNING_RATE)

In [49]:
def train(model: CNN, trainloader: torch.utils.data.DataLoader,testloader: torch.utils.data.DataLoader,optim: optax.GradientTransformation, steps: int, print_every: int,)->CNN:
  opt_state=optim.init(eqx.filter(model,eqx.is_array))
  @eqx.filter_jit
  def make_step(model: CNN, opt_state: PyTree, x: Float[Array,"batch 1 28 28"],y: Int[Array,"batch"],):
    loss_value,grads= eqx.filter_value_and_grad(loss)(model,x,y)
    updates,opt_state=optim.update(grads,opt_state,eqx.filter(model,eqx.is_array))
    model=eqx.apply_updates(model,updates)
    return model,opt_state,loss_value

  def infinite_trainloader():
    while True:
      yield from trainloader

  for step,(x,y) in zip(range(steps),infinite_trainloader()):
    x=x.numpy()
    y=y.numpy()
    model,opt_state,train_loss= make_step(model,opt_state,x,y)
    if (step%print_every)==0 or ( step == steps-1):
      test_loss,test_accuracy= evaluate(model,testloader)
      print(f'{step=},train_loss={train_loss.item()},'f'test_loss={test_loss.item()},test_accuracy={test_accuracy.item()}')
  return model


In [50]:
model=train(model,trainloader,testloader,optim,STEPS,PRINT_EVERY)

step=0,train_loss=2.2766408920288086,test_loss=2.2788712978363037,test_accuracy=0.24452626705169678
step=30,train_loss=1.9921276569366455,test_loss=1.9998061656951904,test_accuracy=0.4595939517021179
step=60,train_loss=1.670916199684143,test_loss=1.597533941268921,test_accuracy=0.6569466590881348
step=90,train_loss=1.2031606435775757,test_loss=1.2037876844406128,test_accuracy=0.6993431448936462
step=120,train_loss=0.8097004890441895,test_loss=0.901013970375061,test_accuracy=0.7844347357749939
step=150,train_loss=0.8470339775085449,test_loss=0.687089741230011,test_accuracy=0.8458399772644043
step=180,train_loss=0.5941312909126282,test_loss=0.5600081086158752,test_accuracy=0.8653463125228882
step=210,train_loss=0.39332640171051025,test_loss=0.47468823194503784,test_accuracy=0.8818670511245728
step=240,train_loss=0.330039381980896,test_loss=0.41024479269981384,test_accuracy=0.8958996534347534
step=270,train_loss=0.3951440453529358,test_loss=0.3695693016052246,test_accuracy=0.9050557613372