In [1]:
import torch
from torch import nn
from torch import optim
from torchinfo import summary
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import v2
from torchvision.models import resnet18, ResNet18_Weights
from pathlib import Path
import wandb
import sys
import matplotlib.pyplot as plt
import numpy as np
import os
from quickai.trainer import Trainer
from quickai.utils import model_size, load_from_checkpoint
from quickai.callbacks import OverfitCallback, EarlyStoppingCallback
from quickai.logger import WandbLogger
from quickai.dataset import MapDataset

sys.path.append("../src")

from models import ResNet18
from module import ResNetModule
import settings as s
from dataset import ObjectDetectionDataset

In [2]:
data_path = Path("../data")
logs_path = Path("../logs")
logs_path.mkdir(exist_ok=True)

In [3]:
logger = WandbLogger(
    project_name=s.project_name,
    config={
        "model": s.model,
        "dataset": s.dataset,
        "max_epochs": s.max_epochs,
        "optimizer": s.optimizer,
        "lr_scheduler": s.lr_scheduler,
        "test_run": s.test_run,
        "transfer_learning": s.transfer_learning
    },
    logs_path=logs_path,
    offline=s.wandb_offline
)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [4]:
# cpu_count = os.cpu_count()
cpu_count = 7

dataset = ObjectDetectionDataset(data_path)

train_dataset, val_dataset = random_split(
    dataset, [s.dataset["train_split"], s.dataset["val_split"]]
)

# val_transforms = ResNet18_Weights.IMAGENET1K_V1.transforms()

