In [7]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from torchvision.transforms import v2
from pathlib import Path
import sys
import wandb
import torch.nn.functional as F
from torchinfo import summary

sys.path.append("../src")

from utils import accuracy
from utils import load_from_checkpoint
from trainer import Trainer
from models import ResNet18

In [2]:
data_path = Path("../data")
logs_path = Path("../logs")
logs_path.mkdir(exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [3]:
test_transforms = v2.Compose([
    # Normalize
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

test_dataset = CIFAR100(data_path, train=False, transform=test_transforms, download=True)

test_dataloader = DataLoader(test_dataset, batch_size=1024)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ..\data\cifar-100-python.tar.gz


100.0%


Extracting ..\data\cifar-100-python.tar.gz to ..\data


In [4]:
@torch.no_grad()
def test(model, test_dataloader, device):
    model.to(device)
    step_test_losses = []
    step_test_accuracies = []

    model.eval()
    num_batches = len(test_dataloader)
    for index, batch in enumerate(test_dataloader, start=1):
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = F.cross_entropy(logits, y)
        acc = accuracy(logits, y)

        step_test_loss = loss.item()
        step_test_accuracy = acc.item()
        step_test_losses.append(step_test_loss)
        step_test_accuracies.append(step_test_accuracy)

        print(f"Batch: {index}/{num_batches}, accuracy: {step_test_accuracy:.2f}")

    test_loss = torch.tensor(step_test_losses).mean()
    test_accuracy = torch.tensor(step_test_accuracies).mean()

    return test_loss, test_accuracy

In [5]:
api = wandb.Api()
run = api.run("sampath017/ImageClassification/24z2beff")
artifact = api.artifact('sampath017/ImageClassification/run-24z2beff-best_val_acc_68.66.pt:v0', type='model')
local_path = artifact.download(root=logs_path)
checkpoint = torch.load(Path(local_path)/"best_val_acc_68.66.pt", weights_only=True, map_location=device)

model = ResNet18(num_classes=100)
model.load_state_dict(checkpoint["model"])
summary(model, input_size=(test_dataset[0][0].shape), batch_size=1024, device="cpu")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Downloading large artifact run-24z2beff-best_val_acc_68.66.pt:v0, 319.00MB. 1 files... 
wandb:   1 of 1 files downloaded.  
Done. 0:0:42.3


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [1024, 64, 32, 32]           1,728
       BatchNorm2d-2         [1024, 64, 32, 32]             128
              ReLU-3         [1024, 64, 32, 32]               0
         ConvBlock-4         [1024, 64, 32, 32]               0
            Conv2d-5        [1024, 128, 32, 32]          73,728
       BatchNorm2d-6        [1024, 128, 32, 32]             256
              ReLU-7        [1024, 128, 32, 32]               0
         ConvBlock-8        [1024, 128, 32, 32]               0
            Conv2d-9        [1024, 128, 32, 32]         147,456
      BatchNorm2d-10        [1024, 128, 32, 32]             256
             ReLU-11        [1024, 128, 32, 32]               0
        ConvBlock-12        [1024, 128, 32, 32]               0
           Conv2d-13        [1024, 128, 32, 32]         147,456
      BatchNorm2d-14        [1024, 128,

In [6]:
loss, accuracy = test(model, test_dataloader, device)
print(f"\n{loss=:.2f}, {accuracy=:.2f}")

Batch: 1/10, accuracy: 66.21
Batch: 2/10, accuracy: 67.77
Batch: 3/10, accuracy: 64.45
Batch: 4/10, accuracy: 67.19
Batch: 5/10, accuracy: 66.21
Batch: 6/10, accuracy: 64.16
Batch: 7/10, accuracy: 66.50
Batch: 8/10, accuracy: 65.72
Batch: 9/10, accuracy: 67.38
Batch: 10/10, accuracy: 66.45

loss=1.51, accuracy=66.21
