In [1]:
import os
import sys

import torch

sys.path.append("..")
from src.models import MyrtleNet
from src.image_data import get_cifar10_loaders

In [2]:
# Device selection. The experiment was written for the second GPU ("cuda:1");
# fall back to the first GPU, then to the CPU, so the notebook still runs
# top-to-bottom on machines with fewer (or no) CUDA devices.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
# model
# Seed torch's global RNG before the network is built so that weight
# initialisation (and anything else drawing from the global generator
# afterwards) is repeatable across Restart-and-Run-All.
seed = 0
torch.manual_seed(seed)

# Keyword arguments forwarded verbatim to the MyrtleNet constructor.
model_cfg = {
    "architecture": "myrtle_net",
    "n_layers": 3,
    "residual_blocks": [0, 2],
}
# Build the network and move its parameters to the selected device.
model = MyrtleNet(**model_cfg).to(device)

In [4]:
# data
batch_size = 512
# Dataset directory. Set the CIFAR10_ROOT environment variable to point the
# notebook at a different location; the default is the path the experiment
# was originally run with.
root = os.environ.get("CIFAR10_ROOT", "/mnt/ssd/ronak/datasets/")

# get_cifar10_loaders returns a (train, validation) pair of loaders.
train_loader, val_loader = get_cifar10_loaders(batch_size, root)

Files already downloaded and verified
Files already downloaded and verified
50,000 training samples.
10,000 test samples.


In [5]:
# optim
# Every optimizer hyper-parameter lives in this dict so the recorded config
# matches what is actually constructed below: the optimizer is AdamW (not
# plain Adam), and the weight decay is part of the config rather than a
# literal buried in the constructor call.
optim_cfg = {
    "optimizer": "adamw",
    "lr": 0.003,
    "weight_decay": 5e-4,
}

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=optim_cfg["lr"],
    weight_decay=optim_cfg["weight_decay"],
)

In [6]:
# Run experiment.
# Runs exactly `max_iters` micro-batch iterations, cycling through
# `train_loader` as many times as needed, and prints the loss of every batch.
max_iters = 100
grad_accumulation_steps = 1

model.train()
# Start from clean gradients so re-running this cell never mixes in
# gradients left over from an interrupted previous run.
optimizer.zero_grad(set_to_none=True)
iter_num = 0
print("Training Loss")
print("-------------")
while iter_num < max_iters:
    for X, Y in train_loader:
        # The model is called with inputs and targets and returns (loss, logits).
        loss, logits = model(X.to(device), Y.to(device))
        # Report the true batch loss, not the accumulation-scaled value.
        print(f"{iter_num:03d}: {loss.item():0.4f}")
        # Scale so gradients summed over an accumulation window correspond
        # to the mean loss over that window.
        loss = loss / grad_accumulation_steps
        loss.backward()
        iter_num += 1
        # Step only once a full window of micro-batches has been accumulated
        # (the previous check fired on iteration 0, after a single batch).
        if iter_num % grad_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        # `>=` stops after exactly max_iters iterations; `>` ran one extra.
        if iter_num >= max_iters:
            break

Training Loss
-------------
000: 3.0742
001: 9.7745
002: 11.2783
003: 7.6125
004: 9.0296
005: 8.1771
006: 6.2588
007: 6.4001
008: 3.6303
009: 4.9894
010: 4.4825
011: 4.5366
012: 5.3237
013: 4.3536
014: 4.0704
015: 3.2507
016: 4.1695
017: 3.7187
018: 3.3210
019: 3.4755
020: 3.2531
021: 3.8041
022: 3.6164
023: 2.9680
024: 2.7723
025: 2.8352
026: 2.6664
027: 2.5693
028: 2.6800
029: 2.9628
030: 2.9264
031: 2.5191
032: 2.7009
033: 2.3546
034: 2.6264
035: 2.3360
036: 2.2495
037: 2.5527
038: 2.5076
039: 2.5318
040: 2.3609
041: 2.2626
042: 2.2117
043: 2.3234
044: 2.2248
045: 2.2021
046: 2.0719
047: 2.0554
048: 2.2171
049: 2.0895
050: 2.1538
051: 2.1404
052: 2.0683
053: 2.0458
054: 2.1160
055: 1.9735
056: 1.9837
057: 2.0055
058: 1.9490
059: 1.9669
060: 1.9862
061: 1.9782
062: 2.0079
063: 1.9301
064: 2.0060
065: 1.9790
066: 1.8536
067: 1.9174
068: 1.9157
069: 1.8783
070: 1.8009
071: 1.8293
072: 1.8390
073: 1.8228
074: 1.8456
075: 1.9386
076: 1.7733
077: 1.8697
078: 1.8332
079: 1.8824
080: 1.8055