# CIFAR-10 ResNet Reference Implementation in Pytorch.
See the mlax implementation in `resnet.ipynb` notebook.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision

### Load the CIFAR-10 dataset.

In [2]:
cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform = torchvision.transforms.Compose([
    torchvision.transforms.AutoAugment(),
    torchvision.transforms.ToTensor()
])
)
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
print(cifar_train.data.shape)
print(cifar_test.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


### Batch the MNIST data with Pytorch dataloaders.

In [3]:
batch_size = 128
train_dataloader = DataLoader(cifar_train, batch_size, shuffle=True, num_workers=6)
test_dataloader = DataLoader(cifar_test, batch_size, num_workers=6)
print(len(train_dataloader), len(test_dataloader))

391 79


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


### Define ResNet using modules.

In [5]:
# Residual block without downsampling (N, C, H, W) -> (N, C, H, W)
class ResBlock1(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(filters, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(filters, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.block(x) + x

# Residual block with downsampling (N, C, H, W) -> (N, 2*C, H/2, W/2) 
class ResBlock2(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(filters, 2*filters, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2*filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(2*filters, 2*filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(2*filters),
            nn.ReLU(inplace=True)
        )
        self.downsample = nn.Sequential(
            nn.Conv2d(filters, 2*filters, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2*filters),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.block(x) + self.downsample(x)

class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        # (N, 3, 32, 32)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # (N, 16, 32, 32)
        self.res1 = ResBlock1(16)
        # (N, 16, 32, 32)
        self.res2 = ResBlock2(16)
        # (N, 32, 16, 16)
        self.res3 = ResBlock2(32)
        # (N, 64, 8, 8)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # (N, 64, 1, 1)
        self.flatten = nn.Flatten()
        # (N, 64)
        self.fc = nn.Linear(64, 10)
        # (N, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        return self.fc(x)

model = ResNet().to(memory_format=torch.channels_last)
print(model)

ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (res1): ResBlock1(
    (block): Sequential(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (res2): ResBlock2(
    (block): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 

### Define loss function and optimizer.

In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

### Define training and testing steps.

In [7]:
@torch.compile
def train_step(X, y):
    with torch.enable_grad():
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

In [8]:
@torch.compile
def test_step(X, y):
    with torch.no_grad():
        preds = model(X)
        loss = loss_fn(preds, y)
    accurate = (preds.argmax(1) == y).type(torch.int).sum()
    return loss.item(), accurate.item()

### Define training and testing loops.

In [9]:
def train(dataloader):
    model.train()
    train_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        train_loss += train_step(X, y)

    print(f"Train loss: {train_loss / len(dataloader)}")

In [10]:
def test(dataloader):
    model.eval()
    test_loss, accurate = 0.0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        loss, acc = test_step(X, y)
        test_loss += loss
        accurate += acc
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [11]:
def train_loop(
    train_dataloader,
    test_dataloader,
    epochs,
    test_every
):
    model.to(device)
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader)
        if (epoch % test_every == 0):
            test(test_dataloader)
        print(f"----------------")

### Train ResNet on CIFAR-10 dataset.

In [12]:
train_loop(train_dataloader, test_dataloader, 40, 5)

Epoch 1
----------------


  return super().apply(*args, **kwargs)  # type: ignore[misc]


Train loss: 1.824835014770098
----------------
Epoch 2
----------------
Train loss: 1.3861767706053947
----------------
Epoch 3
----------------
Train loss: 1.1793399161999794
----------------
Epoch 4
----------------
Train loss: 1.0621479351807128
----------------
Epoch 5
----------------
Train loss: 0.9868572009798816


   function: '<graph break in test_step>' (/tmp/ipykernel_58678/1809820324.py:7)
   reasons:  ___stack0 == 0.8893145322799683
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.


Test loss: 0.933613246754755, accuracy: 0.6675
----------------
Epoch 6
----------------
Train loss: 0.9368689090699491
----------------
Epoch 7
----------------
Train loss: 0.8925590068482987
----------------
Epoch 8
----------------
Train loss: 0.8561032951030585
----------------
Epoch 9
----------------
Train loss: 0.8332217353993975
----------------
Epoch 10
----------------
Train loss: 0.8082764580121735
Test loss: 0.6808276025554801, accuracy: 0.7643
----------------
Epoch 11
----------------
Train loss: 0.7779793547242498
----------------
Epoch 12
----------------
Train loss: 0.7635586895906102
----------------
Epoch 13
----------------
Train loss: 0.7482378232814467
----------------
Epoch 14
----------------
Train loss: 0.7382824225041568
----------------
Epoch 15
----------------
Train loss: 0.7150312553128928
Test loss: 0.6059607388098028, accuracy: 0.7926
----------------
Epoch 16
----------------
Train loss: 0.7075823883113959
----------------
Epoch 17
----------------
Trai