In [None]:
import torch as t
import torch.nn as nn
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): `get_data` and `loader_gen` are defined in a later cell of this
# notebook, so on a fresh kernel this cell raises NameError until that cell has
# been run.
# Build the MNIST loaders with a batch size of 100 and wrap the training loader
# in a generator so that single batches can be pulled with next().
train, test = get_data(100)
train_iter = loader_gen(train)

In [None]:
# Pull the next batch off the training generator for inspection.
x = next(train_iter)

In [None]:
# Number of items in one batch; the training loop unpacks batches as (X, y).
len(x)

In [None]:
# Shape of the first batch item (the inputs `X`).
# Presumably (100, 1, 28, 28) for MNIST with batch size 100 — TODO confirm.
x[0].shape

In [None]:
# Second batch item (the targets `y` passed to the loss function).
x[1]

In [None]:
# Advance to the following batch. This cell originally read from `data`, a name
# that is never defined in this notebook (NameError on a fresh kernel); the
# batch generator created earlier is `train_iter`.
x = next(train_iter)

In [None]:
def get_data(batch_size, root="./_data", num_workers=1):
    """Build shuffled MNIST train and test DataLoaders.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory the dataset is read from (``download=True`` is passed,
            so torchvision fetches it there when needed).
        num_workers: Number of worker processes for each DataLoader.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders apply
        ``ToTensor`` to the images and iterate with ``shuffle=True``.
    """

    def make_loader(is_train):
        # The two splits share every setting except the ``train`` flag.
        dataset = torchvision.datasets.MNIST(
            root=root, train=is_train, download=True,
            transform=torchvision.transforms.ToTensor()
        )
        return t.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )

    train_loader = make_loader(True)
    test_loader = make_loader(False)
    return train_loader, test_loader

    
def loader_gen(loader):
    """Generator that yields the items of ``loader`` one at a time, in order."""
    for batch in loader:
        yield batch

In [None]:
class ConvNet(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by one linear layer.

    The linear layer takes 12 * 4 * 4 features, which is what the convolutional
    stack produces for a 1x28x28 input, and emits 10 values per sample.
    """

    def __init__(self):
        super().__init__()
        feature_stack = [
            # Stage 1: 1 -> 6 channels with a 5x5 kernel, then 2x2 max-pooling.
            nn.Conv2d(1, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Stage 2: 6 -> 12 channels with a 5x5 kernel, then 2x2 max-pooling.
            nn.Conv2d(6, 12, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv_layers = nn.Sequential(*feature_stack)
        self.linear_layers = nn.Sequential(nn.Linear(12 * 4 * 4, 10))

    def forward(self, x):
        """Run ``x`` through the conv stack and the linear head."""
        features = self.conv_layers(x)
        # Collapse the trailing (channel, height, width) dims into one vector.
        flat = features.flatten(start_dim=-3)
        return self.linear_layers(flat)


# Use the first GPU when CUDA is available, otherwise stay on the CPU.
device = 'cuda:0' if t.cuda.is_available() else 'cpu'

# The model has to be moved explicitly; it is created on the CPU by default.
model = ConvNet().to(device)
print('num model params:', sum([p.numel() for p in model.parameters()]))

n_epochs, batch_size = 3, 100
loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.Adam(model.parameters())
train_loader, test_loader = get_data(batch_size)
# Validate on one test batch every `val_freq` training steps. Ceil division
# guarantees the test loader is never asked for more batches than it has within
# an epoch (floor division could exhaust it and raise StopIteration), and the
# lower bound of 1 avoids a modulo-by-zero when the train loader is the shorter.
val_freq = max(1, -(-len(train_loader) // len(test_loader)))
# `losses` holds one float per training step; `val_losses` holds
# (global step, loss) pairs so both can be plotted on the same axis.
losses, val_losses = [], []

model.train()
with tqdm(total=n_epochs * len(train_loader)) as pbar:
    for epoch in range(n_epochs):
        # Fresh pass over the (shuffled) test loader for this epoch.
        val_iter = loader_gen(test_loader)

        for i, (X, y) in enumerate(train_loader):
            # Tensor.to() returns a new tensor, so the result must be rebound.
            X, y = X.to(device), y.to(device)

            optimizer.zero_grad()
            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()

            # Store plain floats so the history is device-independent.
            loss_value = loss.item()
            losses.append(loss_value)
            pbar.update()
            pbar.set_postfix({'loss': '%.4f' % loss_value})

            if i % val_freq == 0:
                val_X, val_y = next(val_iter)
                val_X, val_y = val_X.to(device), val_y.to(device)
                with t.no_grad():
                    val_y_pred = model(val_X)
                    val_loss = loss_fn(val_y_pred, val_y)
                val_losses.append((epoch * len(train_loader) + i, val_loss.item()))

# Hand the trained model back on the CPU so that cells which feed it batches
# straight from a DataLoader (CPU tensors) keep working.
model.cpu()
                
            

In [None]:
# Training loss has one entry per optimisation step; validation loss is stored
# as (global step, loss) pairs, so both curves share the same x-axis.
fig, ax = plt.subplots()
# float() accepts plain numbers as well as 0-d tensors on any device.
ax.plot([float(v) for v in losses], label='train')
if val_losses:
    val_steps, val_values = zip(*val_losses)
    ax.plot(val_steps, [float(v) for v in val_values], label='validation')
ax.set_xlabel('training step')
ax.set_ylabel('cross-entropy loss')
ax.set_title('Training and validation loss')
ax.legend();

In [None]:
# Simple accuracy score: fraction of test samples whose arg-max prediction
# equals the label.

model.eval()
# Feed batches on whichever device the model currently lives on.
model_device = next(model.parameters()).device

correct, total = 0, 0
with t.no_grad():  # inference only, so skip building autograd graphs
    for X, y in tqdm(test_loader):
        X, y = X.to(model_device), y.to(model_device)
        y_pred = model(X)
        # Tensor.sum() counts the matches in one vectorised op (the builtin
        # sum() iterates element by element); .item() keeps the tally an int.
        correct += (y_pred.argmax(1) == y).sum().item()
        total += y.shape[0]

print(correct / total)

In [None]:
# Are classes balanced?
from collections import Counter

# Tally label frequencies over every batch of both loaders. Plain loops are
# used instead of list comprehensions, which would only build throwaway lists.
c = Counter()
for loader in (train_loader, test_loader):
    for _images, labels in loader:
        c.update(labels.tolist())
c
#plt.bar(*zip(*sorted(c.items(), key=lambda kv: kv[0])));

**Basin dimensionality code is in cifar_conv**