In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from pathlib import Path
import sys
import torch.nn.functional as F

sys.path.append("../src")

from utils import accuracy
from utils import load_from_checkpoint
from trainer import Trainer
from models import ResNet9

In [3]:
# Where CIFAR-10 lives on disk, relative to this notebook.
data_path = Path("..") / "data"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [4]:
model = ResNet9()

# NOTE(review): this absolute path ties the notebook to one machine/container;
# consider deriving it from a configurable logs directory next to `data_path`.
checkpoint_path = (
    "/workspace/ImageClassification/logs/wandb/"
    "run-20241216_180718-0g5r3xwq/checkpoints/best_val_acc_91.24.pt"
)

# load_from_checkpoint hands back three values; only the first (the model with
# restored weights) is used in this notebook.
model, _, _ = load_from_checkpoint(
    checkpoint_path,
    model=model,
    device=device,
)
model

ResNet9(
  (feature_extractor): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ConvBlock(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): ResBlock(
      (block): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [5]:
# Evaluation preprocessing, with no augmentation: convert to an image tensor,
# cast to float32 scaled to [0, 1], then normalize each of the 3 channels.
# NOTE(review): these mean/std values presumably have to match the training
# pipeline in ../src — confirm before changing either side.
val_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# Held-out CIFAR-10 test split. DataLoader's default shuffle=False keeps batches
# in dataset order.
test_dataset = CIFAR10(data_path, train=False, transform=val_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=1024)

In [6]:
@torch.no_grad()
def test(test_dataloader, model=None, device=None):
    """Evaluate a classifier over every batch of ``test_dataloader``.

    Args:
        test_dataloader: iterable yielding ``(x, y)`` batches.
        model: module to evaluate. Defaults to the notebook-level ``model``,
            looked up at call time (the previous behavior).
        device: device each batch is moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(loss, accuracy)`` as 0-d float tensors. ``loss`` is the mean
        cross-entropy per sample. ``accuracy`` is the ``accuracy`` metric from
        utils, averaged per sample. Both are NaN if the dataloader yields no
        samples.
    """
    # Fall back to the notebook globals so existing ``test(test_dataloader)``
    # calls keep working; passing them explicitly avoids hidden kernel state.
    if model is None:
        model = globals()["model"]
    if device is None:
        device = globals()["device"]

    total_loss = 0.0
    total_accuracy = 0.0
    num_samples = 0

    model.eval()
    for x, y in test_dataloader:
        x = x.to(device)
        y = y.to(device)
        batch_size = y.size(0)

        logits = model(x)
        # Accumulate sums, not batch means: the final batch can be smaller than
        # the rest (e.g. the CIFAR-10 test split with batch_size=1024). A plain
        # mean of batch means would over-weight that partial batch.
        total_loss += F.cross_entropy(logits, y, reduction="sum").item()
        # NOTE(review): assumes ``accuracy`` returns a mean over the batch, so
        # scaling by batch_size yields the batch total — confirm in utils.
        total_accuracy += accuracy(logits, y).item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Same result the old mean-of-empty-tensor produced.
        nan = float("nan")
        return torch.tensor(nan), torch.tensor(nan)

    test_loss = torch.tensor(total_loss / num_samples)
    test_accuracy = torch.tensor(total_accuracy / num_samples)
    return test_loss, test_accuracy

In [7]:
# Store results under names that do not shadow `accuracy`, the metric function
# imported from utils. Rebinding it to a tensor here broke any later call to
# `test`, since `accuracy(logits, y)` would then try to call a tensor.
test_loss, test_accuracy = test(test_dataloader)
print(f"loss={test_loss:.2f}, accuracy={test_accuracy:.2f}")

loss=0.29, accuracy=90.42


In [19]:
# NOTE(review): this entire cell is commented out. It records how a model
# checkpoint for W&B run "ofrc6h0p" was pulled through the W&B public API.
# Re-enabling it would need `wandb` and a `ResNet` class in scope; the import
# cell at the top of this notebook only brings in `ResNet9`.
# NOTE(review): the saved output below this cell (a run-config dict) looks
# stale — nothing live in this cell produces output. Clear it or delete the cell.
# NOTE(review): torch.load unpickles the file; for downloaded artifacts prefer
# torch.load(..., weights_only=True).
# api = wandb.Api()
# run = api.run("sampath017/ImageClassification/ofrc6h0p")
# config = run.config

# artifact = api.artifact("sampath017/ImageClassification/run-ofrc6h0p-model_19.pt:v0")
# local_path = artifact.download()
# model = ResNet()
# model.load_state_dict(torch.load(Path(local_path)/"model_19.pt", map_location=device))
# model

{'epochs': 20,
 'optimizer': {'name': 'Adam', 'max_lr': 0.01, 'weight_decay': 0.0001},
 'val_split': 8000,
 'batch_size': 1024,
 'train_split': 42000,
 'model_architecture': 'ResNet'}