In [1]:
from __future__ import print_function
import torch
import torch.optim as optim

!pip install torchsummary
from torchsummary import summary

from eva.model import Net
from eva.train import train
from eva.test import test
from eva.dataloader import getMnistDataLoader
from eva.eval import fit




In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [3]:
torch.manual_seed(1)
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

train_loader, test_loader = getMnistDataLoader(batch_size,**kwargs)

In [4]:
kwargs = {}

vanilla_model = Net(**kwargs).to(device)
summary(vanilla_model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 26, 26]              90
              ReLU-2           [-1, 10, 26, 26]               0
            Conv2d-3           [-1, 10, 24, 24]             900
              ReLU-4           [-1, 10, 24, 24]               0
            Conv2d-5           [-1, 10, 22, 22]             900
              ReLU-6           [-1, 10, 22, 22]               0
         AvgPool2d-7           [-1, 10, 11, 11]               0
            Conv2d-8             [-1, 10, 9, 9]             900
              ReLU-9             [-1, 10, 9, 9]               0
           Conv2d-10             [-1, 10, 7, 7]             900
             ReLU-11             [-1, 10, 7, 7]               0
        AvgPool2d-12             [-1, 10, 3, 3]               0
           Linear-13                   [-1, 10]             910
Total params: 4,600
Trainable params: 4

In [5]:
kwargs = {"normalization" : "BN"}

