In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import OneCycleLR
from torchsummary import summary
from tqdm import tqdm
import numpy as np

# Pick the accelerator once; later cells read `use_cuda` (DataLoader kwargs)
# and `device` (where the model and batches are moved).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Using device: {device}")

class Cutout:
    """Zero out ``n_holes`` random square patches of an image tensor.

    Intended to run after ``ToTensor``/``Normalize`` in a transforms pipeline,
    so the input is assumed to be a ``(C, H, W)`` tensor; the same spatial mask
    is applied to every channel.

    Each hole is centred at a uniformly random pixel and spans
    ``[c - length // 2, c + length // 2)`` on each axis, clipped to the image
    border (so holes near an edge are smaller, and an odd ``length`` yields a
    side of ``length - 1``).

    Positions are drawn from torch's RNG (not NumPy's), so ``torch.manual_seed``
    makes the augmentation reproducible and DataLoader workers get distinct
    streams.

    Args:
        n_holes: number of patches to cut out per call.
        length: nominal side length of each square patch, in pixels.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask with the holes set to 0."""
        h, w = img.size(1), img.size(2)
        half = self.length // 2
        # float32 mask (as before) so integer inputs still promote to float.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)
        for _ in range(self.n_holes):
            y = torch.randint(h, (1,)).item()
            x = torch.randint(w, (1,)).item()
            # Only the lower edge can go negative and only the upper edge can
            # exceed the size, so one-sided clamps match the original np.clip.
            y1, y2 = max(y - half, 0), min(y + half, h)
            x1, x2 = max(x - half, 0), min(x + half, w)
            mask[y1:y2, x1:x2] = 0.0
        return img * mask.expand_as(img)

    def __repr__(self):
        return f"{type(self).__name__}(n_holes={self.n_holes}, length={self.length})"

torch.manual_seed(1)
batch_size = 128
train_transforms = transforms.Compose([
    # RandomRotation operates on the PIL image (before ToTensor), so `fill` is
    # in raw 0-255 pixel units, not normalized units. 0 = black, the MNIST
    # background; the previous value 0.1307 was effectively 0 as well.
    transforms.RandomRotation((-7.0, 7.0), fill=(0,)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    # Applied after Normalize, so cut-out pixels become 0 in normalized space.
    Cutout(n_holes=1, length=16),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=train_transforms),
    batch_size=batch_size, shuffle=True, **kwargs)
# Evaluation only aggregates totals, so shuffling the test set buys nothing
# and just makes the evaluation order nondeterministic.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=test_transforms),
    batch_size=batch_size, shuffle=False, **kwargs)

class Net(nn.Module):
    """Small fully-convolutional classifier ending in global average pooling.

    Layout (spatial sizes in the comments assume a 1x28x28 input):
    two padded 3x3 convs (1->16->32 channels), a 2x2 max-pool plus 1x1 conv
    transition (32->16), two more padded 3x3 convs (16->32->32), global
    average pooling, and a bias-free 1x1 conv mapping 32 channels to
    ``num_classes`` scores.

    Args:
        num_classes: number of output classes (default 10, as for MNIST).
        dropout: dropout probability after each 3x3 conv layer (default 0.1).

    ``forward`` returns log-probabilities of shape ``(batch, num_classes)``,
    matching ``F.nll_loss``.
    """

    def __init__(self, num_classes=10, dropout=0.1):
        super().__init__()

        self.convblock1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(), nn.Dropout(dropout),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(), nn.Dropout(dropout),
        )  # output_size = 28

        self.transblock1 = nn.Sequential(
            nn.MaxPool2d(2, 2),  # output_size = 14
            # 1x1 conv squeezes channels 32 -> 16 before the next conv block.
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(),
        )  # output_size = 14

        self.convblock2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(), nn.Dropout(dropout),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(), nn.Dropout(dropout),
        )  # output_size = 14

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.output_conv = nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.convblock1(x)
        x = self.transblock1(x)
        x = self.convblock2(x)
        x = self.gap(x)          # (N, 32, 1, 1)
        x = self.output_conv(x)  # (N, num_classes, 1, 1)
        # flatten(1) keeps the batch axis intact; a hard-coded view(-1, 10)
        # would silently fold batch and class axes if the class count changed.
        x = torch.flatten(x, 1)
        return F.log_softmax(x, dim=1)

