In [1]:
import torch
from torch import optim
from torchsummary import summary
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import CIFAR100
from torchvision.transforms import v2
from pathlib import Path
import wandb
import os
import sys

sys.path.append("../src")

from trainer import Trainer
from module import ResNetModule
from utils import model_size, load_from_checkpoint
from callbacks import OverfitCallback, EarlyStoppingCallback
from logger import WandbLogger
from dataset import MapDataset
import settings as s

In [2]:
# Project directories, resolved relative to the notebook's working directory.
data_path, logs_path = Path("../data"), Path("../logs")
# The logs directory is created here if it does not exist yet.
logs_path.mkdir(exist_ok=True)

In [3]:
# Hyper-parameters recorded alongside the W&B run, all sourced from settings.
run_config = {
    "model": s.model,
    "dataset": s.dataset,
    "max_epochs": s.max_epochs,
    "optimizer": s.optimizer,
    "lr_scheduler": s.lr_scheduler,
}

logger = WandbLogger(
    project_name=s.project_name,
    config=run_config,
    logs_path=logs_path,
    offline=s.wandb_offline,
)

In [4]:
# Worker processes for the DataLoaders (also handed to the Trainer later).
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0, which makes DataLoader load batches in the main process.
cpu_count = os.cpu_count() or 0

dataset = CIFAR100(data_path, train=True, download=True)

# Seed the split so train/val membership is identical on every run. An unseeded
# split reshuffles it each time, so a model resumed from a checkpoint could be
# validated on images it was trained on.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [s.dataset["train_split"], s.dataset["val_split"]],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Per-channel mean/std of the CIFAR-100 training images. The previous constants
# (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) are CIFAR-10 statistics.
# NOTE: checkpoints trained before this change saw the old normalization.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# Shared by train and val: image -> float tensor in [0, 1] -> standardized.
normalize_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

transforms_list = []
if s.dataset["augumentations"]:  # key spelled as in settings (sic)
    transforms_list.extend([
        v2.RandomCrop(size=(32, 32), padding=4, padding_mode='reflect'),
        v2.RandomHorizontalFlip(),
        # Disabled candidates:
        # v2.RandomVerticalFlip(p=0.2),
        # v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        # v2.RandomRotation(degrees=15),
        # v2.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        # v2.RandomGrayscale(p=0.1),
    ])

# Normalization always runs last, after any augmentation.
transforms_list.append(normalize_transforms)

train_transforms = v2.Compose(transforms_list)
val_transforms = normalize_transforms  # no augmentation at evaluation time

# Attach per-split transforms (MapDataset presumably applies `transform`
# to each sample it returns — see src/dataset.py).
train_dataset = MapDataset(train_dataset, transform=train_transforms)
val_dataset = MapDataset(val_dataset, transform=val_transforms)

# Options common to both loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=s.dataset["batch_size"],
    num_workers=cpu_count,
    pin_memory=True,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

Files already downloaded and verified


In [5]:
# Early stopping on validation accuracy (values are percentages, matching the
# accuracies printed during training). Exact stopping semantics live in
# src/callbacks.py.
early_stopping = EarlyStoppingCallback(
    min_val_accuracy=90.0,
    accuracy_diff=5.0,
    wait_epochs=5,
)
callbacks = [early_stopping]
# For an overfit-one-batch sanity check, add instead:
# OverfitCallback(limit_batches=1, batch_size=10, max_epochs=500, augument_data=False)

In [6]:
module = ResNetModule(toy_model=False)

# AdamW with weight decay from settings. A missing/None value falls back to
# 0.01, but an explicit 0 is now honored (the old truthiness test silently
# turned 0 into 0.01, making weight decay impossible to disable).
configured_weight_decay = s.optimizer.get("weight_decay")
optimizer = optim.AdamW(
    params=module.model.parameters(),
    weight_decay=0.01 if configured_weight_decay is None else configured_weight_decay,
)

# Always bind `lr_scheduler` so the Trainer cell never hits a NameError. The
# explicit None check replaces `except TypeError`, which also hid genuine
# TypeErrors raised while constructing the scheduler.
lr_scheduler = None
scheduler_config = s.lr_scheduler
if scheduler_config is None:
    print("lr_scheduler is None!")
elif scheduler_config["name"] == "OneCycleLR":
    # OneCycleLR's schedule spans epochs * steps_per_epoch steps (one per batch).
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=scheduler_config["max_lr"],
        epochs=s.max_epochs,
        steps_per_epoch=len(train_dataloader),
    )
else:
    raise ValueError(f"Unsupported lr_scheduler: {scheduler_config['name']!r}")

lr_scheduler is None!


In [7]:
# To resume a previous run, restore model and optimizer state before training:
# module.model, optimizer = load_from_checkpoint(
#     path="../logs/wandb/offline-run-20241215_132918-77n093vj/checkpoints/best.pt",
#     model=module.model,
#     optimizer=optimizer
# )