bn_model = Net(**kwargs).to(device)
summary(bn_model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 26, 26]              90
              ReLU-2           [-1, 10, 26, 26]               0
       BatchNorm2d-3           [-1, 10, 26, 26]              20
            Conv2d-4           [-1, 10, 24, 24]             900
              ReLU-5           [-1, 10, 24, 24]               0
       BatchNorm2d-6           [-1, 10, 24, 24]              20
            Conv2d-7           [-1, 10, 22, 22]             900
              ReLU-8           [-1, 10, 22, 22]               0
       BatchNorm2d-9           [-1, 10, 22, 22]              20
        AvgPool2d-10           [-1, 10, 11, 11]               0
           Conv2d-11             [-1, 10, 9, 9]             900
             ReLU-12             [-1, 10, 9, 9]               0
      BatchNorm2d-13             [-1, 10, 9, 9]              20
           Conv2d-14             [-1, 1

In [6]:
kwargs = {"normalization" : "LN"}

ln_model = Net(**kwargs).to(device)
summary(ln_model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 26, 26]              90
              ReLU-2           [-1, 10, 26, 26]               0
         GroupNorm-3           [-1, 10, 26, 26]              20
            Conv2d-4           [-1, 10, 24, 24]             900
              ReLU-5           [-1, 10, 24, 24]               0
         GroupNorm-6           [-1, 10, 24, 24]              20
            Conv2d-7           [-1, 10, 22, 22]             900
              ReLU-8           [-1, 10, 22, 22]               0
         GroupNorm-9           [-1, 10, 22, 22]              20
        AvgPool2d-10           [-1, 10, 11, 11]               0
           Conv2d-11             [-1, 10, 9, 9]             900
             ReLU-12             [-1, 10, 9, 9]               0
        GroupNorm-13             [-1, 10, 9, 9]              20
           Conv2d-14             [-1, 1

In [7]:
kwargs = {"normalization" : "GN"}

gn_model = Net(**kwargs).to(device)
summary(gn_model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 26, 26]              90
              ReLU-2           [-1, 10, 26, 26]               0
         GroupNorm-3           [-1, 10, 26, 26]              20
            Conv2d-4           [-1, 10, 24, 24]             900
              ReLU-5           [-1, 10, 24, 24]               0
         GroupNorm-6           [-1, 10, 24, 24]              20
            Conv2d-7           [-1, 10, 22, 22]             900
              ReLU-8           [-1, 10, 22, 22]               0
         GroupNorm-9           [-1, 10, 22, 22]              20
        AvgPool2d-10           [-1, 10, 11, 11]               0
           Conv2d-11             [-1, 10, 9, 9]             900
             ReLU-12             [-1, 10, 9, 9]               0
        GroupNorm-13             [-1, 10, 9, 9]              20
           Conv2d-14             [-1, 1

In [8]:
epochs = 2
lr = 0.01
momentum=0.9
optimizer=optim.SGD
weight_decay = 0
lambda_l1 = None

In [9]:
kwargs = {
    'device':device,
    'epochs':epochs,
    'train':train,
    'test':test,
    'train_loader':train_loader,
    'test_loader':test_loader,
    'optimizer':optimizer,
    'learning_rate':lr,
    'momentum':momentum,
    'weight_decay': weight_decay,
    'lambda_l1': lambda_l1
            }

print("Vanilla Model")
vanilla_output = fit(vanilla_model, **kwargs)

torch.manual_seed(1)

print("BatchNorm Model")
bn_output = fit(bn_model,  **kwargs)

torch.manual_seed(1)

print("LayerNorm Model")
ln_output = fit(ln_model,  **kwargs)

torch.manual_seed(1)

print("GroupNorm Model")
gn_output = fit(gn_model,  **kwargs)

torch.manual_seed(1)

print("BatchNorm Model + L1 + L2")
kwargs['weight_decay'] = 1e-5
kwargs['lambda_l1'] = 0.001
bn_l1_l2_output = fit(bn_model,  **kwargs)

torch.manual_seed(1)

print("LayerNorm Model + L2")
kwargs['weight_decay'] = 1e-5
kwargs['lambda_l1'] = None
ln_l2_output = fit(ln_model,  **kwargs)

torch.manual_seed(1)

print("GroupNorm Model + L1")
kwargs['weight_decay'] = 0
kwargs['lambda_l1'] = 0.001
gn_l1_output = fit(gn_model,  **kwargs)


  0%|          | 0/469 [00:00<?, ?it/s]

Vanilla Model
Epoch 1


loss=0.12847991287708282 batch_id=468: 100%|██████████| 469/469 [00:10<00:00, 44.69it/s]


Train set: Average loss: 0.0089, Accuracy: 34947/60000 (58.24%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.2095, Accuracy: 9334/10000 (93.34%)

Epoch 2


loss=0.18545116484165192 batch_id=468: 100%|██████████| 469/469 [00:10<00:00, 43.91it/s]


Train set: Average loss: 0.0013, Accuracy: 56915/60000 (94.86%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.1105, Accuracy: 9668/10000 (96.68%)

BatchNorm Model
Epoch 1


loss=0.10892511159181595 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 42.54it/s]


Train set: Average loss: 0.0018, Accuracy: 56234/60000 (93.72%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0612, Accuracy: 9818/10000 (98.18%)

Epoch 2


loss=0.015855086967349052 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 42.33it/s]


Train set: Average loss: 0.0005, Accuracy: 58958/60000 (98.26%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0452, Accuracy: 9869/10000 (98.69%)

LayerNorm Model
Epoch 1


loss=0.12319368124008179 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 41.65it/s]


Train set: Average loss: 0.0019, Accuracy: 55778/60000 (92.96%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0690, Accuracy: 9791/10000 (97.91%)

Epoch 2


loss=0.014586624689400196 batch_id=468: 100%|██████████| 469/469 [00:10<00:00, 43.18it/s]


Train set: Average loss: 0.0005, Accuracy: 58799/60000 (98.00%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0447, Accuracy: 9864/10000 (98.64%)

GroupNorm Model
Epoch 1


loss=0.16030772030353546 batch_id=468: 100%|██████████| 469/469 [00:10<00:00, 43.29it/s]


Train set: Average loss: 0.0018, Accuracy: 56245/60000 (93.74%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0796, Accuracy: 9751/10000 (97.51%)

Epoch 2


loss=0.008139369077980518 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 42.07it/s]


Train set: Average loss: 0.0005, Accuracy: 58848/60000 (98.08%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0489, Accuracy: 9851/10000 (98.51%)

BatchNorm Model + L1 + L2
Epoch 1


loss=0.3526088297367096 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 40.52it/s]


Train set: Average loss: 0.0030, Accuracy: 59094/60000 (98.49%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0655, Accuracy: 9806/10000 (98.06%)

Epoch 2


loss=0.23470309376716614 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 39.82it/s]


Train set: Average loss: 0.0023, Accuracy: 58992/60000 (98.32%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0631, Accuracy: 9806/10000 (98.06%)

LayerNorm Model + L2
Epoch 1


loss=0.06715451925992966 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 41.43it/s]


Train set: Average loss: 0.0004, Accuracy: 59085/60000 (98.47%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0426, Accuracy: 9869/10000 (98.69%)

Epoch 2


loss=0.007055711466819048 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 42.22it/s]


Train set: Average loss: 0.0003, Accuracy: 59219/60000 (98.70%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0346, Accuracy: 9887/10000 (98.87%)

GroupNorm Model + L1
Epoch 1


loss=0.3797954022884369 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 39.37it/s]


Train set: Average loss: 0.0031, Accuracy: 58990/60000 (98.32%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0621, Accuracy: 9815/10000 (98.15%)

Epoch 2


loss=0.24373072385787964 batch_id=468: 100%|██████████| 469/469 [00:11<00:00, 39.92it/s]


Train set: Average loss: 0.0023, Accuracy: 58885/60000 (98.14%)







Test set: Average loss: 0.0748, Accuracy: 9775/10000 (97.75%)



In [10]:
#TO do print 4 metrics


In [11]:
#To print misclassified images/wrong predictions(20)