# Imports

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, random_split, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from collections import OrderedDict
import glob

%matplotlib inline

# Inits

In [None]:
# Checkpoint file whose contents are passed to `load_state_dict` further down.
PATH = "dog-breed-classifier-wideresnet_with_data_aug.pth"

In [None]:
# Preprocessing applied to each image before it is handed to the model:
# resize to a fixed 168x168, then convert to a tensor.
_inference_steps = [
    transforms.Resize((168, 168)),
    transforms.ToTensor(),
]
inference_transform = transforms.Compose(_inference_steps)

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose arg-max class equals the label.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of integer class indices, one per sample.

    Returns:
        A 0-dim float tensor holding the mean of the per-sample matches.
    """
    preds = torch.argmax(outputs, dim=1)
    return (preds == labels).float().mean()


class ImageClassificationBase(nn.Module):
    """Mixin-style base providing train/validation helpers for classifiers.

    Subclasses must implement ``forward`` and return log-probabilities,
    because the loss used here is ``F.nll_loss``.
    """

    def training_step(self, batch):
        """Run the model on one ``(images, targets)`` batch and return its NLL loss."""
        img, targets = batch
        out = self(img)
        return F.nll_loss(out, targets)

    def validation_step(self, batch):
        """Evaluate one ``(images, targets)`` batch.

        Returns:
            Dict with detached tensors under ``'val_acc'`` and ``'val_loss'``.
        """
        img, targets = batch
        out = self(img)
        loss = F.nll_loss(out, targets)
        # `accuracy` was previously referenced without being defined anywhere
        # in this file, so this method raised NameError when called.
        acc = accuracy(out, targets)
        return {'val_acc': acc.detach(), 'val_loss': loss.detach()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts produced by ``validation_step``.

        Args:
            outputs: non-empty sequence of dicts with tensor values under
                ``'val_loss'`` and ``'val_acc'``.

        Returns:
            Dict with Python floats under ``'val_loss'`` and ``'val_acc'``.
            Note this is an unweighted mean over batches.
        """
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of ``result`` for the given epoch.

        ``result`` must contain ``'train_loss'``, ``'val_loss'`` and ``'val_acc'``.
        """
        print(
            f"Epoch [{epoch}] : train_loss: {result['train_loss']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}"
        )
        
        
class DogBreedPretrainedWideResnet(ImageClassificationBase):
    """``wide_resnet50_2`` backbone with its final layer replaced by a
    ``Linear`` + ``LogSoftmax`` classification head.

    Args:
        num_classes: width of the output layer. Defaults to 120, the value
            that was previously hard-coded.
        pretrained: forwarded to ``models.wide_resnet50_2``. Defaults to
            ``True`` (previous behaviour). Pass ``False`` when a checkpoint is
            loaded straight afterwards, so the pretrained weights are not
            fetched only to be overwritten.
    """

    def __init__(self, num_classes=120, pretrained=True):
        super().__init__()

        self.network = models.wide_resnet50_2(pretrained=pretrained)
        # Swap the stock classifier for one sized to our label set; LogSoftmax
        # pairs with the nll_loss used by the base class.
        in_features = self.network.fc.in_features
        self.network.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, xb):
        """Return the network's log-probabilities for the batch ``xb``."""
        return self.network(xb)

# Load

In [None]:
# load model
model = DogBreedPretrainedWideResnet()
# map_location="cpu": the model is never moved off the CPU in this notebook, and
# without it a checkpoint saved from a GPU fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Switch to inference mode.
model.eval()

In [None]:
# Read the breed table from disk; rows are looked up by predicted class index.
breeds = pd.read_csv(filepath_or_buffer='data/breeds.csv')

# Inference

In [None]:
def predict_single(model, breeds, img):
    """Classify one image, display it, and print the predicted label.

    Args:
        model: callable network returning one row of class scores per input.
        breeds: DataFrame whose row ``i`` holds, in its first column, the
            label for class index ``i``.
        img: image to classify (a PIL image, as produced by ``Image.open``).

    Returns:
        None. The image is shown with matplotlib and the label is printed.
    """
    # Grayscale / RGBA / palette files would otherwise reach the model with the
    # wrong number of channels; normalise to 3-channel RGB first.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    batch = inference_transform(img).unsqueeze(0)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(batch)[0]
    index = scores.argmax(dim=0).item()
    # Positional lookup: the class index is a row position, not an index label.
    label = breeds.iloc[index].values[0]

    # Show the image at its own resolution (CHW tensor -> HWC for imshow).
    plt.imshow(transforms.ToTensor()(img).permute(1, 2, 0))
    plt.show()
    print('Predicted :', label)

    return

In [None]:
# sorted(): glob returns paths in arbitrary order; sort for reproducible output.
for f in sorted(glob.glob("data/test/*")):
    # Context manager closes the underlying file handle after each prediction.
    with Image.open(f) as test_image:
        predict_single(model, breeds, test_image)

In [None]:
# Bare expression: shows the loaded breeds table as this cell's output.
breeds