In [None]:
from google.colab import drive

# Mount Google Drive under /content/gdrive (requires the google.colab runtime).
drive.mount('/content/gdrive')

In [None]:

"""
We will do the following steps in order:
1. Load and normalize the CIFAR10 training and test datasets using 'torchvision'
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data
"""
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import copy

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

############################################
# 1. Loading and normalizing CIFAR-10
############################################

# torchvision datasets yield PIL images. ToTensor() converts them to float
# tensors in [0, 1]; Normalize(mean=0.5, std=0.5) then rescales every channel
# to [-1, 1] via (x - 0.5) / 0.5.

# Training pipeline: random horizontal flip and a random 32x32 crop taken
# from the image padded by 4 pixels, followed by the normalization above.
transform_train = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Both datasets are downloaded into ./data if they are not already present.
# NOTE(review): num_workers=8 may exceed the CPU cores available on the host
# (e.g. a Colab VM) -- confirm and lower if DataLoader warns about it.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8, pin_memory=True)


############################################
# 2. Define a Convolutional Neural Network
############################################

#### kaiming initialization ####
def He_init(layer):
    """Apply Kaiming-normal initialization to a layer's weight, in place.

    Intended to be passed to ``nn.Module.apply``. Only ``nn.Linear`` and
    ``nn.Conv2d`` layers are touched; every other module type (including
    batch-norm layers) and all biases keep their existing values.

    Args:
        layer: the module to (possibly) re-initialize.
    """
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(layer.weight)

#### Identity mapping for matching dimensions ####
class IdentityMapping(nn.Module):
    """Parameter-free shortcut that reshapes the input to a block's output shape.

    The input is zero-padded with ``out_channels - in_channels`` extra
    channels and then spatially subsampled by a 1x1 max-pool with the given
    stride (a 1x1 window simply keeps every ``stride``-th pixel).

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels the shortcut must produce.
        stride: spatial subsampling factor.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.pooling = nn.MaxPool2d(1, stride=stride)
        self.dist_channels = out_channels - in_channels

    def forward(self, x):
        # F.pad consumes (left, right) pairs starting from the last dimension,
        # so for an (N, C, H, W) tensor the third pair pads the channel axis:
        # nothing in front, `dist_channels` zero channels appended at the end.
        channel_pad = (0, 0, 0, 0, 0, self.dist_channels)
        return self.pooling(F.pad(x, channel_pad))
	
#### each ResNet Block ####
class ResNetBlock(nn.Module):
    """Residual block made of two (BatchNorm -> ReLU -> 3x3 conv) stages.

    The block's input is added back onto the output of the second
    convolution; no activation is applied after the addition.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        sub_sampling: when True the shortcut is routed through an
            ``IdentityMapping`` so it matches the residual branch's shape;
            when False the input is added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, sub_sampling=False):
        super().__init__()

        # First stage: the only place where the spatial stride is applied.
        self.block_bn1 = nn.BatchNorm2d(in_channels)
        self.block_conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )

        # Second stage: keeps both resolution and channel count.
        self.block_bn2 = nn.BatchNorm2d(out_channels)
        self.block_conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )

        # Optional shape-matching shortcut; None means "use the input as is".
        self.sub_sampling = (
            IdentityMapping(in_channels, out_channels, stride)
            if sub_sampling else None
        )

    def forward(self, x):
        # Residual branch: two BN -> ReLU -> conv stages.
        out = self.block_conv1(F.relu(self.block_bn1(x)))
        out = self.block_conv2(F.relu(self.block_bn2(out)))

        # Shortcut branch: raw input, reshaped only when requested.
        if self.sub_sampling is None:
            identity = x
        else:
            identity = self.sub_sampling(x)

        out += identity
        return out

