In [1]:
import torch
from torchsummary import summary
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR10, Food101
from torchvision.transforms import v2
from pathlib import Path
import wandb

import sys

sys.path.append("../src")

from trainer import Trainer
from module import VGGNetModule
from utils import model_size
from callbacks import OverfitCallback, EarlyStoppingCallback
from logger import WandbLogger
from dataset import MapDataset

In [2]:
# Dataset cache and run-log locations, relative to the notebook's working directory.
data_path, logs_path = Path("..") / "data", Path("..") / "logs"

# Only the log directory is created here; `exist_ok` makes re-running the cell a no-op.
logs_path.mkdir(exist_ok=True)

In [3]:
# Experiment logger; the `config` mapping is handed to WandbLogger and read back
# later in the notebook through `logger.config[...]` (batch size, split fractions).
# NOTE(review): the config names "ResNet" with 20 layers and max_epochs=7, while
# the notebook instantiates VGGNetModule and passes max_epochs=500 to
# OverfitCallback — confirm which values are intended before comparing runs.
logger = WandbLogger(
    project_name="ImageClassification",
    logs_path=logs_path,
    config=dict(
        model_architecture="ResNet",
        num_model_layers=20,
        batch_size=128,
        max_epochs=7,
        optimizer=dict(name="Adam"),
        train_split=0.7,
        val_split=0.3,
    ),
)

In [4]:
# Load the CIFAR-10 training split; `download=True` fetches it into data_path if needed.
dataset = CIFAR10(data_path, train=True, download=True)

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG, so every kernel restart yields a different train/val partition and
# validation metrics are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [logger.config["train_split"], logger.config["val_split"]],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Deterministic preprocessing shared by both splits: convert to an image tensor,
# cast to float32 (with value scaling), then normalize each channel.
val_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # presumably per-channel CIFAR-10 statistics — TODO confirm their source
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# Training-only data augmentation, applied before the shared preprocessing.
train_transforms = v2.Compose([
    v2.RandomCrop(size=(32, 32), padding=4, padding_mode='reflect'),
    v2.RandomHorizontalFlip(),
    val_transforms,
])

# Attach the transforms to each subset (random_split subsets carry no transform of their own).
train_dataset = MapDataset(train_dataset, transform=train_transforms)
val_dataset = MapDataset(val_dataset, transform=val_transforms)

# Options common to both loaders; only the training loader shuffles.
loader_options = dict(
    batch_size=logger.config["batch_size"],
    num_workers=6,
    pin_memory=True,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)

Files already downloaded and verified


In [None]:
# Callbacks handed to the Trainer below.
# OverfitCallback presumably limits training to a few batches to check that the
# model can memorize them — confirm the exact semantics in callbacks.py.
callbacks = [
    OverfitCallback(
        augument_data=True,
        max_epochs=500,
        limit_batches=5,
    ),
    # Disabled for this run:
    # EarlyStoppingCallback(min_val_accuracy=90.0, accuracy_diff=5.0, wait_epochs=5),
]

In [6]:
# Model wrapper to train.
module = VGGNetModule()

# Wire the module, logger and callbacks into the training loop.
trainer = Trainer(
    module=module,
    logger=logger,
    callbacks=callbacks,
    logs_path=logs_path,
    # Reuse the worker count configured on the training DataLoader.
    num_workers=train_dataloader.num_workers,
    checkpoint="best_train",
    measure_time=True,
    fast_dev_run=False,
)

# Print a layer-by-layer summary, using the shape of the first (transformed)
# training image as the input size.
summary(module.model, input_size=train_dataset[0][0].shape, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 32, 32]             224
       BatchNorm2d-2            [-1, 8, 32, 32]              16
              ReLU-3            [-1, 8, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]           1,168
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
         MaxPool2d-7           [-1, 16, 16, 16]               0
          VGGBlock-8           [-1, 16, 16, 16]               0
            Conv2d-9           [-1, 32, 16, 16]           4,640
      BatchNorm2d-10           [-1, 32, 16, 16]              64
             ReLU-11           [-1, 32, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]          18,496
      BatchNorm2d-13           [-1, 64, 16, 16]             128
             ReLU-14           [-1, 64,

In [7]:
# Display the device reported by the module (evaluates to 'cuda' in the recorded run).
module.device

'cuda'

In [8]:
# Run training; a manual kernel interrupt is treated as a normal way to stop.
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    # The exception object is not needed, so it is not bound to a name.
    print("Run stopped!")
finally:
    # Always close the wandb run, whether training finished, was interrupted, or failed.
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


Time per epoch: 2.87 seconds
Epoch: 0, train_accuracy: 13.80, val_accuracy: 9.11
Epoch: 1, train_accuracy: 13.80, val_accuracy: 8.59
Epoch: 2, train_accuracy: 19.79, val_accuracy: 8.59
Epoch: 3, train_accuracy: 18.23, val_accuracy: 8.59
Epoch: 4, train_accuracy: 21.09, val_accuracy: 8.59
Epoch: 5, train_accuracy: 22.66, val_accuracy: 8.59
Epoch: 6, train_accuracy: 23.18, val_accuracy: 8.59
Epoch: 7, train_accuracy: 25.26, val_accuracy: 8.59
Epoch: 8, train_accuracy: 23.96, val_accuracy: 8.59
Epoch: 9, train_accuracy: 27.60, val_accuracy: 8.59
Epoch: 10, train_accuracy: 26.82, val_accuracy: 8.59
Epoch: 11, train_accuracy: 25.78, val_accuracy: 8.59
Epoch: 12, train_accuracy: 27.08, val_accuracy: 7.55
Epoch: 13, train_accuracy: 25.78, val_accuracy: 12.50
Epoch: 14, train_accuracy: 27.86, val_accuracy: 13.80
Epoch: 15, train_accuracy: 29.17, val_accuracy: 11.98
Epoch: 16, train_accuracy: 26.56, val_accuracy: 11.20
Epoch: 17, train_accuracy: 29.17, val_accuracy: 11.20
Epoch: 18, train_accur

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█
epoch_train_accuracy,▁▂▂▅▅▆▆▆▆▆▇▇▇▇█▇█████▇███▇██████████▇███
epoch_train_loss,██▇▆▆▅▅▄▄▃▃▃▃▃▂▂▂▂▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
epoch_val_accuracy,▄▅▄▄▃▃▂▅▅▁▄▄▄▅▆▄▃▄▄▅█▆▃▄▆▄▃▃▅▄▅▅▄▆▆▅▄▅█▄
epoch_val_loss,▄▁▁▁▁▂▂▂▃▂█▂▃▃▃▃▃▄▄▃▄▄▄▃▄▃▄▅▅▅▅▃▃▃▄▄▄▆▅▄
step,▁▅█▁▅█▅▁▅▅█▁▁▅▁██▁▁▁█▁██▅▅▅▅██▅▁▁▁██▅█▅█

0,1
epoch,499
epoch_train_accuracy,97.65625
epoch_train_loss,0.11724
epoch_val_accuracy,9.63542
epoch_val_loss,17.22898
model_architecture,VGGNet(  (feature_e...
step,2


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)