normalize_transforms = v2.Compose([
    # Normalize
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

train_transforms = v2.Compose([
    # v2.RandomHorizontalFlip(),
    # v2.RandomVerticalFlip(p=0.2),
    # v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # v2.RandomRotation(degrees=15),
    # v2.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    # v2.RandomGrayscale(p=0.1),
    
    v2.Resize(224),
    v2.CenterCrop(224),
    normalize_transforms
])

train_dataset = MapDataset(train_dataset, transform=train_transforms)
val_dataset = MapDataset(val_dataset, transform=normalize_transforms)

train_dataloader = DataLoader(
    train_dataset, batch_size=s.dataset["batch_size"], shuffle=True, num_workers=cpu_count)
val_dataloader = DataLoader(
    val_dataset, batch_size=s.dataset["batch_size"],  num_workers=cpu_count)

len(dataset)

Loading from disk!


  self.single_object_images = torch.load(data_path)


1865

In [5]:
# inverse_norm = v2.Compose([
#     v2.Normalize(mean=[0., 0., 0.], std=[1 / 0.2023, 1 / 0.1994, 1 / 0.2010]),
#     v2.Normalize(mean=(-0.4914, -0.4822, -0.4465), std=[1., 1., 1.]),
# ])

# for x, y in train_dataset:
#     break

# image = inverse_norm(x)
# image = (image * 255).permute(1, 2, 0)
# image = image.to(torch.int)

# plt.imshow(image)
# plt.show()

In [6]:
callbacks = [
    EarlyStoppingCallback(min_val_accuracy=90.0, accuracy_diff=5.0, wait_epochs=5),
    # OverfitCallback(limit_batches=1, batch_size=10, max_epochs=500, augument_data=False)
]

In [7]:
api = wandb.Api()
run = api.run("sampath017/ImageClassification/24z2beff")
artifact = api.artifact('sampath017/ImageClassification/run-24z2beff-best_val_acc_68.66.pt:v0', type='model')
local_path = artifact.download(root=logs_path)
checkpoint = torch.load(Path(local_path)/"best_val_acc_68.66.pt", weights_only=True, map_location=device)

model = ResNet18(num_classes=100)
model.load_state_dict(checkpoint["model"])

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/codespace/.netrc
[34m[1mwandb[0m: Downloading large artifact run-24z2beff-best_val_acc_68.66.pt:v0, 319.00MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:27.1


<All keys matched successfully>

: 

In [None]:
model.classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 10),
)

summary(
    model,
    input_size=(s.dataset["batch_size"],
                *train_dataset[0][0].shape),
    device="cpu",
    mode="train",
    depth=1
)

In [7]:
module = ResNetModule(model)

optimizer = optim.AdamW(
    params=module.model.parameters(),
    weight_decay=s.optimizer["weight_decay"]
)

try:
    if s.lr_scheduler["name"] == "OneCycleLR":
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=s.lr_scheduler["max_lr"],
            epochs=s.max_epochs,
            steps_per_epoch=len(train_dataloader),
        )

        print(s.lr_scheduler["name"])
except TypeError:
    lr_scheduler = None
    print("lr_scheduler is None!")

OneCycleLR


In [8]:
trainer = Trainer(
    module=module,
    logger=logger,
    optimizer=optimizer,
    callbacks=callbacks,
    logs_path=logs_path,
    fast_dev_run=s.fast_dev_run,
    limit_batches=s.limit_batches,
    lr_scheduler=lr_scheduler,
    save_checkpoint_type="best_val",
    num_workers=cpu_count
)

Using device: cuda!


In [9]:
try:
    trainer.fit(train_dataloader, val_dataloader)
except KeyboardInterrupt as e:
    print("Run stopped!")
finally:
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msampath017[0m. Use [1m`wandb login --relogin`[0m to force relogin


Time per epoch: 4.38 seconds
Epoch: 0, train_accuracy: 20.00, val_accuracy: 15.06, lr: 0.0004
Epoch: 1, train_accuracy: 34.53, val_accuracy: 14.66, lr: 0.0005
Epoch: 2, train_accuracy: 33.19, val_accuracy: 14.99, lr: 0.0008
Epoch: 3, train_accuracy: 28.71, val_accuracy: 5.06, lr: 0.0013
Epoch: 4, train_accuracy: 34.70, val_accuracy: 12.86, lr: 0.0020
Epoch: 5, train_accuracy: 35.39, val_accuracy: 12.98, lr: 0.0028
Epoch: 6, train_accuracy: 32.21, val_accuracy: 9.46, lr: 0.0038
Epoch: 7, train_accuracy: 31.72, val_accuracy: 12.13, lr: 0.0048
Epoch: 8, train_accuracy: 30.86, val_accuracy: 10.52, lr: 0.0058
Epoch: 9, train_accuracy: 31.69, val_accuracy: 11.89, lr: 0.0068
Epoch: 10, train_accuracy: 27.62, val_accuracy: 12.47, lr: 0.0077
Epoch: 11, train_accuracy: 32.21, val_accuracy: 8.58, lr: 0.0085
Epoch: 12, train_accuracy: 33.25, val_accuracy: 12.39, lr: 0.0092
Epoch: 13, train_accuracy: 31.68, val_accuracy: 14.56, lr: 0.0096
Epoch: 14, train_accuracy: 31.11, val_accuracy: 15.71, lr: 0

VBox(children=(Label(value='40.232 MB of 128.165 MB uploaded\r'), FloatProgress(value=0.3139062845193023, max=…

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇████
epoch_train_accuracy,▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇▇███
epoch_train_loss,█▇▆▇▆▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
epoch_val_accuracy,▃▃▃▁▃▂▃▂▃▂▃▄▃▄▅▆▇▅▅▇▆▆▆▆▆▆▇▇▇█▇██████▇▇▇
epoch_val_loss,▁▁▁▂▃█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▂▂▄▄▆▆▇▇████▇▇▇▇▇▆▆▆▅▅▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁
step_train_accuracy,▁▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▄▄▅▆▄▄▆▆█▇███
step_train_loss,█▇▇▆▇▇▇▇▇▇▆▇▆▆▆▅▆▆▅▄▅▄▄▄▄▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁
step_val_accuracy,▂▃▂▂▂▃▂▃▁▃▃▃▃▅▃▄▄▅▅▅▅▅▃▆▅▆▆▆█▆▇▇▇▇▆▆▆▆▆▆
step_val_loss,▁▁▁▃▄█▃▅▃▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49
epoch_train_accuracy,94.53125
epoch_train_loss,0.28311
epoch_val_accuracy,31.08101
epoch_val_loss,2.86014
lr,1e-05
model_architecture,ResNet(  (conv1): C...
step_train_accuracy,100
step_train_loss,0.18841
step_val_accuracy,31.91489


[Metrics](https://api.wandb.ai/links/sampath017/iwrrziwg)