In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader, RandomSampler

from data import load_cifar10, ImageClassificationDataset
from models import MyrtleNet

In [None]:
# TODO: Add logging.
# TODO: Add a learning rate schedule.

In [2]:
# Training configuration shared by the cells below.
epochs = 20        # number of full passes over the training set
batch_size = 512   # samples per batch for both the train and test dataloaders
lr = 3e-4          # learning rate handed to the optimizer
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
# Load the dataset once, then unpack it into the four arrays used below, in
# the order returned by the loader: train inputs, train labels, test inputs,
# test labels.
cifar10_splits = load_cifar10()
x_train, y_train, x_test, y_test = cifar10_splits

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Create dataloaders.
# Training batches are drawn in a fresh random order on every pass.
train_dataset = ImageClassificationDataset(x_train, y_train)
train_dataloader = DataLoader(
    train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size
)
print("{:>5,} training samples.".format(len(train_dataset)))

# Evaluation gains nothing from shuffling: accuracy does not depend on sample
# order, so iterate the test set sequentially. This keeps the evaluation order
# deterministic and avoids drawing a random permutation each epoch.
test_dataset = ImageClassificationDataset(x_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print("{:>5,} test samples.".format(len(test_dataset)))

50,000 training samples.
10,000 test samples.


In [5]:
# Build the network, then cast it to float and move it onto the configured
# device.
model = MyrtleNet()
model = model.float().to(device)

In [6]:
# AdamW over every model parameter, using the learning rate from the config cell.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
# Cross-entropy between the model's logits and the target labels.
loss_func = torch.nn.CrossEntropyLoss()

In [9]:
# Train for `epochs` passes over the training set and report accuracy on the
# test set after each pass.

# Print progress roughly ten times per epoch. max(1, ...) guards against a
# modulo-by-zero when the loader yields fewer than ten batches.
num_batches = len(train_dataloader)
log_every = max(1, num_batches // 10)

for epoch in range(epochs):
    # Switch to training mode. This matters if the model contains layers that
    # behave differently in train/eval mode (e.g. batch norm or dropout);
    # MyrtleNet's layers are not visible here, and the call is harmless otherwise.
    model.train()
    for i, (x_batch, y_batch) in enumerate(train_dataloader):
        if i % log_every == 0:
            print(f"Processing batch {i+1:02d}/{num_batches}.")
        # Clear gradients left over from the previous step before backprop.
        model.zero_grad()
        logits = model(x_batch.to(device))
        loss = loss_func(logits, y_batch.to(device))
        loss.backward()
        optimizer.step()

    # Switch to evaluation mode so validation does not run with training-time
    # layer behaviour (and does not update any running statistics).
    model.eval()
    total = 0  # number of correctly classified test samples this epoch
    with torch.no_grad():
        for x_batch, y_batch in test_dataloader:
            logits = model(x_batch.to(device))
            # Predicted class = index of the largest logit; compare on the CPU
            # because y_batch is never moved to the device in this loop.
            y_pred = torch.argmax(logits, dim=1).cpu()
            # .item() keeps the counter a plain Python int, so the division
            # and formatting below do not depend on tensor semantics.
            total += torch.sum(y_pred == y_batch).item()
    print(f"Epoch {epoch:02d} val. accuracy: {total / len(test_dataset):0.4f}")
            

Processing batch 01/98.
Processing batch 10/98.
Processing batch 19/98.
Processing batch 28/98.
Processing batch 37/98.
Processing batch 46/98.
Processing batch 55/98.
Processing batch 64/98.
Processing batch 73/98.
Processing batch 82/98.
Processing batch 91/98.
Epoch 00 val. accuracy: 0.7602
Processing batch 01/98.
Processing batch 10/98.
Processing batch 19/98.
Processing batch 28/98.
Processing batch 37/98.
Processing batch 46/98.
Processing batch 55/98.
Processing batch 64/98.
Processing batch 73/98.
Processing batch 82/98.
Processing batch 91/98.
Epoch 01 val. accuracy: 0.7953
Processing batch 01/98.
Processing batch 10/98.
Processing batch 19/98.
Processing batch 28/98.
Processing batch 37/98.
Processing batch 46/98.
Processing batch 55/98.


KeyboardInterrupt: 

In [7]:
# NOTE(review): `y` is not assigned in any cell visible in this notebook, and
# this cell's execution count is out of order, so it presumably depends on
# leftover kernel state and would raise NameError on Restart & Run All.
# TODO: confirm what `y` was (possibly model output for a small batch) and
# either define it explicitly or remove this scratch cell.
y.shape

torch.Size([5, 10])