<a href="https://colab.research.google.com/github/venkateshgali91/SRIP/blob/main/notebooks/colab-github-demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [20]:
# Seed for the JAX PRNG key from which the MLP parameters are initialised.
seed = 0
# Height and width of one MNIST image; the product (784) is the flattened input size.
mnist_img_size = (28, 28)

def init_MLP(layer_widths, parent_key, scale=0.01):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: Sequence of layer sizes, input width first, output width last.
        parent_key: JAX PRNG key; it is split into one sub-key per layer, and each
            sub-key is split again into a weight key and a bias key.
        scale: Factor applied to the standard-normal samples.

    Returns:
        A list holding one ``[weights, biases]`` pair per layer, with ``weights``
        of shape ``(out_width, in_width)`` and ``biases`` of shape ``(out_width,)``.
    """
    num_layers = len(layer_widths) - 1
    layer_keys = jax.random.split(parent_key, num=num_layers)
    layer_dims = zip(layer_widths[:-1], layer_widths[1:])

    def init_layer(dims, layer_key):
        # Draw the weight matrix and bias vector of one layer from independent keys.
        fan_in, fan_out = dims
        w_key, b_key = jax.random.split(layer_key)
        weights = scale * jax.random.normal(w_key, shape=(fan_out, fan_in))
        biases = scale * jax.random.normal(b_key, shape=(fan_out,))
        return [weights, biases]

    return [init_layer(dims, layer_key) for dims, layer_key in zip(layer_dims, layer_keys)]
# Build a demo network from a fixed PRNG key and print the shape of every parameter.
key = jax.random.PRNGKey(seed)
MLP = init_MLP([784, 256, 512, 16], key)
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated and
# is no longer available in newer JAX releases.
print(jax.tree_util.tree_map(lambda x: x.shape, MLP))

[[(256, 784), (256,)], [(512, 256), (512,)], [(16, 512), (16,)]]


In [27]:
def MLP_predict(p, x):
    """Run a single input through the MLP and return log-probabilities.

    Args:
        p: List of ``[weights, biases]`` pairs, one per layer. Every layer except
            the last is followed by a ReLU.
        x: Input vector for one example (batching is done externally, e.g. via vmap).

    Returns:
        The output-layer logits minus their logsumexp, i.e. a log-softmax vector.
    """
    hidden_layers, output_layer = p[:-1], p[-1]

    activations = x
    for weights, biases in hidden_layers:
        pre_activation = jnp.dot(weights, activations) + biases
        activations = jax.nn.relu(pre_activation)

    w_out, b_out = output_layer
    logits = jnp.dot(w_out, activations) + b_out

    # Subtracting logsumexp normalises the logits into log-probabilities.
    return logits - logsumexp(logits)

# Smoke tests for MLP_predict.

# Single example: one random vector with the size of a flattened MNIST image.
flat_img_size = np.prod(mnist_img_size)
d_img = np.random.randn(flat_img_size)
print(d_img.shape)

prediction = MLP_predict(MLP, d_img)
print(prediction.shape)

# Batched variant: the parameters are shared (None), the inputs are mapped over axis 0.
batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))

d_img = np.random.randn(16, flat_img_size)
print(d_img.shape)
prediction = batched_MLP_predict(MLP, d_img)


(784,)
(16,)
(16, 784)


In [10]:
def custom_transform(x):
    """Convert an image-like object to a flat 1-D float32 NumPy array."""
    pixels = np.array(x, dtype=np.float32)
    return pixels.ravel()

def custom_collate(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Args:
        batch: Sequence of samples whose first element is the image array and
            whose second element is the label.

    Returns:
        ``(imgs, labels)``: the images stacked along a new leading axis and the
        labels gathered into a 1-D array.
    """
    # Transpose the list of samples into per-field columns.
    columns = tuple(zip(*batch))

    labels = np.array(columns[1])
    return np.stack(columns[0]), labels
# Alternative approach (not used in this notebook):
''' The data could also be loaded with pandas, e.g.
df = pd.read_csv("mnist_dataset.csv"), and a model trained with Keras on the usual x_train, x_test, y_train, y_test splits. '''
# Load MNIST with torchvision datasets and PyTorch DataLoaders.
batch_size = 64

# Datasets: every image goes through custom_transform (flat float32 vector).
train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)

# Loaders yield NumPy batches via custom_collate; incomplete final batches are dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate,
    drop_last=True,
)

# Sanity check: pull one batch and show its shapes and dtypes.
batch_data = next(iter(train_loader))
imgs, lbls = batch_data[0], batch_data[1]
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)

# Keep the full datasets in memory as JAX arrays (one flattened row per image)
# so that accuracy can be evaluated in a single call.
num_train, num_test = len(train_dataset), len(test_dataset)

