# `torch.jit`

Eager execution is great for development and debugging, but it makes the code hard to optimize (automatically) and to deploy.

`torch.jit` comes in two flavours:

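- `torch.jit.trace` runs the code once on example inputs and records the tensor operations that were executed. Control flow is not recorded: only the branch taken for the example input ends up in the trace.
- `torch.jit.script` compiles the source code, including control flow, into an intermediate representation (TorchScript) that can be optimized; it only supports a subset of Python.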

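Note: don't forget `model.eval()` and `model.train()`. Layers such as dropout and batch norm behave differently in the two modes, and tracing bakes in whichever mode is active, so switch to `model.eval()` before tracing a model for inference.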
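A minimal sketch of why the mode matters when tracing, assuming a hypothetical toy model made of one `nn.Linear` and one `nn.Dropout` layer (chosen only because dropout behaves differently in the two modes):

```python
import torch
import torch.nn as nn

# Hypothetical toy model: Dropout is random in train mode, the identity in eval mode
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
example = torch.randn(2, 4)

model.eval()  # switch to inference behaviour *before* tracing
traced_model = torch.jit.trace(model, example)

# In eval mode the trace is deterministic and matches the eager model on new inputs
new_input = torch.randn(2, 4)
print(torch.allclose(traced_model(new_input), model(new_input)))
```

If the model were traced in training mode instead, the recorded graph would keep applying dropout at inference time.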


## References and further reading
- https://pytorch.org/docs/stable/jit.html
- https://speakerdeck.com/perone/pytorch-under-the-hood
- https://lernapparat.de/fast-lstm-pytorch/

## Init, helpers, utils, ...

```python
%load_ext autoreload
%autoreload 2

%matplotlib inline
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
```

```python
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

import utils  # little helpers
from utils import attr
```

# `torch.jit.trace`

```python
def f(x):
    if x.item() < 0:
        return torch.tensor(0)
    else:
        return x
```

```python
f(torch.tensor(-1))
```

```
tensor(0)
```

```python
f(torch.tensor(3))
```

```
tensor(3)
```

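```python
X = torch.tensor(1)
traced = torch.jit.trace(f, X)
```

Tracing emits a warning pointing at the data-dependent line: `x.item() < 0` is evaluated only once, for the example input `X`, and its result is treated as a constant.

```
  if x.item() < 0:
```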

```python
type(traced)
```

```
torch.jit.ScriptFunction
```

```python
traced(torch.tensor(1))
```

```
tensor(1)
```

```python
traced.graph
```

```
graph(%0 : Long(requires_grad=0, device=cpu)):
  return (%0)
```

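```python
traced(torch.tensor(-1))
```

```
tensor(-1)
```

The traced function returns `tensor(-1)` instead of `tensor(0)`: only the `else` branch taken for `X = torch.tensor(1)` was recorded, and the graph simply returns its input.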
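If only a small part of the code needs control flow, tracing and scripting can be mixed: the TorchScript docs state that control flow inside a scripted function called from a traced one is preserved. A minimal sketch, where `clip_negative` and `shifted` are hypothetical names chosen for illustration:

```python
import torch

@torch.jit.script
def clip_negative(x: torch.Tensor) -> torch.Tensor:
    # Scripted: this branch is compiled into the graph instead of being frozen
    if bool(x < 0):
        x = torch.zeros_like(x)
    return x

def shifted(x):
    # Plain Python function calling the scripted helper
    return clip_negative(x) + 1

# Trace with a positive example input, just like the cell above
traced_shifted = torch.jit.trace(shifted, torch.tensor(1))

print(traced_shifted(torch.tensor(-1)))  # negative branch still taken
print(traced_shifted(torch.tensor(3)))
```

With the same positive example input, a pure trace would again have dropped the negative branch.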

## Storing and restoring

```python
traced.save("traced.pt")
```

```python
!file traced.pt
```

```
traced.pt: Zip archive data, at least v?[0] to extract
```
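Since the saved file is a zip archive, Python's standard `zipfile` module can list what is inside. A small sketch, using a hypothetical stand-in function `double` instead of `f` and a temporary directory; the exact entry names vary between PyTorch versions, so none are shown here:

```python
import os
import tempfile
import zipfile

import torch

def double(x):  # hypothetical stand-in for the notebook's `f`
    return x * 2

traced_double = torch.jit.trace(double, torch.tensor(1))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "traced.pt")
    traced_double.save(path)
    # The .pt file is an ordinary zip archive, so zipfile can open it
    with zipfile.ZipFile(path) as archive:
        names = archive.namelist()

for name in names:
    print(name)
```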


```python
g = torch.jit.load("traced.pt")
```

```python
g(torch.tensor(1))
```

```
tensor(1)
```

```python
g(torch.tensor(-1))
```

```
tensor(-1)
```

# `torch.jit.script`

```python
bool(torch.tensor(1) < 2)
```

```
True
```

```python
@torch.jit.script
def f(x):
    if bool(x < 0):
        result = torch.zeros(1)
    else:
        result = x
    return result
```

This is TorchScript, which supports only a subset of Python.
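One visible restriction is typing: TorchScript assumes every unannotated argument is a `Tensor`, so arguments of other types need annotations. A sketch with a hypothetical function `relu_shift`:

```python
import torch

@torch.jit.script
def relu_shift(x: torch.Tensor, shift: float = 0.0) -> torch.Tensor:
    # Without the `float` annotation, `shift` would be assumed to be a Tensor.
    # The `if` on a runtime value is compiled into the graph, not frozen.
    if shift > 0.0:
        x = x + shift
    return torch.relu(x)

print(relu_shift(torch.tensor([-2.0, 1.0]), 1.5))
print(relu_shift(torch.tensor([-2.0, 1.0])))
```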

```python
f(torch.tensor(-1))
```

```
tensor([0.])
```

```python
f(torch.tensor(1))
```

```
tensor(1)
```

```python
type(f)
```

```
torch.jit.ScriptFunction
```

```python
f.graph
```

```
graph(%x.1 : Tensor):
  %8 : None = prim::Constant()
  %2 : int = prim::Constant[value=0]() # <ipython-input-1-5b977b5b82b7>:3:16
  %5 : int = prim::Constant[value=1]() # <ipython-input-1-5b977b5b82b7>:4:29
  %3 : Tensor = aten::lt(%x.1, %2) # <ipython-input-1-5b977b5b82b7>:3:12
  %4 : bool = aten::Bool(%3) # <ipython-input-1-5b977b5b82b7>:3:7
  %result : Tensor = prim::If(%4) # <ipython-input-1-5b977b5b82b7>:3:4
    block0():
      %7 : int[] = prim::ListConstruct(%5)
      %result.1 : Tensor = aten::zeros(%7, %8, %8, %8, %8) # <ipython-input-1-5b977b5b82b7>:4:17
      -> (%result.1)
    block1():
      -> (%x.1)
  return (%result)
```
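The graph is verbose; the `.code` attribute shows the same compiled function pretty-printed as Python-like TorchScript source, with the `if` still present. A sketch using a copy of `f` under the hypothetical name `clip`, so the notebook's `f` stays untouched:

```python
import torch

# Same body as the scripted `f` above, renamed to avoid overwriting it
@torch.jit.script
def clip(x):
    if bool(x < 0):
        result = torch.zeros(1)
    else:
        result = x
    return result

# `.code` pretty-prints the compiled function as TorchScript source
print(clip.code)
```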

## Storing and restoring

```python
torch.jit.save(f, "scripted.pt")
```

```python
!file scripted.pt
```

```
scripted.pt: Zip archive data, at least v?[0] to extract
```


```python
g = torch.jit.load("scripted.pt")
```

```python
g(torch.tensor(-1))
```

```
tensor([0.])
```

```python
g(torch.tensor(1))
```

```
tensor(1)
```
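`torch.jit.save` and `torch.jit.load` also accept file-like objects, which is handy for keeping a model in memory or sending it over the network. A sketch with `io.BytesIO`, using a hypothetical module `Clip` with logic similar to `f`:

```python
import io

import torch
import torch.nn as nn

class Clip(nn.Module):  # hypothetical module with logic similar to `f`
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x < 0):
            result = torch.zeros_like(x)
        else:
            result = x
        return result

scripted_clip = torch.jit.script(Clip())

buffer = io.BytesIO()
torch.jit.save(scripted_clip, buffer)  # write the zip archive into memory
buffer.seek(0)                         # rewind before reading it back
restored = torch.jit.load(buffer)

print(restored(torch.tensor(-1)))  # the compiled branch survives the round trip
print(restored(torch.tensor(3)))
```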

## Subclassing `torch.jit.ScriptModule`
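If you work with `nn.Module`, you can replace it by `torch.jit.ScriptModule` and mark the methods to compile with `@torch.jit.script_method` (see [[tutorial]](https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html) for more). In more recent PyTorch versions, the simpler route is to keep `nn.Module` and pass an instance to `torch.jit.script`.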

```python
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()

    @torch.jit.script_method  # compile forward to TorchScript
    def forward(self, x):
        # ... model logic goes here ...
        return x
```
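In newer PyTorch versions a plain `nn.Module` can also be compiled directly with `torch.jit.script`, without subclassing anything. A sketch with a hypothetical `Gate` module whose data-dependent branch is exactly what tracing would freeze:

```python
import torch
import torch.nn as nn

class Gate(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        # Data-dependent branch: scripting keeps it, tracing would freeze it
        if bool(y.sum() > 0):
            out = y
        else:
            out = torch.zeros_like(y)
        return out

model = Gate().eval()
scripted = torch.jit.script(model)  # compiles forward and the submodules it uses
out = scripted(torch.randn(2, 3))
```

The result is a `torch.jit.ScriptModule`, so it can be saved with `torch.jit.save` like the scripted function above.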

# PyTorch and C++

PyTorch offers a very nice C++ interface (LibTorch) whose API is very close to the Python one.

## Loading TorchScript models from C++

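```c++
#include <torch/script.h>  // everything needed to load TorchScript models
#include <iostream>
#include <vector>

int main(int argc, const char* argv[]) {
    // Load the archive written from Python with torch.jit.save
    torch::jit::script::Module module = torch::jit::load("scripted.pt");
    // forward() expects a vector of IValues, one per input argument
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::tensor(-1));
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output << std::endl;
    return 0;
}
```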