## Install

In [0]:
!pip3 install torch torchvision numpy



## Imports

In [0]:
from matplotlib import pyplot as plt
import numpy as np

import torch as th
from torch import nn
import torchvision
from torchvision import transforms

## Config

In [0]:
device = th.device('cuda' if th.cuda.is_available() else 'cpu')
print(f'Using {device}')

num_epochs = 80
batch_size = 100
learning_rate = 0.001

Using cuda


## CIFAR-10 Dataset

In [0]:
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(
    root='~/code/data/cifar10/',
    train=True,
    transform=transform,
    download=True)
test_dataset = torchvision.datasets.CIFAR10(
    root='~/code/data/cifar10/',
    train=False,
    transform=transforms.ToTensor(),
    download=True)

# Data loader (input pipeline)
train_loader = th.utils.data.DataLoader(dataset=train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
test_loader = th.utils.data.DataLoader(dataset=test_dataset,
                                       batch_size=batch_size,
                                       shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## Model

In [0]:
# 3x3 convolution.
def conv3x3(in_channels, out_channels, stride=1):
  return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                   padding=1, bias=False)

# Residual block.
class ResidualBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride=1, downsample=None):
    super().__init__()
    self.conv1 = conv3x3(in_channels, out_channels, stride)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(out_channels, out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.downsample = downsample
    
  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    residual = x if not self.downsample else self.downsample(x)
    out += residual
    out = self.relu(out)
    return out

# Residual layer with several blocks.
def residual_layer(block, in_channels, out_channels, num_blocks, stride=1):
  downsample = None
  if (stride != 1) or (in_channels != out_channels):
    downsample = nn.Sequential(
        conv3x3(in_channels, out_channels, stride=stride),
        nn.BatchNorm2d(out_channels),
    )
  layers = []
  layers.append(block(in_channels, out_channels, stride, downsample))
  for i in range(1, num_blocks):
    layers.append(block(out_channels, out_channels))
  return nn.Sequential(*layers)

# ResNet
class ResNet(nn.Module):
  def __init__(self, block, block_counts, num_classes=10):
    super().__init__()
    self.conv = conv3x3(3, 16)
    self.bn = nn.BatchNorm2d(16)
    self.relu = nn.ReLU(inplace=True)
    self.layer1 = residual_layer(block, 16, 16, block_counts[0])
    self.layer2 = residual_layer(block, 16, 32, block_counts[1], stride=2)
    self.layer3 = residual_layer(block, 32, 64, block_counts[2], stride=2)
    self.avg_pool = nn.AvgPool2d(8)
    self.fc = nn.Linear(64, num_classes)
    
  def forward(self, x):
    out = self.conv(x)
    out = self.bn(out)
    out = self.relu(out)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.avg_pool(out)
    out = out.view(out.size(0), -1)
    out = self.fc(out)
    return out

model = ResNet(ResidualBlock, [2, 2, 2]).to(device)


## Train

In [0]:
# Loss and optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=learning_rate)

# Learning rate update.
def update_lr(optimizer, lr):
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
    

num_steps = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
  for step, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    # Forward
    outputs = model(images)
    loss = loss_fn(outputs, labels)
    
    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (step + 1) % 100 == 0:
      print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{num_steps}], '
            f'Loss: {loss.item():.4}')
  
  # Decay learning rate.
  if (epoch + 1) % 20 == 0:
    curr_lr /= 3
    update_lr(optimizer, curr_lr)

