## Imports

In [78]:
import numpy as np
from torch.nn import Module
from torchvision import models, transforms
import torch
from PIL import Image

## Initialisation

In [22]:

# Preprocessing applied to every input image: resize the shorter side to
# 256 px, take the central 224x224 crop, convert to a float tensor, then
# normalise each channel with the standard ImageNet mean/std values below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Human-readable class names, one per line; index i is used as the name of
# model output i. Each line is split on ', ' and its second field is kept
# (assumes lines look like "<id>, <name>[, ...]" -- TODO confirm file format).
# Blank lines (e.g. trailing empty lines) are skipped instead of raising
# IndexError; a blank line in the middle of the file would shift the indices.
with open('image_net_classes.txt', encoding='utf-8') as file:
    classes = [line.strip().split(', ')[1] for line in file if line.strip()]

## Inference

In [None]:
def infer(model: Module, images, use_gpu=True, top_k=5):
    """Classify a batch of PIL images and print the top-k predictions.

    Each image is converted to RGB, preprocessed with the module-level
    ``transform`` and stacked into one batch, which is run through ``model``
    in eval mode without gradient tracking. For every image, a table of the
    ``top_k`` highest-probability classes is printed. Class names come from
    the module-level ``classes`` list. Probabilities are the softmax over the
    model output, in percent.

    Args:
        model: Classifier returning one row of scores per image (assumed
            shape ``(batch, num_classes)``). It is put into eval mode and
            moved to the selected device in place.
        images: Iterable of PIL images. Nothing is printed if it is empty.
        use_gpu: Run on CUDA when it is available; otherwise fall back to
            the CPU instead of raising.
        top_k: Number of ranked predictions printed per image. It is capped
            at the number of model outputs.
    """
    images = list(images)
    if not images:
        return

    device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
    model.eval()
    model.to(device)

    # convert('RGB') so grayscale / RGBA inputs match the 3-channel Normalize.
    batch = torch.stack([transform(im.convert('RGB')) for im in images]).to(device)
    with torch.no_grad():
        logits = model(batch).cpu()

    # Softmax once for the whole batch instead of once per printed class.
    percentages = torch.nn.functional.softmax(logits, dim=1) * 100
    k = min(top_k, percentages.shape[1])
    top_percentages, top_indices = torch.topk(percentages, k, dim=1)

    for row_percentages, row_indices in zip(top_percentages, top_indices):
        print('Rank\tInferred class\tProbability(%)')
        for rank, (percentage, class_index) in enumerate(zip(row_percentages, row_indices)):
            print(f'#{rank}\t\t{classes[int(class_index)]}\t{percentage.item()}')
        print('-----------------------------------------')

# Demo: classify three sample images with a pretrained AlexNet.
img = Image.open('img/dog.jpg')
img2 = Image.open('img/strawberries.jpg')
img3 = Image.open('img/bald_eagle.jpg')
# Ask for the GPU only when CUDA is present, so the demo also runs on
# CPU-only machines instead of failing inside model.cuda().
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of the
# `weights=` argument; kept here for compatibility with older releases.
infer(models.alexnet(pretrained=True), [img, img2, img3], use_gpu=torch.cuda.is_available())