<a href="https://colab.research.google.com/github/sambhav-antriksh/AIML-Project/blob/main/MNIST_equinox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating a Convolutional Neural Network on MNIST
Equinox is a library that provides neural-network building blocks (layers and modules) for JAX; here it is used to build a small convolutional classifier for MNIST.

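Importing `equinox` initially failed with `ModuleNotFoundError: No module named 'equinox'` because the library is not installed in this environment, so the next cell installs it.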

In [5]:
!pip install equinox



After running the cell above to install `equinox`, re-run the import cell below; it should no longer raise `ModuleNotFoundError`. The error occurred because Python could not find the `equinox` package at import time; installing the package makes it importable.

In [13]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from jaxtyping import Array,Float,Int,PyTree
import numpy as np

In [9]:
# Hyperparameters
BATCH_SIZE = 64  # samples per DataLoader batch
# Training-loop settings; consumed by cells outside this section.
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678  # seed for the root JAX PRNG key below

# Root PRNG key. It is split (never reused directly) wherever fresh
# randomness is needed, e.g. for model initialisation.
key = jax.random.PRNGKey(SEED)

# The Dataset

In [10]:
# ToTensor converts each PIL image to a float tensor in [0, 1]; Normalize
# then applies (x - 0.5) / 0.5, mapping pixel values into [-1, 1].
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=normalise_data
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=normalise_data
)

# Iterating a loader yields (images, labels) batches of size BATCH_SIZE.
# The training data is reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# Evaluation does not benefit from random order. A fixed order keeps
# test-set iteration reproducible; aggregate metrics are unaffected.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 477kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.32MB/s]


In [33]:
# Sanity-check the data: take one batch from the training loader and turn
# both torch tensors into NumPy arrays via the tensors' own `.numpy()` method.
dummy_x, dummy_y = (tensor.numpy() for tensor in next(iter(trainloader)))
print(dummy_x.shape)  # e.g. (64, 1, 28, 28): batch, channel, height, width
print(dummy_y.shape)  # e.g. (64,): one integer label per image
print(dummy_y)


(64, 1, 28, 28)
(64,)
[2 5 1 1 6 9 5 0 6 6 4 0 0 5 8 9 5 9 4 3 0 8 9 4 1 9 2 4 0 3 3 5 2 5 3 1 1
 7 1 5 7 3 2 8 7 5 9 3 9 5 7 7 4 8 4 0 4 1 3 6 5 4 0 9]


In [31]:
# Iterating the loader directly yields (images, labels) torch tensors;
# report the shapes of the first batch only.
for images, labels in trainloader:
    print(images.shape, labels.shape, sep="\n")
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])


# The Model

In [34]:
class CNN(eqx.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    The network is a fixed sequence of layers applied in order: a
    convolution + max-pool feature extractor, then flattening and a
    three-layer MLP ending in ``log_softmax``.
    """

    # Layers applied in order by ``__call__``: a mix of Equinox modules
    # (which hold parameters) and plain JAX functions (which do not).
    layers: list

    def __init__(self, key):
        """Build the layer stack.

        Args:
            key: JAX PRNG key, split four ways so the conv layer and the
                three linear layers are initialised independently.
        """
        key1, key2, key3, key4 = jax.random.split(key, 4)

        # Standard CNN setup: a convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            # 1 -> 3 channels, no padding: 28x28 -> 25x25
            # (output_size = input_size - kernel_size + 1).
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            # Stride defaults to 1 here (see the printed model): 25x25 -> 24x24.
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            # Flatten (3, 24, 24) -> (1728,), hence the Linear input size.
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            # Log-probabilities over the 10 digit classes.
            jax.nn.log_softmax,
        ]

    # Must live at class level (not nested inside __init__) so that an
    # instance is actually callable as ``model(x)``.
    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run a single, unbatched image through every layer in sequence.

        Args:
            x: one image with shape (1, 28, 28).

        Returns:
            Log-probabilities over the 10 classes, shape (10,).
        """
        for layer in self.layers:
            x = layer(x)
        return x
# Split off a fresh subkey for parameter initialisation and carry the
# other half forward as `key` for any later randomness.
key, subkey = jax.random.split(key)
model = CNN(subkey)

In [36]:
# Show the model's structure: every layer with its hyperparameters and the
# shapes of its weight/bias arrays (activation functions appear as plain
# function objects).
print(model)

CNN(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[3,1,4,4],
      bias=f32[3,1,1],
      in_channels=1,
      out_channels=3,
      kernel_size=(4, 4),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True,
      padding_mode='ZEROS'
    ),
    MaxPool2d(
      init=-inf,
      operation=<function max>,
      num_spatial_dims=2,
      kernel_size=(2, 2),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    <PjitFunction of <function relu at 0x7a8f0fc7a0c0>>,
    <PjitFunction of <function ravel at 0x7a8f10269ee0>>,
    Linear(
      weight=f32[512,1728],
      bias=f32[512],
      in_features=1728,
      out_features=512,
      use_bias=True
    ),
    <PjitFunction of <function sigmoid at 0x7a8f0fd62e80>>,
    Linear(
      weight=f32[64,512],
      bias=f32[64],
      in_features=512,
      out_features=64,
      use_bias=True
    ),
    <PjitFunction of <function relu at 0x