In [1]:
import torch
from torch import optim
from torchsummary import summary
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR10, Food101
from torchvision.transforms import v2
from pathlib import Path
import wandb
import os
import sys

sys.path.append("../src")

from trainer import Trainer
from module import ResNetModule
from utils import model_size, load_from_checkpoint
from callbacks import OverfitCallback, EarlyStoppingCallback
from logger import WandbLogger
from dataset import MapDataset

In [2]:
# Both working directories sit one level above the notebook folder.
project_root = Path("..")
data_path = project_root / "data"   # dataset download / cache location
logs_path = project_root / "logs"   # experiment logs and checkpoints

# Create the log directory on first run; a no-op when it already exists.
logs_path.mkdir(exist_ok=True)

In [3]:
# Experiment logger. `config` is the single source of truth for the
# hyper-parameters read by the later cells (batch size, epochs, weight decay,
# max learning rate, split fractions) and is what gets recorded with the run.
logger = WandbLogger(
    project_name="ImageClassification",
    config={
        "model_architecture": "ResNet18",
        "num_model_layers": 18,
        "batch_size": 1024,
        "max_epochs": 20,
        "optimizer": {
            # The notebook instantiates `optim.AdamW`, so record that name;
            # logging "Adam" would misreport the run (decoupled vs. L2 decay).
            "name": "AdamW",
            "weight_decay": 1e-3
        },
        "lr_scheduler": {
            # Peak learning rate of the one-cycle schedule.
            "max_lr": 0.01
        },
        # Fractions passed to `random_split`; they must sum to 1.
        "train_split": 0.7,
        "val_split": 0.3
    },
    logs_path=logs_path,
    offline=False
)

In [4]:
# Worker-process count for the DataLoaders (also passed to the Trainer later).
# `os.cpu_count()` may return None when the count cannot be determined; fall
# back to 0 so the DataLoader loads in the main process instead of raising.
cpu_count = os.cpu_count() or 0

# Fixed seed for the train/validation partition so that a kernel restart
# reproduces the same split (and therefore comparable validation metrics).
SPLIT_SEED = 42

dataset = CIFAR10(data_path, train=True, download=True)

