In [1]:
import numpy as np 
import pandas as pd
import torch
from torch import nn 
from torch.utils.data import DataLoader,random_split,Dataset
from torchvision import transforms 
from torchvision import datasets

In [31]:
class Block(nn.Module): 
    def __init__(self, in_channel,out_channel,down_sample=None,stride=1): 
        super().__init__(); 
        self.net = nn.Sequential(nn.Conv2d(in_channel,out_channel,kernel_size=1,padding=0),
                                 nn.BatchNorm2d(out_channel),nn.ReLU(),
                                 nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1,
                                          stride=stride),
                                 nn.BatchNorm2d(out_channel),nn.ReLU(),
                                 nn.Conv2d(out_channel,out_channel*4,kernel_size=1,padding=0),
                                 nn.BatchNorm2d(out_channel*4),nn.ReLU()); 
        self.down_sample = down_sample; 
        self.relu = nn.ReLU();
    def forward(self, x): 
        identity = x; 
        x = self.net(x); 
        if (self.down_sample is not None): identity = self.down_sample(identity); 
        x = x.add(identity)
        x = self.relu(x); 
        return x 

In [32]:
class ResNet(nn.Module): 
    def __init__(self,layers,in_channles,num_class): 
        super().__init__(); 
        self.conv1 = nn.Conv2d(in_channles,64,kernel_size=7,padding=3,stride=2); 
        self.bn1 = nn.BatchNorm2d(64); 
        self.relu = nn.ReLU(); 
        self.maxpool = nn.MaxPool2d(kernel_size=3,padding=1,stride=2); 
        self.in_channels = 64;
        self.layer1 = self._make_layer(layers[0],64,stride = 1); 
        self.layer2 = self._make_layer(layers[1],128,stride = 2);
        self.layer3 = self._make_layer(layers[2],256,stride = 2);
        self.layer4 = self._make_layer(layers[3],512,stride = 2);
        self.avgpool = nn.AdaptiveAvgPool2d((1,1));
        self.fc = nn.Linear(512*4,num_class); 
    def forward(self, x): 
        x = self.conv1(x); 
        x = self.bn1(x); 
        x = self.relu(x); 
        x = self.maxpool(x); 
        x = self.layer1(x); 
        x = self.layer2(x); 
        x = self.layer3(x); 
        x = self.layer4(x); 
        x = self.avgpool(x) ;
        x = x.reshape(x.shape[0],-1)
        x = self.fc(x); 
        return x; 
    def _make_layer(self,num_layers,out_channels,stride): 
        down_sample = None;
        if (stride!=1 or self.in_channels!=out_channels*4): 
            down_sample = nn.Sequential(nn.Conv2d(self.in_channels,out_channels*4,kernel_size=1,
                                                  stride=stride)); 
        layers = [] 
        layers.append(Block(self.in_channels,out_channels,stride=stride,down_sample=down_sample))
        self.in_channels = out_channels * 4; 
        for i in range(num_layers-1): 
            layers.append(Block(self.in_channels,out_channels)); 
        return nn.Sequential(*layers); 

In [8]:
transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomVerticalFlip(p=0.5), 
                                transforms.RandomAutocontrast(p=0.5), 
                                transforms.ToTensor()]); 
test_transform = transforms.ToTensor();
train_dataset = datasets.CIFAR10(root='torchvision_datasets',train=True,download=True,transform=transform);
test_dataset = datasets.CIFAR10(root='torchvision_datasets',train=False,download=True,transform=test_transform);


100%|██████████████████████████████████████| 170M/170M [45:51<00:00, 62.0kB/s]


In [43]:
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True); 
test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False); 

In [44]:
model = ResNet(in_channles=3,num_class = 10,layers = [2,2,2,2]); 


In [45]:
optimizer = torch.optim.Adam(lr = 5e-3,params = model.parameters()); 
loss_function = nn.CrossEntropyLoss(); 

In [46]:
model.train();
b = 0; 
for x, y in train_loader : 
    optimizer.zero_grad(); 
    out = model(x); 
    loss = loss_function(out,y); 
    loss.backward(); 
    optimizer.step(); 
    b+=1;
    print(round(loss.item(),4))

2.2969
22.6337
12.0756
6.1198
4.4569
3.1194
3.3161
3.7425
5.0826
6.0931
5.9051
4.7354
7.174
6.4609
3.8481
2.349
2.7132
2.6722
2.5559
5.1292
3.2615
4.9757
2.46
2.7303
2.9017
2.7393
3.852
3.2703
3.5351
2.7802
3.4843
3.8105
3.5447
3.0366
2.6091
2.7476
2.7916
3.2117
2.4037
2.4342
2.9224
2.8669
3.0854
2.4702
2.5968
3.0339
2.3108
2.7963
2.6759
2.9463
2.4476
3.0055
2.3487
2.6877
2.7269
2.4196
2.4093
2.3793
2.4036
2.8644
3.5368
2.3798
3.4934
2.4525
3.3094
2.8138
2.6966
3.2439
2.5255
3.4303
2.8934
3.1222
2.5088
2.5757
3.0219
2.5496
2.4881
2.7805
2.2499
3.2172
2.3858
2.6199
2.3916
2.9201
2.8668
3.513
2.9319
3.4524
2.9456
3.5821
2.3858
2.3121
3.2952
2.3451
2.4732
2.8227
2.9837
2.2479
2.7014
2.4235
2.5619
2.3223
2.4832
2.4968
2.2098
2.6177
2.6166
2.5352
2.2422
2.297
2.2601
2.4597
2.7128
2.4589
2.5733
2.2828
2.72
2.3249
2.2471
2.9196
2.4517
2.289
2.4843
3.6014
2.5218
3.2533
2.545
2.3468
2.3639
2.3949
2.3002
2.5067
2.5707
2.3952
3.6532
2.7054
2.5519
2.6792
2.3939
2.4092
2.2969
2.5552
2.4147
2.4144
2

KeyboardInterrupt: 

In [25]:
32768//4

8192