# Layer-by-layer summary; the input shape is taken from one transformed
# training sample.
sample_image = train_dataset[0][0]
summary(
    module.model,
    input_size=sample_image.shape,
    batch_size=s.dataset["batch_size"],
    device="cpu",
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [256, 64, 32, 32]           1,728
       BatchNorm2d-2          [256, 64, 32, 32]             128
              ReLU-3          [256, 64, 32, 32]               0
         ConvBlock-4          [256, 64, 32, 32]               0
            Conv2d-5         [256, 128, 32, 32]          73,728
       BatchNorm2d-6         [256, 128, 32, 32]             256
              ReLU-7         [256, 128, 32, 32]               0
         MaxPool2d-8         [256, 128, 16, 16]               0
         ConvBlock-9         [256, 128, 16, 16]               0
           Conv2d-10         [256, 128, 16, 16]         147,456
      BatchNorm2d-11         [256, 128, 16, 16]             256
             ReLU-12         [256, 128, 16, 16]               0
           Conv2d-13         [256, 128, 16, 16]         147,456
      BatchNorm2d-14         [256, 128,

In [8]:
# Wire together everything built above, plus run options.
trainer_kwargs = dict(
    module=module,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    callbacks=callbacks,
    logger=logger,
    logs_path=logs_path,
    num_workers=cpu_count,
    fast_dev_run=s.fast_dev_run,
    measure_time=True,
    save_checkpoint_type="best_val",
)
trainer = Trainer(**trainer_kwargs)

Using device: cuda!


In [9]:
# Train. Interrupting the kernel (KeyboardInterrupt) stops the run early; the
# `finally` clause closes the W&B run whether training finished, was
# interrupted, or raised.
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt:
    print("Run stopped!")
finally:
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


Time per epoch: 27.39 seconds
Epoch: 0, train_accuracy: 6.96, val_accuracy: 11.45, lr: 0.0010
Epoch: 1, train_accuracy: 14.05, val_accuracy: 15.71, lr: 0.0010
Epoch: 2, train_accuracy: 21.01, val_accuracy: 25.39, lr: 0.0010
Epoch: 3, train_accuracy: 26.41, val_accuracy: 31.39, lr: 0.0010
Epoch: 4, train_accuracy: 31.35, val_accuracy: 35.60, lr: 0.0010
Epoch: 5, train_accuracy: 34.58, val_accuracy: 38.80, lr: 0.0010
Epoch: 6, train_accuracy: 37.80, val_accuracy: 41.93, lr: 0.0010
Epoch: 7, train_accuracy: 40.70, val_accuracy: 44.50, lr: 0.0010
Epoch: 8, train_accuracy: 43.52, val_accuracy: 43.93, lr: 0.0010
Epoch: 9, train_accuracy: 45.05, val_accuracy: 44.09, lr: 0.0010
Epoch: 10, train_accuracy: 47.92, val_accuracy: 49.56, lr: 0.0010
Epoch: 11, train_accuracy: 49.17, val_accuracy: 49.17, lr: 0.0010
Epoch: 12, train_accuracy: 50.67, val_accuracy: 48.16, lr: 0.0010
Epoch: 13, train_accuracy: 52.12, val_accuracy: 55.93, lr: 0.0010
Epoch: 14, train_accuracy: 53.49, val_accuracy: 47.87, lr

Exception ignored in: <function _releaseLock at 0x74cd62f73b00>
Traceback (most recent call last):
  File "/usr/lib/python3.12/logging/__init__.py", line 243, in _releaseLock
    def _releaseLock():
    
KeyboardInterrupt: 


Epoch: 41, train_accuracy: 76.85, val_accuracy: 64.06, lr: 0.0010
Run stopped!


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
epoch_train_accuracy,▁▂▂▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████
epoch_train_loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
epoch_val_accuracy,▁▂▃▄▄▅▅▅▅▅▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇█████████████
epoch_val_loss,██▆▅▄▄▃▃▃▃▂▃▃▂▃▂▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
step_train_accuracy,▁▂▂▂▂▄▄▅▅▅▆▅▆▆▆▆▆▆▆▆▇▇▇▇▇█▇▇▇█▇███▇▇████
step_train_loss,█▇▇▅▅▅▅▅▅▅▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁
step_val_accuracy,▁▃▂▂▃▄▅▅▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▆▇█▇▇▇▇▇▇▇█▇▇▇██▇
step_val_loss,█▆▅▄▄▄▄▄▃▃▃▂▄▂▂▂▂▂▂▂▂▂▁▁▂▂▂▁▃▁▂▁▁▃▃▂▂▂▂▂
training_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████

0,1
epoch,41
epoch_train_accuracy,76.85159
epoch_train_loss,0.79028
epoch_val_accuracy,64.05832
epoch_val_loss,1.46148
model_architecture,ResNet9(  (feature_...
step_train_accuracy,76.5625
step_train_loss,0.80468
step_val_accuracy,63.81579
step_val_loss,1.55095


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)