# Split fractions come from the logged config so the run metadata matches.
train_dataset, val_dataset = random_split(
    dataset,
    [logger.config["train_split"], logger.config["val_split"]],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Deterministic preprocessing shared by both splits.
val_transforms = v2.Compose([
    # Convert to a float tensor image scaled to [0, 1], then normalize per channel
    # (mean/std are presumably the CIFAR-10 channel statistics -- TODO confirm).
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

# Training split: random augmentation first, then the shared preprocessing.
train_dataset = MapDataset(train_dataset, transform=v2.Compose([
    # Data augmentation
    v2.RandomCrop(size=(32, 32), padding=4, padding_mode='reflect'),
    v2.RandomHorizontalFlip(),

    val_transforms
]))

# Validation split: no augmentation, only the deterministic preprocessing.
val_dataset = MapDataset(val_dataset, transform=val_transforms)

# Only the training loader shuffles; validation order is left as-is.
train_dataloader = DataLoader(
    train_dataset, batch_size=logger.config["batch_size"], shuffle=True, num_workers=cpu_count, pin_memory=True)
val_dataloader = DataLoader(
    val_dataset, batch_size=logger.config["batch_size"],  num_workers=cpu_count, pin_memory=True)

Files already downloaded and verified


In [5]:
# Early-stopping policy handed to the Trainer; thresholds are passed through
# unchanged to the project's EarlyStoppingCallback.
early_stopping = EarlyStoppingCallback(
    min_val_accuracy=95.0,
    accuracy_diff=5.0,
    wait_epochs=5,
)

# For a quick overfit sanity check, append this to the list instead:
#   OverfitCallback(limit_batches=1, batch_size=10, max_epochs=500, augument_data=False)
callbacks = [early_stopping]

In [6]:
# Model wrapper; the optimizer is built over the wrapped network's parameters.
module = ResNetModule(toy_model=False)

# AdamW with the weight decay recorded in the run config (learning rate is
# left at the optimizer default and then driven by the scheduler below).
optimizer_config = logger.config["optimizer"]
optimizer = optim.AdamW(
    module.model.parameters(),
    weight_decay=optimizer_config["weight_decay"],
)

# One-cycle learning-rate schedule whose total length is
# max_epochs * (number of training batches per epoch).
scheduler_config = logger.config["lr_scheduler"]
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=scheduler_config["max_lr"],
    epochs=logger.config["max_epochs"],
    steps_per_epoch=len(train_dataloader),
)

In [7]:
# Optional: resume from a saved checkpoint before inspecting/training.
# module.model, optimizer = load_from_checkpoint(
#     path="../logs/wandb/offline-run-20241215_132918-77n093vj/checkpoints/best.pt",
#     model=module.model,
#     optimizer=optimizer
# )

# Print a layer-by-layer summary, using the shape of one transformed training
# image as the per-sample input size and the configured batch size.
sample_shape = train_dataset[0][0].shape
summary(
    module.model,
    input_size=sample_shape,
    batch_size=logger.config["batch_size"],
    device="cpu",
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [1024, 64, 32, 32]           1,728
       BatchNorm2d-2         [1024, 64, 32, 32]             128
              ReLU-3         [1024, 64, 32, 32]               0
         ConvBlock-4         [1024, 64, 32, 32]               0
            Conv2d-5        [1024, 128, 32, 32]          73,728
       BatchNorm2d-6        [1024, 128, 32, 32]             256
              ReLU-7        [1024, 128, 32, 32]               0
         MaxPool2d-8        [1024, 128, 16, 16]               0
         ConvBlock-9        [1024, 128, 16, 16]               0
           Conv2d-10        [1024, 128, 16, 16]         147,456
      BatchNorm2d-11        [1024, 128, 16, 16]             256
             ReLU-12        [1024, 128, 16, 16]               0
           Conv2d-13        [1024, 128, 16, 16]         147,456
      BatchNorm2d-14        [1024, 128,

In [8]:
# Number of training batches per epoch (the same value is used as the
# scheduler's `steps_per_epoch`).
len(train_dataloader)

35

In [9]:
# Collect the Trainer settings in one place so they are easy to scan and tweak.
trainer_options = {
    # Objects built in the earlier cells.
    "module": module,
    "logger": logger,
    "optimizer": optimizer,
    "callbacks": callbacks,
    "logs_path": logs_path,
    # Run-mode switches.
    "fast_dev_run": False,
    "measure_time": True,
    # Scheduler wiring: `lr_scheduler_on_epoch=False` is passed as-is; it
    # presumably selects per-batch stepping -- TODO confirm in Trainer.
    "lr_scheduler": lr_scheduler,
    "lr_scheduler_on_epoch": False,
    # Checkpoint policy name understood by the project's Trainer.
    "checkpoint": "best_val",
    "num_workers": cpu_count,
}

trainer = Trainer(**trainer_options)

In [10]:
# Display the device reported by the module as a sanity check before training.
module.device

'cuda'

In [11]:
# Run training and always close the W&B run, reporting how it ended:
#   * normal completion or a manual interrupt -> exit code 0 (as before),
#   * any other exception -> exit code 1, so the run is marked as failed
#     instead of being recorded as a successful one; the error is re-raised.
run_exit_code = 0
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    # Manual stop from the notebook: keep the partial run and exit quietly.
    print("Run stopped!")
except Exception:
    run_exit_code = 1
    raise
finally:
    wandb.finish(exit_code=run_exit_code)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


Time per epoch: 19.25 seconds
Epoch: 0, train_accuracy: 27.70, val_accuracy: 24.28, lr: 0.0010
Epoch: 1, train_accuracy: 42.71, val_accuracy: 26.46, lr: 0.0028
Epoch: 2, train_accuracy: 54.25, val_accuracy: 33.68, lr: 0.0052
Epoch: 3, train_accuracy: 60.74, val_accuracy: 49.09, lr: 0.0076
Epoch: 4, train_accuracy: 68.22, val_accuracy: 61.09, lr: 0.0094
Epoch: 5, train_accuracy: 72.12, val_accuracy: 55.12, lr: 0.0100
Epoch: 6, train_accuracy: 75.89, val_accuracy: 62.25, lr: 0.0099
Epoch: 7, train_accuracy: 79.12, val_accuracy: 75.70, lr: 0.0095
Epoch: 8, train_accuracy: 81.61, val_accuracy: 77.95, lr: 0.0089
Epoch: 9, train_accuracy: 83.49, val_accuracy: 77.05, lr: 0.0081
Epoch: 10, train_accuracy: 85.32, val_accuracy: 82.16, lr: 0.0072
Epoch: 11, train_accuracy: 86.90, val_accuracy: 84.03, lr: 0.0061
Epoch: 12, train_accuracy: 88.96, val_accuracy: 86.26, lr: 0.0050
Epoch: 13, train_accuracy: 90.60, val_accuracy: 87.20, lr: 0.0039
Epoch: 14, train_accuracy: 91.93, val_accuracy: 87.40, l

VBox(children=(Label(value='59.511 MB of 292.931 MB uploaded\r'), FloatProgress(value=0.20315719185251385, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
epoch_train_accuracy,▁▃▄▄▅▆▆▆▇▇▇▇▇▇██████
epoch_train_loss,█▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁
epoch_val_accuracy,▁▁▂▄▅▄▅▆▇▇▇▇████████
epoch_val_loss,▆█▆▄▂▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▂▂▂▂▃▃▆▆████▇▇▇▇▆▅▅▅▄▃▃▃▃▃▃▂▂▂▂▁▁▁▁▁
step_train_accuracy,▁▃▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
step_train_loss,█▇▇▆▆▅▅▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
step_val_accuracy,▁▁▂▂▂▂▂▄▄▅▄▅▅▇▇▇▇▇▇▇▇▇▇▇▇███████████████
step_val_loss,▅▅▅▆███▅▆▅▂▂▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19
epoch_train_accuracy,96.15162
epoch_train_loss,0.11261
epoch_val_accuracy,90.80384
epoch_val_loss,0.31369
lr,0.0
model_architecture,ResNet18(  (feature...
step_train_accuracy,97.82609
step_train_loss,0.06194
step_val_accuracy,90.96385


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)