In [1]:
import argparse
import json
import logging
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision

from torchvision import datasets, transforms

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.set_float32_matmul_precision('high') #Uses TF32 when available

In [2]:
def _get_model(model_type='custom'):
    if model_type == 'resnet':
        model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
        model.fc = nn.Linear(512, 10)
        return model
    else:
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)

            def forward(self, x):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x
    return Net()

In [3]:
# Define data augmentation
def _get_transforms():
        transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
        return transform

In [4]:
def _get_dataloaders(batch_size):
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=_get_transforms())
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=_get_transforms())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)
    return trainloader, testloader

In [5]:
def test(model, test_loader, device):
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f"Test set: Average loss: {test_loss}, Accuracy: {correct / len(test_loader.dataset)}\n")

In [6]:
import time
def train(model, batch_size, epochs):
    torch.manual_seed(0)
    lr = 0.01
    momentum=0.9
    train_loader, test_loader = _get_dataloaders(batch_size)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        print(f"Train Epoch: {epoch} Epoch time: {time.time()-start_time:0.4f} Loss: {loss.item()}")
        test(model, test_loader, device)

In [7]:
%%time
print("Train resnet on CIFAR-10 without compilation")
model = _get_model('resnet').to(device)
train(model, batch_size=16, epochs=2)

Train resnet on CIFAR-10 without compilation
Files already downloaded and verified
Files already downloaded and verified
Train Epoch: 1 Epoch time: 35.7816 Loss: 1.7419991493225098
Test set: Average loss: -1.5503838454246521, Accuracy: 0.3308

Train Epoch: 2 Epoch time: 33.3076 Loss: 1.5548328161239624
Test set: Average loss: -1.5215434109210968, Accuracy: 0.3991

CPU times: user 1min 18s, sys: 1.23 s, total: 1min 19s
Wall time: 1min 21s


In [8]:
# %%time
# import torch._inductor.config
# model = _get_model('resnet50').to(device)
# model = torch.compile(model, backend="inductor")

# # model = torch.compile(model, backend="inductor",
# #                       options={'trace.graph_diagram':False,
# #                                      'trace.enabled':False}

# randinput = torch.randn(256,3,32,32).to(device)
# randoutput = torch.randn(256,10).to(device)

# print('Generating forward and backward graphs')
# out = model(randinput)
# nn.CrossEntropyLoss()(out, randoutput).backward()

In [9]:
print("Train compiled resnet on CIFAR-10")
model = _get_model('resnet').to(device)

model = torch.compile(model, backend="inductor", 
                      options={'trace.graph_diagram':False,
                               'trace.enabled':False})

train(model, batch_size=16, epochs=2)

Train compiled resnet on CIFAR-10
Files already downloaded and verified
Files already downloaded and verified
Train Epoch: 1 Epoch time: 46.4201 Loss: 1.404991865158081
Test set: Average loss: -1.601543181371689, Accuracy: 0.3479

Train Epoch: 2 Epoch time: 37.6773 Loss: 1.2884515523910522
Test set: Average loss: -1.9376735474586486, Accuracy: 0.4442

