# Torch vision example

Source: https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/
with minor modifications. Errors are likely mine.

In [None]:
from torchvision import models
import torch

# List the deep learning architectures implemented in the torchvision library.
dir(models)

Notice that there is one entry called **AlexNet** and one called **alexnet**. The capitalised name refers to the Python class (AlexNet) whereas alexnet is a convenience function that returns the model instantiated from the AlexNet class.
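
The same class-plus-factory pattern can be sketched in plain Python. The names `TinyNet` and `tinynet` below are hypothetical stand-ins, not torchvision APIs; the sketch only illustrates why a listing such as `dir(models)` can show both a capitalised and a lowercase entry.

```python
import inspect


class TinyNet:
    """Hypothetical stand-in for a model class such as AlexNet."""

    def __init__(self, num_classes: int = 1000):
        self.num_classes = num_classes


def tinynet(num_classes: int = 1000) -> TinyNet:
    """Hypothetical convenience function that builds and returns a TinyNet."""
    return TinyNet(num_classes=num_classes)


model = tinynet()
print(inspect.isclass(TinyNet))      # the capitalised name is a class
print(inspect.isfunction(tinynet))   # the lowercase name is a plain function
print(isinstance(model, TinyNet))    # the function returns an instance of the class
```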

In [None]:
from torchvision.models.alexnet import AlexNet_Weights

alexnet = models.alexnet(weights=AlexNet_Weights.DEFAULT)
print(alexnet)

In [None]:
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt


In [None]:
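from functools import lru_cache
from torchvision import transforms
import json

@lru_cache(maxsize=1)
def imagenet_classes():
    # Load the 1000 ImageNet class names once; later calls reuse the cached list.
    with open(Path('./imagenet_classes.json')) as f:
        classes = json.load(f)["imagenet_classes"]
    assert len(classes) == 1000 and classes[11] == 'goldfinch'
    return classes

def imagenet_class(idx: int):
    # Map a class index (int or 0-d tensor) to its ImageNet label.
    return imagenet_classes()[int(idx)]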

transform = transforms.Compose([           # Preprocessing pipeline applied to each image
    transforms.Resize(256),                # Resize so the shorter side is 256 pixels (aspect ratio kept)
    transforms.CenterCrop(224),            # Crop the central 224×224 pixels
    transforms.ToTensor(),                 # Convert the PIL image to a float tensor in [0, 1]
    transforms.Normalize(                  # Normalize each channel
        mean=[0.485, 0.456, 0.406],        # Per-channel mean and std, as used when training the network
        std=[0.229, 0.224, 0.225],
    ),
])

def evaluate_image(img):
    img_t = transform(img)                # Apply the transformations on the image
    batch_t = torch.unsqueeze(img_t, 0)   # Add a batch dimension to the image
    alexnet.eval()                        # Set the network to evaluation mode
    with torch.no_grad():                 # Gradients are not needed for inference
        out = alexnet(batch_t)            # Forward propagate the image

    # Rank the classes by score and report the top five as percentages.
    _, indices = torch.sort(out, descending=True)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    return [(imagenet_class(idx), percentage[idx].item()) for idx in indices[0][:5]]
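
The last three lines of `evaluate_image` turn raw network scores into ranked percentages. A minimal plain-Python sketch of that arithmetic follows; the helper names and the four made-up scores are assumptions for illustration, and the notebook itself relies on `torch.sort` and `torch.nn.functional.softmax`.

```python
import math


def softmax_percent(logits):
    """Convert raw scores to percentages that sum to 100."""
    peak = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [100.0 * e / total for e in exps]


def top_k(logits, k=5):
    """Return (index, percentage) pairs for the k highest-scoring classes."""
    percent = softmax_percent(logits)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return [(i, percent[i]) for i in order[:k]]


logits = [2.0, 0.5, -1.0, 3.0]                  # made-up scores for four classes
for idx, pct in top_k(logits, k=2):
    print(f"class {idx}: {pct:.1f}%")
```

Softmax preserves the ordering of the scores, so sorting the raw outputs and sorting the percentages select the same classes.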

In [None]:
img = Image.open(Path('./dog.jpg'))
best = evaluate_image(img)
best
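
For reference, the geometry of the `Resize(256)` and `CenterCrop(224)` steps can be sketched with plain arithmetic. This is a rough illustration under stated assumptions: the helper names and the 500×375 input size are hypothetical, and torchvision's exact rounding may differ by a pixel.

```python
def resized_size(width, height, short_side=256):
    """Size after scaling so the shorter side equals short_side (aspect ratio kept)."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a crop x crop box centred in the image."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


w, h = resized_size(500, 375)                   # hypothetical 500x375 input image
box = center_crop_box(w, h)
print((w, h))
print(box)
```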

# testing alexnet with the STL-10 torchvision dataset

As it happens, the iNaturalist dataset was good for our purposes but far too big.
As [the STL-10 page notes](https://cs.stanford.edu/~acoates/stl10/), these
images were drawn from ImageNet, so presumably the model recognizes them readily.

In [None]:
from charmory.evaluation import SysConfig
from torchvision.datasets import STL10
from pprint import pprint

root = SysConfig().dataset_cache / 'stl10'
stl = STL10(root=root, split='test', download=True)

stl_classes = 'airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck'.split(', ')


for i in range(2):
    img, label = stl[i]
    target = stl_classes[label]
    plt.imshow(img)
    plt.show()
    print(f"{label=}, {target=}")
    best = evaluate_image(img)
    pprint(best)

Well, that didn't work. STL-10 has only 10 classes, while the model outputs
1000 ImageNet classes. As shown in the first output from the cell above, the first STL-10 item
is a horse, but ImageNet has no class named "horse". So the most likely class comes out as "oxcart",
which isn't a bad guess, but it is not a horse.
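
One possible way to compare the two label sets, which this notebook does not pursue, is to fold fine-grained ImageNet labels into the coarser STL-10 classes. The sketch below assumes a hand-written mapping; `IMAGENET_TO_STL` and `coarse_predictions` are hypothetical names, the mapping is a tiny illustrative subset, the scores are made up, and the exact label strings depend on the class list in use.

```python
# Hypothetical, hand-written mapping from a few ImageNet labels to STL-10 classes.
# It is illustrative only and far from complete.
IMAGENET_TO_STL = {
    "goldfinch": "bird",
    "tabby": "cat",
    "sorrel": "horse",
    "airliner": "airplane",
}


def coarse_predictions(predictions, mapping=IMAGENET_TO_STL):
    """Sum the percentages of ImageNet labels that map to the same STL-10 class."""
    totals = {}
    for label, percent in predictions:
        stl_label = mapping.get(label)          # None when there is no STL-10 counterpart
        if stl_label is not None:
            totals[stl_label] = totals.get(stl_label, 0.0) + percent
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Made-up scores in the (label, percentage) format returned by evaluate_image.
example = [("oxcart", 40.0), ("sorrel", 35.0), ("tabby", 5.0)]
print(coarse_predictions(example))
```

Labels without a counterpart, such as "oxcart" here, are simply dropped, so the summed percentages no longer total 100.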


# using the ImageNet-1k dataset

The ImageNet-1k dataset in torchvision requires tarfiles downloaded from image-net.org.
I've applied for access, but approval takes "up to 5 days". In the meantime I've downloaded parquet
files for the same dataset and stashed them at
s3://armory-library-data/datasets/huggingface/imagenet/ so at least I don't have to
hunt them down again.

This means that I'm going to use a Hugging Face dataloader for now.

Or perhaps not: to keep to the same framework, I'll use the torchvision
`VisionDataset` class and see if I can get it to work.

# using pyarrow to create a torchvision.datasets.VisionDataset

So what I've got is the Hugging Face parquet files, but I'd like them to look like a
torchvision dataset. Thus I make a pyarrow table from the parquet files and wrap it in the
`__init__` and `__getitem__` protocol of a `torchvision.datasets.VisionDataset`.

This has all been encapsulated in the `imagenet_tst.ImageNetTST` class.
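
The shape of that protocol can be sketched without any dependencies. This is not the `ImageNetTST` implementation: `RecordDataset` is a hypothetical class, a plain list of dicts stands in for the pyarrow table, and the `"image"` and `"label"` keys are assumed column names. A real version would presumably subclass `VisionDataset` and decode the stored image bytes.

```python
from typing import Callable, Optional, Sequence


class RecordDataset:
    """Hypothetical map-style dataset over (image, label) records."""

    def __init__(self, records: Sequence[dict], transform: Optional[Callable] = None):
        self.records = records                  # stand-in for rows of a pyarrow table
        self.transform = transform

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int):
        row = self.records[index]
        image, label = row["image"], row["label"]   # assumed column names
        if self.transform is not None:          # only call the transform when one was given
            image = self.transform(image)
        return image, label


# Strings stand in for decoded images so the sketch runs anywhere.
rows = [{"image": "img-0", "label": 7}, {"image": "img-1", "label": 3}]
ds = RecordDataset(rows, transform=str.upper)
print(len(ds), ds[1])
```

Implementing `__len__` and `__getitem__` is what lets such an object be indexed as `ds[i]` and iterated like the torchvision datasets used earlier.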

In [1]:
from matplotlib import pyplot as plt
from imagenet_tst import get_local_imagenettst

imtst = get_local_imagenettst("val")


In [2]:
image, label = imtst[0]
print(f"{image=}, {label=} {imtst.label(label)=}")
plt.imshow(image)
plt.show()

TypeError: 'NoneType' object is not callable

In [None]:
for i in range(90, 100):
    img, label = imtst[i]
    print(f"{img=}, {label=}")
    target = imtst.label(label)
    plt.imshow(img)
    plt.show()
    print(f"{label=}, {target=}")
    best = evaluate_image(img)
    pprint(best)