In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torchtnt.utils.device import copy_data_to_device
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

from tqdm.auto import tqdm

from matplotlib import cm
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default theme to every matplotlib figure in this notebook.
sns.set_theme()

# 'jet' colormap handle. `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7
# and removed in 3.9; `plt.get_cmap` performs the same lookup and is available on
# both old and new Matplotlib releases.
jet = plt.get_cmap('jet')

from resnet import *
from convnet import *
from dataset import *

# Prefer the GPU at index 1 (the device this notebook was originally pinned to).
# Fall back to GPU 0, then to the CPU, so that later `.to(device)` calls do not
# fail on machines with fewer than two CUDA devices.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the train and validation loaders from a single shared configuration;
# the two calls differ only in the `split` argument. The generator is consumed
# in order, so the "train" loader is created before the "val" loader.
train_loader, val_loader = (
    get_loader(
        dataset="CIFAR100",
        split=split_name,
        img_size=40,
        batch_size=512,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )
    for split_name in ("train", "val")
)

# Cross-entropy (Xent) training configs

- MNIST ConvNet:  Adam@1e-2, max_epochs=20, StepLR(optimizer, 1, 0.7) **VAL ACC = 0.992**

- MNIST ResNet18: Adam@1e-2, max_epochs=20, StepLR(optimizer, 1, 0.7)  **VAL ACC = 0.995**

- FashionMNIST ConvNet:  Adam@1e-2, max_epochs=50, StepLR(optimizer, 4, 0.5) **VAL ACC = 0.930**

- FashionMNIST ResNet18: Adam@1e-2, max_epochs=50, StepLR(optimizer, 4, 0.5) **VAL ACC = 0.935**

- CIFAR10 ResNet18: SGD@1e-1, momentum=0.9, max_epochs=200, CosineAnnealingLR(T_max=200) **VAL ACC = 0.929**

In [None]:
# Model to train: ResNet18 with 3 input channels and 100 output classes.
# The commented-out line is the alternative ConvNet setup (1 channel, 10 classes).
module = resnet18(in_dim=3, out_dim=100, activation=torch.nn.functional.leaky_relu)
#module = ConvNet(in_dim=1, out_dim=10, activation=torch.nn.functional.leaky_relu)
module = module.to(device)

max_epochs = 200

optimizer = optim.SGD(
    module.parameters(),
    lr=1e-1,
    momentum=0.9,
)

# Tie the cosine period to the epoch budget so that changing `max_epochs`
# cannot silently leave the schedule out of sync with the training length.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
# Alternative schedule; StepLR takes `step_size`/`gamma` (it has no `T_max`).
#lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

for epoch in range(max_epochs):
    # ---- Training pass: one optimizer step per batch ----
    module.train()
    for batch in train_loader:
        optimizer.zero_grad()
        batch = copy_data_to_device(batch, device=device)
        logits = module(batch.images)
        # argmax over the last axis converts `batch.labels` to class indices for
        # cross_entropy (presumably labels are one-hot — confirm in dataset.py).
        loss = F.cross_entropy(logits,
                               torch.argmax(batch.labels, dim=-1))
        loss.backward()
        optimizer.step()
    # The scheduler advances once per epoch, not once per batch.
    lr_scheduler.step()

    # ---- Validation pass: top-1 accuracy over the whole val loader ----
    module.eval()
    num_correct = 0
    total = 0
    # No gradients are needed for evaluation; disabling autograd avoids building
    # graphs and keeps activation memory from being retained.
    with torch.no_grad():
        for batch in val_loader:
            batch = copy_data_to_device(batch, device=device)
            logits = module(batch.images)
            pred_idx = torch.argmax(logits, dim=-1)
            target_idx = torch.argmax(batch.labels, dim=-1)
            # .item() keeps the counter a plain Python int, so the log line
            # prints a number rather than a tensor repr.
            num_correct += (pred_idx == target_idx).sum().item()
            total += len(pred_idx)

    # Guard against an empty validation loader instead of dividing by zero.
    val_acc = num_correct / total if total else float("nan")
    print(f"Epoch {epoch} val acc: {val_acc} | lr = {lr_scheduler.get_last_lr()[0]}")