# Instantiate once to inspect the architecture: torchsummary prints per-layer
# output shapes and parameter counts for a single 1x28x28 input.
model = Net()
model = model.to(device)  # nn.Module.to moves parameters and returns the module
summary(model, input_size=(1, 28, 28))

def train(model, device, train_loader, optimizer, scheduler):
    """Train ``model`` for one pass over ``train_loader``.

    The loss is ``F.nll_loss``, so ``model`` must output log-probabilities.
    ``scheduler`` is stepped once per batch, as a per-iteration schedule such
    as OneCycleLR expects. A tqdm bar shows the current batch loss, the learning
    rate that batch was trained with, and the running training accuracy.
    """
    model.train()
    pbar = tqdm(train_loader)
    correct, processed = 0, 0
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Read the LR used for this batch *before* advancing the schedule;
        # afterwards it reports the next batch's LR (and overshoots slightly
        # below zero after the final OneCycle step).
        lr = scheduler.get_last_lr()[0]
        scheduler.step()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        processed += len(data)
        pbar.set_description(desc=f'Loss={loss.item():.4f} LR={lr:.6f} Acc={100*correct/processed:0.2f}')

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print a one-line summary.

    ``model`` must output log-probabilities (the loss is ``F.nll_loss``).
    Runs in eval mode under ``torch.no_grad()``.

    Returns:
        ``(avg_loss, accuracy)``: mean per-sample NLL over the whole dataset and
        accuracy as a percentage. Existing callers that ignore the result are
        unaffected.

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    total = len(test_loader.dataset)
    if total == 0:
        raise ValueError("test_loader.dataset is empty; nothing to evaluate")
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the final average is exact even when
            # the last batch is smaller than the rest.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
    test_loss /= total
    accuracy = 100. * correct / total
    print(f'\nTest set: Avg loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')
    return test_loss, accuracy


# This is an evaluation helper, not a pytest test; stop collectors from
# treating it as one (its arguments are not fixtures).
test.__test__ = False

EPOCHS = 20
MAX_LR = 0.015
# Fresh model so re-running this cell restarts training from scratch.
model = Net().to(device)
# OneCycleLR overwrites every param group's lr with max_lr / div_factor
# (default 25) as soon as it is constructed, so the optimizer's own lr is
# never used. Pass the same constant so the two settings cannot disagree.
optimizer = optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=0.01)
# One scheduler step per batch (train() steps it), hence steps_per_epoch.
scheduler = OneCycleLR(optimizer, max_lr=MAX_LR, steps_per_epoch=len(train_loader),
                       epochs=EPOCHS, anneal_strategy='linear')

for epoch in range(1, EPOCHS + 1):
    print(f"Epoch {epoch}")
    train(model, device, train_loader, optimizer, scheduler)
    test(model, device, test_loader)

Using device: cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             144
       BatchNorm2d-2           [-1, 16, 28, 28]              32
              ReLU-3           [-1, 16, 28, 28]               0
           Dropout-4           [-1, 16, 28, 28]               0
            Conv2d-5           [-1, 32, 28, 28]           4,608
       BatchNorm2d-6           [-1, 32, 28, 28]              64
              ReLU-7           [-1, 32, 28, 28]               0
           Dropout-8           [-1, 32, 28, 28]               0
         MaxPool2d-9           [-1, 32, 14, 14]               0
           Conv2d-10           [-1, 16, 14, 14]             512
      BatchNorm2d-11           [-1, 16, 14, 14]              32
             ReLU-12           [-1, 16, 14, 14]               0
           Conv2d-13           [-1, 32, 14, 14]           4,608
      BatchNorm2d-14

Loss=0.7161 LR=0.003001 Acc=60.54: 100%|██████████| 469/469 [00:25<00:00, 18.73it/s]



Test set: Avg loss: 0.3449, Accuracy: 9033/10000 (90.33%)

Epoch 2


Loss=0.4342 LR=0.005402 Acc=85.08: 100%|██████████| 469/469 [00:25<00:00, 18.14it/s]



Test set: Avg loss: 0.1416, Accuracy: 9608/10000 (96.08%)

Epoch 3


Loss=0.3046 LR=0.007803 Acc=87.61: 100%|██████████| 469/469 [00:24<00:00, 18.92it/s]



