In [2]:
from __future__ import print_function

import torch
from torchsummary import summary
import torch.optim as optim

import sys
sys.path.append('/content/drive/My Drive/tsai/S8/session_7')
from models.resnet18 import ResNet18
from data_loader.data_loader_cifar import get_train_loader
from data_loader.data_loader_cifar import get_test_loader
from scoring.scoring import test
from training.training import train
from tqdm import tqdm


# Reproducibility: use one fixed seed for every random number generator in play.
SEED = 1
CUDA = torch.cuda.is_available()
torch.manual_seed(SEED)
print("CUDA is available:", CUDA)
if CUDA:
    # A GPU is present, so seed the CUDA generator explicitly as well.
    torch.cuda.manual_seed(SEED)

# Target device for the model: the GPU when available, the CPU otherwise.
device = torch.device("cuda") if CUDA else torch.device("cpu")
print(device)

# Instantiate the ResNet18 network, then place it on the selected device.
model = ResNet18()
model = model.to(device)
# Print the per-layer output shapes and parameter counts for a 3x32x32 input.
summary(model, input_size=(3, 32, 32))

# Metric histories handed to `train` / `test` on every epoch; those helpers
# are presumably the ones appending to them -- confirm in
# training/training.py and scoring/scoring.py.
train_losses, train_acc = [], []
test_losses, test_acc = [], []

# CIFAR loaders; batch size and transforms are decided inside
# data_loader.data_loader_cifar, not here.
train_loader = get_train_loader()
test_loader = get_test_loader()

# SGD with momentum over every parameter of the model.
EPOCHS = 30
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# NOTE(review): the run recorded with this cell reports a negative
# "Average loss" for the test set on every epoch. A mean NLL/cross-entropy
# over proper log-probabilities cannot be negative, so the loss computed
# inside `test` is worth checking (e.g. whether it is fed log-softmax
# outputs) -- TODO confirm in scoring/scoring.py.
for epoch in range(EPOCHS):
    print("EPOCH:", epoch)
    # One optimisation pass over the training set ...
    train(
        model, device, train_loader, optimizer, epoch,
        train_losses, train_acc,
    )
    # ... followed by an evaluation pass over the test set.
    test(model, device, test_loader, test_losses, test_acc)

CUDA is available: True
cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
        BasicBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
           Conv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
       BasicBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13          [-1, 128, 16, 16]          73,728
      Batc

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


  0%|          | 0/782 [00:00<?, ?it/s]

EPOCH: 0


Loss=0.7651723623275757 Batch_id=781 Accuracy=52.77: 100%|██████████| 782/782 [01:07<00:00, 11.54it/s]
Loss=0.6828662157058716 Batch_id=1 Accuracy=70.31:   0%|          | 2/782 [00:00<01:08, 11.42it/s]


Test set: Average loss: -4.1966, Accuracy: 6540/10000 (65.40%)

EPOCH: 1


Loss=0.4453192949295044 Batch_id=781 Accuracy=73.51: 100%|██████████| 782/782 [01:09<00:00, 12.58it/s]
Loss=0.43890419602394104 Batch_id=1 Accuracy=81.25:   0%|          | 1/782 [00:00<01:18,  9.95it/s]


Test set: Average loss: -5.8825, Accuracy: 7458/10000 (74.58%)

EPOCH: 2


Loss=0.3932861089706421 Batch_id=781 Accuracy=81.17: 100%|██████████| 782/782 [01:09<00:00, 11.17it/s]
Loss=0.36898332834243774 Batch_id=1 Accuracy=82.03:   0%|          | 2/782 [00:00<01:09, 11.27it/s]


Test set: Average loss: -7.0529, Accuracy: 7971/10000 (79.71%)

EPOCH: 3


Loss=0.6923843622207642 Batch_id=781 Accuracy=86.11: 100%|██████████| 782/782 [01:10<00:00, 12.32it/s]
Loss=0.24753251671791077 Batch_id=1 Accuracy=89.84:   0%|          | 2/782 [00:00<01:12, 10.75it/s]


Test set: Average loss: -8.7118, Accuracy: 8121/10000 (81.21%)

EPOCH: 4


Loss=0.3302212655544281 Batch_id=781 Accuracy=89.82: 100%|██████████| 782/782 [01:11<00:00, 12.15it/s]
Loss=0.32324323058128357 Batch_id=1 Accuracy=92.19:   0%|          | 2/782 [00:00<01:13, 10.64it/s]


