In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

%matplotlob inline

### Device configuration

In [3]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Hyperparameters

In [4]:
# Number of passes over the training set, and the initial learning rate
# handed to the optimizer (the training loop lowers it over time).
num_epochs, learning_rate = 80, 0.1

### Image preprocessing modules

In [5]:
# Training-time augmentation pipeline, applied in order:
# pad by 4, random horizontal flip, random 32-pixel crop, convert to tensor.
transform = transforms.Compose(
    [
        transforms.Pad(padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32),
        transforms.ToTensor(),
    ]
)

### CIFAR-10 dataset and dataloader

In [6]:
# CIFAR-10 splits stored under ./cifar10 (downloaded when missing).
# Only the training split goes through the augmentation pipeline; the test
# split is just converted to tensors.
train_dataset = torchvision.datasets.CIFAR10(
    root='./cifar10', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./cifar10', train=False, download=True, transform=transforms.ToTensor()
)

# Mini-batches of 128 loaded by 4 worker processes; only training data is shuffled.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=4
)

Files already downloaded and verified
Files already downloaded and verified


### 3x3 convolution

In [7]:
def conv_3x3(in_channels, out_channels, stride=1):
    """Create a 3x3 2-D convolution layer with padding 1.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the convolution (defaults to 1).

    Returns:
        A freshly constructed ``nn.Conv2d`` module.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

### Residual block

In [8]:
class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The main path is conv -> batch norm -> ReLU -> conv -> batch norm. The
    shortcut (the block input, optionally passed through ``downsample``) is
    added to the main path and a final ReLU is applied to the sum.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to the input before the addition.
            It has to produce a tensor that can be added to the main-path
            output, so it is needed whenever ``stride`` or the channel count
            changes the output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()

        self.conv1 = conv_3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = conv_3x3(out_channels, out_channels)
        # Each convolution gets its own BatchNorm: sharing one module would
        # tie the affine parameters and mix the running statistics of two
        # different layers.
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Explicit None check: the truth value of a module container depends
        # on its length, not on whether a shortcut module was supplied.
        residual = x if self.downsample is None else self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return self.relu(out + residual)

### ResNet

In [9]:
class ResNet(nn.Module):
    """Three-stage residual network with 16, 32 and 64 channels.

    Layout: 3x3 conv (3 -> 16 channels) + batch norm + ReLU, three stages of
    residual blocks (the 32- and 64-channel stages start with a stride-2
    block), average pooling with an 8x8 window, and a linear classifier.

    NOTE(review): the fixed 8x8 pooling window followed by ``Linear(64, ...)``
    appears to assume 32x32 inputs (two stride-2 stages: 32 -> 16 -> 8);
    confirm before feeding other image sizes.

    Args:
        block_class: residual block class, called as
            ``block_class(in_channels, out_channels, stride, downsample)`` for
            the first block of a stage and
            ``block_class(out_channels, out_channels)`` for the others.
        n: number of residual blocks per stage. Every stage contains at least
            one block, even if ``n`` is smaller than 1.
        num_classes: size of the output layer.
    """

    def __init__(self, block_class, n, num_classes=10):
        super(ResNet, self).__init__()

        self.n = n
        self.block_class = block_class

        # Stem.
        self.conv = conv_3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages.
        self.res_block_16channels = self.make_res_block(16, 16)
        self.res_block_32channels = self.make_res_block(16, 32, 2)
        self.res_block_64channels = self.make_res_block(32, 64, 2)

        # Classification head.
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(64, num_classes)

    def make_res_block(self, in_channels, out_channels, stride=1):
        """Build one stage: a (possibly strided) first block plus ``n - 1`` more.

        Args:
            in_channels: channels entering the stage.
            out_channels: channels produced by every block of the stage.
            stride: stride of the first block only.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage.
        """
        # The identity shortcut only works when the first block keeps the
        # tensor shape; otherwise project it with a strided 1x1 convolution.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, stride),
                                       nn.BatchNorm2d(out_channels))

        layers = [self.block_class(in_channels, out_channels, stride, downsample)]

        # After the first block the feature map already has out_channels
        # channels, so the remaining blocks map out_channels -> out_channels.
        for _ in range(self.n - 1):
            layers.append(self.block_class(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class scores of shape ``(batch, num_classes)`` for ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.res_block_16channels(x)
        x = self.res_block_32channels(x)
        x = self.res_block_64channels(x)

        x = self.avg_pool(x)
        # Flatten everything but the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

### Define the model, optimizer, and loss function

In [10]:
# ResNet with one residual block per stage, moved to the selected device.
model = ResNet(block_class=ResidualBlock, n=1)
model = model.to(device)

# Cross-entropy loss, minimised by SGD with momentum 0.9 and weight decay 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.0001,
)

### Update learning rate

In [11]:
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of mappings
            that each carry an ``'lr'`` entry.
        lr: new learning rate written into every group.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr

### Train the model

In [14]:
# Train for num_epochs epochs, recording the loss of every mini-batch.
# Make sure layers such as BatchNorm are in training mode, even if an
# evaluation cell switched the model to eval mode earlier in the session.
model.train()

losses = []  # one Python float per training iteration
curr_lr = learning_rate
for epoch in range(num_epochs):

    for i, (images, labels) in enumerate(train_dataloader):

        images = images.to(device)
        labels = labels.to(device)

        # forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float rather than the tensor: the tensor keeps its
        # autograd graph alive and cannot be handed to matplotlib directly.
        loss_value = loss.item()
        losses.append(loss_value)

        # report progress every 100 iterations
        if (i + 1) % 100 == 0:
            print('epoch {} round {} loss {}'.format(epoch, i + 1, loss_value))

    # divide the learning rate by 3 every 20 epochs
    if (epoch + 1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

epoch 0 round 100 loss 1.7892673015594482
epoch 0 round 200 loss 1.5608744621276855
epoch 0 round 300 loss 1.5897588729858398
epoch 1 round 100 loss 1.4723455905914307


Process Process-16:
Process Process-14:
Process Process-15:
Process Process-13:
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/dushuchen/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/dushuchen/anaconda/lib/python3.6

KeyboardInterrupt: 

### Plot training curve

In [None]:
# Plot the per-iteration training loss.
# Convert every entry to a float first so the cell also works when `losses`
# holds 0-dim tensors (possibly on the GPU or still attached to the graph).
loss_values = [float(value) for value in losses]

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(loss_values)

ax.set_title('training loss')
ax.set_xlabel('iteration')
ax.set_ylabel('training loss')

ax.grid(True)

### Test the model

In [None]:
# Evaluate classification accuracy on the test split.
model.eval()

total = 0
correct = 0

# No gradients are needed for evaluation.
with torch.no_grad():

    for images, labels in test_dataloader:
        # The batch must live on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)

        # Predicted class = index of the highest score in each row.
        prediction = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (prediction == labels).sum().item()

print('test accuracy: {}'.format(correct / total))