In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [2]:
simple_transform=transforms.Compose([transforms.ToTensor()])
exp_dataset=datasets.CIFAR10('./data',train=True,download=True,transform=simple_transform)

Files already downloaded and verified


In [3]:
train_loader = torch.utils.data.DataLoader(exp_dataset, batch_size=len(exp_dataset), shuffle=False)
data = next(iter(train_loader))

# Calculate mean and std per channel
mean = data[0].mean(dim=(0, 2, 3))  # Calculate mean across all images and height, width dimensions
std = data[0].std(dim=(0, 2, 3))    # Calculate std across all images and height, width dimensions

print("Mean per channel:", mean)
print("Std per channel:", std)

Mean per channel: tensor([0.4914, 0.4822, 0.4465])
Std per channel: tensor([0.2470, 0.2435, 0.2616])


In [4]:
import cv2
import torchvision

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label
train_transform=A.Compose([
                            A.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616)),
                           A.HorizontalFlip(),
                           A.ShiftScaleRotate (shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=1, border_mode=4, value=None, mask_value=None, shift_limit_x=None, shift_limit_y=None, rotate_method='largest_box', always_apply=False, p=0.5),
                           A.CoarseDropout(max_holes = 1, max_height=16, max_width=16, min_holes = 1, min_height=16, min_width=16, fill_value=(0.4914, 0.4822, 0.4465), mask_fill_value=None),
                            ToTensorV2()
                           
                                   ])
test_transform=A.Compose([A.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616)),
                          ToTensorV2()
                                   ])
train_data=Cifar10SearchDataset('./data',train=True,download=True,transform=train_transform)
test_data=Cifar10SearchDataset('./data',train=False,download=True,transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

# dataloader arguments - something you'll fetch these from cmdprmt
dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# train dataloader
train_loader = torch.utils.data.DataLoader(train_data, **dataloader_args)

# test dataloader
test_loader = torch.utils.data.DataLoader(test_data, **dataloader_args)

CUDA Available? False


In [6]:
from model import Net
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
model = Net().to(device)
summary(model, input_size=(3, 32, 32))

cpu
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             864
       BatchNorm2d-2           [-1, 32, 30, 30]              64
              ReLU-3           [-1, 32, 30, 30]               0
            Conv2d-4           [-1, 64, 15, 15]          18,432
       BatchNorm2d-5           [-1, 64, 15, 15]             128
              ReLU-6           [-1, 64, 15, 15]               0
            Conv2d-7           [-1, 32, 15, 15]           2,048
            Conv2d-8           [-1, 32, 15, 15]           9,216
       BatchNorm2d-9           [-1, 32, 15, 15]              64
             ReLU-10           [-1, 32, 15, 15]               0
           Conv2d-11           [-1, 32, 15, 15]           9,216
      BatchNorm2d-12           [-1, 32, 15, 15]              64
             ReLU-13           [-1, 32, 15, 15]               0
           Conv2d-14             [-

In [9]:
from train import train,test
device = torch.device("mps")
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
train_losses = []
test_losses = []
train_acc = []
test_acc = []
for epoch in range(1, 41):
    print(epoch)
    train_loss,train_accuracy=train(model, device, train_loader, optimizer, epoch)
    test_loss,test_accuracy=test(model, device, test_loader)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_acc.append(train_accuracy)
    test_acc.append(test_accuracy)

1


loss=1.594364047050476 batch_id=781 Accuracy=35.93: 100%|█| 782/782 [00:32<00:00



Test set: Average loss: 1.4632, Accuracy: 4604/10000 (46%)

2


loss=1.4038478136062622 batch_id=781 Accuracy=47.22: 100%|█| 782/782 [00:36<00:0



Test set: Average loss: 1.3943, Accuracy: 5193/10000 (52%)

3


loss=1.3733062744140625 batch_id=781 Accuracy=52.42: 100%|█| 782/782 [00:44<00:0



Test set: Average loss: 1.0730, Accuracy: 6129/10000 (61%)

4


loss=1.3032957315444946 batch_id=781 Accuracy=56.54: 100%|█| 782/782 [00:46<00:0



Test set: Average loss: 0.9879, Accuracy: 6568/10000 (66%)

5


loss=0.794598400592804 batch_id=781 Accuracy=59.41: 100%|█| 782/782 [00:48<00:00



Test set: Average loss: 0.9310, Accuracy: 6747/10000 (67%)

6


loss=1.077425479888916 batch_id=781 Accuracy=61.45: 100%|█| 782/782 [00:46<00:00



Test set: Average loss: 0.9116, Accuracy: 6827/10000 (68%)

7


loss=0.7175000905990601 batch_id=781 Accuracy=63.42: 100%|█| 782/782 [00:48<00:0



Test set: Average loss: 0.8275, Accuracy: 7104/10000 (71%)

8


loss=0.8193728923797607 batch_id=781 Accuracy=64.61: 100%|█| 782/782 [00:49<00:0



Test set: Average loss: 0.7542, Accuracy: 7383/10000 (74%)

9


loss=1.1781092882156372 batch_id=781 Accuracy=65.32: 100%|█| 782/782 [00:51<00:0



Test set: Average loss: 0.7367, Accuracy: 7433/10000 (74%)

10


loss=1.3658347129821777 batch_id=781 Accuracy=66.29: 100%|█| 782/782 [00:56<00:0



Test set: Average loss: 0.7697, Accuracy: 7433/10000 (74%)

11


loss=1.4331800937652588 batch_id=781 Accuracy=67.45: 100%|█| 782/782 [00:56<00:0



Test set: Average loss: 0.7020, Accuracy: 7582/10000 (76%)

12


loss=1.134873628616333 batch_id=781 Accuracy=67.97: 100%|█| 782/782 [00:58<00:00



Test set: Average loss: 0.7183, Accuracy: 7581/10000 (76%)

13


loss=0.9734516143798828 batch_id=781 Accuracy=68.55: 100%|█| 782/782 [00:54<00:0



Test set: Average loss: 0.6914, Accuracy: 7557/10000 (76%)

14


loss=0.944484531879425 batch_id=781 Accuracy=69.37: 100%|█| 782/782 [00:45<00:00



Test set: Average loss: 0.7075, Accuracy: 7593/10000 (76%)

15


loss=0.8846458196640015 batch_id=781 Accuracy=69.90: 100%|█| 782/782 [00:39<00:0



Test set: Average loss: 0.6693, Accuracy: 7701/10000 (77%)

16


loss=0.5492952466011047 batch_id=781 Accuracy=70.11: 100%|█| 782/782 [00:38<00:0



Test set: Average loss: 0.6331, Accuracy: 7832/10000 (78%)

17


loss=1.6262816190719604 batch_id=781 Accuracy=70.67: 100%|█| 782/782 [00:41<00:0



Test set: Average loss: 0.6361, Accuracy: 7826/10000 (78%)

18


loss=0.5116596817970276 batch_id=781 Accuracy=71.06: 100%|█| 782/782 [00:40<00:0



Test set: Average loss: 0.6253, Accuracy: 7846/10000 (78%)

19


loss=1.4397834539413452 batch_id=781 Accuracy=71.54: 100%|█| 782/782 [00:38<00:0



Test set: Average loss: 0.5997, Accuracy: 7964/10000 (80%)

20


loss=0.6980091333389282 batch_id=781 Accuracy=72.06: 100%|█| 782/782 [00:37<00:0



Test set: Average loss: 0.6140, Accuracy: 7869/10000 (79%)

21


loss=0.9816020727157593 batch_id=781 Accuracy=72.35: 100%|█| 782/782 [00:37<00:0



Test set: Average loss: 0.5771, Accuracy: 8053/10000 (81%)

22


loss=0.8938164710998535 batch_id=781 Accuracy=72.63: 100%|█| 782/782 [00:37<00:0



Test set: Average loss: 0.5703, Accuracy: 8052/10000 (81%)

23


loss=1.3695602416992188 batch_id=781 Accuracy=72.74: 100%|█| 782/782 [00:44<00:0



Test set: Average loss: 0.5817, Accuracy: 8020/10000 (80%)

24


loss=0.5839904546737671 batch_id=781 Accuracy=72.78: 100%|█| 782/782 [00:29<00:0



Test set: Average loss: 0.6044, Accuracy: 7907/10000 (79%)

25


loss=0.8143808841705322 batch_id=781 Accuracy=73.47: 100%|█| 782/782 [00:28<00:0



Test set: Average loss: 0.5925, Accuracy: 7994/10000 (80%)

26


loss=0.9397359490394592 batch_id=781 Accuracy=73.41: 100%|█| 782/782 [00:26<00:0



Test set: Average loss: 0.5640, Accuracy: 8054/10000 (81%)

27


loss=1.2687121629714966 batch_id=781 Accuracy=73.86: 100%|█| 782/782 [00:27<00:0



Test set: Average loss: 0.5663, Accuracy: 8073/10000 (81%)

28


loss=0.9542284607887268 batch_id=781 Accuracy=74.10: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.6098, Accuracy: 7933/10000 (79%)

29


loss=0.7712061405181885 batch_id=781 Accuracy=74.28: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.5359, Accuracy: 8175/10000 (82%)

30


loss=0.83130943775177 batch_id=781 Accuracy=74.59: 100%|█| 782/782 [00:28<00:00,



Test set: Average loss: 0.5327, Accuracy: 8199/10000 (82%)

31


loss=0.715196967124939 batch_id=781 Accuracy=74.84: 100%|█| 782/782 [00:27<00:00



Test set: Average loss: 0.5645, Accuracy: 8152/10000 (82%)

32


loss=1.1297940015792847 batch_id=781 Accuracy=74.88: 100%|█| 782/782 [00:24<00:0



Test set: Average loss: 0.5450, Accuracy: 8157/10000 (82%)

33


loss=1.029753565788269 batch_id=781 Accuracy=75.18: 100%|█| 782/782 [00:27<00:00



Test set: Average loss: 0.4998, Accuracy: 8302/10000 (83%)

34


loss=0.4920450448989868 batch_id=781 Accuracy=75.30: 100%|█| 782/782 [00:24<00:0



Test set: Average loss: 0.5134, Accuracy: 8270/10000 (83%)

35


loss=0.9288721084594727 batch_id=781 Accuracy=75.36: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.5122, Accuracy: 8278/10000 (83%)

36


loss=0.4819714426994324 batch_id=781 Accuracy=75.29: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.5012, Accuracy: 8267/10000 (83%)

37


loss=0.656170129776001 batch_id=781 Accuracy=75.57: 100%|█| 782/782 [00:24<00:00



Test set: Average loss: 0.5180, Accuracy: 8234/10000 (82%)

38


loss=0.9347673654556274 batch_id=781 Accuracy=76.08: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.5226, Accuracy: 8214/10000 (82%)

39


loss=0.8540305495262146 batch_id=781 Accuracy=76.21: 100%|█| 782/782 [00:26<00:0



Test set: Average loss: 0.5145, Accuracy: 8258/10000 (83%)

40


loss=0.6189147233963013 batch_id=781 Accuracy=76.03: 100%|█| 782/782 [00:25<00:0



Test set: Average loss: 0.5001, Accuracy: 8317/10000 (83%)