Test set: Average loss: -9.6295, Accuracy: 7963/10000 (79.63%)

EPOCH: 5


Loss=0.4221039116382599 Batch_id=781 Accuracy=92.85: 100%|██████████| 782/782 [01:11<00:00, 12.17it/s]
Loss=0.20550081133842468 Batch_id=1 Accuracy=95.31:   0%|          | 2/782 [00:00<01:11, 10.98it/s]


Test set: Average loss: -10.5382, Accuracy: 7888/10000 (78.88%)

EPOCH: 6


Loss=0.15630605816841125 Batch_id=781 Accuracy=94.81: 100%|██████████| 782/782 [01:11<00:00, 12.13it/s]
Loss=0.08284130692481995 Batch_id=1 Accuracy=96.88:   0%|          | 2/782 [00:00<01:10, 11.00it/s]


Test set: Average loss: -12.4249, Accuracy: 8005/10000 (80.05%)

EPOCH: 7


Loss=0.004976332187652588 Batch_id=781 Accuracy=96.52: 100%|██████████| 782/782 [01:11<00:00, 12.17it/s]
Loss=0.07951920479536057 Batch_id=1 Accuracy=97.66:   0%|          | 2/782 [00:00<01:12, 10.81it/s]


Test set: Average loss: -13.0317, Accuracy: 8130/10000 (81.30%)

EPOCH: 8


Loss=0.0019849538803100586 Batch_id=781 Accuracy=97.28: 100%|██████████| 782/782 [01:11<00:00, 12.12it/s]
Loss=0.05859377235174179 Batch_id=1 Accuracy=98.44:   0%|          | 2/782 [00:00<01:13, 10.60it/s]


Test set: Average loss: -13.6039, Accuracy: 8096/10000 (80.96%)

EPOCH: 9


Loss=0.028850644826889038 Batch_id=781 Accuracy=98.29: 100%|██████████| 782/782 [01:11<00:00, 12.14it/s]
Loss=0.03656589239835739 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:13, 10.55it/s]


Test set: Average loss: -14.3536, Accuracy: 8165/10000 (81.65%)

EPOCH: 10


Loss=0.03949004411697388 Batch_id=781 Accuracy=98.72: 100%|██████████| 782/782 [01:12<00:00, 12.15it/s]
Loss=0.03258052468299866 Batch_id=1 Accuracy=96.09:   0%|          | 2/782 [00:00<01:11, 10.86it/s]


Test set: Average loss: -15.5460, Accuracy: 8257/10000 (82.57%)

EPOCH: 11


Loss=0.036685049533843994 Batch_id=781 Accuracy=99.00: 100%|██████████| 782/782 [01:12<00:00, 12.06it/s]
Loss=0.005033180117607117 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:14, 10.50it/s]


Test set: Average loss: -16.1035, Accuracy: 8269/10000 (82.69%)

EPOCH: 12


Loss=0.01838773488998413 Batch_id=781 Accuracy=99.24: 100%|██████████| 782/782 [01:11<00:00, 12.08it/s]
Loss=0.004500873386859894 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:11, 10.87it/s]


Test set: Average loss: -16.9845, Accuracy: 8324/10000 (83.24%)

EPOCH: 13


Loss=0.08563518524169922 Batch_id=781 Accuracy=98.92: 100%|██████████| 782/782 [01:11<00:00, 12.19it/s]
Loss=0.03137633949518204 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:13, 10.57it/s]


Test set: Average loss: -16.6247, Accuracy: 8135/10000 (81.35%)

EPOCH: 14


Loss=0.03539919853210449 Batch_id=781 Accuracy=99.16: 100%|██████████| 782/782 [01:11<00:00, 12.08it/s]
Loss=0.06481960415840149 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:11, 10.90it/s]


Test set: Average loss: -16.9296, Accuracy: 8296/10000 (82.96%)

EPOCH: 15


Loss=0.001752018928527832 Batch_id=781 Accuracy=99.40: 100%|██████████| 782/782 [01:12<00:00, 12.11it/s]
Loss=0.01994071900844574 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:12, 10.73it/s]


Test set: Average loss: -17.9115, Accuracy: 8397/10000 (83.97%)

EPOCH: 16


Loss=0.0916474461555481 Batch_id=781 Accuracy=99.56: 100%|██████████| 782/782 [01:11<00:00, 12.17it/s]
Loss=0.05051082372665405 Batch_id=1 Accuracy=99.22:   0%|          | 2/782 [00:00<01:11, 10.93it/s]


