In [1]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


# Directory holding the CIFAR-10 files and the mini-batch size used by the
# DataLoaders below.
datadir = './'
batch_size = 1024

# Baseline preprocessing: convert each image to a tensor, then normalise every
# RGB channel with mean 0.5 and std 0.5. No augmentation is applied here.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

# download=True keeps this cell runnable on a fresh checkout: torchvision
# fetches the archive only when it is not already present under `root`, so an
# existing local copy is reused as before.
trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
                                        download=True, transform=transform)
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
                                       download=True, transform=transform)
# Evaluation order is fixed (no shuffling).
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [2]:
class Net(nn.Module):
    """Small convolutional classifier producing 10 logits per input image.

    Pipeline: conv (3->16, 3x3, stride 2) -> ReLU -> conv (16->32, 3x3)
    -> batch norm -> ReLU -> global average pool -> three fully connected
    layers (32 -> 128 -> 64 -> 10) with ReLU between them.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.norm = nn.BatchNorm2d(num_features=32)
        # Collapses each channel to a single value, whatever the spatial size.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        # Fully connected classification head.
        self.fc1 = nn.Linear(in_features=32, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.norm(self.conv2(features)))
        # (N, 32, 1, 1) after pooling -> (N, 32) for the linear layers.
        pooled = self.pool(features).flatten(start_dim=1)
        hidden = F.relu(self.fc1(pooled))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [3]:
from tqdm.notebook import tqdm
import composer.functional as cf

# Number of passes over the training loader; read as a global by train_and_eval.
num_epochs = 5

def train_and_eval(model, train_loader, test_loader):
    """Train ``model`` and print its accuracy on ``test_loader`` after every epoch.

    The model is moved to CUDA when available (otherwise CPU) and optimised
    with Adam at its default settings on a cross-entropy loss. The number of
    epochs comes from the module-level ``num_epochs``.

    Args:
        model: Module mapping a batch of inputs to per-class scores. It is
            trained in place; any weights it already holds are the starting
            point.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.

    Returns:
        None. Progress and accuracy are reported via ``tqdm`` and ``print``.
    """
    # Seeded here, i.e. after the caller has already constructed the model, so
    # this does not affect the model's initial weights.
    torch.manual_seed(42)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(num_epochs):
        print(f"---- Beginning epoch {epoch} ----")
        model.train()
        progress_bar = tqdm(train_loader)
        for X, y in progress_bar:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = F.cross_entropy(y_hat, y)
            progress_bar.set_postfix_str(f"train loss: {loss.item():.4f}")
            loss.backward()
            opt.step()
            # Clear gradients so the next batch starts from zero.
            opt.zero_grad()
        model.eval()
        num_right = 0
        eval_size = 0
        # Evaluation needs no gradients; disabling autograd avoids building a
        # graph for every test batch and so saves memory and time.
        with torch.no_grad():
            for X, y in test_loader:
                X = X.to(device)
                y = y.to(device)
                y_hat = model(X)
                num_right += (y_hat.argmax(dim=1) == y).sum().item()
                eval_size += len(y)
        acc_percent = 100 * num_right / eval_size
        print(f"Epoch {epoch} validation accuracy: {acc_percent:.2f}%")

In [4]:
# Baseline run: a newly constructed Net trained on the loaders from the first
# cell, whose transform only converts to tensor and normalises.
model = Net()
train_and_eval(model, trainloader, testloader)

---- Beginning epoch 0 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 0 validation accuracy: 26.36%
---- Beginning epoch 1 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 1 validation accuracy: 34.80%
---- Beginning epoch 2 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 2 validation accuracy: 38.04%
---- Beginning epoch 3 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 3 validation accuracy: 41.64%
---- Beginning epoch 4 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 4 validation accuracy: 41.74%


In [5]:
# Rebuild the train and test datasets/loaders. Both pipelines share the tensor
# conversion and normalisation; only the training pipeline additionally applies
# cf.colout_batch to each sample.
shared_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# New list: the shared steps followed by the ColOut augmentation.
train_transforms = [*shared_transforms, cf.colout_batch]

test_transform = transforms.Compose(shared_transforms)
train_transform = transforms.Compose(train_transforms)

trainset = torchvision.datasets.CIFAR10(
    root=datadir,
    train=True,
    download=True,
    transform=train_transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root=datadir,
    train=False,
    download=True,
    transform=test_transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Second run: a newly constructed Net trained on the rebuilt loaders, whose
# training transform now ends with cf.colout_batch.
model = Net()
# Only one data augmentation is used: the small model runs quickly, which
# leaves the dataloader little time to do anything more expensive.
train_and_eval(model, trainloader, testloader)

---- Beginning epoch 0 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 0 validation accuracy: 25.61%
---- Beginning epoch 1 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 1 validation accuracy: 30.13%
---- Beginning epoch 2 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 2 validation accuracy: 35.77%
---- Beginning epoch 3 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 3 validation accuracy: 40.30%
---- Beginning epoch 4 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 4 validation accuracy: 42.62%


In [7]:
# Squeeze-excite can add a lot of overhead for small conv2d operations, so
# only add it after convs with a minimum number of channels.
# In the repr displayed below, only conv2 (16 -> 32) is wrapped; conv1
# (3 -> 16) is left as a plain Conv2d.
# NOTE(review): `model` appears to be modified in place (the next cell reuses
# it without reassignment) — confirm against the Composer docs.
cf.apply_squeeze_excite(model, latent_channels=64, min_channels=16)

Net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2))
  (conv2): SqueezeExciteConv2d(
    (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (se): SqueezeExcite2d(
      (pool_and_mlp): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Flatten(start_dim=1, end_dim=-1)
        (2): Linear(in_features=32, out_features=64, bias=False)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=32, bias=False)
        (5): Sigmoid()
      )
    )
  )
  (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): AdaptiveAvgPool2d(output_size=1)
  (fc1): Linear(in_features=32, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [8]:
# Third run: `model` is not re-created here, so training continues from the
# weights left by the previous run, now with the squeeze-excite layers added.
train_and_eval(model, trainloader, testloader)

---- Beginning epoch 0 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 0 validation accuracy: 42.76%
---- Beginning epoch 1 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 1 validation accuracy: 43.39%
---- Beginning epoch 2 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 2 validation accuracy: 45.60%
---- Beginning epoch 3 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 3 validation accuracy: 48.12%
---- Beginning epoch 4 ----


  0%|          | 0/49 [00:00<?, ?it/s]

Epoch 4 validation accuracy: 48.12%
