In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from prettytable import PrettyTable

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.datasets import CIFAR10
import numpy as np
from torch.utils.data import DataLoader

In [27]:

# 1. Load the CIFAR-10 training dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
data = train_dataset.data

data = data / 255.0

mean = np.mean(data, axis=(0, 1, 2))

std = np.std(data, axis=(0, 1, 2))

# Print the results
print(f"Calculated Mean: {mean}")
print(f"Calculated Std Dev: {std}")

Calculated Mean: [0.49139968 0.48215841 0.44653091]
Calculated Std Dev: [0.24703223 0.24348513 0.26158784]


In [28]:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]

train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.1),
    A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=16, min_width=16, fill_value=mean, mask_fill_value=None, p=0.1),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

test_transforms = A.Compose([
    A.Normalize(mean=mean, std=std),
    ToTensorV2(),
])



class AlbumentationDataset(CIFAR10):
    def __init__(self, root="./data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label

train_dataset = AlbumentationDataset(root='./data', train=True, download=True, transform=train_transforms)
test_dataset = AlbumentationDataset(root='./data', train=False, download=True, transform=test_transforms)



  original_init(self, **validated_kwargs)
  A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=16, min_width=16, fill_value=mean, mask_fill_value=None, p=0.1),


In [29]:
SEED = 2

# CUDA?
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
cuda = torch.cuda.is_available()
print("GPU Available?", device)

# For reproducibility
torch.manual_seed(SEED)

if device == "cuda":
    torch.cuda.manual_seed(SEED)

# dataloader arguments - something you'll fetch these from cmdprmt
dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=128)

# train dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, **dataloader_args)

# test dataloader
test_loader = torch.utils.data.DataLoader(test_dataset, **dataloader_args)

# Pretty table for collecting all the accuracy and loss parameters in a table
log_table = PrettyTable()


GPU Available? cuda




In [30]:
import torch.nn as nn

DROPOUT_PROB = 0.1

class CIFAR10Net(nn.Module):
    def __init__(self):
        super(CIFAR10Net, self).__init__()

        # BLOCK 1
        # Input: 32x32x3 | Output: 32x32x32 | RF: 5
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(DROPOUT_PROB),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(DROPOUT_PROB)
        )

        # BLOCK 2
        # Input: 32x32x32 | Output: 16x16x32 | RF: 7
        self.conv_block2 = nn.Sequential(
             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False),
             nn.ReLU(),
             nn.BatchNorm2d(32)
        )

        # BLOCK 3
        # Input: 16x16x32 | Output: 16x16x64 | RF: 31
        self.conv_block3 = nn.Sequential(
            # Depthwise Separable Convolution
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, groups=32, bias=False), # Depthwise
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, bias=False), # Pointwise
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(DROPOUT_PROB),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=2, dilation=2, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(DROPOUT_PROB),
        )

        # BLOCK 4
        # Input: 16x16x64 | Output: 8x8x128 | RF: 39
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(128)
        )

        # OUTPUT BLOCK
        # Input: 8x8x128 | Output: 1x1x10 | RF: 55
        self.output_block = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), # GAP layer
            nn.Conv2d(in_channels=128, out_channels=10, kernel_size=1, bias=False)
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.output_block(x)
        x = x.view(-1, 10)
        return x



In [31]:
from torchsummary import summary
use_cuda = torch.cuda.is_available()
cuda = torch.device("cuda" if use_cuda else "cpu")
print(cuda)
model = CIFAR10Net().to(cuda)
summary(model, input_size=(3, 32, 32))


cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
              ReLU-2           [-1, 16, 32, 32]               0
       BatchNorm2d-3           [-1, 16, 32, 32]              32
           Dropout-4           [-1, 16, 32, 32]               0
            Conv2d-5           [-1, 32, 32, 32]           4,608
              ReLU-6           [-1, 32, 32, 32]               0
       BatchNorm2d-7           [-1, 32, 32, 32]              64
           Dropout-8           [-1, 32, 32, 32]               0
            Conv2d-9           [-1, 32, 16, 16]           9,216
             ReLU-10           [-1, 32, 16, 16]               0
      BatchNorm2d-11           [-1, 32, 16, 16]              64
           Conv2d-12           [-1, 32, 16, 16]             288
           Conv2d-13           [-1, 64, 16, 16]           2,048
             ReLU-14           [-1

In [32]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []

def train(model, device, train_loader, optimizer, epoch):
  model.train()
  pbar = tqdm(train_loader)
  correct = 0
  processed = 0
  for batch_idx, (data, target) in enumerate(pbar):
    # get samples
    data, target = data.to(device), target.to(device)

    # Init
    optimizer.zero_grad()
    # In PyTorch, we need to set the gradients to zero before starting to do backpropragation because PyTorch accumulates the gradients on subsequent backward passes.
    # Because of this, when you start your training loop, ideally you should zero out the gradients so that you do the parameter update correctly.

    # Predict
    y_pred = model(data)

    # Calculate loss
    criterion = nn.CrossEntropyLoss()
    loss = criterion(y_pred, target)
    #loss = F.nll_loss(y_pred, target)
    train_losses.append(loss)

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Update pbar-tqdm

    pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc= f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_acc.append(100*correct/processed)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            #test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))


In [33]:
from torch.optim.lr_scheduler import StepLR
from prettytable import PrettyTable, MSWORD_FRIENDLY, MARKDOWN

print("model running on: ", device)
log_table = PrettyTable()
log_table.field_names = ["Epoch", "Training Accuracy", "Test Accuracy", "Diff", "Training Loss", "Test Loss"]

model =  CIFAR10Net().to(device)
#optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.003)
scheduler = StepLR(optimizer, step_size=6, gamma=0.1)

EPOCHS = 50
for epoch in range(EPOCHS):
    print("EPOCH:", epoch+1)
    train(model, device, train_loader, optimizer, epoch)
    #scheduler.step()
    test(model, device, test_loader)
    log_table.add_row([epoch+1, f"{train_acc[-1]:.2f}%", f"{test_acc[-1]:.2f}%", f"{float(train_acc[-1]) - float(test_acc[-1]):.2f}" ,f"{train_losses[-1]:.4f}", f"{test_losses[-1]:.4f}"])
log_table.set_style(MARKDOWN)
print(log_table)


  from prettytable import PrettyTable, MSWORD_FRIENDLY, MARKDOWN
  from prettytable import PrettyTable, MSWORD_FRIENDLY, MARKDOWN


model running on:  cuda
EPOCH: 1


Loss=1.223803162574768 Batch_id=390 Accuracy=49.58: 100%|██████████| 391/391 [00:10<00:00, 35.99it/s]



Test set: Average loss: 1.0870, Accuracy: 6057/10000 (60.57%)

EPOCH: 2


Loss=0.9151042699813843 Batch_id=390 Accuracy=64.34: 100%|██████████| 391/391 [00:10<00:00, 38.26it/s]



Test set: Average loss: 0.9155, Accuracy: 6800/10000 (68.00%)

EPOCH: 3


Loss=0.9748958349227905 Batch_id=390 Accuracy=70.18: 100%|██████████| 391/391 [00:09<00:00, 39.12it/s]



Test set: Average loss: 0.7707, Accuracy: 7281/10000 (72.81%)

EPOCH: 4


Loss=0.7730000019073486 Batch_id=390 Accuracy=73.57: 100%|██████████| 391/391 [00:09<00:00, 39.82it/s]



Test set: Average loss: 0.7294, Accuracy: 7477/10000 (74.77%)

EPOCH: 5


Loss=0.5815104246139526 Batch_id=390 Accuracy=76.14: 100%|██████████| 391/391 [00:10<00:00, 38.56it/s]



Test set: Average loss: 0.6831, Accuracy: 7625/10000 (76.25%)

EPOCH: 6


Loss=0.8225975036621094 Batch_id=390 Accuracy=77.98: 100%|██████████| 391/391 [00:10<00:00, 36.17it/s]



Test set: Average loss: 0.6047, Accuracy: 7890/10000 (78.90%)

EPOCH: 7


Loss=0.7528146505355835 Batch_id=390 Accuracy=79.31: 100%|██████████| 391/391 [00:10<00:00, 36.41it/s]



Test set: Average loss: 0.5960, Accuracy: 7940/10000 (79.40%)

EPOCH: 8


Loss=0.6128249168395996 Batch_id=390 Accuracy=80.24: 100%|██████████| 391/391 [00:10<00:00, 36.55it/s]



Test set: Average loss: 0.5838, Accuracy: 7988/10000 (79.88%)

EPOCH: 9


Loss=0.6028963923454285 Batch_id=390 Accuracy=81.17: 100%|██████████| 391/391 [00:10<00:00, 37.53it/s]



Test set: Average loss: 0.5375, Accuracy: 8129/10000 (81.29%)

EPOCH: 10


Loss=0.39518359303474426 Batch_id=390 Accuracy=81.67: 100%|██████████| 391/391 [00:10<00:00, 37.54it/s]



Test set: Average loss: 0.5178, Accuracy: 8205/10000 (82.05%)

EPOCH: 11


Loss=0.5978726744651794 Batch_id=390 Accuracy=82.56: 100%|██████████| 391/391 [00:10<00:00, 36.65it/s]



Test set: Average loss: 0.5259, Accuracy: 8210/10000 (82.10%)

EPOCH: 12


Loss=0.35919463634490967 Batch_id=390 Accuracy=82.93: 100%|██████████| 391/391 [00:10<00:00, 38.06it/s]



Test set: Average loss: 0.5172, Accuracy: 8242/10000 (82.42%)

EPOCH: 13


Loss=0.5130105018615723 Batch_id=390 Accuracy=83.33: 100%|██████████| 391/391 [00:09<00:00, 41.52it/s]



Test set: Average loss: 0.5061, Accuracy: 8281/10000 (82.81%)

EPOCH: 14


Loss=0.6213909983634949 Batch_id=390 Accuracy=83.77: 100%|██████████| 391/391 [00:09<00:00, 39.94it/s]



Test set: Average loss: 0.4939, Accuracy: 8323/10000 (83.23%)

EPOCH: 15


Loss=0.5182145833969116 Batch_id=390 Accuracy=84.30: 100%|██████████| 391/391 [00:11<00:00, 34.71it/s]



Test set: Average loss: 0.4951, Accuracy: 8326/10000 (83.26%)

EPOCH: 16


Loss=0.39655908942222595 Batch_id=390 Accuracy=84.60: 100%|██████████| 391/391 [00:10<00:00, 37.67it/s]



Test set: Average loss: 0.5275, Accuracy: 8251/10000 (82.51%)

EPOCH: 17


Loss=0.38757190108299255 Batch_id=390 Accuracy=84.96: 100%|██████████| 391/391 [00:10<00:00, 37.13it/s]



Test set: Average loss: 0.4907, Accuracy: 8337/10000 (83.37%)

EPOCH: 18


Loss=0.4423726201057434 Batch_id=390 Accuracy=85.35: 100%|██████████| 391/391 [00:10<00:00, 37.64it/s]



Test set: Average loss: 0.5149, Accuracy: 8290/10000 (82.90%)

EPOCH: 19


Loss=0.41987913846969604 Batch_id=390 Accuracy=85.49: 100%|██████████| 391/391 [00:10<00:00, 36.65it/s]



Test set: Average loss: 0.5059, Accuracy: 8301/10000 (83.01%)

EPOCH: 20


Loss=0.39618009328842163 Batch_id=390 Accuracy=85.84: 100%|██████████| 391/391 [00:10<00:00, 37.83it/s]



Test set: Average loss: 0.4948, Accuracy: 8393/10000 (83.93%)

EPOCH: 21


Loss=0.3927505910396576 Batch_id=390 Accuracy=86.06: 100%|██████████| 391/391 [00:10<00:00, 38.15it/s]



Test set: Average loss: 0.4611, Accuracy: 8432/10000 (84.32%)

EPOCH: 22


Loss=0.5805143117904663 Batch_id=390 Accuracy=86.19: 100%|██████████| 391/391 [00:09<00:00, 40.28it/s]



Test set: Average loss: 0.4875, Accuracy: 8366/10000 (83.66%)

EPOCH: 23


Loss=0.4517667293548584 Batch_id=390 Accuracy=86.54: 100%|██████████| 391/391 [00:09<00:00, 41.39it/s]



Test set: Average loss: 0.4880, Accuracy: 8353/10000 (83.53%)

EPOCH: 24


Loss=0.25974878668785095 Batch_id=390 Accuracy=86.57: 100%|██████████| 391/391 [00:10<00:00, 38.62it/s]



Test set: Average loss: 0.4722, Accuracy: 8426/10000 (84.26%)

EPOCH: 25


Loss=0.35850661993026733 Batch_id=390 Accuracy=87.00: 100%|██████████| 391/391 [00:10<00:00, 37.31it/s]



Test set: Average loss: 0.4892, Accuracy: 8410/10000 (84.10%)

EPOCH: 26


Loss=0.21230335533618927 Batch_id=390 Accuracy=87.38: 100%|██████████| 391/391 [00:10<00:00, 37.76it/s]



Test set: Average loss: 0.4682, Accuracy: 8443/10000 (84.43%)

EPOCH: 27


Loss=0.6947134137153625 Batch_id=390 Accuracy=87.15: 100%|██████████| 391/391 [00:10<00:00, 36.80it/s]



Test set: Average loss: 0.4675, Accuracy: 8446/10000 (84.46%)

EPOCH: 28


Loss=0.4665605425834656 Batch_id=390 Accuracy=87.43: 100%|██████████| 391/391 [00:10<00:00, 37.57it/s]



Test set: Average loss: 0.4910, Accuracy: 8390/10000 (83.90%)

EPOCH: 29


Loss=0.29158759117126465 Batch_id=390 Accuracy=87.71: 100%|██████████| 391/391 [00:10<00:00, 37.67it/s]



Test set: Average loss: 0.4846, Accuracy: 8465/10000 (84.65%)

EPOCH: 30


Loss=0.41242653131484985 Batch_id=390 Accuracy=87.79: 100%|██████████| 391/391 [00:09<00:00, 40.69it/s]



Test set: Average loss: 0.4545, Accuracy: 8496/10000 (84.96%)

EPOCH: 31


Loss=0.680749773979187 Batch_id=390 Accuracy=88.04: 100%|██████████| 391/391 [00:09<00:00, 41.04it/s]



Test set: Average loss: 0.4524, Accuracy: 8498/10000 (84.98%)

EPOCH: 32


Loss=0.31119805574417114 Batch_id=390 Accuracy=88.19: 100%|██████████| 391/391 [00:10<00:00, 37.90it/s]



Test set: Average loss: 0.4653, Accuracy: 8451/10000 (84.51%)

EPOCH: 33


Loss=0.41903629899024963 Batch_id=390 Accuracy=88.15: 100%|██████████| 391/391 [00:10<00:00, 37.81it/s]



Test set: Average loss: 0.4791, Accuracy: 8454/10000 (84.54%)

EPOCH: 34


Loss=0.2843922972679138 Batch_id=390 Accuracy=88.12: 100%|██████████| 391/391 [00:10<00:00, 37.73it/s]



Test set: Average loss: 0.4733, Accuracy: 8434/10000 (84.34%)

EPOCH: 35


Loss=0.5898496508598328 Batch_id=390 Accuracy=88.57: 100%|██████████| 391/391 [00:10<00:00, 37.27it/s]



Test set: Average loss: 0.4586, Accuracy: 8519/10000 (85.19%)

EPOCH: 36


Loss=0.37316054105758667 Batch_id=390 Accuracy=88.56: 100%|██████████| 391/391 [00:10<00:00, 37.54it/s]



Test set: Average loss: 0.4617, Accuracy: 8454/10000 (84.54%)

EPOCH: 37


Loss=0.4005200266838074 Batch_id=390 Accuracy=88.73: 100%|██████████| 391/391 [00:09<00:00, 39.19it/s]



Test set: Average loss: 0.4684, Accuracy: 8465/10000 (84.65%)

EPOCH: 38


Loss=0.4067765772342682 Batch_id=390 Accuracy=88.92: 100%|██████████| 391/391 [00:09<00:00, 41.53it/s]



Test set: Average loss: 0.4705, Accuracy: 8460/10000 (84.60%)

EPOCH: 39


Loss=0.3579431474208832 Batch_id=390 Accuracy=88.95: 100%|██████████| 391/391 [00:09<00:00, 40.08it/s]



Test set: Average loss: 0.4563, Accuracy: 8525/10000 (85.25%)

EPOCH: 40


Loss=0.3399641513824463 Batch_id=390 Accuracy=88.97: 100%|██████████| 391/391 [00:10<00:00, 37.37it/s]



Test set: Average loss: 0.4537, Accuracy: 8530/10000 (85.30%)

EPOCH: 41


Loss=0.23141101002693176 Batch_id=390 Accuracy=89.03: 100%|██████████| 391/391 [00:11<00:00, 34.96it/s]



Test set: Average loss: 0.4453, Accuracy: 8561/10000 (85.61%)

EPOCH: 42


Loss=0.4962076246738434 Batch_id=390 Accuracy=89.54: 100%|██████████| 391/391 [00:10<00:00, 37.89it/s]



Test set: Average loss: 0.4596, Accuracy: 8525/10000 (85.25%)

EPOCH: 43


Loss=0.3248372972011566 Batch_id=390 Accuracy=89.08: 100%|██████████| 391/391 [00:10<00:00, 36.79it/s]



Test set: Average loss: 0.4859, Accuracy: 8424/10000 (84.24%)

EPOCH: 44


Loss=0.3755224347114563 Batch_id=390 Accuracy=89.52: 100%|██████████| 391/391 [00:10<00:00, 37.72it/s]



Test set: Average loss: 0.4852, Accuracy: 8482/10000 (84.82%)

EPOCH: 45


Loss=0.3703649342060089 Batch_id=390 Accuracy=89.61: 100%|██████████| 391/391 [00:09<00:00, 39.19it/s]



Test set: Average loss: 0.4560, Accuracy: 8540/10000 (85.40%)

EPOCH: 46


Loss=0.25547662377357483 Batch_id=390 Accuracy=89.44: 100%|██████████| 391/391 [00:09<00:00, 41.62it/s]



Test set: Average loss: 0.4606, Accuracy: 8487/10000 (84.87%)

EPOCH: 47


Loss=0.2557060122489929 Batch_id=390 Accuracy=89.75: 100%|██████████| 391/391 [00:09<00:00, 39.69it/s]



Test set: Average loss: 0.4732, Accuracy: 8469/10000 (84.69%)

EPOCH: 48


Loss=0.3020472824573517 Batch_id=390 Accuracy=89.71: 100%|██████████| 391/391 [00:10<00:00, 36.86it/s]



Test set: Average loss: 0.4761, Accuracy: 8478/10000 (84.78%)

EPOCH: 49


Loss=0.3766241669654846 Batch_id=390 Accuracy=89.77: 100%|██████████| 391/391 [00:10<00:00, 37.62it/s]



Test set: Average loss: 0.4630, Accuracy: 8540/10000 (85.40%)

EPOCH: 50


Loss=0.41594618558883667 Batch_id=390 Accuracy=90.05: 100%|██████████| 391/391 [00:10<00:00, 35.83it/s]



Test set: Average loss: 0.4559, Accuracy: 8538/10000 (85.38%)

| Epoch | Training Accuracy | Test Accuracy |  Diff  | Training Loss | Test Loss |
| :---: | :---------------: | :-----------: | :----: | :-----------: | :-------: |
|   1   |       49.58%      |     60.57%    | -10.99 |     1.2238    |   1.0870  |
|   2   |       64.34%      |     68.00%    | -3.66  |     0.9151    |   0.9155  |
|   3   |       70.18%      |     72.81%    | -2.63  |     0.9749    |   0.7707  |
|   4   |       73.57%      |     74.77%    | -1.20  |     0.7730    |   0.7294  |
|   5   |       76.14%      |     76.25%    | -0.11  |     0.5815    |   0.6831  |
|   6   |       77.98%      |     78.90%    | -0.92  |     0.8226    |   0.6047  |
|   7   |       79.31%      |     79.40%    | -0.09  |     0.7528    |   0.5960  |
|   8   |       80.24%      |     79.88%    |  0.36  |     0.6128    |   0.5838  |
|   9   |       81.17%      |     81.29%    | -0.12  |     0.6029    |   0.5375  |
|   10  |       81.67% 