# Using the PyTorch JIT Compiler with Pyro
This tutorial shows how to use the PyTorch jit compiler in Pyro models.

## Summary:
* You can use compiled functions in Pyro models.

* You cannot use Pyro primitives inside compiled functions.

* If your model has static structure, you can use a Jit* version of an ELBO algorithm, e.g.

```
- Trace_ELBO()
+ JitTrace_ELBO()

```
* The HMC and NUTS classes accept a `jit_compile=True` kwarg.

* Models should take all tensors as `*args` and all non-tensors as `**kwargs`.

* Each distinct value of `**kwargs` triggers a separate compilation.

* Use `**kwargs` to specify all variation in structure (e.g. time series length).

* To ignore jit warnings in safe code blocks, use `with pyro.util.ignore_jit_warnings():`.

* To ignore all jit warnings in HMC or NUTS, pass `ignore_jit_warnings=True`.

## Table of contents
* Introduction

* A simple model

* Varying structure

In [5]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, JitTrace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, SVI
from pyro.infer.mcmc import MCMC, NUTS
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')

## Introduction
PyTorch 1.0 includes a jit compiler to speed up models. You can think of compilation as a “static mode”, whereas PyTorch usually operates in “eager mode”.

Pyro supports the jit compiler in two ways. First you can use compiled functions inside Pyro models (but those functions cannot contain Pyro primitives). Second, you can use Pyro’s jit inference algorithms to compile entire inference steps; in static models this can reduce the Python overhead of Pyro models and speed up inference.

The rest of this tutorial focuses on Pyro’s jitted inference algorithms: JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO, JitMeanField_ELBO, HMC(jit_compile=True), and NUTS(jit_compile=True).

## A simple model
Let’s start with a simple Gaussian model and an autoguide.

In [7]:
def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    scale = pyro.sample("scale", dist.LogNormal(0., 3.))
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

guide = AutoDiagonalNormal(model)

data = dist.Normal(0.5, 2.).sample((100,))

First let’s run as usual with an SVI object and Trace_ELBO.

In [14]:
%%time
pyro.clear_param_store()
elbo = Trace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
    svi.step(data)

CPU times: user 946 ms, sys: 6.05 ms, total: 952 ms
Wall time: 951 ms


Next, to run with jit-compiled inference, we simply replace
```
- elbo = Trace_ELBO()
+ elbo = JitTrace_ELBO()
```
Also note that the AutoDiagonalNormal guide behaves a little differently on its first invocation (it runs the model to produce a prototype trace), and we don't want to record this warmup behavior when compiling. Thus we call `guide(data)` once to initialize, then run the compiled SVI:

In [13]:
%%time
pyro.clear_param_store()

guide(data)   # Do any lazy initialization before compiling

elbo = JitTrace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
    svi.step(data)

CPU times: user 525 ms, sys: 11.2 ms, total: 536 ms
Wall time: 548 ms


Notice that we get a speedup of roughly 1.7x (951 ms versus 548 ms wall time) for this small model.
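
The speedup can be checked directly from the wall times reported by the two timed cells above. This is only a sketch of the arithmetic, and the variable names are ours. Keep in mind that the jitted cell's time also includes the one-off `guide(data)` call and the compilation itself.

```python
# Wall times copied from the %%time outputs above, in milliseconds.
eager_wall_ms = 951.0  # SVI with Trace_ELBO
jit_wall_ms = 548.0    # SVI with JitTrace_ELBO

speedup = eager_wall_ms / jit_wall_ms
print(f"speedup: {speedup:.2f}x")
```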

Let us now use the same model, but this time use MCMC to generate samples from the model's posterior. We will use the No-U-Turn Sampler (NUTS).

In [18]:
%%time
nuts_kernel = NUTS(model)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100)
mcmc_run.run(data)  # run() returns None, so keep a handle on the MCMC object


Sample: 100%|██████████| 200/200 [00:00, 362.25it/s, step size=1.05e+00, acc. prob=0.839]

CPU times: user 549 ms, sys: 7.31 ms, total: 557 ms
Wall time: 556 ms

We can compile the potential energy computation in NUTS using the jit_compile=True argument to the NUTS kernel. We also silence JIT warnings due to the presence of tensor constants in the model by using ignore_jit_warnings=True.

In [19]:
%%time
nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100)
mcmc_run.run(data)  # run() returns None, so keep a handle on the MCMC object

Sample: 100%|██████████| 200/200 [00:00, 390.15it/s, step size=6.46e-01, acc. prob=0.963]

CPU times: user 509 ms, sys: 8.12 ms, total: 518 ms
Wall time: 516 ms

We notice a modest increase in sampling throughput when JIT compilation is enabled (about 390 it/s versus 362 it/s in the run above).

## Varying structure
Time series models often run on datasets of multiple time series with different lengths. To accommodate varying structure like this, Pyro requires models to separate all model inputs into tensors and non-tensors.

* Non-tensor inputs should be passed as `**kwargs` to the model and guide. These can determine model structure, so that a model is compiled for each value of the passed `**kwargs`.

* Tensor inputs should be passed as `*args`. These must not determine model structure. However `len(args)` may determine model structure (as is used e.g. in semisupervised models).
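
To make the "one compilation per value of `**kwargs`" rule concrete, here is a toy, pure-Python sketch of a cache keyed on `len(args)` and the keyword arguments. It is an illustration only and not Pyro's implementation; `PerStructureCache`, `build_prefix_sum`, and `compile_log` are hypothetical names, and "compiling" is simulated by building a closure.

```python
from typing import Callable, Dict, Tuple


class PerStructureCache:
    """Toy stand-in for a per-structure compilation cache (not Pyro's code)."""

    def __init__(self, build: Callable[..., Callable]):
        self._build = build
        self._compiled: Dict[Tuple, Callable] = {}

    def __call__(self, *args, **kwargs):
        # Structure is identified by the number of positional (tensor)
        # inputs and by the values of the non-tensor keyword inputs.
        key = (len(args), tuple(sorted(kwargs.items())))
        if key not in self._compiled:
            self._compiled[key] = self._build(**kwargs)  # "compile" once
        return self._compiled[key](*args)


compile_log = []  # records every simulated compilation


def build_prefix_sum(length):
    compile_log.append(length)

    def run(sequence):
        total = 0.0
        for t in range(length):  # loop bound is fixed at build time
            total += sequence[t]
        return total

    return run


cached = PerStructureCache(build_prefix_sum)
first = cached([1.0, 2.0, 3.0], length=2)   # builds for length=2
second = cached([4.0, 5.0, 6.0], length=2)  # reuses the length=2 build
third = cached([1.0, 2.0, 3.0], length=3)   # new length, builds again
```

Changing only the data (the positional input) reuses the existing build, while changing `length` triggers a new one, which mirrors the rule described in the bullets above.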

To illustrate this with a time series model, we will pass in a sequence of observations as a tensor arg and the sequence length as a non-tensor kwarg: