In [46]:
import torch
import torchvision
import torchvision.transforms as tf
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import random_split, DataLoader
from torchvision.datasets.utils import download_url
import torch.nn as nn
import torch.nn.functional as F
import tarfile
%matplotlib inline

We use the architecture from https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/

In [47]:
import torchvision.datasets as datasets

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True,transform=tf.ToTensor()  )
data = cifar_trainset.data / 255 

Files already downloaded and verified


In [48]:
means = data.mean(axis = (0,1,2)) 
stds = data.std(axis = (0,1,2))

In [49]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [50]:
train_transforms=tf.Compose([tf.RandomCrop(32,padding=4,padding_mode='reflect'),
                            tf.RandomHorizontalFlip(),
                            tf.ToTensor(),
                            tf.Normalize(means,stds,inplace=True) ])

In [51]:
valid_transforms = tf.Compose([tf.ToTensor(), tf.Normalize(means,stds,inplace=True)])

In [52]:
train_ds = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transforms)
valid_ds = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=valid_transforms)

Files already downloaded and verified
Files already downloaded and verified


In [53]:
batch_size=400
train_dl=DataLoader(train_ds,batch_size,shuffle=True,num_workers=2,pin_memory=True)
valid_dl=DataLoader(valid_ds,batch_size*2,num_workers=2,pin_memory=True)



We are using denormalization to show images from dataloader and as the data loader is 4d so for multiplying and adding std and mean we need to reshape them to 4d


In [54]:
def denormalize(img,means,stds):
  means=torch.tensor(means).reshape(1,3,1,1)
  stds=torch.tensor(stds).reshape(1,3,1,1)
  return img*stds + means

In [55]:
def to_device(data,device):
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device,non_blocking=True)

class DeviceLoading:
    def __init__(self,dl,device):
        self.dl=dl
        self.device=device
    def __iter__(self):
        for x in self.dl:
            yield to_device(x,self.device)
    def __len__(self):
        return len(self.dl)

In [56]:
train_dl=DeviceLoading(train_dl,device)
valid_dl=DeviceLoading(valid_dl,device)

In [57]:
device

device(type='cuda', index=0)

In [58]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class Loss_Acc(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self.forward(images)                  
        loss = F.cross_entropy(out, labels) 
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self.forward(images)                    
        acc = accuracy(out, labels)           
        return acc

In [71]:
def Add_layers(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
              nn.BatchNorm2d(out_channels), 
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet9(Loss_Acc):
    def __init__(self, in_channels, num_classes): 
        super(ResNet9,self).__init__()
#         400x3x32x32
        self.conv1=Add_layers(in_channels, 64) # 400x64x32x32
        self.conv2=Add_layers(64, 128, pool=True) # 400x128x16x16
        self.res1=nn.Sequential(Add_layers(128, 128), Add_layers(128, 128))
        
        self.conv3=Add_layers(128, 256, pool=True)  # 400x256x8x8
        self.conv4=Add_layers(256, 256, pool=True)
        self.res2=nn.Sequential(Add_layers(256, 256), Add_layers(256, 256))
        
        self.classifier=nn.Sequential(nn.MaxPool2d(4), # 10x512x1x1
                                        nn.Flatten(), # 512
                                        nn.Dropout(0.2),
                                        nn.Linear(256, num_classes))
        
    def forward(self, xb):
        out=self.conv1(xb)
        out=self.conv2(out)
        out=self.res1(out) + out
        out=self.conv3(out)
        out=self.conv4(out)
        out=self.res2(out) + out
        out=self.classifier(out)
        return out

In [74]:
torch.cuda.empty_cache()
model=to_device(ResNet9(3, 10), device)

In [61]:
@torch.no_grad()
def evaluation(model,val_loader,t_loss,lrs):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    epoch_acc = torch.stack([x for x in outputs]).mean().item()
    return epoch_acc

In [62]:
def new_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [63]:
def fit(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt=torch.optim.Adam):
    torch.cuda.empty_cache()
    optimizer=opt(model.parameters(),max_lr,weight_decay=weight_decay)
    #every epoch has different lr in OneCycleLR
    lr_sched=torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))

    records=[]
    for epoch in range(epochs):
        model.train()
        lrs=[]
        train_losses=[]
        for batch in train_loader:
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
                
            optimizer.step()
            optimizer.zero_grad()
            
            lrs.append(new_lr(optimizer))
            lr_sched.step()
        
        t_loss = torch.stack(train_losses).mean().item()
        lrs = lrs
        result = evaluation(model, val_loader,t_loss,lrs)
        records.append([result,t_loss,lrs[-1]])
        print(f"result:{result},t_loss:{t_loss},lrs:{lrs[-1]}")
    
    return records

In [77]:
rec=fit(20,0.01,model,train_dl,valid_dl,1e-4,0.1)

result:0.543749988079071,t_loss:1.342811942100525,lrs:0.0010347147065464194
result:0.6287500858306885,t_loss:0.9583999514579773,lrs:0.0027883855836428892
result:0.7487499713897705,t_loss:0.799932062625885,lrs:0.005189933488862935
result:0.7665385007858276,t_loss:0.7103345990180969,lrs:0.007594185749005899
result:0.7064422965049744,t_loss:0.6261388063430786,lrs:0.00935524316949251
result:0.8009615540504456,t_loss:0.5548087358474731,lrs:0.01
result:0.7893269062042236,t_loss:0.49511200189590454,lrs:0.009874640062350875
result:0.7846154570579529,t_loss:0.46220862865448,lrs:0.009504846320134737
result:0.8079807758331299,t_loss:0.434877872467041,lrs:0.0089091617757105
result:0.8335577249526978,t_loss:0.4105146825313568,lrs:0.008117456539497631
result:0.8371153473854065,t_loss:0.38311952352523804,lrs:0.007169430017913008
result:0.820192277431488,t_loss:0.3550102114677429,lrs:0.0061126202193628925
result:0.8544229865074158,t_loss:0.32791367173194885,lrs:0.00500002
result:0.8710577487945557,t_l

In [65]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [76]:
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                   Param #
ResNet9                                  --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─BatchNorm2d: 2-2                  128
│    └─ReLU: 2-3                         --
├─Sequential: 1-2                        --
│    └─Conv2d: 2-4                       73,856
│    └─BatchNorm2d: 2-5                  256
│    └─ReLU: 2-6                         --
│    └─MaxPool2d: 2-7                    --
├─Sequential: 1-3                        --
│    └─Sequential: 2-8                   --
│    │    └─Conv2d: 3-1                  147,584
│    │    └─BatchNorm2d: 3-2             256
│    │    └─ReLU: 3-3                    --
│    └─Sequential: 2-9                   --
│    │    └─Conv2d: 3-4                  147,584
│    │    └─BatchNorm2d: 3-5             256
│    │    └─ReLU: 3-6                    --
├─Sequential: 1-4                        --
│    └─Conv2d: 2-10                      295,168
│