# The PyTorch JIT
## PyCon 2019 - Berlin
### Tilman Krokotsch

# About Me

### Tilman Krokotsch
#### Deep Learning Engineer @ IAV GmbH automotive engineering
#### PhD Student @ TU Berlin under Prof. Clemens Gühmann

# Imports and Stuff

First of all, we need a few imports: PyTorch itself, torchvision for the pretrained models, NumPy for preparing the input arrays, Pillow for image manipulation and json for loading the class labels.

In [11]:
import torch
import torchvision
import torch.jit as jit
import numpy as np
import PIL.Image
import json

from IPython.display import Image, display

In [12]:
with open('./imagenet_classes.json', mode='rt') as f:
    CLASS_DICT = json.load(f)
    
IMAGENET_MEAN = np.array([[[0.485, 0.456, 0.406]]], dtype=np.float32)
IMAGENET_STD = np.array([[[0.229, 0.224, 0.225]]], dtype=np.float32)

def predict_imagenet(net, image_file):
    # Load image, force three RGB channels and resize
    image = PIL.Image.open(image_file).convert('RGB').resize((244, 244))
    # Convert to numpy, scale to [0, 1] and normalize with the ImageNet statistics
    image = np.array(image, dtype=np.float32) / 255.
    image = (image - IMAGENET_MEAN) / IMAGENET_STD
    # Convert to PyTorch, add a batch dimension and make channel first
    image = torch.as_tensor(image).unsqueeze(0).permute(0, 3, 1, 2)
    # Predict top class; no gradients are needed for inference
    with torch.no_grad():
        logits = net(image)
    class_idx = logits.squeeze(0).argmax().item()
    # Output predictions
    print('It is a %s.' % CLASS_DICT[str(class_idx)])
    display(Image(filename=image_file))
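
The preprocessing inside predict_imagenet() can be checked without an image file. The sketch below uses a random array as a hypothetical stand-in for a loaded RGB image (the name `fake_image` is not part of the talk) and shows how the (1, 1, 3) mean and standard deviation broadcast over height and width before the tensor is rearranged to channel-first layout.

```python
import numpy as np
import torch

IMAGENET_MEAN = np.array([[[0.485, 0.456, 0.406]]], dtype=np.float32)
IMAGENET_STD = np.array([[[0.229, 0.224, 0.225]]], dtype=np.float32)

# Hypothetical stand-in for a loaded and resized RGB image (height, width, channel)
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(244, 244, 3), dtype=np.uint8)

# Scale to [0, 1]; the (1, 1, 3) statistics broadcast over height and width
scaled = fake_image.astype(np.float32) / 255.
normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD

# Add a batch dimension and move channels first: (H, W, C) -> (1, C, H, W)
batch = torch.as_tensor(normalized).unsqueeze(0).permute(0, 3, 1, 2)
print(batch.shape, batch.dtype)
```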

# Basic Functionality

We will use a pretrained AlexNet from the torchvision model zoo for this example. Let's load it and have a look at its architecture. Printing the network lets us know pretty much everything: layer types, layer order, kernel sizes and so on.

In [13]:
net = torchvision.models.alexnet(pretrained=True)
print(net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
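    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)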

Now we can convert our network into TorchScript, the language used by the JIT. For most feed-forward networks this is done by tracing. We set the network into evaluation mode, as we want to deploy it and layers such as dropout must behave deterministically, and define a representative input. The trace() function feeds the input through the forward() function of our network and records all operations that are executed. Out comes our desired ScriptModule.

In [14]:
x = torch.randn(1, 3, 244, 244)
net.eval()
traced_net = jit.trace(net, x)
print(traced_net)

TracedModule[AlexNet](
  (features): TracedModule[Sequential](
    (0): TracedModule[Conv2d]()
    (1): TracedModule[ReLU]()
    (2): TracedModule[MaxPool2d]()
    (3): TracedModule[Conv2d]()
    (4): TracedModule[ReLU]()
    (5): TracedModule[MaxPool2d]()
    (6): TracedModule[Conv2d]()
    (7): TracedModule[ReLU]()
    (8): TracedModule[Conv2d]()
    (9): TracedModule[ReLU]()
    (10): TracedModule[Conv2d]()
    (11): TracedModule[ReLU]()
    (12): TracedModule[MaxPool2d]()
  )
  (avgpool): TracedModule[AdaptiveAvgPool2d]()
  (classifier): TracedModule[Sequential](
    (0): TracedModule[Dropout]()
    (1): TracedModule[Linear]()
    (2): TracedModule[ReLU]()
    (3): TracedModule[Dropout]()
    (4): TracedModule[Linear]()
    (5): TracedModule[ReLU]()
    (6): TracedModule[Linear]()
  )
)

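A quick way to gain confidence in a trace is to compare it against the original module on an input it has not seen. The sketch below does this for a hypothetical small network (`tiny_net` is a stand-in for AlexNet, chosen only to keep the example fast) and then prints the TorchScript code that the tracer recorded.

```python
import torch
import torch.nn as nn
import torch.jit as jit

# Hypothetical small network standing in for AlexNet
tiny_net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((2, 2)),
    nn.Flatten(),
    nn.Dropout(p=0.5),
    nn.Linear(8 * 2 * 2, 10),
)
tiny_net.eval()  # disables dropout so original and trace are deterministic

example = torch.randn(1, 3, 32, 32)
traced_tiny = jit.trace(tiny_net, example)

# The traced module should reproduce the original on unseen inputs
new_input = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    expected = tiny_net(new_input)
    actual = traced_tiny(new_input)
print(torch.allclose(expected, actual, atol=1e-6))

# The recorded operations are available as TorchScript source
print(traced_tiny.code)
```

Tracing only records the operations executed for the example input, so a comparison like this does not cover networks whose forward() branches on the data.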

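Our traced network is now ready to be written to disk. For that we use the save() function of the jit module. It is called the same way as the conventional torch.save() function, but the file it writes contains the recorded code alongside the weights, so it can later be loaded without the original Python class.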

In [15]:
jit.save(traced_net, './model.pth')
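
Saving and loading can be tried out as a round trip. This is a minimal sketch, assuming a hypothetical two-layer network (`small_net`) and a temporary directory instead of the model and path used in the talk. It saves the traced module, loads it back with jit.load() and compares the outputs.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.jit as jit

# Hypothetical small network standing in for the traced AlexNet
small_net = nn.Sequential(nn.Linear(4, 3), nn.ReLU()).eval()
example = torch.randn(1, 4)
traced_small = jit.trace(small_net, example)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'small_model.pth')
    jit.save(traced_small, path)
    # Loading needs only the file, not the Python code that built the network
    restored = jit.load(path)

with torch.no_grad():
    same = torch.allclose(traced_small(example), restored(example))
print(same)
```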

# That's it folks!

We have saved our network. Now let's head over to a fresh notebook where we can load and test it out.