# `torch.jit`

Eager execution is great for development and debugging, but eager code is hard to optimize (automatically) and to deploy.

`torch.jit` addresses this with two flavours:

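- `torch.jit.trace` runs the code once on example inputs and records the tensor operations that were executed. Control flow is not recorded, only the path taken for those inputs.
- `torch.jit.script` compiles the Python source, including control flow, into an intermediate representation (TorchScript) that can be optimized; it only supports a subset of Python.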

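Note: don't forget `model.eval()` before tracing or scripting a model for inference (and `model.train()` to switch back). Layers such as dropout and batch norm behave differently in the two modes, and a traced model keeps whichever behaviour was active during tracing.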
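A minimal sketch of why the mode matters, assuming a hypothetical toy `Net` with a dropout layer: after `model.eval()` dropout is a no-op, so the traced module is deterministic and matches eager inference.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # hypothetical toy model: dropout behaves differently in train and eval mode
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(self.linear(x))

model = Net()
model.eval()  # switch dropout off *before* tracing
example = torch.randn(2, 4)
traced_net = torch.jit.trace(model, example)

# in eval mode the traced module is deterministic and matches the eager output
with torch.no_grad():
    out_eager = model(example)
    out_traced = traced_net(example)
print(torch.allclose(out_eager, out_traced))  # True
```

Had the model been traced in training mode, the recorded graph would keep dropout active.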


## References and more
- https://pytorch.org/docs/stable/jit.html
- https://speakerdeck.com/perone/pytorch-under-the-hood
- https://lernapparat.de/fast-lstm-pytorch/

## Init, helpers, utils, ...

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [None]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

import utils  # little helpers
from utils import attr

# `torch.jit.trace`

In [None]:
def f(x):
    if x.item() < 0:
        return torch.tensor(0)
    else:
        return x

In [None]:
f(torch.tensor(-1))

In [None]:
f(torch.tensor(3))

In [None]:
X = torch.tensor(1)
# tracing runs f once on X and records only the branch taken (`return x`)
traced = torch.jit.trace(f, X)

In [None]:
type(traced)

In [None]:
traced(torch.tensor(1))

In [None]:
traced.graph

In [None]:
traced(torch.tensor(-1))  # tensor(-1), not tensor(0): the negative branch was never traced
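The traced function silently keeps only the branch taken for the example input. A self-contained sketch that makes this explicit (the function name `clip_negative` is hypothetical):

```python
import warnings
import torch

def clip_negative(x):
    # data-dependent branch: which path runs depends on the tensor's value
    if x.item() < 0:
        return torch.zeros_like(x)
    return x

with warnings.catch_warnings():
    # tracing may warn that converting a tensor to a Python number can make
    # the trace incorrect; silence warnings for this demo
    warnings.simplefilter("ignore")
    traced_clip = torch.jit.trace(clip_negative, torch.tensor(5.0))

neg = torch.tensor(-2.0)
print(clip_negative(neg))  # tensor(0.)
print(traced_clip(neg))    # tensor(-2.): only the `return x` path was recorded
```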

## Storing and restoring

In [None]:
traced.save("traced.pt")

In [None]:
!file traced.pt

In [None]:
g = torch.jit.load("traced.pt")

In [None]:
g(torch.tensor(1))

In [None]:
g(torch.tensor(-1))
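Serialization stores the traced graph, not the original Python source, so the missing branch stays missing after a reload. A self-contained sketch that writes to a temporary directory instead of the working directory (the function name `clip_negative` is hypothetical):

```python
import os
import tempfile
import warnings
import torch

def clip_negative(x):
    # hypothetical example function with a data-dependent branch
    if x.item() < 0:
        return torch.zeros_like(x)
    return x

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide tracer warnings about item()
    traced_clip = torch.jit.trace(clip_negative, torch.tensor(5.0))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "traced.pt")
    torch.jit.save(traced_clip, path)
    restored = torch.jit.load(path)  # loaded into memory, file no longer needed

# the restored function carries the same branch-free graph
print(restored(torch.tensor(-2.0)))  # tensor(-2.)
```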

# `torch.jit.script`

In [None]:
# comparisons return a tensor; bool() converts a one-element tensor to a Python bool
bool(torch.tensor(1) < 2)

In [None]:
@torch.jit.script
def f(x):
    if bool(x < 0):
        result = torch.zeros(1)
    else:
        result = x
    return result

This is TorchScript, which supports only a subset of Python.
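TorchScript is statically typed: arguments without annotations are assumed to be `torch.Tensor`, so other types need Python type hints. Loops are compiled into the graph too. A small sketch with a hypothetical function `repeat_add`:

```python
import torch

@torch.jit.script
def repeat_add(x: torch.Tensor, n: int) -> torch.Tensor:
    # `n` must be annotated as int; unannotated arguments default to Tensor
    out = torch.zeros_like(x)
    for i in range(n):  # compiled to a loop node, not unrolled
        out = out + x
    return out

print(repeat_add(torch.ones(2), 3))  # tensor([3., 3.])
print(repeat_add(torch.ones(2), 5))  # tensor([5., 5.])
```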

In [None]:
f(torch.tensor(-1))

In [None]:
f(torch.tensor(1))

In [None]:
type(f)

In [None]:
f.graph
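Comparing the graphs of a traced and a scripted version of the same function makes the difference visible: only the scripted graph contains a `prim::If` node. A minimal sketch (the function name `clip_negative` is hypothetical):

```python
import warnings
import torch

def clip_negative(x):
    if bool(x < 0):
        result = torch.zeros_like(x)
    else:
        result = x
    return result

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing may warn about bool(tensor)
    traced_clip = torch.jit.trace(clip_negative, torch.tensor(5.0))
scripted_clip = torch.jit.script(clip_negative)

print("prim::If" in str(traced_clip.graph))    # False
print("prim::If" in str(scripted_clip.graph))  # True
print(scripted_clip.code)  # Python-like view of the compiled function
```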

## Storing and restoring

In [None]:
torch.jit.save(f, "scripted.pt")

In [None]:
!file scripted.pt

In [None]:
g = torch.jit.load("scripted.pt")

In [None]:
g(torch.tensor(-1))

In [None]:
g(torch.tensor(1))
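Unlike the traced version, a scripted function keeps both branches through a save/load round trip. A self-contained sketch using a temporary directory (the function name `clip_negative` is hypothetical):

```python
import os
import tempfile
import torch

@torch.jit.script
def clip_negative(x):
    # hypothetical scripted function with a data-dependent branch
    if bool(x < 0):
        result = torch.zeros_like(x)
    else:
        result = x
    return result

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scripted.pt")
    torch.jit.save(clip_negative, path)
    restored = torch.jit.load(path)

print(restored(torch.tensor(-2.0)))  # tensor(0.)
print(restored(torch.tensor(3.0)))   # tensor(3.)
```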

## Subclassing `torch.jit.ScriptModule`
If you work with `nn.Module`, subclass `torch.jit.ScriptModule` instead and mark the methods to compile with `@torch.jit.script_method` (see the [[tutorial]](https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html) for more).

```python
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()

    @torch.jit.script_method  # compile forward to TorchScript
    def forward(self, x):
        # ...
        return x
```
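In more recent PyTorch versions an alternative is to keep subclassing `nn.Module` and compile an instance with `torch.jit.script`, which recursively compiles `forward` and the submodules it uses. A sketch with a hypothetical `Gate` module:

```python
import torch
import torch.nn as nn

class Gate(nn.Module):
    # hypothetical module with data-dependent control flow in forward
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x):
        y = self.linear(x)
        if bool(y.sum() > 0):
            out = y
        else:
            out = -y
        return out

model = Gate().eval()
scripted = torch.jit.script(model)  # compiles forward, keeping the `if`
x = torch.randn(2, 3)
print(torch.allclose(model(x), scripted(x)))  # True
```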

# PyTorch and C++

PyTorch offers a very nice C++ API (LibTorch) that closely mirrors the Python interface.

## Loading traced or scripted models from C++

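```c++
#include <torch/script.h>  // one-stop header for loading TorchScript
#include <iostream>
#include <vector>

int main(int argc, const char* argv[]) {
    // load the module serialized from Python with torch.jit.save
    auto module = torch::jit::load("scripted.pt");
    // inputs are passed as a vector of IValues
    std::vector<torch::jit::IValue> inputs{torch::tensor(-1)};
    // forward() returns an IValue; unpack it as a tensor
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output << '\n';
}
```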