In [1]:
import torch
from torchsummary import summary
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from pathlib import Path
import wandb

import sys

# Make the project's own modules (trainer, module, utils, callbacks, logger)
# importable from ../src. The path is relative to the notebook's working
# directory. Guard the append so re-running this cell does not keep adding
# duplicate entries to sys.path.
SRC_DIR = "../src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from trainer import Trainer
from module import VGGNetModule
from utils import model_size
from callbacks import OverfitCallback, EarlyStoppingCallback
from logger import WandbLogger

In [2]:
# Project directories, relative to the notebook's working directory:
# CIFAR-10 is stored under ../data and run logs are written under ../logs.
data_path, logs_path = Path("../data"), Path("../logs")
logs_path.mkdir(exist_ok=True)

In [3]:
# Hyper-parameters and run metadata recorded with the W&B run. The data and
# training cells read batch_size and the split sizes back via logger.config.
run_config = dict(
    model_architecture="VGGNet",
    batch_size=64,
    max_epochs=100,
    optimizer=dict(name="Adam"),
    train_split=42_000,
    val_split=8000,
)

logger = WandbLogger(
    project_name="ImageClassification",
    config=run_config,
    logs_path=logs_path,
)

In [4]:
# Fix the RNG seed so the train/val split is identical across kernel restarts.
# Without an explicit generator, random_split draws from torch's global RNG
# and produces a different split on every run, so results from different runs
# are not comparable. The global seed also makes DataLoader shuffling
# repeatable.
SEED = 42
torch.manual_seed(SEED)

# Convert each image to a float32 tensor scaled to [0, 1], then normalise per
# channel.
# NOTE(review): these mean/std values are the commonly circulated CIFAR-10
# constants. Verify the std against statistics computed on the training set.
dataset = CIFAR10(data_path, train=True, download=True, transform=v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
]))

# The split sizes come from the run config and must sum to len(dataset);
# random_split raises otherwise.
train_dataset, val_dataset = random_split(
    dataset,
    [logger.config["train_split"], logger.config["val_split"]],
    generator=torch.Generator().manual_seed(SEED),
)

train_dataloader = DataLoader(
    train_dataset, batch_size=logger.config["batch_size"], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=logger.config["batch_size"])

Files already downloaded and verified


In [5]:
# Callbacks handed to the Trainer. Arguments are kept exactly as configured.
# NOTE(review): min_val_accuracy / accuracy_diff appear to be on the same
# 0-100 scale as the printed accuracies. Confirm against callbacks.py.
early_stopping = EarlyStoppingCallback(
    min_val_accuracy=80.0, accuracy_diff=4.0, wait_epochs=5)
overfit_check = OverfitCallback(limit_batches=2, max_epochs=200)

callbacks = [early_stopping, overfit_check]

In [6]:
module = VGGNetModule()

# Trainer wiring: the module, the W&B logger, the callbacks and the log
# directory. fast_dev_run is off (full run). measure_time is on; it is
# presumably the source of the "Time per epoch" line in the training output.
trainer_options = {
    "module": module,
    "logger": logger,
    "callbacks": callbacks,
    "logs_path": logs_path,
    "fast_dev_run": False,
    "measure_time": True,
}
trainer = Trainer(**trainer_options)

# Print the size of the underlying network (see utils.model_size).
model_size(module.model)

model size: 0.02 MB


In [7]:
# Show which device the module reports before training starts.
module_device = module.device
module_device

'cuda'

In [8]:
# Train, and always close the W&B run afterwards.
# - A manual interrupt (ctrl+c / kernel interrupt) counts as a deliberate stop.
# - Any other exception closes the run with a non-zero exit code, so W&B marks
#   it as failed rather than finished, and is then re-raised.
exit_code = 0
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    print("Run stopped!")
except Exception:
    exit_code = 1
    raise
finally:
    wandb.finish(exit_code=exit_code)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Time per epoch: 12.76 seconds
Epoch: 0, train_accuracy: 33.64, val_accuracy: 44.71
Epoch: 1, train_accuracy: 48.77, val_accuracy: 48.03
Epoch: 2, train_accuracy: 53.86, val_accuracy: 52.78
Epoch: 3, train_accuracy: 56.48, val_accuracy: 53.49
Epoch: 4, train_accuracy: 58.79, val_accuracy: 56.42
Epoch: 5, train_accuracy: 60.25, val_accuracy: 59.51
Epoch: 6, train_accuracy: 61.39, val_accuracy: 59.04
Epoch: 7, train_accuracy: 62.40, val_accuracy: 58.67
Epoch: 8, train_accuracy: 63.24, val_accuracy: 61.10
Epoch: 9, train_accuracy: 64.08, val_accuracy: 59.44
Epoch: 10, train_accuracy: 64.74, val_accuracy: 62.84
Epoch: 11, train_accuracy: 65.48, val_accuracy: 62.69
Epoch: 12, train_accuracy: 65.67, val_accuracy: 62.79
Epoch: 13, train_accuracy: 66.35, val_accuracy: 63.34
Epoch: 14, train_accuracy: 66.71, val_accuracy: 63.62
Epoch: 15, train_accuracy: 67.05, val_accuracy: 63.78
Epoch: 16, train_accuracy: 67.09, val_accuracy: 64.28
Epoch: 17, train_accuracy: 67.43, val_accuracy: 65.04
Epoch: 1

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇██
epoch_train_accuracy,▁▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████
epoch_train_loss,█▅▄▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_val_accuracy,▁▃▄▅▅▆▆▆▇▇▇▇▇▇▇▇███▇▇█████▇█████████████
epoch_val_loss,█▆▅▄▃▂▃▂▂▂▂▂▁▁▁▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁
step,▅▁▇▂▆▃█▆▇▂▂▅▅▄██▄▅▅▃▇▇▅▁█▅▅▄▅▅▁▆▂▇▂▄▄▁▅▆

0,1
epoch,99
epoch_train_accuracy,73.48269
epoch_train_loss,0.74564
epoch_val_accuracy,67.3375
epoch_val_loss,0.93405
model_architecture,VGGNet(  (feature_e...
step,656


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)