Epoch [1/80], Step [100/500], Loss: 1.711
Epoch [1/80], Step [200/500], Loss: 1.538
Epoch [1/80], Step [300/500], Loss: 1.442
Epoch [1/80], Step [400/500], Loss: 1.233
Epoch [1/80], Step [500/500], Loss: 1.146
Epoch [2/80], Step [100/500], Loss: 1.155
Epoch [2/80], Step [200/500], Loss: 1.047
Epoch [2/80], Step [300/500], Loss: 1.102
Epoch [2/80], Step [400/500], Loss: 0.9901
Epoch [2/80], Step [500/500], Loss: 0.916
Epoch [3/80], Step [100/500], Loss: 1.159
Epoch [3/80], Step [200/500], Loss: 0.9383
Epoch [3/80], Step [300/500], Loss: 0.8163
Epoch [3/80], Step [400/500], Loss: 0.8732
Epoch [3/80], Step [500/500], Loss: 0.7863
Epoch [4/80], Step [100/500], Loss: 0.6293
Epoch [4/80], Step [200/500], Loss: 0.948
Epoch [4/80], Step [300/500], Loss: 0.7323
Epoch [4/80], Step [400/500], Loss: 0.5984
Epoch [4/80], Step [500/500], Loss: 0.7478
Epoch [5/80], Step [100/500], Loss: 0.8586
Epoch [5/80], Step [200/500], Loss: 0.6831
Epoch [5/80], Step [300/500], Loss: 0.6886
Epoch [5/80], Step [40

Epoch [14/80], Step [300/500], Loss: 0.4608
Epoch [14/80], Step [400/500], Loss: 0.3135
Epoch [14/80], Step [500/500], Loss: 0.4462
Epoch [15/80], Step [100/500], Loss: 0.5551
Epoch [15/80], Step [200/500], Loss: 0.6233
Epoch [15/80], Step [300/500], Loss: 0.5657
Epoch [15/80], Step [400/500], Loss: 0.3665
Epoch [15/80], Step [500/500], Loss: 0.4609
Epoch [16/80], Step [100/500], Loss: 0.3753
Epoch [16/80], Step [200/500], Loss: 0.5481
Epoch [16/80], Step [300/500], Loss: 0.6467
Epoch [16/80], Step [400/500], Loss: 0.4033
Epoch [16/80], Step [500/500], Loss: 0.4289
Epoch [17/80], Step [100/500], Loss: 0.4285
Epoch [17/80], Step [200/500], Loss: 0.284
Epoch [17/80], Step [300/500], Loss: 0.3859
Epoch [17/80], Step [400/500], Loss: 0.3915
Epoch [17/80], Step [500/500], Loss: 0.2972
Epoch [18/80], Step [100/500], Loss: 0.3581
Epoch [18/80], Step [200/500], Loss: 0.3169
Epoch [18/80], Step [300/500], Loss: 0.2848
Epoch [18/80], Step [400/500], Loss: 0.324
Epoch [18/80], Step [500/500], Los

Epoch [27/80], Step [400/500], Loss: 0.2027
Epoch [27/80], Step [500/500], Loss: 0.258
Epoch [28/80], Step [100/500], Loss: 0.2994
Epoch [28/80], Step [200/500], Loss: 0.367
Epoch [28/80], Step [300/500], Loss: 0.2483
Epoch [28/80], Step [400/500], Loss: 0.2916
Epoch [28/80], Step [500/500], Loss: 0.1852
Epoch [29/80], Step [100/500], Loss: 0.2242
Epoch [29/80], Step [200/500], Loss: 0.3158
Epoch [29/80], Step [300/500], Loss: 0.3571
Epoch [29/80], Step [400/500], Loss: 0.3707
Epoch [29/80], Step [500/500], Loss: 0.2543
Epoch [30/80], Step [100/500], Loss: 0.2209
Epoch [30/80], Step [200/500], Loss: 0.2618
Epoch [30/80], Step [300/500], Loss: 0.2028
Epoch [30/80], Step [400/500], Loss: 0.2935
Epoch [30/80], Step [500/500], Loss: 0.2072
Epoch [31/80], Step [100/500], Loss: 0.2393
Epoch [31/80], Step [200/500], Loss: 0.2948
Epoch [31/80], Step [300/500], Loss: 0.3133
Epoch [31/80], Step [400/500], Loss: 0.2679
Epoch [31/80], Step [500/500], Loss: 0.3198
Epoch [32/80], Step [100/500], Los

Epoch [40/80], Step [500/500], Loss: 0.1724
Epoch [41/80], Step [100/500], Loss: 0.219
Epoch [41/80], Step [200/500], Loss: 0.07018
Epoch [41/80], Step [300/500], Loss: 0.2831
Epoch [41/80], Step [400/500], Loss: 0.2509
Epoch [41/80], Step [500/500], Loss: 0.3095
Epoch [42/80], Step [100/500], Loss: 0.1266
Epoch [42/80], Step [200/500], Loss: 0.2
Epoch [42/80], Step [300/500], Loss: 0.1179
Epoch [42/80], Step [400/500], Loss: 0.1778
Epoch [42/80], Step [500/500], Loss: 0.1977
Epoch [43/80], Step [100/500], Loss: 0.1754
Epoch [43/80], Step [200/500], Loss: 0.1131
Epoch [43/80], Step [300/500], Loss: 0.1942
Epoch [43/80], Step [400/500], Loss: 0.2143
Epoch [43/80], Step [500/500], Loss: 0.2518
Epoch [44/80], Step [100/500], Loss: 0.1977
Epoch [44/80], Step [200/500], Loss: 0.1878
Epoch [44/80], Step [300/500], Loss: 0.1396
Epoch [44/80], Step [400/500], Loss: 0.07896
Epoch [44/80], Step [500/500], Loss: 0.2089
Epoch [45/80], Step [100/500], Loss: 0.2353
Epoch [45/80], Step [200/500], Los

Epoch [54/80], Step [100/500], Loss: 0.2042
Epoch [54/80], Step [200/500], Loss: 0.3252
Epoch [54/80], Step [300/500], Loss: 0.1563
Epoch [54/80], Step [400/500], Loss: 0.2023
Epoch [54/80], Step [500/500], Loss: 0.1492
Epoch [55/80], Step [100/500], Loss: 0.2223
Epoch [55/80], Step [200/500], Loss: 0.2036
Epoch [55/80], Step [300/500], Loss: 0.2828
Epoch [55/80], Step [400/500], Loss: 0.1725
Epoch [55/80], Step [500/500], Loss: 0.128
Epoch [56/80], Step [100/500], Loss: 0.1771
Epoch [56/80], Step [200/500], Loss: 0.07497
Epoch [56/80], Step [300/500], Loss: 0.2644
Epoch [56/80], Step [400/500], Loss: 0.1623
Epoch [56/80], Step [500/500], Loss: 0.2105
Epoch [57/80], Step [100/500], Loss: 0.1762
Epoch [57/80], Step [200/500], Loss: 0.1549
Epoch [57/80], Step [300/500], Loss: 0.2236
Epoch [57/80], Step [400/500], Loss: 0.286
Epoch [57/80], Step [500/500], Loss: 0.2549
Epoch [58/80], Step [100/500], Loss: 0.1637
Epoch [58/80], Step [200/500], Loss: 0.2434
Epoch [58/80], Step [300/500], Lo

Epoch [67/80], Step [200/500], Loss: 0.1054
Epoch [67/80], Step [300/500], Loss: 0.168
Epoch [67/80], Step [400/500], Loss: 0.09866
Epoch [67/80], Step [500/500], Loss: 0.09712
Epoch [68/80], Step [100/500], Loss: 0.25
Epoch [68/80], Step [200/500], Loss: 0.08042
Epoch [68/80], Step [300/500], Loss: 0.1209
Epoch [68/80], Step [400/500], Loss: 0.1039
Epoch [68/80], Step [500/500], Loss: 0.1045
Epoch [69/80], Step [100/500], Loss: 0.2255
Epoch [69/80], Step [200/500], Loss: 0.1349
Epoch [69/80], Step [300/500], Loss: 0.1225
Epoch [69/80], Step [400/500], Loss: 0.2016
Epoch [69/80], Step [500/500], Loss: 0.1131
Epoch [70/80], Step [100/500], Loss: 0.172
Epoch [70/80], Step [200/500], Loss: 0.1455
Epoch [70/80], Step [300/500], Loss: 0.09429
Epoch [70/80], Step [400/500], Loss: 0.08697
Epoch [70/80], Step [500/500], Loss: 0.1469
Epoch [71/80], Step [100/500], Loss: 0.1391
Epoch [71/80], Step [200/500], Loss: 0.1861
Epoch [71/80], Step [300/500], Loss: 0.2539
Epoch [71/80], Step [400/500], 

Epoch [80/80], Step [300/500], Loss: 0.1519
Epoch [80/80], Step [400/500], Loss: 0.2194
Epoch [80/80], Step [500/500], Loss: 0.1939


## Test

In [0]:
with th.no_grad():
  correct, total = 0, 0
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, predicted = th.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
  accuracy = correct / total
  print(f'Accuracy of model on 10000 test images: {100 * accuracy:0.2f}%')

Accuracy of model on 10000 test images: 85.33%


## Save model

In [0]:
th.save(model.state_dict(), '/tmp/resnet.ckpt')