In [1]:
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR10
from torchvision.datasets.utils import extract_archive
from torchvision.transforms import v2
from pathlib import Path
import wandb

import sys

sys.path.append("../src")

from dataset import MapDataset
from trainer import Trainer
from module import ResNet
from utils import model_size

In [None]:
# Start the W&B run for this experiment, recording every hyperparameter the
# data, model and training cells below rely on.
wandb.init(
    project="ImageClassification",
    config=dict(
        model_architecture="ResNet",
        batch_size=1024,
        epochs=20,
        optimizer=dict(
            name="Adam",
            weight_decay=1e-4,
            max_lr=1e-2,
        ),
        train_split=42_000,
        val_split=8000,
    ),
)

# Expose the run's config object under the name the rest of the notebook uses.
config = wandb.config

In [2]:
data_path = Path("../data")

# torchvision's CIFAR10(root) reads the extracted "cifar-10-batches-py" folder
# under `data_path`. Extract the downloaded archive only when that folder is
# missing, so a fresh checkout survives Restart & Run All and later re-runs
# skip the (slow) extraction. `extract_archive` extracts next to the archive
# when no destination is given.
cifar_archive = data_path / "cifar-10-python.tar.gz"
cifar_extracted_dir = data_path / "cifar-10-batches-py"
if not cifar_extracted_dir.exists() and cifar_archive.exists():
    extract_archive(cifar_archive)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [5]:
# Single source of truth: take the hyperparameters from the run started in the
# `wandb.init` cell instead of re-typing the same literal here. A second copy
# can silently drift from what W&B logged, and it would also discard any values
# W&B injects into the run config (e.g. sweep parameters).
# `dict(...)` turns the W&B config object into a plain dict, the same type the
# data/model/Trainer cells below previously received from this cell.
config = dict(wandb.config)

In [6]:
dataset = CIFAR10(data_path, train=True)

# Seed the split so train/val membership is identical on every run (and across
# runs compared in W&B). The seed can be set via config["seed"]; defaults to 42.
split_generator = torch.Generator().manual_seed(config.get("seed", 42))
train_subset, val_subset = random_split(
    dataset,
    [config["train_split"], config["val_split"]],
    generator=split_generator,
)

# Shared tail of both pipelines: convert to an image tensor, scale to float32
# in [0, 1], then standardize each channel with hard-coded CIFAR-10 statistics.
val_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# Training additionally applies data augmentation (reflect-padded random crop
# + random horizontal flip) before the shared conversion/normalization steps.
train_dataset = MapDataset(train_subset, transform=v2.Compose([
    v2.RandomCrop(size=(32, 32), padding=4, padding_mode="reflect"),
    v2.RandomHorizontalFlip(),
    val_transforms,
]))
val_dataset = MapDataset(val_subset, transform=val_transforms)

train_dataloader = DataLoader(
    train_dataset, batch_size=config["batch_size"], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"])

In [None]:
model = ResNet()
# NOTE(review): `model_size` is a project helper; because this call is not the
# cell's last expression, its result is only visible if the helper prints it.
model_size(model)

# Resolve the optimizer class from the config (e.g. "Adam" -> torch.optim.Adam)
# so the optimizer name logged with the run is the one actually used.
optimizer_cfg = config["optimizer"]
optimizer_cls = getattr(torch.optim, optimizer_cfg["name"])
optimizer = optimizer_cls(
    params=model.parameters(),
    weight_decay=optimizer_cfg["weight_decay"],
)

# OneCycleLR overwrites the optimizer's learning rate on construction (hence no
# `lr=` above) and anneals it up to max_lr and back down over the whole run.
# steps_per_epoch=len(train_dataloader) assumes Trainer steps it once per batch.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=optimizer_cfg["max_lr"],
    epochs=config["epochs"],
    steps_per_epoch=len(train_dataloader),
)

trainer = Trainer(model, config, optimizer, lr_scheduler, device=device, limit_batches=None)

In [None]:
# Train the model. Interrupting the kernel stops training early without a
# traceback, and `finally` closes the W&B run whether training completes,
# is interrupted, or raises.
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    print("Run stopped!")
finally:
    wandb.finish()

[Training metrics for this run (W&B report)](https://api.wandb.ai/links/sampath017/iwrrziwg)