In [32]:
from PIL import Image
import numpy as np
import torch
from torchvision import datasets, transforms, models

# Inference for classification

Now you'll write a function that uses a trained network for inference. That is, you'll pass an image into the network and predict the class of the flower in the image. Write a function called `predict` that takes an image and a model, then returns the top $K$ most likely classes along with their probabilities. It should look like this:

```python
probs, classes = predict(image_path, model)
print(probs)
print(classes)
> [ 0.01558163  0.01541934  0.01452626  0.01443549  0.01407339]
> ['70', '3', '45', '62', '55']
```
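The last step of `predict` turns the network's output into the top $K$ probabilities and class labels. Below is a minimal NumPy sketch of that step. It assumes the model returns raw scores (logits) for a single image, and that you have a `class_to_idx` dictionary mapping class labels to output indices (torchvision's `ImageFolder` provides one). The helper name `top_k_classes` is hypothetical. If your model ends in `LogSoftmax`, replace the softmax with `np.exp(logits)`. In PyTorch you would normally use `torch.topk` for the same job.

```python
import numpy as np

def top_k_classes(logits, class_to_idx, k=5):
    """Return the k highest softmax probabilities and their class labels.

    Hypothetical helper: `logits` is a 1-D array of raw scores for one image,
    `class_to_idx` maps class labels (e.g. folder names) to output indices.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # argsort is ascending, so take the last k indices and reverse them
    top_idx = np.argsort(probs)[-k:][::-1]
    # Invert the mapping so output indices can be turned back into labels
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    return probs[top_idx].tolist(), [idx_to_class[int(i)] for i in top_idx]

probs, classes = top_k_classes([0.1, 2.0, -1.0, 1.0],
                               {'1': 0, '10': 1, '2': 2, '3': 3}, k=2)
print(probs)
print(classes)
```

Inverting `class_to_idx` matters because `ImageFolder` sorts folder names as strings, so output index 1 corresponds to folder `'10'` here, not `'2'`.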

First you'll need to handle processing the input image such that it can be used in your network. 

## Image Preprocessing

You'll want to use `PIL` to load the image ([documentation](https://pillow.readthedocs.io/en/latest/reference/Image.html)). It's best to write a function that preprocesses the image so it can be used as input for the model. This function should process the images in the same manner used for training. 

First, resize the image so that its shortest side is 256 pixels while keeping the aspect ratio. This can be done with the [`thumbnail`](http://pillow.readthedocs.io/en/3.1.x/reference/Image.html#PIL.Image.Image.thumbnail) or [`resize`](http://pillow.readthedocs.io/en/3.1.x/reference/Image.html#PIL.Image.Image.resize) methods. Then crop out the center 224x224 portion of the image.
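Here is a minimal PIL sketch of these two steps, using `resize` rather than a fixed `(256, 256)` size so the aspect ratio is preserved. The function names are illustrative, not part of any library. Integer division in the crop guarantees an exact 224x224 result.

```python
from PIL import Image

def resize_shortest_side(img, target=256):
    """Resize so the shorter side equals `target`, keeping the aspect ratio."""
    width, height = img.size
    if width < height:
        new_size = (target, round(height * target / width))
    else:
        new_size = (round(width * target / height), target)
    return img.resize(new_size)

def center_crop(img, size=224):
    """Crop the central size x size region."""
    width, height = img.size
    left = (width - size) // 2
    top = (height - size) // 2
    return img.crop((left, top, left + size, top + size))

# A blank image with the same size as the test image used below (500x601)
resized = resize_shortest_side(Image.new('RGB', (500, 601)))
cropped = center_crop(resized)
print(resized.size, cropped.size)
```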

Color channels of images are typically encoded as integers 0-255, but the model expects floats 0-1, so you'll need to convert the values. This is easiest with a NumPy array, which you can get from a PIL image with `np_image = np.array(pil_image)`.

As before, the network expects the images to be normalized in a specific way: the means are `[0.485, 0.456, 0.406]` and the standard deviations are `[0.229, 0.224, 0.225]`. Subtract the mean from each color channel, then divide by that channel's standard deviation.
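A minimal NumPy sketch of the scaling and normalization, assuming an `(H, W, 3)` `uint8` array like the one `np.array(pil_image)` returns. The function name is hypothetical. Because the mean and std arrays have shape `(3,)`, broadcasting applies them along the last (channel) axis.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def scale_and_normalize(np_image):
    """Map an (H, W, 3) uint8 array to normalized float32 values."""
    # Convert to float before dividing so values land in [0, 1]
    floats = np_image.astype(np.float32) / 255.0
    # (3,) arrays broadcast over the last axis: one mean/std per channel
    return (floats - MEAN) / STD

# The top-left pixel of the cropped test image below is (6, 13, 6)
print(scale_and_normalize(np.array([[[6, 13, 6]]], dtype=np.uint8))[0, 0])
```

For that pixel this gives roughly `[-2.0152, -1.8081, -1.6999]`, matching the first entry of each channel in the normalized tensor printed further down.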

Finally, PyTorch expects the color channel to be the first dimension, but it's the third dimension in the PIL image and NumPy array. You can reorder dimensions using [`ndarray.transpose`](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.transpose.html). The color channel needs to come first, while the other two dimensions keep their original order.
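A short sketch of the reordering. Passing `(2, 0, 1)` moves axis 2 (channels) to the front while rows and columns stay in order. A non-square array makes it easy to confirm which axis went where. The function name is illustrative.

```python
import numpy as np

def to_channels_first(np_image):
    """Reorder an (H, W, C) array to the (C, H, W) layout PyTorch expects."""
    # New axis order: old axis 2 (channels), then old axis 0 (rows), then old axis 1 (cols)
    return np_image.transpose((2, 0, 1))

hwc = np.zeros((224, 200, 3))
hwc[5, 7, 1] = 1.0  # row 5, column 7, green channel
chw = to_channels_first(hwc)
print(chw.shape)
```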

To check your work, the function below converts a PyTorch tensor and displays it in the notebook. If your `process_image` function works, running its output through this function should display the original image (except for the cropped-out portions).
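The display check works by undoing the preprocessing. The sketch below shows that inverse in NumPy only, assuming a normalized `(3, H, W)` array (for a CPU tensor, call `.numpy()` first). The name `undo_preprocessing` is hypothetical, and the actual plotting, e.g. with matplotlib's `imshow`, is left out.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def undo_preprocessing(chw_array):
    """Turn a normalized (3, H, W) array back into an (H, W, 3) image in [0, 1]."""
    # Move channels back to the last axis, the layout image plotting expects
    hwc = np.asarray(chw_array).transpose((1, 2, 0))
    # Reverse the normalization: multiply by std, then add the mean
    hwc = hwc * STD + MEAN
    # Clip small overshoots so the values stay in the displayable range
    return np.clip(hwc, 0, 1)

original = np.random.default_rng(0).random((4, 5, 3))
normalized = ((original - MEAN) / STD).transpose((2, 0, 1))
restored = undo_preprocessing(normalized)
```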

In [33]:
# Self-test: load one test image and inspect it

filepath1 = 'flowers/test/1/image_06743.jpg'
image1 = Image.open(filepath1)

print('image1 format: ',image1.format )
print('image1 size: ',image1.size )
print('image1 mode: ',image1.mode )


image1 format:  JPEG
image1 size:  (500, 601)
image1 mode:  RGB


In [34]:
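# Note: a fixed (256, 256) size does not preserve the aspect ratio;
# this 500x601 image gets squashed vertically
size = (256,256)
image1 = image1.resize(size)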

In [35]:
print('image1 format: ',image1.format )
print('image1 size: ',image1.size )
print('image1 mode: ',image1.mode )

image1 format:  None
image1 size:  (256, 256)
image1 mode:  RGB


In [36]:
width, height = image1.size
print('width: ', width)
print('height: ', height)
new_width = 224
new_height = 224
left = (width - new_width)/2
top = (height - new_height)/2
right = (width + new_width)/2
bottom = (height + new_height)/2
print('left: ',left)
print('top: ',top)
print('right: ',right)
print('bottom: ',bottom)
image1 = image1.crop((left, top, right, bottom))

width:  256
height:  256
left:  16.0
top:  16.0
right:  240.0
bottom:  240.0


In [37]:
print('image1 format: ',image1.format )
print('image1 size: ',image1.size )
print('image1 mode: ',image1.mode )

image1 format:  None
image1 size:  (224, 224)
image1 mode:  RGB


In [38]:
np_image1 = np.array(image1)

In [44]:
# No manual /255 here: transforms.ToTensor() below scales uint8 values to [0, 1]
image1 = np_image1
image1

array([[[ 6, 13,  6],
        [13, 20, 13],
        [23, 28, 22],
        ..., 
        [33, 48, 25],
        [58, 78, 51],
        [61, 81, 56]],

       [[17, 24, 16],
        [21, 28, 20],
        [21, 26, 19],
        ..., 
        [12, 19, 11],
        [ 6, 20,  3],
        [14, 28, 11]],

       [[25, 32, 24],
        [23, 30, 22],
        [16, 21, 14],
        ..., 
        [14, 18, 19],
        [20, 30, 19],
        [28, 40, 26]],

       ..., 
       [[47, 63, 27],
        [46, 64, 26],
        [43, 64, 23],
        ..., 
        [37, 61,  3],
        [36, 62,  0],
        [34, 58,  0]],

       [[43, 61, 23],
        [39, 60, 19],
        [44, 67, 23],
        ..., 
        [29, 55,  0],
        [29, 56,  0],
        [30, 56,  0]],

       [[32, 52,  3],
        [29, 48,  2],
        [23, 42,  0],
        ..., 
        [28, 50,  4],
        [29, 49,  0],
        [28, 48,  0]]], dtype=uint8)

In [45]:
print('the dimensions of the array: ', image1.shape)
print('the total number of elements of the array: ', image1.size)
print('the type of the elements in the array: ',image1.dtype.name)

the dimensions of the array:  (224, 224, 3)
the total number of elements of the array:  150528
the type of the elements in the array:  uint8


In [46]:
# Convert to a float tensor (ToTensor scales to [0, 1] and moves channels first), then normalize each channel
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                 std = [0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(), normalize])
torch_image1_norm = transform(image1)

In [47]:
print('tensor size: ',torch_image1_norm.size())
torch_image1_norm

tensor size:  torch.Size([3, 224, 224])


tensor([[[-2.0152, -1.8953, -1.7240,  ..., -1.5528, -1.1247, -1.0733],
         [-1.8268, -1.7583, -1.7583,  ..., -1.9124, -2.0152, -1.8782],
         [-1.6898, -1.7240, -1.8439,  ..., -1.8782, -1.7754, -1.6384],
         ...,
         [-1.3130, -1.3302, -1.3815,  ..., -1.4843, -1.5014, -1.5357],
         [-1.3815, -1.4500, -1.3644,  ..., -1.6213, -1.6213, -1.6042],
         [-1.5699, -1.6213, -1.7240,  ..., -1.6384, -1.6213, -1.6384]],

        [[-1.8081, -1.6856, -1.5455,  ..., -1.1954, -0.6702, -0.6176],
         [-1.6155, -1.5455, -1.5805,  ..., -1.7031, -1.6856, -1.5455],
         [-1.4755, -1.5105, -1.6681,  ..., -1.7206, -1.5105, -1.3354],
         ...,
         [-0.9328, -0.9153, -0.9153,  ..., -0.9678, -0.9503, -1.0203],
         [-0.9678, -0.9853, -0.8627,  ..., -1.0728, -1.0553, -1.0553],
         [-1.1253, -1.1954, -1.3004,  ..., -1.1604, -1.1779, -1.1954]],

        [[-1.6999, -1.5779, -1.4210,  ..., -1.3687, -0.9156, -0.8284],
         [-1.5256, -1.4559, -1.4733,  ..., -1