# Torch vision example

Source: https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/
with minor modifications. Errors are likely mine.

In [None]:
from torchvision import models
import torch

# List every name exposed by the torchvision `models` module so we can see which
# deep learning architectures it ships with.
dir(models)

Notice that there is one entry called **AlexNet** and one called **alexnet**. The capitalised name refers to the Python class (AlexNet) whereas alexnet is a convenience function that returns the model instantiated from the AlexNet class.

In [None]:
from torchvision.models.alexnet import AlexNet_Weights

# Build the network through the `alexnet` builder function, handing it the
# `DEFAULT` entry of the weights enum, then print the resulting module.
pretrained_weights = AlexNet_Weights.DEFAULT
alexnet = models.alexnet(weights=pretrained_weights)
print(alexnet)

In [None]:
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

# Open the sample image inside a context manager so the underlying file handle
# is released, and force RGB mode: the normalisation applied later in this
# notebook supplies three per-channel mean/std values, so a grayscale, palette
# or RGBA file would otherwise fail at that step.
with Image.open(Path('./dog.jpg')) as dog_file:
    img = dog_file.convert('RGB')

# Show the picture on matplotlib axes, then let the notebook render the PIL image itself.
plt.imshow(img)
img

In [None]:
# Preview only: the resized image is displayed as the cell output but not assigned,
# so `img` itself is left unchanged for the following cells.
img.resize((256,256))

In [None]:
from torchvision import transforms
transform = transforms.Compose([        # Chain the preprocessing steps into one callable named `transform`
 transforms.Resize(256),                # Resize so the shorter side is 256 pixels (an int size keeps the aspect ratio, it does not force 256×256)
 transforms.CenterCrop(224),            # Crop the image to 224×224 pixels about the center
 transforms.ToTensor(),                 # Convert the image to PyTorch Tensor data type
 transforms.Normalize(                  # Normalize each channel with the mean/std given below
 mean=[0.485, 0.456, 0.406],            # Mean and std of image as also used when training the network
 std=[0.229, 0.224, 0.225]
 )])

In [None]:
# Run the image through the preprocessing pipeline and show the resulting tensor shape.
# NOTE(review): presumably (3, 224, 224) for an RGB input given CenterCrop(224) — confirm from the output.
img_t = transform(img)
img_t.shape

In [None]:
# Insert a new leading dimension at position 0 so the single image forms a batch.
batch_t = img_t.unsqueeze(0)

In [None]:
# Put the model into evaluation mode before running inference
# (this changes the behaviour of train-time-only layers such as dropout).
alexnet.eval()

In [None]:
# Forward pass for inference only: disable autograd so no gradient graph is
# built or kept alive for a result we never backpropagate through.
with torch.no_grad():
    out = alexnet(batch_t)
out.shape

In [None]:
import json

# Read the class names stored under the "imagenet_classes" key. The `with`
# block closes the file as soon as the JSON has been parsed.
with open(Path('./imagenet_classes.json')) as classes_file:
    raw_classes = json.load(classes_file)["imagenet_classes"]

# The assertion below and the prediction cell both use the name `classes`,
# which was never bound (only `raw_classes` was), so a fresh kernel raised NameError.
# NOTE(review): if `raw_classes` was meant to be post-processed before use, add that step here.
classes = raw_classes

# Sanity check: expect 1000 entries, with 'goldfinch' at index 11.
assert len(classes) == 1000 and classes[11] == 'goldfinch'
classes[:5]

In [None]:
# Order the class indices by descending score; the sorted values themselves are discarded.
_, indices = out.sort(descending=True)

# Softmax over dim 1, take the first batch row and scale it to percentages.
percentage = torch.softmax(out, dim=1)[0] * 100

# Pair each of the ten best-scoring indices of the first row with its class name and percentage.
top_ten = indices[0][:10]
[(classes[idx], percentage[idx].item()) for idx in top_ten]

# Testing AlexNet with the STL10 torchvision dataset

As it happens, the iNaturalist dataset was good for our purposes but far too big, so we use STL-10 instead.
As [the STL-10 page notes](https://cs.stanford.edu/~acoates/stl10/), these
images were drawn from ImageNet, so presumably they are readily recognized by the pretrained model.

In [None]:
from charmory.evaluation import SysConfig
from torchvision.datasets import STL10

# Keep the download in the 'stl10' subdirectory of SysConfig().dataset_cache.
root = SysConfig().dataset_cache / 'stl10'
stl = STL10(root=root, split='test', download=True)

# Inspect the first test sample: print the raw item, then unpack it and plot the image.
first_sample = stl[0]
print(first_sample)
img, label = first_sample
plt.imshow(img)