train_images = jnp.array(train_dataset.data).reshape(num_train, -1)
train_lbls = jnp.array(train_dataset.targets)

test_images = jnp.array(test_dataset.data).reshape(num_test, -1)
test_lbls = jnp.array(test_dataset.targets)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting train_mnist/MNIST/raw/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting test_mnist/MNIST/raw/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw

(64, 784) float32 (64,) int64


In [11]:
# Number of passes over the training loader performed by the training loop.
num_epochs = 10
# The functions below use JAX to compute the loss, its gradients, and the parameter update.
def loss_function(pam, imgs, gt_lbls):
    """Return the negated mean of the predictions weighted element-wise by the targets.

    Args:
        pam: Network parameters forwarded to ``batched_MLP_predict``.
        imgs: Batch of inputs forwarded to ``batched_MLP_predict``.
        gt_lbls: Target array multiplied element-wise with the predictions
            (the training loop passes one-hot encoded labels).

    Returns:
        A scalar: the mean is taken over every element of the product, not per example.
    """
    log_probs = batched_MLP_predict(pam, imgs)
    weighted = log_probs * gt_lbls

    return -jnp.mean(weighted)

def accuracy(p, dataset_img, dataset_lbls):
    """Return the fraction of examples whose arg-max prediction equals the label.

    Args:
        p: Network parameters forwarded to ``batched_MLP_predict``.
        dataset_img: Batch of inputs, one example per row.
        dataset_lbls: Integer class labels compared against the predicted classes.
    """
    log_probs = batched_MLP_predict(p, dataset_img)
    pred_classes = jnp.argmax(log_probs, axis=1)
    is_correct = dataset_lbls == pred_classes
    return jnp.mean(is_correct)

@jit
def update(p, imgs, gt_lbls, lr=0.01):
    """Perform one gradient-descent step and return the loss with the new parameters.

    Args:
        p: Parameter pytree differentiated by ``value_and_grad`` (first argument
            of ``loss_function``).
        imgs: Batch of inputs forwarded to ``loss_function``.
        gt_lbls: Targets forwarded to ``loss_function``.
        lr: Step size applied to every gradient leaf.

    Returns:
        ``(loss, new_params)``: the loss evaluated at the incoming ``p`` and a
        pytree with the same structure as ``p`` holding ``param - lr * grad``.
    """
    loss, grads = value_and_grad(loss_function)(p, imgs, gt_lbls)

    # jax.tree_multimap has been removed from JAX; jax.tree_util.tree_map accepts
    # several trees and applies the function leaf-wise across them.
    new_params = jax.tree_util.tree_map(lambda param, g: param - lr * g, p, grads)
    return loss, new_params
# Initialise the MLP that is actually trained: 784 inputs, two hidden layers,
# and one output per MNIST class.
num_classes = len(MNIST.classes)
MLP_params = init_MLP([np.prod(mnist_img_size), 512, 256, num_classes], key)

for epoch in range(num_epochs):

    for batch_idx, (imgs, lbls) in enumerate(train_loader):
        # One-hot encode the integer labels so they can weight the log-probabilities.
        gt_labels = jax.nn.one_hot(lbls, num_classes)
        loss, MLP_params = update(MLP_params, imgs, gt_labels)

        # Report the mini-batch loss every 50 steps.
        if batch_idx % 50 == 0:
            print(loss)

    # End-of-epoch accuracy on the full in-memory train and test sets.
    print(f'Epoch {epoch}, train acc = {accuracy(MLP_params, train_images, train_lbls)} test acc = {accuracy(MLP_params, test_images, test_lbls)}')

0.24110737
0.089520425
0.07011603
0.07218732
0.028679645
0.0663233
0.019861856
0.0130532
0.014857905
0.032602765
0.028439471
0.028668124
0.015437047
0.029392604
0.027350953
0.0117994165
0.045196466
0.02232051
0.027745938
Epoch 0, train acc = 0.9310500025749207 test acc = 0.9307000041007996
0.01673922
0.025583241
0.01991645
0.022684613
0.008694423
0.0319991
0.027315242
0.025488183
0.01830566
0.024629407
0.040597536
0.01910992
0.019701323
0.016781632
0.009218305
0.010125055
0.018095199
0.02024169
0.020485487
Epoch 1, train acc = 0.9521333575248718 test acc = 0.9488999843597412
0.016769517
0.009163578
0.010169112
0.021908669
0.009106166
0.011004353
0.009706794
0.016292302
0.019835763
0.010217128
0.015932344
0.008749564
0.011998409
0.009054077
0.01893959
0.014921847
0.0098811155
0.017204558
0.028199507
Epoch 2, train acc = 0.9623000025749207 test acc = 0.9578999876976013
0.020758877
0.0066460744
0.019411061
0.011036934
0.015032447
0.027433751
0.018825796
0.01600807
0.009995681
0.012628595