Test set: Avg loss: 0.1299, Accuracy: 9607/10000 (96.07%)

Epoch 4


Loss=0.4072 LR=0.010203 Acc=88.67: 100%|██████████| 469/469 [00:25<00:00, 18.33it/s]



Test set: Avg loss: 0.1166, Accuracy: 9639/10000 (96.39%)

Epoch 5


Loss=0.5040 LR=0.012604 Acc=89.62: 100%|██████████| 469/469 [00:26<00:00, 17.87it/s]



Test set: Avg loss: 0.1505, Accuracy: 9542/10000 (95.42%)

Epoch 6


Loss=0.3854 LR=0.014998 Acc=89.94: 100%|██████████| 469/469 [00:25<00:00, 18.66it/s]



Test set: Avg loss: 0.1116, Accuracy: 9641/10000 (96.41%)

Epoch 7


Loss=0.2758 LR=0.013926 Acc=90.42: 100%|██████████| 469/469 [00:24<00:00, 18.99it/s]



Test set: Avg loss: 0.0471, Accuracy: 9851/10000 (98.51%)

Epoch 8


Loss=0.3483 LR=0.012855 Acc=91.09: 100%|██████████| 469/469 [00:24<00:00, 18.93it/s]



Test set: Avg loss: 0.0986, Accuracy: 9677/10000 (96.77%)

Epoch 9


Loss=0.1626 LR=0.011783 Acc=91.30: 100%|██████████| 469/469 [00:24<00:00, 19.21it/s]



Test set: Avg loss: 0.0918, Accuracy: 9738/10000 (97.38%)

Epoch 10


Loss=0.2863 LR=0.010712 Acc=91.71: 100%|██████████| 469/469 [00:24<00:00, 18.97it/s]



Test set: Avg loss: 0.0604, Accuracy: 9808/10000 (98.08%)

Epoch 11


Loss=0.2377 LR=0.009641 Acc=91.98: 100%|██████████| 469/469 [00:24<00:00, 18.79it/s]



Test set: Avg loss: 0.0666, Accuracy: 9795/10000 (97.95%)

Epoch 12


Loss=0.2666 LR=0.008569 Acc=92.18: 100%|██████████| 469/469 [00:24<00:00, 19.42it/s]



Test set: Avg loss: 0.0552, Accuracy: 9835/10000 (98.35%)

Epoch 13


Loss=0.2124 LR=0.007498 Acc=92.33: 100%|██████████| 469/469 [00:24<00:00, 19.52it/s]



Test set: Avg loss: 0.0652, Accuracy: 9795/10000 (97.95%)

Epoch 14


Loss=0.1837 LR=0.006426 Acc=92.52: 100%|██████████| 469/469 [00:24<00:00, 19.50it/s]



Test set: Avg loss: 0.0427, Accuracy: 9870/10000 (98.70%)

Epoch 15


Loss=0.2166 LR=0.005355 Acc=92.71: 100%|██████████| 469/469 [00:23<00:00, 19.58it/s]



Test set: Avg loss: 0.0487, Accuracy: 9850/10000 (98.50%)

Epoch 16


Loss=0.3447 LR=0.004283 Acc=92.94: 100%|██████████| 469/469 [00:23<00:00, 19.69it/s]



Test set: Avg loss: 0.0475, Accuracy: 9852/10000 (98.52%)

Epoch 17


Loss=0.1100 LR=0.003212 Acc=93.27: 100%|██████████| 469/469 [00:25<00:00, 18.17it/s]



Test set: Avg loss: 0.0422, Accuracy: 9870/10000 (98.70%)

Epoch 18


Loss=0.2463 LR=0.002141 Acc=93.51: 100%|██████████| 469/469 [00:25<00:00, 18.62it/s]



Test set: Avg loss: 0.0346, Accuracy: 9886/10000 (98.86%)

Epoch 19


Loss=0.3769 LR=0.001069 Acc=93.77: 100%|██████████| 469/469 [00:25<00:00, 18.74it/s]



Test set: Avg loss: 0.0319, Accuracy: 9894/10000 (98.94%)

Epoch 20


Loss=0.1407 LR=-0.000002 Acc=93.94: 100%|██████████| 469/469 [00:24<00:00, 18.83it/s]



Test set: Avg loss: 0.0320, Accuracy: 9896/10000 (98.96%)

