In [None]:
import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

In [4]:
class CustomDataset(Dataset):
    """Map-style dataset that pairs sprite images with integer labels.

    Args:
        sprites: Indexable collection of images (``from_np`` supplies a NumPy array).
        slabels: Labels aligned index-for-index with ``sprites``.
        transform: Callable applied to each image when it is fetched. A falsy
            value (e.g. ``None``) returns the stored image untouched.
        null_context: When true, every item is given the constant label ``0``
            instead of its stored label.
        argmax: When true, ``slabels`` is reduced with ``np.argmax(axis=1)``
            at construction time (presumably one-hot rows -> class indices;
            confirm against the label file).
    """

    def __init__(self, sprites, slabels, transform=default_tfms, null_context=False, argmax=False):
        self.sprites = sprites
        self.slabels = np.argmax(slabels, axis=1) if argmax else slabels
        self.transform = transform
        self.null_context = null_context

    @classmethod
    def from_np(cls,
                path,
                sfilename="sprites_1788_16x16.npy",
                lfilename="sprite_labels_nc_1788_16x16.npy",
                transform=default_tfms,
                null_context=False,
                argmax=False):
        """Build a dataset from two ``.npy`` files located in the directory ``path``."""
        root = Path(path)
        sprites = np.load(root / sfilename)
        slabels = np.load(root / lfilename)
        return cls(sprites, slabels, transform, null_context, argmax)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``; the label is an int64 tensor."""
        image = self.sprites[idx]
        # Only the image depends on the transform. Previously the label (and the
        # image) were assigned solely inside the transform branch, so a dataset
        # built with transform=None failed with UnboundLocalError on access.
        if self.transform:
            image = self.transform(image)

        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def subset(self, slice_size=1000):
        """Return a new dataset made of ``slice_size`` items drawn at random without replacement.

        Raises:
            ValueError: if ``slice_size`` exceeds ``len(self)`` (raised by ``random.sample``).
        """
        indices = random.sample(range(len(self)), slice_size)
        # Indexing with a list of positions relies on the stored collections
        # supporting it (NumPy arrays do). Labels are passed through as stored,
        # so ``argmax`` is deliberately left at its default here.
        return CustomDataset(self.sprites[indices], self.slabels[indices], self.transform, self.null_context)

    def split(self, pct=0.2):
        """Randomly split into ``(train, test)`` subsets, with ``pct`` of the items going to test."""
        train_size = int((1 - pct) * len(self))
        test_size = len(self) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(self, [train_size, test_size])
        return train_dataset, test_dataset

NameError: name 'default_tfms' is not defined

In [None]:
def get_dataloaders(data_dir, batch_size, slice_size=None, valid_pct=0.2, num_workers=1):
    """Get train/val dataloaders for classification on the sprites dataset.

    Args:
        data_dir: Directory containing the sprite and label ``.npy`` files.
        batch_size: Batch size used by both loaders.
        slice_size: If truthy, train/validate on a random subset of this many
            items instead of the full dataset.
        valid_pct: Fraction of the (possibly sliced) dataset held out for validation.
        num_workers: Worker processes per loader. Defaults to 1, the value that
            was previously hard-coded.

    Returns:
        ``(train_dl, valid_dl)``: the training loader shuffles, the validation
        loader does not.
    """
    dataset = CustomDataset.from_np(Path(data_dir), argmax=True)
    if slice_size:
        dataset = dataset.subset(slice_size)

    train_ds, valid_ds = dataset.split(valid_pct)

    # Settings shared by both loaders; only shuffling differs between them.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    valid_dl = DataLoader(valid_ds, shuffle=False, **loader_kwargs)

    return train_dl, valid_dl

In [4]:
import math
from pathlib import Path
from types import SimpleNamespace
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from utilities import get_dataloaders

import wandb

ModuleNotFoundError: No module named 'utilities'

In [6]:
### Sprite Classification

In [7]:
# --- Data location and compute device ---
DATA_DIR = Path("./data/")
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Label names (five entries, matching OUTPUT_SIZE below) ---
CLASSES = [
    "hero",
    "non-hero",
    "food",
    "spell",
    "side-facing",
]

# --- Model dimensions ---
# 768 values per example; presumably a flattened 16x16 sprite with 3 channels (TODO confirm).
INPUT_SIZE = 16 * 16 * 3
HIDDEN_SIZE = 256
OUTPUT_SIZE = 5

# --- Data loading ---
NUM_WORKERS = 3

In [9]:
### Hyperparameters for a training run, bundled in one attribute-access object
config = SimpleNamespace(
    epochs=2, batch_size=128,           # passes over the data / examples per batch
    lr=1e-5, dropout=0.5,               # optimizer learning rate / value handed to get_model
    slice_size=10_000, valid_pct=0.2,   # random subset size / validation fraction
)

In [None]:
### Training the model

def train_model(config):
    """Train a model with a given config, logging metrics to Weights & Biases.

    Args:
        config: Object exposing ``epochs``, ``batch_size``, ``lr``, ``dropout``,
            ``slice_size`` and ``valid_pct`` as attributes. It is also passed to
            ``wandb.init`` as the run configuration.

    Returns:
        None. Training-loss metrics are logged once per batch and validation
        metrics once per epoch.
    """
    # Running inside the context manager closes the W&B run even when training
    # raises part-way through; previously wandb.finish() was skipped on error.
    with wandb.init(project="dlai_intro", config=config):
        # Get the data
        train_dl, valid_dl = get_dataloaders(DATA_DIR,
                                             config.batch_size,
                                             config.slice_size,
                                             config.valid_pct)

        # Build the model (get_model is defined elsewhere). Batches are moved
        # to DEVICE below, so the model must live there too; .to() is a no-op
        # if get_model already placed it on DEVICE.
        model = get_model(config.dropout).to(DEVICE)

        # Make the loss and optimizer
        loss_func = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=config.lr)

        example_ct = 0  # running count of training examples seen so far

        for epoch in tqdm(range(config.epochs), total=config.epochs):
            model.train()

            for step, (images, labels) in enumerate(train_dl):
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                outputs = model(images)
                train_loss = loss_func(outputs, labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                example_ct += len(images)
                metrics = {
                    # Log a plain Python float rather than a tensor that is
                    # still attached to the autograd graph.
                    "train/train_loss": train_loss.item(),
                    "train/epoch": epoch + 1,
                    "train/example_ct": example_ct
                }
                wandb.log(metrics)

            # Compute validation metrics at the end of every epoch
            val_loss, accuracy = validate_model(model, valid_dl, loss_func)
            val_metrics = {
                "val/val_loss": val_loss,
                "val/val_accuracy": accuracy
            }
            wandb.log(val_metrics)
