In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F


debug = True

# TODO stride=2 is used in the paper but this is causing shape issues right now :(
def conv_3x3(in_channels, out_channels, pad=1, stride=1):
    return nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=pad)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        A ResidualBlock implements the basic residual block discussed in the paper, for the network used on CIFAR 10.
        It consists of a pair of 3x3 convolutional layers, all with the same output feature map size (either 32, 16, or 8)
        We apply conv-BN-relu for the first layer, then conv-BN, then add the input (residual connection) and do a final RELU.
        We zero-pad the input for dimension mismatches, so that no new parameters are introduced in the residual connections.
        """
        self.in_channels, self.out_channels = in_channels, out_channels
        super(ResidualBlock, self).__init__()
        self.conv1 = conv_3x3(in_channels, in_channels)
        self.bn1 = nn.BatchNorm2d(num_features=in_channels)
        self.conv2 = conv_3x3(in_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        
    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # The paper said that identity shortcuts are used in all cases, 
        # so in cases of shape mismatch I pad to align dimensions, which introduces no new parameters.
        if x.shape != identity.shape:
            shape_diff = abs((sum(identity.shape) - sum(x.shape)))
            identity = F.pad(identity, pad=(0,0,0,0,shape_diff//2,shape_diff//2))
        return F.relu(identity + x)

    def __repr__(self):
        return f'Residual block with in_channels {self.in_channels} and out channels {self.out_channels}'


class Resnet(nn.Module):
    def __init__(self, n=1, dbg=False):
        super(Resnet, self).__init__()
        debug = dbg
        self.residual_blocks = []
        # create number of residual blocks needed
        cur_feature_map_size = 16
        changed = False
        for i in range(3*n):
            if i != 0 and i % n == 0:
                cur_feature_map_size = cur_feature_map_size*2
                changed = True
            block = ResidualBlock(cur_feature_map_size if not changed else cur_feature_map_size//2, cur_feature_map_size)
            changed = False
            self.residual_blocks.append(block)

        self.linear = nn.Linear(16384, 10)
        self.conv1 = conv_3x3(3, 16)
        self.pool = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        for block in self.residual_blocks:
            x = block(x)
        # x = self.first_block(x)
        # flatten the multidimensional input to a single matix for input into the FC layer
        x = self.pool(x) # only difference is this pool
        x = x.view(x.shape[0], -1)
        x = self.linear(x)
        return x

In [26]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import numpy as np

def get_dataset():
    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=1)
    return trainloader, testloader 

def get_test_accuracy(net, testloader):
  accs = []
  net.eval()
  with torch.no_grad():
    for i, data in enumerate(testloader, 0):
      x, y = data
      x, y = x.cuda(), y.cuda()
      preds = resnet(x)
      _, predicted = torch.max(preds, 1)
      accuracy = accuracy_score(predicted.cpu(), y.cpu())
      accs.append(accuracy)
      break
  net.train()
  return np.mean(accs)

def update_learning_rate(current_lr, optim):
  new_lr = current_lr/10
  for g in optim.param_groups:
    g['lr'] = new_lr
  return new_lr

    


if __name__ == '__main__':
    num_epochs = 25
    # get cifar 10 data
    trainloader, testloader = get_dataset()
    benchmark, debug = False, True
    resnet = Resnet(dbg=debug)
    resnet.train()
    resnet = resnet.cuda()
    for block in resnet.residual_blocks:
      block.cuda()
    current_lr = 0.1e-4
#     optimizer = optim.Adam(resnet.parameters(), lr=1e-4, weight_decay=0.0001)
    optimizer = optim.SGD(resnet.parameters(), lr=current_lr, weight_decay=0.0001, momentum=0.9)
    train_accs, test_accs = [], []
    def train_model():
      stopping_threshold, current_count = 3, 0
      n_iters = 0
      for e in range(num_epochs):
        # modify learning rate at 
          for i, data in enumerate(trainloader, 0):
              x, y = data
              x, y = x.cuda(), y.cuda()
              # zero the grad
              optimizer.zero_grad()
              preds = resnet(x)
              loss = F.cross_entropy(preds, y)
              loss.backward()
              optimizer.step()
              if i % 10 == 0:
                  _, predicted = torch.max(preds, 1)
                  accuracy = accuracy_score(predicted.cpu(), y.cpu())
                  train_accs.append(accuracy)
                  print('loss: {}, accuracy: {}'.format(loss, accuracy))
              if i % 50 == 0:
                # get test accuracy
                test_acc = get_test_accuracy(resnet, testloader)
                print('test accuracy: {}'.format(test_acc))
                test_accs.append(test_acc)
              n_iters+=1
              if n_iters == 32000 or n_iters == 48000:
                current_lr = update_learning_rate(current_lr, optim)
      print('iterated {} times'.format(n_iters))
      return resnet
    
    trained_resnet_model = train_model()
    
                



Files already downloaded and verified
Files already downloaded and verified
loss: 2.367396831512451, accuracy: 0.0859375
test accuracy: 0.091
loss: 2.3379154205322266, accuracy: 0.1171875
loss: 2.319140911102295, accuracy: 0.09375
loss: 2.3387510776519775, accuracy: 0.09375
loss: 2.321025848388672, accuracy: 0.1015625
loss: 2.2964954376220703, accuracy: 0.1015625
test accuracy: 0.1
loss: 2.332428455352783, accuracy: 0.109375
loss: 2.252169370651245, accuracy: 0.125
loss: 2.3197128772735596, accuracy: 0.109375
loss: 2.2791786193847656, accuracy: 0.09375
loss: 2.28896164894104, accuracy: 0.09375
test accuracy: 0.118
loss: 2.285951614379883, accuracy: 0.1328125
loss: 2.2776944637298584, accuracy: 0.15625
loss: 2.2602126598358154, accuracy: 0.171875
loss: 2.282613754272461, accuracy: 0.109375
loss: 2.246964693069458, accuracy: 0.203125
test accuracy: 0.133
loss: 2.2511887550354004, accuracy: 0.15625
loss: 2.2664756774902344, accuracy: 0.125
loss: 2.2104454040527344, accuracy: 0.2265625
los

In [27]:
n_iters

NameError: ignored