In [1]:
import math
from pathlib import Path
from types import SimpleNamespace

import wandb
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset

from utilities import *


In [2]:
# wandb.login() # uncomment if you want to login to wandb

In [3]:
# Label names; a class index is used to look up its name here.
CLASSES = ["hero", "non-hero", "food", "spell", "side-facing"]

# Layer widths of the MLP: flattened 3 x 16 x 16 input, one logit per class out.
INPUT_SIZE = 3 * 16 * 16
HIDDEN_SIZE = 256
OUTPUT_SIZE = len(CLASSES)

NUM_WORKERS = 2

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Folder containing the .npy sprite and label arrays.
data_dir = Path("./data/")

def get_dataloaders(batch_size, slice_size=None, valid_pct=0.2, num_workers=None):
    """Build the training and validation DataLoaders for the sprite dataset.

    Args:
        batch_size: Batch size used by both loaders.
        slice_size: When truthy, forwarded to ``dataset.subset`` before the
            train/validation split; ``None`` (or 0) keeps the whole dataset.
        valid_pct: Forwarded to ``dataset.split``.
            NOTE(review): presumably the fraction held out for validation --
            confirm against ``CustomDataset`` in ``utilities``.
        num_workers: Worker processes for each DataLoader. ``None`` (the
            default) uses the notebook-level ``NUM_WORKERS`` constant, which
            was previously declared but ignored in favour of a hard-coded 1.

    Returns:
        ``(train_dl, valid_dl)``. Only the training loader shuffles.
    """
    if num_workers is None:
        num_workers = NUM_WORKERS

    dataset = CustomDataset.from_np(
        data_dir / "sprites_1788_16x16.npy",
        data_dir / "sprite_labels_nc_1788_16x16.npy",
        argmax=True,
    )

    if slice_size:
        dataset = dataset.subset(slice_size)

    train_ds, valid_ds = dataset.split(valid_pct)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_dl, valid_dl

def get_model(dropout):
    """Assemble the one-hidden-layer MLP classifier and move it to ``device``.

    Args:
        dropout: Drop probability handed to ``nn.Dropout`` after the ReLU.

    Returns:
        An ``nn.Sequential`` mapping a flattened input of ``INPUT_SIZE``
        features to ``OUTPUT_SIZE`` logits.
    """
    layers = [
        nn.Flatten(),
        nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
        nn.BatchNorm1d(HIDDEN_SIZE),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE),
    ]
    mlp = nn.Sequential(*layers)
    return mlp.to(device)


In [4]:
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate the model on the validation loader and optionally log a wandb.Table.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        valid_dl: Validation DataLoader yielding ``(images, labels)`` batches.
        loss_func: Callable ``loss_func(outputs, labels)``. The per-batch value
            is weighted by the batch size, so this assumes it returns the batch
            mean (as the default ``nn.CrossEntropyLoss`` does).
        log_images: When true, log the predictions of one batch as a table.
        batch_idx: Index of the batch to log, so the same batch is shown each time.

    Returns:
        ``(val_loss, accuracy)`` as Python floats, both averaged over
        ``len(valid_dl.dataset)``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0

    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)

            # .item() keeps the running total a plain float instead of a
            # device tensor, so the returned loss matches the accuracy's type.
            total_loss += loss_func(outputs, labels).item() * labels.size(0)

            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if log_images and i == batch_idx:
                log_image_predictions_table(images, predicted, labels, outputs.softmax(dim=1))

    n_samples = len(valid_dl.dataset)
    return total_loss / n_samples, correct / n_samples


def log_image_predictions_table(images, predicted, labels, probs):
    """Log one batch to wandb as a table of image, predicted/true class and scores.

    Each row holds the image, the ``CLASSES`` names for the predicted and
    target indices, and one ``score_<i>`` column per output class. The table is
    logged with ``commit=False`` so it attaches to the next committed
    ``wandb.log`` call rather than advancing the step itself.
    """
    score_columns = [f"score_{i}" for i in range(OUTPUT_SIZE)]
    table = wandb.Table(columns=["image", "pred", "target", *score_columns])

    # Bring everything to the CPU once, then walk the batch sample by sample.
    rows = zip(images.cpu(), predicted.cpu(), labels.cpu(), probs.cpu())
    for sprite, pred_idx, true_idx, scores in rows:
        table.add_data(
            wandb.Image(sprite),
            CLASSES[pred_idx],
            CLASSES[true_idx],
            *scores.numpy(),
        )

    wandb.log({"predictions_table": table}, commit=False)


