In [17]:
# Reload the project modules imported from ../src (trainer, module, callbacks, ...)
# so edits to those files are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from pathlib import Path
import wandb

import sys

# Make the project's local package directory (../src) importable.
# Resolve it to an absolute path once and only append it if it is missing, so
# re-running this cell does not stack duplicate sys.path entries and a later
# working-directory change does not break module lookup.
SRC_DIR = str(Path("../src").resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from trainer import Trainer
from module import VGGNetModule
from utils import model_size
from callbacks import OverfitCallback
from logger import WandbLogger

In [19]:
# Project directories, relative to the notebook's working directory.
data_path = Path("../data")
logs_path = Path("../logs")

# Seed torch's global RNG so stochastic steps are repeatable across runs
# (e.g. DataLoader shuffling and, presumably, weight init in VGGNetModule — confirm).
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [20]:
# Run configuration: sent to W&B by the logger and read back in later cells
# (batch size and train/val split sizes come from logger.config).
# NOTE(review): the model trained below is VGGNetModule; confirm "ConvNet" is
# the intended architecture label for this run.
run_config = dict(
    model_architecture="ConvNet",
    batch_size=64,
    epochs=50,
    optimizer=dict(name="Adam"),
    train_split=42_000,
    val_split=8000,
)

logger = WandbLogger(
    project_name="ImageClassification",
    config=run_config,
    logs_path=logs_path,
)

In [21]:
# CIFAR-10 training set: images -> float32 tensors scaled to [0, 1], then
# normalized per channel.
# NOTE(review): std=(0.2023, 0.1994, 0.2010) is a widely copied set of CIFAR-10
# values that appears not to be the per-pixel channel std of the training set
# (commonly quoted as ~(0.2470, 0.2435, 0.2616)). Kept unchanged so results stay
# comparable with earlier runs — confirm before changing.
dataset = CIFAR10(data_path, train=True, download=True, transform=v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
]))

# Split with a dedicated, seeded generator so the train/val partition is the
# same on every run. Without it the validation set changes between runs and
# val metrics are not comparable. random_split itself raises if the two sizes
# do not sum to len(dataset).
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [logger.config["train_split"], logger.config["val_split"]],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(
    train_dataset, batch_size=logger.config["batch_size"], shuffle=True)
val_dataloader = DataLoader(
    val_dataset, batch_size=logger.config["batch_size"], shuffle=False)

Files already downloaded and verified


In [22]:
# Optional overfit sanity check: flip to True to run OverfitCallback
# (presumably trains on 2 batches for up to 200 epochs to confirm the model can
# memorise a tiny subset — see callbacks.OverfitCallback). Off for real runs.
OVERFIT_SANITY_CHECK = False

callbacks = (
    [OverfitCallback(limit_batches=2, max_epochs=200)] if OVERFIT_SANITY_CHECK else []
)

In [23]:
# Instantiate the model wrapper and the training loop that drives it.
module = VGGNetModule()

trainer_options = dict(
    logger=logger,
    callbacks=callbacks,
    logs_path=logs_path,
    device=device,
    fast_dev_run=False,  # presumably a quick smoke-test mode when True — confirm in trainer.py
)
trainer = Trainer(module=module, **trainer_options)

In [24]:
# Train until completion. A manual kernel interrupt is treated as a deliberate
# early stop rather than an error. The W&B run is always closed via
# wandb.finish(), even on interrupt or failure.
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    print("Run stopped!")
finally:
    wandb.finish()

Epoch: 0, train_accuracy: 46.52, val_accuracy: 54.33
Epoch: 1, train_accuracy: 58.75, val_accuracy: 59.14
Epoch: 2, train_accuracy: 62.00, val_accuracy: 61.65
Epoch: 3, train_accuracy: 63.72, val_accuracy: 61.75
Epoch: 4, train_accuracy: 65.33, val_accuracy: 62.49
Epoch: 5, train_accuracy: 66.33, val_accuracy: 64.35
Epoch: 6, train_accuracy: 66.94, val_accuracy: 64.62
Epoch: 7, train_accuracy: 67.73, val_accuracy: 65.45
Epoch: 8, train_accuracy: 68.51, val_accuracy: 62.79
Epoch: 9, train_accuracy: 68.79, val_accuracy: 63.38
Epoch: 10, train_accuracy: 69.37, val_accuracy: 65.90
Epoch: 11, train_accuracy: 69.52, val_accuracy: 65.78
Epoch: 12, train_accuracy: 69.84, val_accuracy: 65.66
Epoch: 13, train_accuracy: 70.12, val_accuracy: 66.30
Epoch: 14, train_accuracy: 70.33, val_accuracy: 66.34
Epoch: 15, train_accuracy: 70.70, val_accuracy: 66.20
Epoch: 16, train_accuracy: 70.99, val_accuracy: 65.26
Epoch: 17, train_accuracy: 71.02, val_accuracy: 66.94
Epoch: 18, train_accuracy: 71.08, val_

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
epoch_train_accuracy,▁▄▅▆▆▆▇▇▇▇▇▇▇▇███████████
epoch_train_loss,█▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
epoch_val_accuracy,▁▄▅▅▅▆▇▇▆▆▇▇▇▇▇▇▇█▇█▇█▇██
epoch_val_loss,█▆▄▄▄▃▃▂▄▄▂▂▂▂▂▂▂▁▃▁▂▂▂▁▁
step,▇▇▅▇▁▅▁▃▆█▇▇▇▇█▂▂▅▃▇▃▅▇█▂█▅▃▆▄▆▆▁▂▃▅▅█▁▆

0,1
epoch,24.0
epoch_train_accuracy,72.11758
epoch_train_loss,0.80032
epoch_val_accuracy,67.425
epoch_val_loss,0.94173
step,478.0


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)