In [83]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10

In [43]:
# CIFAR-10 loading: per-channel normalization, then a fixed train/validation split of the
# official training set (the last NUM_VAL images are held out for validation).
BATCH_SIZE = 128
NUM_VAL = 5000  # images held out from the official training set for validation
SEED = 0        # seeds the samplers so the shuffling order is reproducible across runs

transform = transforms.Compose([
    transforms.ToTensor(),
    # NOTE(review): this std triple is a widely copied value; the per-pixel std of the
    # CIFAR-10 training set is roughly (0.247, 0.243, 0.261) -- confirm which is intended.
    transforms.Normalize([0.49139969, 0.48215842, 0.44653093], [0.20220212, 0.19931542, 0.20086347])
])
training_data = CIFAR10('cifar/train', train=True, download=True, transform=transform)
# Second instance over the same training files; presumably kept separate so validation can
# later get its own transform (e.g. without augmentation) independently of training_data.
val_data = CIFAR10('cifar/train', train=True, download=True, transform=transform)
test_data = CIFAR10('cifar/test', train=False, download=True, transform=transform)
N = len(training_data)

train_idx = range(0, N - NUM_VAL)
val_idx = range(N - NUM_VAL, N)

# The sampler both restricts each loader to its index range and randomizes order, so
# `shuffle` stays at its default (DataLoader forbids shuffle=True together with a sampler).
loader_train = DataLoader(
    training_data,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(train_idx, generator=torch.Generator().manual_seed(SEED)),
)
loader_val = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(val_idx, generator=torch.Generator().manual_seed(SEED + 1)),
)
loader_test = DataLoader(test_data, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [198]:
class Resnet(nn.Module):
    """CIFAR-style ResNet with 6n+2 weighted layers.

    Layout: a 3x3 stem convolution to 16 channels, three stages of `n` Resblocks each
    (16, 32, 64 channels; the second and third stages start with stride 2), global
    average pooling, and a linear classifier. With n=9 this is a 56-layer network.

    Args:
        n: number of Resblocks per stage; must be >= 1.
        num_classes: size of the output score vector (default 10).

    Raises:
        ValueError: if n < 1.
    """

    def __init__(self, n, num_classes=10):
        super().__init__()
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.block1 = self._create_block(n, 16, 16, 1)
        self.block2 = self._create_block(n, 16, 32, 2)
        self.block3 = self._create_block(n, 32, 64, 2)
        # Averages over whatever spatial size block3 produces. For 32x32 inputs that is
        # 8x8, which matches the former AvgPool2d(8); other input sizes also work.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def _create_block(self, n, num_input, num_output, stride):
        """Build one stage. The first Resblock may change channels/stride; the other n-1 keep them."""
        sub_blocks = [Resblock(num_input, num_output, stride)]
        sub_blocks += [Resblock(num_output, num_output, 1) for _ in range(n - 1)]
        return nn.Sequential(*sub_blocks)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes); assumes x is (batch, 3, H, W)."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.pool(out)
        out = out.flatten(1)  # (batch, 64, 1, 1) -> (batch, 64)
        return self.fc(out)

In [199]:
class Resblock(nn.Module):
    """Basic two-convolution residual block.

    Computes relu(bn2(conv2(relu(bn1(conv1(x))))) + identity(x)).

    Args:
        num_input: number of channels of the incoming tensor.
        num_output: number of channels produced by both convolutions.
        stride: stride of the first convolution. The shortcut (IdentityMapping) is built
            with the same stride and channel counts, so both branches have the same shape.
    """

    def __init__(self, num_input, num_output, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(num_input, num_output, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(num_output)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(num_output, num_output, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_output)
        self.relu2 = nn.ReLU()  # applied AFTER the residual addition
        self.identity = IdentityMapping(stride, num_input, num_output)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the shortcut before the final ReLU (as in the ResNet paper), and do it
        # out-of-place. The previous `out += ...` wrote into relu2's output, which
        # autograd saves for ReLU's backward, so loss.backward() raised an in-place
        # modification error.
        out = out + self.identity(x)
        return self.relu2(out)

In [200]:
# Parameter-free shortcut ("option A" in He et al., 2016): match the residual branch's
# output shape by spatial subsampling plus zero-padding of the extra channels.
class IdentityMapping(nn.Module):
    """Shortcut that subsamples spatially and appends all-zero channels.

    Takes every `stride`-th pixel along H and W (a 1x1 max-pool with the given stride),
    so stride=2 halves the spatial size. It then appends `num_output - num_input`
    zero channels after the existing ones. No learnable parameters are added.

    Args:
        stride: spatial subsampling factor (1 leaves H and W unchanged).
        num_input: number of channels of the incoming tensor.
        num_output: number of channels of the returned tensor; must be >= num_input.

    Raises:
        ValueError: if num_output < num_input (a negative pad would silently crop channels).
    """

    def __init__(self, stride, num_input, num_output):
        super().__init__()
        if num_output < num_input:
            raise ValueError(
                f"num_output ({num_output}) must be >= num_input ({num_input})"
            )
        self.maxpool = nn.MaxPool2d(kernel_size=1, stride=stride)
        self.extra_filters = num_output - num_input

    def forward(self, x):
        # Assumes x is (N, C, H, W). Subsample first, then pad. The result is identical
        # to the reverse order (max-pooling all-zero channels yields zeros), but the
        # zero channels are only allocated at the reduced resolution.
        out = self.maxpool(x)
        # F.pad's tuple runs from the last dim backwards: (W_l, W_r, H_t, H_b, C_front, C_back).
        return F.pad(out, (0, 0, 0, 0, 0, self.extra_filters))

In [204]:
# Smoke test: a single forward pass on one training batch to check the output shape.
model = Resnet(9)  # n=9 -> 6*9+2 = 56 weighted layers
data, labels = next(iter(loader_train))
# Shape check only, so no autograd graph is needed. The model is still in training mode,
# so this pass does update the BatchNorm running statistics.
with torch.no_grad():
    scores = model(data)
assert scores.shape == (data.shape[0], 10), scores.shape
scores.shape

torch.Size([128, 10])
