# CIFAR-10 ResNet Reference Implementation in PyTorch.
See the mlax implementation in the `resnet.ipynb` notebook.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

### Load the CIFAR-10 dataset.

In [2]:

# Training images are augmented with AutoAugment (default policy) and then
# converted to tensors; test images are only converted to tensors.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.AutoAugment(),
        torchvision.transforms.ToTensor(),
    ]
)

# Both splits are stored under ../data and downloaded if not already present.
cifar_train = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True, transform=train_transform
)
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Show the raw array shapes of each split.
for split in (cifar_train, cifar_test):
    print(split.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


### Batch the CIFAR-10 data with PyTorch dataloaders.

In [3]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(cifar_train, batch_size=128, shuffle=True, num_workers=8)
# Evaluation does not benefit from shuffling. Because test() averages the
# per-batch losses and the final batch is smaller than the rest, a fixed order
# also makes the reported test loss reproducible from run to run.
test_dataloader = DataLoader(cifar_test, batch_size=128, shuffle=False, num_workers=8)
# Number of batches per epoch for each split.
print(len(train_dataloader), len(test_dataloader))

391 79


In [4]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Define ResNet using modules.

In [5]:
# Residual block without downsampling (N, C, H, W) -> (N, C, H, W)
class ResBlock1(nn.Module):
    """Shape-preserving residual block.

    The main path applies two consecutive (3x3 conv without bias -> batch norm
    -> ReLU) stages with the same channel count; its result is added to the
    block's input.
    """

    def __init__(self, filters):
        """Create the block for inputs that have ``filters`` channels."""
        super().__init__()
        stages = []
        for _ in range(2):
            stages.extend(
                [
                    nn.Conv2d(filters, filters, 3, padding=1, bias=False),
                    nn.BatchNorm2d(filters),
                    nn.ReLU(inplace=True),
                ]
            )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the main-path output plus the identity shortcut ``x``."""
        out = self.block(x)
        return out + x

# Residual block with downsampling (N, C, H, W) -> (N, 2*C, H/2, W/2)
class ResBlock2(nn.Module):
    """Residual block that doubles the channels and halves height and width.

    The main path is a stride-2 (3x3 conv -> batch norm -> ReLU) stage followed
    by a stride-1 stage of the same form. The shortcut path is a single
    stride-2 (3x3 conv -> batch norm -> ReLU) stage so that both paths produce
    tensors of the same shape before they are summed.
    """

    def __init__(self, filters):
        """Create the block for inputs that have ``filters`` channels."""
        super().__init__()
        out_filters = 2 * filters
        # Main path: the first conv performs the downsampling.
        self.block = nn.Sequential(
            nn.Conv2d(filters, out_filters, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_filters, out_filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_filters),
            nn.ReLU(inplace=True),
        )
        # Shortcut path: projects the input to the main path's output shape.
        self.downsample = nn.Sequential(
            nn.Conv2d(filters, out_filters, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_filters),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the sum of the main path and the downsampled shortcut."""
        main = self.block(x)
        shortcut = self.downsample(x)
        return main + shortcut

class ResNet(nn.Module):
    """Small ResNet classifier for (N, 3, 32, 32) image batches.

    Layout: stem conv -> ResBlock1(16) -> ResBlock2(16) -> ResBlock2(32)
    -> global average pool -> flatten -> linear layer with 10 outputs.
    """

    def __init__(self):
        """Create all layers; the shape after each layer is noted inline."""
        super().__init__()
        # Stem: (N, 3, 32, 32) -> (N, 16, 32, 32)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.res1 = ResBlock1(16)  # -> (N, 16, 32, 32)
        self.res2 = ResBlock2(16)  # -> (N, 32, 16, 16)
        self.res3 = ResBlock2(32)  # -> (N, 64, 8, 8)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # -> (N, 64, 1, 1)
        self.flatten = nn.Flatten()  # -> (N, 64)
        self.fc = nn.Linear(64, 10)  # -> (N, 10)

    def forward(self, x):
        """Map an image batch to a (N, 10) tensor of class scores."""
        features = self.res3(self.res2(self.res1(self.conv1(x))))
        pooled = self.flatten(self.avg_pool(features))
        return self.fc(pooled)

# Build the network and store its parameters in channels-last memory format.
model = ResNet()
model = model.to(memory_format=torch.channels_last)
# Print the module tree of the eager model before compiling it.
print(model)
# Compile the model with TorchScript; the scripted module is used from here on.
model = torch.jit.script(model)

ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (res1): ResBlock1(
    (block): Sequential(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (res2): ResBlock2(
    (block): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 

### Define loss function and optimizer.

In [6]:
# Cross-entropy loss module, compiled with TorchScript.
cross_entropy = torch.jit.script(nn.CrossEntropyLoss())
# Adam optimizer over every model parameter, learning rate 1e-2.
adam = optim.Adam(params=model.parameters(), lr=1e-2)

### Define training and testing loops.

In [7]:
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch and print the mean per-batch training loss.

    Args:
        dataloader: Sized iterable yielding ``(X, y)`` batches of tensors.
            ``X`` is converted to channels-last memory format.
        model: Module to train; moved to ``device`` and put in training mode.
        loss_fn: Callable ``loss_fn(predictions, y)`` returning a scalar loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device that the model and each batch are moved to.

    Returns:
        None. The result is reported with ``print``.
    """
    model.to(device)
    model.train()

    # Accumulate plain Python floats, as test() does. Summing the loss tensors
    # themselves would chain every step's autograd history into the running
    # total for the whole epoch.
    train_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device, memory_format=torch.channels_last), y.to(device)

        loss = loss_fn(model(X), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    print(f"Train loss: {train_loss / len(dataloader)}")

In [8]:
def test(dataloader, model, loss_fn, device):
    model.to(device)
    model.eval()

    test_loss, accurate = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device, memory_format=torch.channels_last), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            accurate += (pred.argmax(1) == y).type(torch.float).sum().item()
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [9]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model, loss_fn, optimizer,
    device,
    epochs, test_every):
    """Train for ``epochs`` epochs, evaluating every ``test_every`` epochs.

    Each epoch prints a header, runs ``train`` on ``train_dataloader``, runs
    ``test`` on ``test_dataloader`` when the 1-based epoch number is a multiple
    of ``test_every``, and then prints a separator line.
    """
    separator = "----------------"
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n{separator}")
        train(train_dataloader, model, loss_fn, optimizer, device)
        is_test_epoch = epoch % test_every == 0
        if is_test_epoch:
            test(test_dataloader, model, loss_fn, device)
        print(separator)

### Train ResNet on CIFAR-10 dataset.

In [10]:
# Train for 50 epochs and evaluate on the test split every 5th epoch.
train_loop(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    model=model,
    loss_fn=cross_entropy,
    optimizer=adam,
    device=device,
    epochs=50,
    test_every=5,
)

Epoch 1
----------------
Train loss: 1.7638086080551147
----------------
Epoch 2
----------------
Train loss: 1.3673878908157349
----------------
Epoch 3
----------------
Train loss: 1.180383563041687
----------------
Epoch 4
----------------
Train loss: 1.068069338798523
----------------
Epoch 5
----------------
Train loss: 0.9908968806266785
Test loss: 0.8300878895988947, accuracy: 0.7051
----------------
Epoch 6
----------------
Train loss: 0.9373352527618408
----------------
Epoch 7
----------------
Train loss: 0.8934167623519897
----------------
Epoch 8
----------------
Train loss: 0.8572283983230591
----------------
Epoch 9
----------------
Train loss: 0.8314029574394226
----------------
Epoch 10
----------------
Train loss: 0.8067886829376221
Test loss: 0.6768896594832216, accuracy: 0.7653
----------------
Epoch 11
----------------
Train loss: 0.7879920601844788
----------------
Epoch 12
----------------
Train loss: 0.7613210678100586
----------------
Epoch 13
----------------
T