In [None]:
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# Number of images per minibatch for every loader defined in this cell.
batch_size = 100

# Preprocessing shared by the training, validation and test images:
# resize to 120x120, centre-crop to 110x110, convert to a tensor and
# normalise every channel with mean 0.5 and std 0.5.
# The pipeline gets its own name so that the imported `transforms` module is
# not overwritten (rebinding it made this cell fail when it was run twice).
image_transform = transforms.Compose([
    transforms.Resize((120, 120)),
    transforms.CenterCrop((110, 110)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = datasets.CIFAR10(root='data', train=True, transform=image_transform, download=True)
testset = datasets.CIFAR10(root='data', train=False, transform=image_transform, download=True)

# Split the training set 80/20 by index: the first 80% of the indices feed
# the training loader, the remaining 20% feed the validation loader.
dataset_len = len(trainset)
indices = list(range(dataset_len))
len_trainset = int(dataset_len * 0.8)

train_idx = torch.utils.data.SubsetRandomSampler(indices[:len_trainset])
validation_idx = torch.utils.data.SubsetRandomSampler(indices[len_trainset:])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_idx)
validationloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=validation_idx)
# The test loader must read the held-out `testset`; it previously wrapped
# `trainset`, so the reported test accuracy was measured on training data.
# Shuffling is not needed for evaluation.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck']

# Sanity check: show the shapes of a single training batch.
images, labels = next(iter(trainloader))
print('Image batch dimensions:', images.shape)
print('Label batch dimensions:', labels.shape)
print('Class labels of 10 examples:', labels[:10])

In [None]:
from torch.utils.data import non_deterministic
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  """Return a bias-free 3x3 `nn.Conv2d` whose padding equals its dilation."""
  conv_options = dict(kernel_size=3, stride=stride, padding=dilation,
                      groups=groups, dilation=dilation, bias=False)
  return nn.Conv2d(in_planes, out_planes, **conv_options)
  
def conv1x1(in_planes, out_planes, stride=1):
  """Return a bias-free 1x1 `nn.Conv2d` with the given stride."""
  pointwise = nn.Conv2d(in_planes, out_planes,
                        kernel_size=1, stride=stride, bias=False)
  return pointwise

class BasicBlock(nn.Module):
  """Residual block made of two 3x3 convolutions plus a shortcut connection.

  The main path is conv3x3 -> norm -> ReLU -> conv3x3 -> norm. The shortcut is
  the input itself, or `downsample(input)` when a `downsample` module is
  given. Both are summed in place and passed through a final ReLU.

  Args:
    inplanes: channel count of the input.
    planes: channel count produced by both convolutions.
    stride: stride of the first convolution.
    downsample: optional module applied to the input before the addition.
    groups: must be 1.
    base_width: must be 64.
    dilation: must not exceed 1.
    norm_layer: callable building a normalisation layer from a channel
      count; `nn.BatchNorm2d` when None.

  Raises:
    ValueError: if `groups != 1` or `base_width != 64`.
    NotImplementedError: if `dilation > 1`.
  """
  expansion: int = 1

  def __init__(self, inplanes, planes, stride=1, downsample=None,
               groups=1, base_width=64, dilation=1, norm_layer=None):
    super().__init__()

    # Reject the configurations this block does not implement.
    standard_width = groups == 1 and base_width == 64
    if not standard_width:
      raise ValueError('BasicBlock only supports groups=1 and base_width=64')
    if dilation > 1:
      raise NotImplementedError('Dilation > 1 not supported by BasicBlock')

    make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

    # Sub-modules are registered in the order they are applied.
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = make_norm(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = make_norm(planes)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """Run the main path, add the shortcut and apply the final ReLU."""
    residual = self.relu(self.bn1(self.conv1(x)))
    residual = self.bn2(self.conv2(residual))

    shortcut = x if self.downsample is None else self.downsample(x)

    residual += shortcut
    return self.relu(residual)
  
class Bottleneck(torch.nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The first 1x1 convolution maps `inplanes` to an internal `width`, the 3x3
    convolution (which carries `stride`, `groups` and `dilation`) keeps that
    width, and the last 1x1 convolution expands to `planes * expansion`
    channels. The shortcut is the input itself, or `downsample(input)` when a
    `downsample` module is supplied.

    Args:
      inplanes: channel count of the input.
      planes: base channel count; the output has `planes * expansion` channels.
      stride: stride of the 3x3 convolution.
      downsample: optional module applied to the input before the addition.
      groups: group count of the 3x3 convolution.
      base_width: scales the internal width as
        `int(planes * (base_width / 64.)) * groups`.
      dilation: dilation (and padding) of the 3x3 convolution.
      norm_layer: callable building a normalisation layer from a channel
        count; `nn.BatchNorm2d` when None.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):

      # `super` must be called to obtain the proxy object; the previous
      # `super.__init__()` raised a TypeError as soon as the block was built.
      super().__init__()
      if norm_layer is None:
        norm_layer = nn.BatchNorm2d

      width = int(planes * (base_width / 64.)) * groups
      self.conv1 = conv1x1(inplanes, width)
      self.bn1 = norm_layer(width)
      self.conv2 = conv3x3(width, width, stride, groups, dilation)
      self.bn2 = norm_layer(width)
      self.conv3 = conv1x1(width, planes * self.expansion)
      self.bn3 = norm_layer(planes * self.expansion)
      self.relu = nn.ReLU(inplace=True)
      self.downsample = downsample
      self.stride = stride

    def forward(self, x):
      """Run the three convolutions, add the shortcut, apply the final ReLU."""
      identity = x

      out = self.conv1(x)
      out = self.bn1(out)
      out = self.relu(out)

      out = self.conv2(out)
      out = self.bn2(out)
      out = self.relu(out)

      out = self.conv3(out)
      out = self.bn3(out)

      # Bring the shortcut to the output's shape when a projection is given.
      if self.downsample is not None:
        identity = self.downsample(x)

      out += identity
      out = self.relu(out)

      return out


class ResNet(nn.Module):
  """Residual network: convolutional stem, four residual stages, global
  average pooling and a linear classifier.

  Args:
    block: residual block class. It must expose an `expansion` class
      attribute and accept the constructor arguments used in `make_layer_`.
    layers: sequence whose first four entries give the number of blocks in
      stages 1 to 4.
    num_classes: output size of the final linear layer.
    zero_init_residual: when True, the weight of the last normalisation layer
      of every `Bottleneck` / `BasicBlock` is set to zero.
    groups: forwarded to every block as `groups`.
    width_per_group: forwarded to every block as `base_width`.
    replace_stride_with_dilation: None, or three booleans (for stages 2, 3
      and 4); a True entry makes that stage use dilation instead of stride.
    norm_layer: callable building a normalisation layer from a channel
      count; `nn.BatchNorm2d` when None.

  Raises:
    ValueError: if `replace_stride_with_dilation` does not have 3 elements.
  """

  def __init__(self, block, layers, num_classes, zero_init_residual=False, groups=1,
               width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):

    super().__init__()
    if norm_layer is None:
      norm_layer = nn.BatchNorm2d
    self.norm_layer = norm_layer

    # Running state consumed and updated by `make_layer_`.
    self.inplanes = 64
    self.dilation = 1
    if replace_stride_with_dilation is None:
      replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
      # The message now names the actual parameter (it used to say
      # `replace_stride_with_value`, which does not exist).
      raise ValueError('replace_stride_with_dilation should be None or a 3-element tuple')

    self.groups = groups
    self.base_width = width_per_group

    # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                           bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Four residual stages; stages 2-4 use stride 2 unless dilated instead.
    self.layer1 = self.make_layer_(block, 64, layers[0])
    self.layer2 = self.make_layer_(block, 128, layers[1], stride=2,
                                   dilate=replace_stride_with_dilation[0])
    self.layer3 = self.make_layer_(block, 256, layers[2], stride=2,
                                   dilate=replace_stride_with_dilation[1])
    self.layer4 = self.make_layer_(block, 512, layers[3], stride=2,
                                   dilate=replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool2d((1,1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    # Weight initialisation: Kaiming-normal for convolutions, weight 1 and
    # bias 0 for batch/group normalisation layers.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

    # Optionally zero the last norm weight of each block, so the residual
    # branch initially outputs zeros and the block passes its shortcut on.
    if zero_init_residual:
      for m in self.modules():
        if isinstance(m, Bottleneck):
          nn.init.constant_(m.bn3.weight, 0)
        elif isinstance(m, BasicBlock):
          nn.init.constant_(m.bn2.weight, 0)

  def make_layer_(self, block, planes, blocks, stride=1, dilate=False):
    """Build one residual stage of `blocks` blocks as an `nn.Sequential`.

    Only the first block receives `stride` and the `downsample` projection;
    `self.inplanes` is updated to `planes * block.expansion` afterwards.
    With `dilate=True` the stride is multiplied into `self.dilation` and the
    stage runs with stride 1.
    """
    norm_layer = self.norm_layer
    downsample = None
    previous_dilation = self.dilation
    if dilate:
      self.dilation *= stride
      stride = 1
    # A projection shortcut is needed whenever the first block changes the
    # spatial size or the channel count.
    if stride != 1 or self.inplanes != planes * block.expansion:
      downsample = nn.Sequential(
          conv1x1(self.inplanes, planes * block.expansion, stride),
          norm_layer(planes * block.expansion)
      )

    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer))
    self.inplanes = planes * block.expansion
    for _ in range(1, blocks):
      layers.append(block(self.inplanes, planes, groups=self.groups,
                          base_width=self.base_width, dilation=self.dilation,
                          norm_layer=norm_layer))

    return nn.Sequential(*layers)


  def forward(self, x):
    """Map an image batch to class logits of shape (batch, num_classes)."""
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x

In [None]:
# Build the network from BasicBlocks, two per stage across the four stages,
# with ten output classes, then print its layer structure.
blocks_per_stage = [2, 2, 2, 2]
model = ResNet(block=BasicBlock, layers=blocks_per_stage, num_classes=10)
print(model)

In [None]:
def compute_accuracy(model, data_loader):
  """Return the classification accuracy of `model` over `data_loader` in percent.

  The function does not change the model's train/eval mode; callers that
  need evaluation behaviour should call `model.eval()` beforehand.

  Args:
    model: callable mapping a feature batch to logits with one row per
      example and one column per class.
    data_loader: iterable yielding `(features, targets)` batches, where
      `targets` holds the class index of every example.

  Returns:
    A 0-dim float tensor with the percentage of correctly classified
    examples, or NaN when `data_loader` yields no examples (this case used to
    fail with an AttributeError on a plain int).
  """
  correct_pred, num_examples = 0, 0

  with torch.no_grad():
    for features, targets in data_loader:
      logits = model(features)
      # The predicted class is the column with the largest logit.
      _, predicted_labels = torch.max(logits, 1)

      num_examples += targets.size(0)
      correct_pred = correct_pred + (predicted_labels == targets).sum()

  if num_examples == 0:
    # Accuracy is undefined without examples.
    return torch.tensor(float('nan'))

  return correct_pred.float() / num_examples * 100

In [None]:
def train_model(model, num_epochs, train_loader, valid_loader, test_loader,
                optimizer, logging_interval=50, scheduler=None,
                scheduler_on='valid_acc'):
  """Train `model` with cross-entropy loss and report accuracies.

  Each epoch runs one pass over `train_loader`, then measures the accuracy on
  `train_loader` and `valid_loader` with `compute_accuracy` and optionally
  steps `scheduler`. Test accuracy is measured once, after the last epoch.

  Args:
    model: network to train; called on each feature batch to obtain logits.
    num_epochs: number of passes over `train_loader`.
    train_loader: iterable of `(features, targets)` training batches.
    valid_loader: iterable of `(features, targets)` validation batches.
    test_loader: iterable of `(features, targets)` test batches.
    optimizer: optimizer that updates the parameters of `model`.
    logging_interval: the loss is printed every `logging_interval` batches.
    scheduler: optional scheduler whose `step(metric)` is called once per
      epoch.
    scheduler_on: metric handed to the scheduler, either `'valid_acc'` (last
      validation accuracy) or `'minibatch_loss'` (last minibatch loss).

  Returns:
    Tuple `(minibatch_loss_list, train_acc_list, valid_acc_list)` holding the
    loss of every minibatch and the per-epoch accuracies in percent.

  Raises:
    ValueError: if a scheduler is given and `scheduler_on` is not one of the
      two supported choices.
  """
  # Validate up front so a typo does not cost a whole epoch of training.
  if scheduler is not None and scheduler_on not in ('valid_acc', 'minibatch_loss'):
    raise ValueError(f'Invalid `scheduler_on` choice: {scheduler_on!r}')

  start_time = time.time()
  minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []

  for epoch in range(num_epochs):

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

      logits = model(features)
      loss = nn.functional.cross_entropy(logits, targets)

      # Clear stale gradients, back-propagate, update the parameters.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      minibatch_loss_list.append(loss.item())
      if not batch_idx % logging_interval:
        print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
              f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
              f'| Loss: {loss:.4f}')

    model.eval()
    with torch.no_grad():
      train_acc = compute_accuracy(model, train_loader)
      valid_acc = compute_accuracy(model, valid_loader)
      print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
            f'| Train: {train_acc :.2f}% '
            f'| Validation: {valid_acc :.2f}%')
      train_acc_list.append(train_acc.item())
      valid_acc_list.append(valid_acc.item())

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

    if scheduler is not None:
      if scheduler_on == 'valid_acc':
        scheduler.step(valid_acc_list[-1])
      else:
        scheduler.step(minibatch_loss_list[-1])

  # Everything below runs once, after the final epoch. It used to sit inside
  # the epoch loop, so the function returned at the end of the first epoch.
  elapsed = (time.time() - start_time)/60
  print(f'Total Training Time: {elapsed:.2f} min')

  model.eval()
  test_acc = compute_accuracy(model, test_loader)
  print(f'Test accuracy {test_acc:.2f}%')

  return minibatch_loss_list, train_acc_list, valid_acc_list

In [None]:
def plot_training_loss(minibatch_loss_list, num_epochs, iter_per_epoch,
                       results_dir=None, averaging_iterations=100):
    """Plot the per-minibatch loss, its running average and an epoch axis.

    Args:
      minibatch_loss_list: loss value of every minibatch, in training order.
      num_epochs: number of epochs, used to label the secondary x-axis.
      iter_per_epoch: minibatches per epoch, used to place the epoch ticks.
      results_dir: if not None, the figure is saved there as
        'plot_training_loss.pdf'.
      averaging_iterations: window length of the running average.
    """
    plt.figure()
    ax1 = plt.subplot(1, 1, 1)
    ax1.plot(range(len(minibatch_loss_list)),
             (minibatch_loss_list), label='Minibatch Loss')

    # With a long history, scale the y-axis by the losses after the first
    # 1000 iterations so the large initial values do not dominate the plot.
    if len(minibatch_loss_list) > 1000:
        ax1.set_ylim([
            0, np.max(minibatch_loss_list[1000:])*1.5
            ])
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')

    # np.convolve(mode='valid') swaps its operands when the window is longer
    # than the data (yielding a meaningless curve) and raises on empty input,
    # so the running average is only drawn when a full window fits.
    if 0 < averaging_iterations <= len(minibatch_loss_list):
        ax1.plot(np.convolve(minibatch_loss_list,
                             np.ones(averaging_iterations,)/averaging_iterations,
                             mode='valid'),
                 label='Running Average')
    ax1.legend()

    # Secondary x-axis below the first one, labelled in epochs.
    ax2 = ax1.twiny()
    newlabel = list(range(num_epochs+1))

    newpos = [e*iter_per_epoch for e in newlabel]

    # Only every 10th epoch gets a tick.
    ax2.set_xticks(newpos[::10])
    ax2.set_xticklabels(newlabel[::10])

    ax2.xaxis.set_ticks_position('bottom')
    ax2.xaxis.set_label_position('bottom')
    ax2.spines['bottom'].set_position(('outward', 45))
    ax2.set_xlabel('Epochs')
    ax2.set_xlim(ax1.get_xlim())

    plt.tight_layout()

    if results_dir is not None:
        image_path = os.path.join(results_dir, 'plot_training_loss.pdf')
        plt.savefig(image_path)


def plot_accuracy(train_acc_list, valid_acc_list, results_dir):
    """Plot training and validation accuracy per epoch on the current figure.

    Args:
      train_acc_list: training accuracy of every epoch.
      valid_acc_list: validation accuracy of every epoch.
      results_dir: if not None, the figure is saved there as
        'plot_acc_training_validation.pdf'.
    """
    # Epochs are numbered from 1 on the x-axis.
    epoch_axis = np.arange(1, len(train_acc_list) + 1)

    curves = ((train_acc_list, 'Training'), (valid_acc_list, 'Validation'))
    for accuracies, curve_label in curves:
        plt.plot(epoch_axis, accuracies, label=curve_label)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()

    if results_dir is None:
        return

    plt.savefig(os.path.join(results_dir, 'plot_acc_training_validation.pdf'))



def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None):
    """Draw a confusion matrix as an annotated heat map.

    Args:
      conf_mat: square NumPy array of counts; rows are true labels and
        columns are predicted labels (see the axis labels set below).
      hide_spines: hide the four axis spines.
      hide_ticks: remove all axis ticks.
      figsize: figure size; defaults to 1.25 units per class in each
        direction.
      cmap: matplotlib colormap; defaults to `plt.cm.Blues`.
      colorbar: add a colorbar to the figure.
      show_absolute: annotate the cells with the absolute counts.
      show_normed: annotate the cells with row-normalised values; these also
        colour the heat map when enabled.
      class_names: optional tick labels, one per class.

    Returns:
      Tuple `(fig, ax)` with the created figure and axes.

    Raises:
      AssertionError: if both `show_absolute` and `show_normed` are False, or
        if `class_names` does not have one entry per class.
    """
    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')
    if class_names is not None and len(class_names) != len(conf_mat):
        raise AssertionError('len(class_names) should be equal to number of '
                             'classes in the dataset')

    # Row-normalise; rows without any sample stay at 0 rather than dividing
    # by zero.
    total_samples = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = np.divide(conf_mat.astype('float'), total_samples,
                                out=np.zeros(conf_mat.shape, dtype=float),
                                where=total_samples != 0)

    # The default size has to be resolved before the figure is created; it
    # used to be computed after `plt.subplots` and was therefore never used.
    if figsize is None:
        figsize = (len(conf_mat)*1.25, len(conf_mat)*1.25)

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)
    if cmap is None:
        cmap = plt.cm.Blues

    if show_normed:
        matshow = ax.matshow(normed_conf_mat, cmap=cmap)
    else:
        matshow = ax.matshow(conf_mat, cmap=cmap)

    if colorbar:
        fig.colorbar(matshow)

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            cell_text = ""
            if show_absolute:
                cell_text += format(conf_mat[i, j], 'd')
                if show_normed:
                    cell_text += "\n" + '('
                    cell_text += format(normed_conf_mat[i, j], '.2f') + ')'
            else:
                cell_text += format(normed_conf_mat[i, j], '.2f')
            # White text on cells holding more than half of their row's
            # samples, black text elsewhere.
            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if normed_conf_mat[i, j] > 0.5 else "black")

    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks, class_names, rotation=90)
        plt.yticks(tick_marks, class_names)

    if hide_spines:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if hide_ticks:
        ax.axes.get_yaxis().set_ticks([])
        ax.axes.get_xaxis().set_ticks([])

    plt.xlabel('predicted label')
    plt.ylabel('true label')
    return fig, ax

In [None]:
from torch import optim
# Number of passes over the training data.
num_epochs = 10

# Plain SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)

# Multiply the learning rate by 0.1 when the monitored metric stops
# increasing (mode='max').
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases - confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.1, verbose=True)

In [None]:
# Train the model; the learning-rate scheduler follows validation accuracy.
training_history = train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_on='valid_acc',
    num_epochs=num_epochs,
    train_loader=trainloader,
    valid_loader=validationloader,
    test_loader=testloader,
    logging_interval=10)
minibatch_loss_list, train_acc_list, valid_acc_list = training_history

# Loss curve with a 200-iteration running average; nothing is saved to disk.
plot_training_loss(minibatch_loss_list=minibatch_loss_list,
                   num_epochs=num_epochs,
                   iter_per_epoch=len(trainloader),
                   averaging_iterations=200,
                   results_dir=None)
plt.show()

# Accuracy curves, with the y-axis limited to the 60-100 % range.
plot_accuracy(train_acc_list=train_acc_list,
              valid_acc_list=valid_acc_list,
              results_dir=None)
plt.ylim([60, 100])
plt.show()