In [None]:
'''
Change your code in such a way that all of these are in their respective files:
model
training code
testing code
regularization techniques (dropout, L1, L2, etc)
dataloader/transformations/image-augmentations
misc items like finding misclassified images
'''

In [3]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torchsummary import summary

# Let's visualize some of the images
#%matplotlib inline
import matplotlib.pyplot as plt
import argparse

from torch.optim.lr_scheduler import StepLR,OneCycleLR

from train import *
from test import *
from model import *
from plotter import *
from data import *
# from model_group_norm import *
# from model_layer_norm import *

#from parser_args import norm, epochs

In [4]:
norm='bn'
epochs=20

In [5]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)



# train dataloader
train_loader = load_train()

# test dataloader
test_loader = load_test()

CUDA Available? True


In [7]:
# Printing the summary of the model
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using: ",device)

if norm == 'bn':
	print("Loading Batchnorm Model")
	model = Net().to(device)
elif norm == 'group':
	print("Loading Group Model")
	model = Net_group_norm().to(device)

elif norm == 'layer':
	print("Loading layer Model")
	model = Net_layer_norm().to(device)

#model = Net().to(device)
#model = Net_group_norm().to(device)
#model = Net_layer_norm().to(device)

model.apply(weights_init)
summary(model, input_size=(1, 28, 28))


Using:  cuda
Loading Batchnorm Model
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
           Dropout-3            [-1, 8, 26, 26]               0
            Conv2d-4           [-1, 16, 24, 24]           1,152
              ReLU-5           [-1, 16, 24, 24]               0
       BatchNorm2d-6           [-1, 16, 24, 24]              32
           Dropout-7           [-1, 16, 24, 24]               0
         MaxPool2d-8           [-1, 16, 12, 12]               0
            Conv2d-9            [-1, 8, 12, 12]             128
           Conv2d-10           [-1, 10, 10, 10]             720
             ReLU-11           [-1, 10, 10, 10]               0
      BatchNorm2d-12           [-1, 10, 10, 10]              20
          Dropout-13           [-1, 10, 10, 10]               0
  

In [8]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

scheduler = OneCycleLR(optimizer, max_lr=0.020,epochs=20,steps_per_epoch=len(train_loader))


for epoch in range(epochs):
    print("EPOCH:", epoch+1)
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

EPOCH: 1


Loss=0.6363967061042786 Batch_id=468 Accuracy=52.06: 100%|██████████| 469/469 [00:23<00:00, 20.16it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.2832, Accuracy: 9334/10000 (93.34%)

EPOCH: 2


Loss=0.33075740933418274 Batch_id=468 Accuracy=84.72: 100%|██████████| 469/469 [00:21<00:00, 21.73it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.1331, Accuracy: 9624/10000 (96.24%)

EPOCH: 3


Loss=0.30546125769615173 Batch_id=468 Accuracy=89.56: 100%|██████████| 469/469 [00:21<00:00, 21.59it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0935, Accuracy: 9727/10000 (97.27%)

EPOCH: 4


Loss=0.4090825617313385 Batch_id=468 Accuracy=92.08: 100%|██████████| 469/469 [00:21<00:00, 21.87it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0941, Accuracy: 9715/10000 (97.15%)

EPOCH: 5


Loss=0.24273762106895447 Batch_id=468 Accuracy=93.02: 100%|██████████| 469/469 [00:21<00:00, 21.87it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0862, Accuracy: 9725/10000 (97.25%)

EPOCH: 6


Loss=0.16012538969516754 Batch_id=468 Accuracy=93.85: 100%|██████████| 469/469 [00:22<00:00, 20.39it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0645, Accuracy: 9801/10000 (98.01%)

EPOCH: 7


Loss=0.07229413092136383 Batch_id=468 Accuracy=94.44: 100%|██████████| 469/469 [00:23<00:00, 19.65it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0609, Accuracy: 9812/10000 (98.12%)

EPOCH: 8


Loss=0.2702573537826538 Batch_id=468 Accuracy=94.84: 100%|██████████| 469/469 [00:25<00:00, 18.66it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0529, Accuracy: 9830/10000 (98.30%)

EPOCH: 9


Loss=0.13641127943992615 Batch_id=468 Accuracy=95.24: 100%|██████████| 469/469 [00:23<00:00, 19.68it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0464, Accuracy: 9845/10000 (98.45%)

EPOCH: 10


Loss=0.07752441614866257 Batch_id=468 Accuracy=95.47: 100%|██████████| 469/469 [00:23<00:00, 20.25it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0437, Accuracy: 9845/10000 (98.45%)

EPOCH: 11


Loss=0.2021065056324005 Batch_id=468 Accuracy=95.78: 100%|██████████| 469/469 [00:21<00:00, 21.91it/s]  
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0430, Accuracy: 9861/10000 (98.61%)

EPOCH: 12


Loss=0.03920157998800278 Batch_id=468 Accuracy=96.00: 100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0392, Accuracy: 9868/10000 (98.68%)

EPOCH: 13


Loss=0.07932155579328537 Batch_id=468 Accuracy=96.16: 100%|██████████| 469/469 [00:22<00:00, 20.53it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0380, Accuracy: 9875/10000 (98.75%)

EPOCH: 14


Loss=0.08386734873056412 Batch_id=468 Accuracy=96.35: 100%|██████████| 469/469 [00:22<00:00, 20.98it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0363, Accuracy: 9872/10000 (98.72%)

EPOCH: 15


Loss=0.17029717564582825 Batch_id=468 Accuracy=96.46: 100%|██████████| 469/469 [00:22<00:00, 20.58it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0336, Accuracy: 9881/10000 (98.81%)

EPOCH: 16


Loss=0.131439208984375 Batch_id=468 Accuracy=96.44: 100%|██████████| 469/469 [00:23<00:00, 19.55it/s]   
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0349, Accuracy: 9883/10000 (98.83%)

EPOCH: 17


Loss=0.1454789638519287 Batch_id=468 Accuracy=96.55: 100%|██████████| 469/469 [00:22<00:00, 21.08it/s]  
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0335, Accuracy: 9883/10000 (98.83%)

EPOCH: 18


Loss=0.1349489986896515 Batch_id=468 Accuracy=96.66: 100%|██████████| 469/469 [00:22<00:00, 20.78it/s]  
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0311, Accuracy: 9883/10000 (98.83%)

EPOCH: 19


Loss=0.18580052256584167 Batch_id=468 Accuracy=96.78: 100%|██████████| 469/469 [00:21<00:00, 22.17it/s] 
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0290, Accuracy: 9898/10000 (98.98%)

EPOCH: 20


Loss=0.045082949101924896 Batch_id=468 Accuracy=96.87: 100%|██████████| 469/469 [00:22<00:00, 21.22it/s]



Test set: Average loss: 0.0296, Accuracy: 9893/10000 (98.93%)