Test set: Average loss: -17.5583, Accuracy: 8373/10000 (83.73%)

EPOCH: 17


Loss=0.005097866058349609 Batch_id=781 Accuracy=99.54: 100%|██████████| 782/782 [01:11<00:00, 12.22it/s]
Loss=0.0006350874900817871 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:13, 10.62it/s]


Test set: Average loss: -18.2667, Accuracy: 8398/10000 (83.98%)

EPOCH: 18


Loss=9.781122207641602e-05 Batch_id=781 Accuracy=99.89: 100%|██████████| 782/782 [01:11<00:00, 12.14it/s]
Loss=0.000662490725517273 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.85it/s]


Test set: Average loss: -18.6492, Accuracy: 8504/10000 (85.04%)

EPOCH: 19


Loss=1.245737075805664e-05 Batch_id=781 Accuracy=99.92: 100%|██████████| 782/782 [01:11<00:00, 12.04it/s]
Loss=0.004414603114128113 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:14, 10.48it/s]


Test set: Average loss: -19.1195, Accuracy: 8541/10000 (85.41%)

EPOCH: 20


Loss=0.000814974308013916 Batch_id=781 Accuracy=99.96: 100%|██████████| 782/782 [01:11<00:00, 12.23it/s]
Loss=0.00012071430683135986 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.93it/s]


Test set: Average loss: -18.9986, Accuracy: 8499/10000 (84.99%)

EPOCH: 21


Loss=0.0065310001373291016 Batch_id=781 Accuracy=99.96: 100%|██████████| 782/782 [01:11<00:00, 11.97it/s]
Loss=0.0001156926155090332 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:12, 10.69it/s]


Test set: Average loss: -19.3497, Accuracy: 8594/10000 (85.94%)

EPOCH: 22


Loss=8.32676887512207e-05 Batch_id=781 Accuracy=99.99: 100%|██████████| 782/782 [01:11<00:00, 12.14it/s]
Loss=9.767711162567139e-05 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.98it/s]


Test set: Average loss: -19.1351, Accuracy: 8595/10000 (85.95%)

EPOCH: 23


Loss=0.0018911957740783691 Batch_id=781 Accuracy=100.00: 100%|██████████| 782/782 [01:12<00:00, 12.08it/s]
Loss=4.889070987701416e-05 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.85it/s]


Test set: Average loss: -19.3743, Accuracy: 8630/10000 (86.30%)

EPOCH: 24


Loss=1.8537044525146484e-05 Batch_id=781 Accuracy=100.00: 100%|██████████| 782/782 [01:11<00:00, 12.08it/s]
Loss=0.0001688152551651001 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:10, 11.00it/s]


Test set: Average loss: -19.3740, Accuracy: 8612/10000 (86.12%)

EPOCH: 25


Loss=0.004179775714874268 Batch_id=781 Accuracy=100.00: 100%|██████████| 782/782 [01:12<00:00, 12.11it/s]
Loss=0.000188484787940979 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:10, 11.08it/s]


Test set: Average loss: -19.5762, Accuracy: 8622/10000 (86.22%)

EPOCH: 26


Loss=7.772445678710938e-05 Batch_id=781 Accuracy=100.00: 100%|██████████| 782/782 [01:12<00:00, 12.10it/s]
Loss=0.0002073347568511963 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.98it/s]


Test set: Average loss: -19.6248, Accuracy: 8599/10000 (85.99%)

EPOCH: 27


Loss=0.024475514888763428 Batch_id=781 Accuracy=100.00: 100%|██████████| 782/782 [01:11<00:00, 12.09it/s]
Loss=6.133317947387695e-05 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:11, 10.98it/s]


Test set: Average loss: -19.8531, Accuracy: 8610/10000 (86.10%)

EPOCH: 28


Loss=0.0015168190002441406 Batch_id=781 Accuracy=99.91: 100%|██████████| 782/782 [01:11<00:00, 12.17it/s]
Loss=0.00040866434574127197 Batch_id=1 Accuracy=100.00:   0%|          | 2/782 [00:00<01:12, 10.80it/s]


Test set: Average loss: -19.3234, Accuracy: 8506/10000 (85.06%)

EPOCH: 29


Loss=0.0006255507469177246 Batch_id=781 Accuracy=99.96: 100%|██████████| 782/782 [01:11<00:00, 12.18it/s]



Test set: Average loss: -19.5128, Accuracy: 8577/10000 (85.77%)

