In [1]:
import os

import numpy as np
import torch

from data import get_cifar10_loaders
from models import MyrtleNet
from helper import ExperimentHelper, TRAIN, VAL

In [2]:
# TODO: Logging — save model iterates (checkpoints) during training.

# TODO: Add a learning-rate schedule.

In [3]:
# Identifiers and output locations for this run.
experiment_name = "loss_curve"
model_name = "myrtle_net"
experiment_id = "001"
logs_dir = "logs/"
# The default is a machine-specific absolute path; set CIFAR10_OUTPUT_DIR to
# redirect outputs on other machines instead of editing the notebook.
output_dir = os.environ.get("CIFAR10_OUTPUT_DIR", "/mnt/hdd/ronak/cifar10_resnets")


In [4]:
# Training hyperparameters.
epochs = 2
batch_size = 512
lr = 3e-4
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed the global RNGs before the data loaders and model are created (next
# cells) so that shuffling and weight initialisation are reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [5]:
# Build the CIFAR-10 loaders. The second one is the test split, which this
# notebook also uses for its per-epoch validation pass.
cifar10_loaders = get_cifar10_loaders(batch_size)
train_dataloader, test_dataloader = cifar10_loaders

Files already downloaded and verified
Files already downloaded and verified
50,000 training samples.
10,000 test samples.


In [6]:
# Instantiate the network, cast its parameters to fp32, then move it to `device`.
model = MyrtleNet()
model = model.float()
model = model.to(device)

In [7]:
# Objective: multi-class cross-entropy on the model's raw logits.
loss_func = torch.nn.CrossEntropyLoss()
# AdamW over every model parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [8]:
# Steps per epoch for each phase (the training log reports e.g. "batch 40 / 98").
n_train_steps = len(train_dataloader)
n_val_steps = len(test_dataloader)
# Same keys as the dict returned by evaluate() and passed to helper.end_step().
tracked_metrics = ["loss", "accuracy"]

helper = ExperimentHelper(
    epochs,
    n_train_steps,
    n_val_steps,
    tracked_metrics,
    experiment_name,
    model_name,
    experiment_id,
    logs_dir,
    output_dir,
)

In [9]:
def evaluate(logits, labels, loss_fn=None):
    """Compute scalar loss and accuracy for one batch of predictions.

    Args:
        logits: Model outputs; argmax is taken over dim 1, so presumably
            shaped (batch, num_classes) — TODO confirm for other models.
        labels: Integer class targets, one per row of ``logits``, on the
            same device as ``logits``.
        loss_fn: Loss used for the reported ``"loss"``; defaults to the
            notebook-level ``loss_func`` when omitted.

    Returns:
        dict with Python floats ``"loss"`` and ``"accuracy"`` (fraction of
        rows whose argmax equals the label).
    """
    if loss_fn is None:
        loss_fn = loss_func
    # Metrics never need gradients. Without no_grad, passing logits straight
    # from a training forward pass would build an unused autograd graph.
    with torch.no_grad():
        loss = loss_fn(logits, labels).item()
        accuracy = (logits.argmax(dim=1) == labels).float().mean().item()
    return {"loss": loss, "accuracy": accuracy}

In [10]:
# Train for `epochs` epochs, running a validation pass on the test split
# after each one; the helper records per-step metrics and epoch timings.
helper.start_experiment(model)
for epoch in range(epochs):
    # --- Training pass ---
    # Training mode: layers such as BatchNorm/Dropout (if MyrtleNet uses
    # them) use batch statistics / stochastic behaviour.
    model.train()
    helper.start_epoch(epoch, TRAIN)
    for i, (x_batch, y_batch) in enumerate(train_dataloader):
        helper.start_step(i)
        # Move each batch to the device once and reuse it below.
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        logits = model(x_batch)
        loss = loss_func(logits, y_batch)
        loss.backward()
        optimizer.step()
        # Detach so metric computation does not extend the autograd graph.
        helper.end_step(evaluate(logits.detach(), y_batch))
    helper.end_epoch(epoch)

    # --- Validation pass ---
    # Inference mode: without this, BatchNorm would normalise with test-batch
    # statistics and update its running stats on test data.
    model.eval()
    helper.start_epoch(epoch, VAL)
    with torch.no_grad():
        for i, (x_batch, y_batch) in enumerate(test_dataloader):
            helper.start_step(i)
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            logits = model(x_batch)
            helper.end_step(evaluate(logits, y_batch))
    helper.end_epoch(epoch, model=model)
helper.end_experiment()
            


2023-02-05 13:05:49,777 [INFO] Training...
2023-02-05 13:06:00,150 [INFO]   batch    40 /    98.    elapsed: 0:00:10.
2023-02-05 13:06:10,081 [INFO]   batch    80 /    98.    elapsed: 0:00:20.
2023-02-05 13:06:14,468 [INFO]   train loss: 1.298
2023-02-05 13:06:14,469 [INFO]   train accuracy: 0.552

2023-02-05 13:06:14,470 [INFO]   train epoch 0 took: 0:00:25

2023-02-05 13:06:14,471 [INFO] Running validation...
2023-02-05 13:06:16,123 [INFO]   validation loss: 0.943
2023-02-05 13:06:16,124 [INFO]   validation accuracy: 0.667

2023-02-05 13:06:16,154 [INFO]   validation epoch 0 took: 0:00:02

2023-02-05 13:06:16,155 [INFO] Training...
2023-02-05 13:06:26,137 [INFO]   batch    40 /    98.    elapsed: 0:00:10.
2023-02-05 13:06:36,127 [INFO]   batch    80 /    98.    elapsed: 0:00:20.
2023-02-05 13:06:40,543 [INFO]   train loss: 0.745
2023-02-05 13:06:40,544 [INFO]   train accuracy: 0.743

2023-02-05 13:06:40,545 [INFO]   train epoch 1 took: 0:00:24

2023-02-05 13:06:40,546 [INFO] Running