# Image Classification with PyTorch

Before we explore how Dask can accelerate an image classification problem, we should briefly describe the classification process we'll be using.

## ImageNet and ResNet

ImageNet is a database of over 14 million labeled images, assembled by academic researchers, that is widely used to train neural networks for computer vision tasks. Because of it, researchers have been able to pretrain many models that are useful for general computer vision tasks, such as the one we'll use in Chapter 4.

The ResNet model originates from [a 2015 publication](https://arxiv.org/abs/1512.03385), which introduces the residual learning framework for neural networks. In a residual network, each block of layers learns a residual function relative to its input, rather than an entirely new mapping. As a result, very deep networks are easier to train and reach better accuracy than equally deep networks without residual connections. If you're interested in learning more, the paper is a great resource.
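
To make the idea of "learning the residual" concrete, here is a minimal, framework-free sketch in plain Python. It is not how ResNet is implemented; the names `residual_block`, `zero_layers`, and `small_correction` are hypothetical and exist only for illustration. The point is that a block outputs `f(x) + x`, so when the inner layers contribute nothing the block simply passes its input through.

```python
def residual_block(x, f):
    """Return f(x) + x, the output of one residual block.

    `f` is a stand-in for the stacked layers inside the block
    (a hypothetical placeholder, not a real network layer).
    """
    fx = f(x)
    return [fi + xi for fi, xi in zip(fx, x)]


def zero_layers(x):
    # Inner layers that have learned "nothing": the residual is zero.
    return [0.0 for _ in x]


def small_correction(x):
    # Inner layers that have learned a small adjustment to the input.
    return [0.1 * xi for xi in x]


x = [1.0, -2.0, 3.0]
identity_out = residual_block(x, zero_layers)        # same values as x
corrected_out = residual_block(x, small_correction)  # x plus 10% of x
```

Because the identity mapping is the default behavior, adding more blocks does not force the network to relearn what earlier layers already computed, which is the intuition the paper gives for why depth becomes easier to use.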

We won't dig into the model much further, but the background is useful to know. Because we are using ResNet50, our pretrained model has 50 layers.

If you care to learn more about the way that computer vision and neural nets work, there are numerous great courses and books on the subject!

## How We'll Use It

We are going to use PyTorch for the demonstration in Chapter 4, so it's good to take a quick look at the infrastructure and how we'll be using it.


### Model

The PyTorch ecosystem handily offers computer vision datasets, transformation tools, and prebuilt models in the `torchvision` library, which we'll use here to load ResNet50.

In [None]:
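from torchvision import datasets, transforms, models

# Load ResNet50 with weights pretrained on ImageNet.
# Note: torchvision 0.13 and later deprecate `pretrained=True` in favor of the `weights=` argument.
resnet = models.resnet50(pretrained=True)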

### Datasets

You can load the images you want directly from S3 or another cloud storage system. Here we're using a public S3 bucket on AWS that holds a copy of the Stanford Dogs dataset.

In [5]:
import s3fs
from PIL import Image

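# Connect anonymously, since the bucket is public.
s3 = s3fs.S3FileSystem(anon=True)

# Read one image and make sure it has three color channels.
with s3.open("s3://saturn-public-data/dogs/2-dog.jpg", 'rb') as f:
    img = Image.open(f).convert("RGB")

# Resize, crop to a centered square, and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(250),
    transforms.ToTensor()])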
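
To see what the resize and crop steps do to an image's dimensions, the following plain-Python sketch works through the geometry. It assumes the behavior the torchvision documentation describes for an integer size (the shorter side is scaled to that size and the aspect ratio is preserved). The helper names and the 600x400 input are hypothetical, and the exact rounding inside torchvision may differ by a pixel.

```python
def resized_size(width, height, size=256):
    """Scale so the shorter side equals `size`, preserving aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=250):
    """Return (left, top, right, bottom) of a centered crop-by-crop square."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


# A hypothetical 600x400 (width x height) photo.
w, h = resized_size(600, 400)   # shorter side becomes 256
box = center_crop_box(w, h)     # 250x250 square taken from the middle
```

Whatever the original size, every image leaves this pipeline as a 250x250 square, which is what lets images of different sizes share one tensor shape.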

### Inference Task

Finally, we create a function that runs through the inference task. 

In [6]:
import torch
to_pil = transforms.ToPILImage()

def classify_img(transform, img, model):
    # `classes` is assumed to be a list of ImageNet label names,
    # indexed by class id, that has been defined beforehand.
    img_t = transform(img)
    batch_t = torch.unsqueeze(img_t, 0)

    model.eval()
    out = model(batch_t)

    _, indices = torch.sort(out, descending=True)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    labelset = [(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
    return to_pil(img_t), labelset

Key aspects of the function to pay attention to include:

* `img_t = transform(img)` : We must run the transformation we defined above on every image before we try to classify it.
* `batch_t = torch.unsqueeze(img_t, 0)` : The model expects a batch of images, so this step adds a leading batch dimension, turning our single image tensor into a batch of one.
* `model.eval()` : A model can be in either training or evaluation mode, and a freshly loaded model is in training mode by default. We need evaluation mode here, so that layers such as batch normalization behave consistently and the model returns stable predictions.
* `out = model(batch_t)` : This step actually evaluates the batch. Our batch holds a single image here, but the model accepts batches, so many images can be classified at once.
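
The batch dimension can be pictured without any tensor library. The sketch below uses nested Python lists as a stand-in for tensors; the `shape` helper and the tiny 3x2x2 "image" are hypothetical illustrations, not part of PyTorch. Wrapping the image in one more list plays the role that `torch.unsqueeze(img_t, 0)` plays in the function.

```python
def shape(nested):
    """Return the dimensions of a regularly nested list, outermost first."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# A tiny stand-in for an image tensor: 3 channels, 2 rows, 2 columns.
image = [[[0.0, 0.1], [0.2, 0.3]] for _ in range(3)]

# Adding an outer list gives a batch that contains exactly one image.
batch = [image]

image_shape = shape(image)  # (channels, height, width)
batch_shape = shape(batch)  # (batch, channels, height, width)
```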

### Results Processing

* `_, indices = torch.sort(out, descending=True)` : Sorts the results, high score to low (gives us the most likely labels at the top).
* `percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100` : Applies softmax to convert the model's raw scores into probabilities, then scales them to percentages.
* `labelset = [(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]` : Pairs the five highest-scoring class indices with their human-readable label names and percentages.
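
The three results-processing steps can be sketched in plain Python to show what the PyTorch calls compute. This is an illustration under stated assumptions, not the notebook's code: the `classes` list and the `scores` values below are made up, and `softmax` and `top_k_labels` are hypothetical helpers.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(scores, classes, k=5):
    """Return the k highest-scoring (label, percentage) pairs, best first."""
    percentages = [p * 100 for p in softmax(scores)]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(classes[i], percentages[i]) for i in order[:k]]


# Hypothetical labels and raw model scores, for illustration only.
classes = ["beagle", "pug", "collie", "tabby", "goldfish", "teapot"]
scores = [2.0, 4.5, 1.0, -0.5, -3.0, 0.0]

labelset = top_k_labels(scores, classes)
```

Sorting by raw score and sorting by probability give the same order, because softmax preserves the ranking of its inputs. That is why the function can sort `out` directly and apply softmax separately.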

## Running the Classification

At this point we just have to run the function, and we'll get human-readable results as well as an image we can look at.

In [None]:
%%time

dogpic, labels = classify_img(transform, img, resnet)

In [None]:
dogpic

In [None]:
labels

We have shown that our image classification runs on a single image and returns sensible labels! This sets us up to complete our case study, translating this use case to a highly parallel job on a Dask cluster.