In [5]:
def train_model(config):
    """Train the MLP with the given config and log the run to Weights & Biases.

    Args:
        config: Object exposing ``epochs``, ``batch_size``, ``lr``, ``dropout``,
            ``slice_size`` and ``valid_pct`` as attributes. It is also passed
            to ``wandb.init`` as the run config.

    Side effects:
        Starts a wandb run in project "intro" and always closes it: normally on
        success, or with ``exit_code=1`` if training raises (the exception is
        re-raised). Train metrics are logged per step and validation metrics
        once per epoch; a predictions table is logged on the final epoch.
    """
    wandb.init(
        project="intro",
        config=config,
        anonymous="allow",
    )

    try:
        # Get the data
        train_dl, valid_dl = get_dataloaders(config.batch_size, config.slice_size, config.valid_pct)
        n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

        # A simple MLP model
        model = get_model(config.dropout)

        # Make the loss and optimizer
        loss_func = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=config.lr)

        example_ct = 0

        for epoch in tqdm(range(config.epochs), total=config.epochs):
            model.train()

            # Defined up front so the epoch-end log below still works if the
            # training loader yields no batches.
            metrics = {}

            for step, (images, labels) in enumerate(train_dl):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                train_loss = loss_func(outputs, labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                example_ct += len(images)
                metrics = {
                    # Log a plain float, not the tensor attached to the autograd graph.
                    "train/train_loss": train_loss.item(),
                    "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                    "train/example_ct": example_ct,
                }

                # The last step of the epoch is logged together with the
                # validation metrics below, so skip it here.
                if step + 1 < n_steps_per_epoch:
                    wandb.log(metrics)

            val_loss, accuracy = validate_model(
                model, valid_dl, loss_func, log_images=(epoch == (config.epochs - 1))
            )

            # Log train and validation metrics to wandb
            val_metrics = {
                "val/val_loss": float(val_loss),
                "val/val_accuracy": accuracy,
            }
            wandb.log({**metrics, **val_metrics})

        # If you had a test set, this is how you could log it as a Summary metric.
        # NOTE: 0.8 is a hard-coded placeholder, not a measured accuracy.
        wandb.run.summary['test_accuracy'] = 0.8
    except BaseException:
        # Close the run as failed (this also covers a notebook interrupt) so it
        # does not stay open in the kernel, then let the error propagate.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()


In [6]:
# Hyperparameters for the baseline run; train_model reads each one by attribute.
config = SimpleNamespace(
    epochs=3,            # passes over the training loader
    batch_size=128,      # samples per batch for both loaders
    lr=1e-3,             # Adam learning rate
    dropout=0.1,         # drop probability in the hidden layer
    slice_size=10_000,   # forwarded to dataset.subset before the split
    valid_pct=0.2,       # forwarded to dataset.split
)

In [7]:
# Baseline run: trains with the current `config` and logs a new wandb run.
train_model(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m ([33mdeeplearning-ai-temp[0m). Use [1m`wandb login --relogin`[0m to force relogin


sprite shape: (89400, 16, 16, 3)
labels shape: (89400,)
sprite shape: (10000, 16, 16, 3)
labels shape: (10000,)


  0%|          | 0/3 [00:00<?, ?it/s]

0,1
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/example_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/train_loss,█▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▅█
val/val_loss,█▃▁

0,1
test_accuracy,0.8
train/epoch,3.0
train/example_ct,24000.0
train/train_loss,0.0144
val/val_accuracy,1.0
val/val_loss,0.00669


Let's try with another value of dropout:

In [8]:
# Second run with a higher dropout. This mutates the shared `config` in place,
# so `config.dropout` stays at 0.5 for every cell executed after this one
# (including a re-run of the earlier train_model(config) cell).
config.dropout = 0.5
train_model(config)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670229483376413, max=1.0…

sprite shape: (89400, 16, 16, 3)
labels shape: (89400,)
sprite shape: (10000, 16, 16, 3)
labels shape: (10000,)


  0%|          | 0/3 [00:00<?, ?it/s]

0,1
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/example_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/train_loss,█▅▃▃▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁██
val/val_loss,█▂▁

0,1
test_accuracy,0.8
train/epoch,3.0
train/example_ct,24000.0
train/train_loss,0.02836
val/val_accuracy,1.0
val/val_loss,0.00944
