In [12]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import torchvision.transforms as T
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import numpy as np

## Reference
[Stanford Dogs Classification notebook](https://github.com/Muhammad-MujtabaSaeed/Stanford-Dogs-Classification/blob/master/Stanford_Dogs_Classification.ipynb)

In [5]:
# Root directory for every dataset this notebook downloads (relative path).
working_dir = "./data"
# The Oxford-IIIT Pet evaluation data lives in <working_dir>/OxfordPets/validation.
pets_path_val = os.path.join(working_dir, *("OxfordPets", "validation"))

In [8]:
# Evaluation-time preprocessing, keyed by phase name:
# resize the shorter side to 256 px, take the central 224x224 crop, convert to a
# tensor and normalise each channel with the usual ImageNet mean/std statistics.
data_transforms = dict(
    Validation=T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [9]:
batch_size: int = 8  # samples per batch for the evaluation DataLoader

In [10]:
# Oxford-IIIT Pet "test" split, preprocessed with the evaluation transforms.
# NOTE(review): the classifier head built further down has 120 outputs, while
# `class_names` here is the Oxford-IIIT Pet category list — confirm the
# checkpoint was trained on this label set before trusting its predictions.
pets_test = torchvision.datasets.OxfordIIITPet(
    root=pets_path_val,
    split="test",
    target_types="category",
    download=True,
    transform=data_transforms["Validation"],
)
# Shuffle so the visualisation shows a varied sample, but seed the generator so
# Restart & Run All reproduces the same batches (seed 0 is arbitrary).
test_dataloader = DataLoader(
    pets_test,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
class_names = pets_test.classes

Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to data/OxfordPets/Validation/oxford-iiit-pet/images.tar.gz


  4%|▍         | 32571392/791918971 [00:06<02:22, 5331759.31it/s]


KeyboardInterrupt: 

In [3]:
# Architecture only: `models.resnet18()` is called without requesting pretrained
# weights, so the network starts randomly initialised. The trained weights are
# restored from the checkpoint in the next cell.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features  # input width of the final fully-connected layer
# Output size of the replacement classifier head; it must match the `fc` layer
# stored in the checkpoint, or a strict load_state_dict fails on a shape mismatch.
num_classes = 120
model_ft.fc = nn.Linear(num_ftrs, num_classes)

In [4]:
# map_location="cpu" lets a checkpoint saved during GPU training load on a
# CPU-only machine; model_ft is not moved to a GPU in the cells shown.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('./resnet_best_model.pth', map_location="cpu")
model_ft.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


In [None]:
def visualize_model(model, data_loader, num_images=6):
    """Plot the model's predictions for the first ``num_images`` samples of ``data_loader``.

    Samples are drawn in a two-column grid, each titled with its ground-truth
    and predicted class names (looked up in the notebook-level ``class_names``).
    The model is evaluated in inference mode without gradient tracking, and its
    original train/eval mode is restored afterwards.

    Args:
        model: a ``torch.nn.Module`` classifier returning one score per class.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        num_images: number of samples to plot; nothing is drawn if it is <= 0.
    """
    if num_images <= 0:
        return

    def _label(idx):
        # NOTE(review): the model head may have more outputs than `class_names`
        # has entries (120 vs. the Oxford Pets list); show the raw index rather
        # than raising IndexError when a class has no name.
        idx = int(idx)
        return class_names[idx] if 0 <= idx < len(class_names) else str(idx)

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    n_rows = (num_images + 1) // 2  # round up so an odd num_images still fits
    images_so_far = 0
    plt.figure()

    was_training = model.training
    model.eval()  # BatchNorm/Dropout must use inference behaviour, not batch stats
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                preds = preds.cpu()

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title('class: {} predicted: {}'.format(_label(labels[j]), _label(preds[j])))
                    imshow(inputs.cpu()[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(was_training)
     