In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F


debug = False

def conv_3x3(in_channels, out_channels, pad=1):
    return nn.Conv2d(in_channels, out_channels, 3, padding=pad)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        A ResidualBlock implements the basic residual block discussed in the paper, for the network used on CIFAR 10.
        It consists of a pair of 3x3 convolutional layers, all with the same output feature map size (either 32, 16, or 8)
        We apply conv-BN-relu for the first layer, then conv-BN, then add the input (residual connection) and do a final RELU.
        We zero-pad the input for dimension mismatches, so that no new parameters are introduced in the residual connections.
        """
        self.in_channels, self.out_channels = in_channels, out_channels
        super(ResidualBlock, self).__init__()
        self.conv1 = conv_3x3(in_channels, in_channels)
        self.bn1 = nn.BatchNorm2d(num_features=in_channels)
        self.conv2 = conv_3x3(in_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        identity = x
        if debug: print(x.shape)
        x = self.conv1(x)
        # x = self.bn1(x)
        x = F.relu(x)
        if debug: print('After conv x.shape is {}'.format(x.shape))
        x = self.conv2(x)
        # x = self.bn2(x)
        if x.shape != identity.shape:
            if debug: print(f'Shape mismatch - x: {x.shape}, original identity: {identity.shape}')
            shape_diff = abs((sum(identity.shape) - sum(x.shape)))
            identity = F.pad(identity, pad=(0,0,0,0,shape_diff//2,shape_diff//2)) # TODO - instead of 8, compute difference in dimensionality and pad based on that (half each), this only currently works if the dim is 16. Also fix the docs on https://pytorch.org/docs/master/_modules/torch/nn/functional.html#pad!!
        if debug: print(f': {x.shape}, {identity.shape}')
        return F.relu(identity + x)

    def __repr__(self):
        return f'Residual block with in_channels {self.in_channels} and out channels {self.out_channels}'


class Resnet(nn.Module):
    def __init__(self, n=1, dbg=False):
        super(Resnet, self).__init__()
        debug = dbg
        self.residual_blocks = []
        # create number of residual blocks needed
        cur_feature_map_size = 16
        changed = False
        for i in range(3*n):
            if i != 0 and i % n == 0:
                cur_feature_map_size = cur_feature_map_size*2
                changed = True
            block = ResidualBlock(cur_feature_map_size if not changed else cur_feature_map_size//2, cur_feature_map_size)
            changed = False
            self.residual_blocks.append(block)
        for b in self.residual_blocks:
            if debug: print(b)

        self.linear = nn.Linear(65536, 10)

        self.conv1 = conv_3x3(3, 16)
        # self.conv2 = conv_3x3(32, 32)
        # self.first_block = ResidualBlock(32, 32)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        for block in self.residual_blocks:
            x = block(x)
        # x = self.first_block(x)
        # flatten the multidimensional input to a single matix for input into the FC layer
        x = x.view(x.shape[0], -1)
        if debug: print('Shape before FC layer: {}'.format(x.shape))
        x = self.linear(x)
        return x

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import numpy as np

def get_dataset():
    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=1)
    return trainloader, testloader 

def get_test_accuracy(net, testloader):
  accs = []
  for i, data in enumerate(testloader, 0):
    x, y = data
    print(x.shape, y.shape)
    preds = resnet(x)
    _, predicted = torch.max(preds, 1)
    accuracy = accuracy_score(predicted, y)
    return accuracy # just doing a sample for now
    accs.append(accuracy)
  return np.mean(accs)

    


if __name__ == '__main__':
    num_epochs = 25
    # get cifar 10 data
    trainloader, testloader = get_dataset()
    benchmark, debug = False, False
    if not benchmark:
        print('Using my resnet')
        resnet = Resnet(dbg=debug)
    else:
        print('Using benchmark resnet')
        resnet = example_resnet.ResNet18()
    resnet.train()
    optimizer = optim.Adam(resnet.parameters(), lr=0.003)
    for e in range(num_epochs):
        for i, data in enumerate(trainloader, 0):
#             if i == 5:
#                 break
            print(i)
            x, y = data
            # zero the grad
            optimizer.zero_grad()
            preds = resnet(x)
            loss = F.cross_entropy(preds, y)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                _, predicted = torch.max(preds, 1)
                accuracy = accuracy_score(predicted, y)
                print('loss: {}, accuracy: {}'.format(loss, accuracy))
            if i % 50 == 0:
              # get test accuracy
              test_acc = get_test_accuracy(resnet, testloader)
              print('test accuracy: {}'.format(test_acc))



Files already downloaded and verified
Files already downloaded and verified
Using my resnet
0
loss: 2.309892177581787, accuracy: 0.1640625
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
torch.Size([1000, 3, 32, 32]) torch.Size([1000])
test accuracy: 0.18569999999999998
1
2
3
4
5
6
7
8
9
10
loss: 5.442500591278076, accuracy: 0.140625
11
12
13
14
15
16
17
18
19
20
loss: 4.242424488067627, accuracy: 0.1171875
21
22
23
24
25
26
27
28
29
30
loss: 2.3667941093444824, accuracy: 0.2578125
31
32
33
34
35
36
37
38
39
40
loss: 1.8923252820968628, accuracy: 0.375
41
42
43
44
45
46
47
48
49
50
loss: 2.0040180683