#### Net Structure ####
class Net(nn.Module):
    """ResNet-style image classifier with three residual stages and 10 outputs.

    Layout: 3x3 stem convolution (3 -> 16 channels) + BatchNorm + ReLU, three
    stages of ``num_blocks`` residual blocks producing 16, 32 and 64 channels
    (the 2nd and 3rd stage start with a stride-2 block), a final
    BatchNorm + ReLU, global average pooling and a 64 -> 10 linear layer.
    All Linear/Conv2d weights are initialized through ``He_init``.

    Args:
        num_blocks: number of residual blocks in each of the three stages.
        block: residual block class (e.g. ``ResNetBlock``), called as
            ``block(in_channels, out_channels, stride, sub_sampling)``.
    """

    def __init__(self, num_blocks, block):
        super(Net, self).__init__()

        # Stem: stride 1 / padding 1 keeps the spatial size of the input.
        self.conv_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        # Residual stages. With ResNetBlock and 32x32 inputs the feature maps
        # are 32x32x16, 16x16x32 and 8x8x64 after stages 1, 2 and 3.
        self.block_1 = self.get_block(block, 16, 16, num_blocks, stride=1)
        self.block_2 = self.get_block(block, 16, 32, num_blocks, stride=2)
        self.block_3 = self.get_block(block, 32, 64, num_blocks, stride=2)

        # The last block ends on a bare convolution + addition, so a final
        # BatchNorm + ReLU is applied before pooling.
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Classification head: one 64-d feature vector -> 10 class scores.
        self.linear = nn.Linear(64, 10)

        self.apply(He_init)

    def get_block(self, block, in_channels, out_channels, num_blocks, stride):
        """Build one stage: ``num_blocks`` residual blocks in an nn.Sequential.

        Only the first block changes the channel count and applies ``stride``;
        the remaining ``num_blocks - 1`` blocks map ``out_channels`` to
        ``out_channels`` with stride 1.

        Args:
            block: residual block class/callable.
            in_channels: channels entering the stage.
            out_channels: channels produced by every block of the stage.
            num_blocks: number of blocks in the stage.
            stride: stride of the first block.

        Returns:
            nn.Sequential containing the stage's blocks in order.
        """
        # The first block needs a shape-matching shortcut whenever it changes
        # the resolution or the channel count. (Previously only stride 1 and 2
        # were handled and any other stride raised UnboundLocalError.)
        sub_sampling = stride != 1 or in_channels != out_channels

        layers = [block(in_channels, out_channels, stride, sub_sampling)]
        layers.extend(
            block(out_channels, out_channels, stride=1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class scores (logits), shape (batch, 10), for ``x``."""
        # Stem convolution, batch normalization and activation.
        x = F.relu(self.bn_1(self.conv_1(x)))

        # Residual stages.
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)

        x = self.relu(self.bn(x))

        # Global average pooling over whatever spatial size is left. The old
        # fixed kernel of 64 (the channel count, not the 8x8 map size) made
        # avg_pool2d fail with "Output size is too small" on CIFAR inputs.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return self.linear(x)

# Net(n, block) builds a (6n+2)-layer network: 3 stages x n blocks x 2 convs,
# plus the stem convolution and the final linear layer. n=3 gives 20 layers.
net = Net(3, ResNetBlock).to(device)

############################################
# 3. Define a Loss function and optimizer
############################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
# Multiply the learning rate by 0.1 at epochs 150 and 250
# (0.1 -> 0.01 -> 0.001).
step_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)

############################################
# 4. Train the network
############################################
def train():
    """Train the module-level ``net`` for 500 epochs and checkpoint the best weights.

    Uses the module-level ``trainloader``, ``device``, ``criterion``,
    ``optimizer`` and ``step_lr_scheduler``. After every epoch the model is
    evaluated with ``test()``; whenever the returned accuracy beats the best
    one seen so far, the current ``state_dict`` is written to
    'gdrive/My Drive/best_model.pt'. The mean loss is printed every 100
    mini-batches.

    NOTE(review): the checkpoint is selected on the accuracy returned by
    ``test()``, which reads the test split -- use a held-out validation split
    if an unbiased test score is needed.
    """
    best_acc = 0

    for epoch in range(500):  # loop over the dataset multiple times
        running_loss = 0.0
        net.train()  # test() leaves the model in eval mode

        for i, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print('[{}, {}] loss: {:.4f}'.format(epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0

        # Advance the LR schedule only after this epoch's optimizer steps.
        # Since PyTorch 1.1, stepping the scheduler before the optimizer
        # skips the first schedule value and drops the LR one epoch early.
        step_lr_scheduler.step()

        print('epoch', epoch + 1, " : " , end=" ")

        # save the best model
        test_acc = test()
        if test_acc > best_acc:
            best_acc = test_acc
            # state_dict() is serialized immediately, so no deep copy of the
            # whole model is needed just to save it.
            torch.save(net.state_dict(), 'gdrive/My Drive/best_model.pt')
    print('Finished Training')


############################################
# 5. Test the network on the test data
############################################
def test():
    """Evaluate the module-level ``net`` on ``testloader`` and report accuracy.

    Switches ``net`` to eval mode (and leaves it there), runs every test
    batch without gradient tracking, prints the top-1 accuracy and returns it.

    Returns:
        Accuracy in percent (0-100); 0 if the loader yields no samples.
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Predicted class = index of the largest logit. Gradients are
            # already disabled, so the deprecated `.data` detour is not needed.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Compute the ratio once after the loop; guard against an empty loader.
    accuracy = 100 * correct / total if total else 0
    print('Accuracy on test images: {:.2f}%'.format(accuracy))
    return accuracy


# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    train()