In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
path = "datasets/cifar-10"
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))])
trainset = torchvision.datasets.CIFAR10(path, train=True, transform=transform)
testset = torchvision.datasets.CIFAR10(path, train=False, transform=transform)

def run(model, loss_fn, optim, num_epochs, batch_size):

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False)

    for epoch in range(num_epochs):

        model.train()
        train_loss = 0
        train_top1 = 0
        train_step = 0
        progress = tqdm(trainloader, desc=f"epoch={epoch} train")
        for inputs, targets in progress:
            optim.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optim.step()
            top1 = (outputs.max(1).indices  == targets).float().mean()
            train_loss += loss.item()
            train_top1 += top1.item()
            train_step += 1
            progress.set_postfix({
                "train_loss": train_loss / train_step,
                "train_top1": train_top1 / train_step,
            })

        model.eval()
        test_loss = 0
        test_top1 = 0
        test_step = 0
        progress = tqdm(testloader, desc=f"epoch={epoch} test")
        for inputs, targets in progress:
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            top1 = (outputs.max(1).indices  == targets).float().mean()
            test_loss += loss.item()
            test_top1 += top1.item()
            test_step += 1
            progress.set_postfix({
                "test_loss": test_loss / test_step,
                "test_top1": test_top1 / test_step,
            })

        metrics = {
            "epoch": epoch,
            "train_loss": train_loss / train_step,
            "train_top1": train_top1 / train_step,
            "test_loss": test_loss / test_step,
            "test_top1": test_top1 / test_step,
        }
        yield metrics

In [7]:
class Dropout(nn.Module):

    def __init__(self, std, zero_mean):
        super().__init__()
        self.std = std
        self.zero_mean = zero_mean

    def forward(self, x):
        if self.training:
            z = torch.randn_like(x)
            if self.zero_mean:
                z = z - z.mean(-1, keepdim=True)
            x = x * (1 + self.std * z)
        return x

num_epochs = 10
batch_size = 100
hidden_dim = 100
learning_rate = 1e-2
std = 0.5
zero_mean = False

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3072, hidden_dim), nn.ReLU(), Dropout(std, zero_mean),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), Dropout(std, zero_mean),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), Dropout(std, zero_mean),
    nn.Linear(hidden_dim, 10))
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

list(run(model, loss_fn, optim, num_epochs, batch_size))

epoch=0 train: 100%|██████████| 500/500 [00:06<00:00, 74.08it/s, train_loss=1.87, train_top1=0.325]
epoch=0 test: 100%|██████████| 100/100 [00:00<00:00, 107.62it/s, test_loss=1.65, test_top1=0.416]
epoch=1 train: 100%|██████████| 500/500 [00:06<00:00, 79.30it/s, train_loss=1.66, train_top1=0.412]
epoch=1 test: 100%|██████████| 100/100 [00:00<00:00, 116.79it/s, test_loss=1.53, test_top1=0.467]
epoch=2 train: 100%|██████████| 500/500 [00:06<00:00, 77.17it/s, train_loss=1.58, train_top1=0.442]
epoch=2 test: 100%|██████████| 100/100 [00:00<00:00, 115.23it/s, test_loss=1.49, test_top1=0.474]
epoch=3 train: 100%|██████████| 500/500 [00:06<00:00, 78.36it/s, train_loss=1.53, train_top1=0.46] 
epoch=3 test: 100%|██████████| 100/100 [00:00<00:00, 111.01it/s, test_loss=1.43, test_top1=0.491]
epoch=4 train: 100%|██████████| 500/500 [00:06<00:00, 75.82it/s, train_loss=1.49, train_top1=0.475]
epoch=4 test: 100%|██████████| 100/100 [00:00<00:00, 112.38it/s, test_loss=1.41, test_top1=0.492]
epoch=5 tr

[{'epoch': 0,
  'train_loss': 1.8701764814853667,
  'train_top1': 0.3248000002503395,
  'test_loss': 1.6479486954212188,
  'test_top1': 0.4159999969601631},
 {'epoch': 1,
  'train_loss': 1.6603232519626618,
  'train_top1': 0.41151999807357786,
  'test_loss': 1.5291156268119812,
  'test_top1': 0.46729999750852586},
 {'epoch': 2,
  'train_loss': 1.5824878711700439,
  'train_top1': 0.44163999700546264,
  'test_loss': 1.4897980260849,
  'test_top1': 0.4738999956846237},
 {'epoch': 3,
  'train_loss': 1.529931223630905,
  'train_top1': 0.46035999757051466,
  'test_loss': 1.4316531491279603,
  'test_top1': 0.4905999964475632},
 {'epoch': 4,
  'train_loss': 1.4912743566036224,
  'train_top1': 0.4745799962878227,
  'test_loss': 1.414997432231903,
  'test_top1': 0.4920999965071678},
 {'epoch': 5,
  'train_loss': 1.452074227333069,
  'train_top1': 0.48647999674081804,
  'test_loss': 1.394192453622818,
  'test_top1': 0.507399995625019},
 {'epoch': 6,
  'train_loss': 1.428001995563507,
  'train_top