In [None]:
# inverse_norm = v2.Compose([
#     v2.Normalize(mean=[0., 0., 0.], std=[1 / 0.2023, 1 / 0.1994, 1 / 0.2010]),
#     v2.Normalize(mean=(-0.4914, -0.4822, -0.4465), std=[1., 1., 1.]),
# ])

# for x, y in train_dataset:
#     break

# image = inverse_norm(x)
# image = (image * 255).permute(1, 2, 0)
# image = image.to(torch.int)

# plt.imshow(image)
# plt.show()

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from pathlib import Path
import sys
import wandb
import torch.nn.functional as F
from torchinfo import summary

sys.path.append("../src")

from utils import accuracy
from utils import load_from_checkpoint
from trainer import Trainer
from models import ResNet18
import settings as s

In [2]:
data_path = Path("../data")
logs_path = Path("../logs")
logs_path.mkdir(exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [3]:
test_transforms = v2.Compose([
    # Normalize
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

test_dataset = CIFAR10(data_path, train=False, transform=test_transforms, download=True)

test_dataloader = DataLoader(test_dataset, batch_size=1024)

Files already downloaded and verified


In [4]:
@torch.no_grad()
def test(model, test_dataloader, device):
    model.to(device)
    step_test_losses = []
    step_test_accuracies = []

    model.eval()
    num_batches = len(test_dataloader)
    for index, batch in enumerate(test_dataloader, start=1):
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = F.cross_entropy(logits, y)
        acc = accuracy(logits, y)

        step_test_loss = loss.item()
        step_test_accuracy = acc.item()
        step_test_losses.append(step_test_loss)
        step_test_accuracies.append(step_test_accuracy)

        print(f"Batch: {index}/{num_batches}, accuracy: {step_test_accuracy:.2f}")

    test_loss = torch.tensor(step_test_losses).mean()
    test_accuracy = torch.tensor(step_test_accuracies).mean()

    return test_loss, test_accuracy

In [5]:
api = wandb.Api()
run = api.run("sampath017/ImageClassification/24z2beff")
artifact = api.artifact('sampath017/ImageClassification/run-5y973ba9-best_val_acc_93.97.pt:v0', type='model')
local_path = artifact.download(root=logs_path)
checkpoint = torch.load(Path(local_path)/"best_val_acc_93.97.pt", weights_only=True, map_location=device)

model = ResNet18(num_classes=10)
model.load_state_dict(checkpoint["model"])
summary(
    model,
    input_size=(1024, *test_dataset[0][0].shape),
    device="cpu",
    mode="train",
    depth=1
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Downloading large artifact run-5y973ba9-best_val_acc_93.97.pt:v0, 318.47MB. 1 files... 
wandb:   1 of 1 files downloaded.  
Done. 0:0:1.5


Layer (type:depth-idx)                             Output Shape              Param #
ResNet18                                           [1024, 10]                --
├─Sequential: 1-1                                  [1024, 512, 1, 1]         27,811,392
├─Sequential: 1-2                                  [1024, 10]                5,130
Total params: 27,816,522
Trainable params: 27,816,522
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 2.82
Input size (MB): 12.58
Forward/backward pass size (MB): 17851.04
Params size (MB): 111.27
Estimated Total Size (MB): 17974.89

In [6]:
loss, accuracy = test(model, test_dataloader, device)
print(f"\n{loss=:.2f}, {accuracy=:.2f}")

Batch: 1/10, accuracy: 92.68
Batch: 2/10, accuracy: 91.99
Batch: 3/10, accuracy: 94.24
Batch: 4/10, accuracy: 90.92
Batch: 5/10, accuracy: 92.38
Batch: 6/10, accuracy: 92.58
Batch: 7/10, accuracy: 92.87
Batch: 8/10, accuracy: 93.16
Batch: 9/10, accuracy: 92.68
Batch: 10/10, accuracy: 93.62

loss=0.30, accuracy=92.71
