In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import argparse

In [None]:
class Residual(nn.Module):
    """Skip-connection wrapper around a callable.

    The wrapped callable ``fn`` is applied to the input and the input itself
    is added back onto the result, i.e. ``forward(x) == fn(x) + x``.
    The output of ``fn`` must therefore be addable to ``x``.
    """

    def __init__(self, fn):
        """Store ``fn`` (e.g. an ``nn.Module``) as the wrapped transformation."""
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return the wrapped transformation of ``x`` plus ``x`` itself."""
        transformed = self.fn(x)
        return transformed + x

def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):
    """Build a ConvMixer image classifier as an ``nn.Sequential``.

    Args:
        dim: Number of channels used throughout the network.
        depth: Number of repeated mixer blocks.
        kernel_size: Kernel size of the depthwise convolution in each block.
        patch_size: Kernel size and stride of the first convolution, which
            splits the 3-channel input into non-overlapping patches.
        n_classes: Number of outputs of the final linear layer.

    Returns:
        An ``nn.Sequential`` mapping a ``(N, 3, H, W)`` batch to
        ``(N, n_classes)`` outputs.
    """
    return nn.Sequential(
        # Patch embedding: stride == kernel size, so patches do not overlap.
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
            # Depthwise conv (groups=dim) with a residual connection;
            # padding="same" keeps the spatial size so the skip add is valid.
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )),
            # Pointwise (1x1) conv mixes information across channels.
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        ) for _ in range(depth)],
        # AdaptiveAvgPool2d takes a single output_size argument; the original
        # call passed two positionals (1, 1), which raises TypeError.
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

In [None]:
# Channel-wise mean and std handed to transforms.Normalize below.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Training-time augmentation. The PIL-level augmentations come first, then
# conversion to a tensor, normalisation, and finally random erasing (which
# operates on the tensor).
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio = (1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5), 
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation uses no augmentation: only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

epochs = 25
batch_size = 512

# Single download/cache directory shared by both splits. The original used
# './data' for train and '.data' for test, which stores the dataset twice.
data_root = './data'

# Fixed: the original passed the undefined name `trains_transforms` here,
# which raised NameError; the pipeline defined above is `train_transform`.
trainset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                        download=True, transform=train_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root=data_root, train=False,
                                        download=True, transform=test_transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                        shuffle=False, num_workers=4)

In [None]:
# Show the name of CUDA device 0 as this cell's output (sanity check that a
# GPU is visible before the model is moved to it with .cuda() later on).
torch.cuda.get_device_name(0)

In [None]:
def lr_scheduler(t):
    """Return the learning rate at (possibly fractional) epoch ``t``.

    The schedule is piecewise linear: 0 at epoch 0, rising to 0.01 at
    ``epochs*2//5``, falling to 0.01/20 at ``epochs*4//5`` and reaching 0 at
    ``epochs``. ``epochs`` is read from the enclosing scope at call time.
    """
    knot_epochs = [0, epochs*2//5, epochs*4//5, epochs]
    knot_rates = [0, 0.01, 0.01/20.0, 0]
    # np.interp is given a 1-element list, so unwrap the single result.
    return np.interp([t], knot_epochs, knot_rates)[0]


# Model / training hyperparameters.
hdim = 256       # channel width passed to ConvMixer as `dim`
depth = 10       # number of mixer blocks
psize = 2        # patch size of the first convolution
conv_ks = 5      # depthwise convolution kernel size
clip_norm = True

# Build the network, wrap it in DataParallel on device 0 and move it to GPU.
model = nn.DataParallel(
    ConvMixer(hdim, depth, patch_size=psize, kernel_size=conv_ks, n_classes=10),
    device_ids=[0],
).cuda()

# Optimiser, loss and the gradient scaler used for mixed-precision training.
opt = optim.AdamW(params=model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0
    for i, (X,y) in enumerate(trainloader):
        model.train()
        X, y = X.cuda(), y.cuda()

        lr = lr_scheduler()
