In [1]:
import argparse
import json
import logging
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision

from torchvision import datasets, transforms

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.set_float32_matmul_precision('high') #Uses TF32 when available

In [2]:
def _get_model():
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    return Net()

In [3]:
# Define data augmentation
def _get_transforms():
        transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        return transform

In [4]:
def _get_dataloaders(batch_size):
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=_get_transforms())
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=_get_transforms())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)
    return trainloader, testloader

In [5]:
def test(model, test_loader, device):
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='mean').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f"Test set: Average loss: {test_loss}, Accuracy: {correct / len(test_loader.dataset)}\n")

In [6]:
import time
def train(model, batch_size, epochs):
    torch.manual_seed(0)
    lr = 0.01
    momentum=0.9
    train_loader, test_loader = _get_dataloaders(batch_size)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        print(f"Train Epoch: {epoch} Epoch time: {time.time()-start_time:0.4f} Loss: {loss.item()}")
        test(model, test_loader, device)

In [7]:
%%time
print("Train in eager mode on CIFAR-10")
model = _get_model().to(device)

train(model, batch_size=16, epochs=50)

Train in eager mode on CIFAR-10
Files already downloaded and verified
Files already downloaded and verified
Train Epoch: 1 Epoch time: 18.4492 Loss: 1.5229963064193726
Test set: Average loss: -0.11479381905794143, Accuracy: 0.3574

Train Epoch: 2 Epoch time: 17.1039 Loss: 1.255014419555664
Test set: Average loss: -0.12258006700873375, Accuracy: 0.4207

Train Epoch: 3 Epoch time: 17.0540 Loss: 1.2103943824768066
Test set: Average loss: -0.14219269587993622, Accuracy: 0.4306

Train Epoch: 4 Epoch time: 17.1195 Loss: 1.6141940355300903
Test set: Average loss: -0.1223116991698742, Accuracy: 0.4078

Train Epoch: 5 Epoch time: 17.1593 Loss: 1.7167662382125854
Test set: Average loss: -0.1371232023358345, Accuracy: 0.427

Train Epoch: 6 Epoch time: 17.0450 Loss: 1.6442832946777344
Test set: Average loss: -0.11765576581060887, Accuracy: 0.4195

Train Epoch: 7 Epoch time: 17.2778 Loss: 1.1586203575134277
Test set: Average loss: -0.12454334303736686, Accuracy: 0.4175

Train Epoch: 8 Epoch time: 1

In [11]:
%%time
import torch._inductor.config
model = _get_model().to(device)
model = torch.compile(model, backend="inductor",
                      mode="max-autotune")

randinput = torch.randn(16,3,32,32).to(device)
randoutput = torch.randn(16,10).to(device)

print('Generating forward and backward graphs')
out = model(randinput)
nn.CrossEntropyLoss()(out, randoutput).backward()

Generating forward and backward graphs
CPU times: user 1.01 s, sys: 12 ms, total: 1.02 s
Wall time: 1.02 s


In [12]:
%%time
print("Train in compiled mode on CIFAR-10")
model = _get_model().to(device)
train(model, batch_size=16, epochs=50)

Train in compiled mode on CIFAR-10
Files already downloaded and verified
Files already downloaded and verified
Train Epoch: 1 Epoch time: 16.8979 Loss: 1.1281245946884155
Test set: Average loss: -0.10141951135396958, Accuracy: 0.3753

Train Epoch: 2 Epoch time: 16.8175 Loss: 1.5165073871612549
Test set: Average loss: -0.11379602966904641, Accuracy: 0.4128

Train Epoch: 3 Epoch time: 16.8435 Loss: 1.3197925090789795
Test set: Average loss: -0.14007682649493217, Accuracy: 0.4195

Train Epoch: 4 Epoch time: 17.1099 Loss: 1.7537651062011719
Test set: Average loss: -0.12210166681408882, Accuracy: 0.444

Train Epoch: 5 Epoch time: 16.9783 Loss: 1.5357666015625
Test set: Average loss: -0.15047866759300232, Accuracy: 0.4326

Train Epoch: 6 Epoch time: 16.7821 Loss: 1.6223543882369995
Test set: Average loss: -0.143238749986887, Accuracy: 0.4263

Train Epoch: 7 Epoch time: 16.8694 Loss: 1.7723138332366943
Test set: Average loss: -0.12494028667807579, Accuracy: 0.4065

Train Epoch: 8 Epoch time: 