# End-to-End TorchScript Pipeline

## About

In this notebook we take a hands-on look at the end-to-end TorchScript pipeline. We first refresh some of the idiosyncrasies of TorchScript, then dive into the customary ResNet-50 example to inspect how the pieces connect, and finally write our own model in PyTorch and compile it with TorchScript.

## Outline

1. [Refresher on TorchScript](#torchscript)
2. [ResNet-50 Example](#resnet-50)
3. [Writing our own PyTorch Model](#own-model)

To get set up, we need to install PyTorch and make sure that the right version (1.13.1+cu116) is available to us:

In [1]:
# Quote the requirement so the shell does not treat ">" as output redirection
!pip install "torch>=1.2.0"
%matplotlib inline

In [None]:
import torch
print(torch.__version__)

## 1. Refresher on TorchScript <a name="torchscript"></a>

To begin with, we define a few simple kernels to inspect the "rough edges" of TorchScript that we have to navigate when saving and exporting our machine learning models.

In [4]:
x, h = torch.rand(3, 4), torch.rand(3, 4)


In [None]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

With the model traced, we now have access to two main representations of it: the graph representation (`traced_cell.graph`) and the code representation (`traced_cell.code`). The latter lets us check whether PyTorch actually traced what we thought it traced.

In [None]:
print(traced_cell.graph)

In [None]:
print(traced_cell.code)
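Both representations can also be inspected programmatically. The sketch below re-declares a hypothetical stripped-down cell without the decision gate (`PlainCell`) so that it runs on its own, then walks the nodes of the traced graph and lists their kinds. The exact node list may differ between PyTorch versions, so treat the printed output as illustrative.

```python
import torch

class PlainCell(torch.nn.Module):
    """Hypothetical stripped-down variant of MyCell without the decision gate."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

x, h = torch.rand(3, 4), torch.rand(3, 4)
traced = torch.jit.trace(PlainCell(), (x, h))

# Every node of the IR graph has a kind such as "aten::tanh" or "prim::Constant"
kinds = [node.kind() for node in traced.graph.nodes()]
print(kinds)

# The code representation is an ordinary Python string
print("torch.tanh" in traced.code)
```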

No control flow so far: the `if`/`else` of `MyDecisionGate` does not appear in the traced code. To preserve control flow, we need the **script compiler**, which directly analyzes the Python source code and transforms it into TorchScript.

> Without the script compiler, TorchScript only records the code path actually executed for the example input, not the entire code including its control flow.

In [None]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
# Note: despite the variable name, this module is scripted rather than traced
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
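To see the difference in behavior rather than just in the printed code, the sketch below traces and scripts the same gate and feeds both a negative input. The tensors `positive` and `negative` are illustrative choices: tracing on `positive` bakes in the `return x` branch, whereas the scripted gate still evaluates the condition at run time.

```python
import warnings

import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()
positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

# Tracing runs the gate once on `positive`, so only the `return x` branch is
# recorded. PyTorch emits a TracerWarning about the data-dependent `if`,
# which we silence here for brevity.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced_gate = torch.jit.trace(gate, positive)

# Scripting compiles the Python source, so both branches survive
scripted_gate = torch.jit.script(gate)

traced_out = traced_gate(negative)      # baked-in branch: input returned unchanged
scripted_out = scripted_gate(negative)  # sum is negative, so the gate returns -x

print(torch.equal(traced_out, negative))     # True
print(torch.equal(scripted_out, -negative))  # True
```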

In [None]:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)

The compiled model can then be saved to disk with its `save` method. The result is a self-contained TorchScript archive, stored here with PyTorch's conventional `.pt` extension.

In [42]:
traced_cell.save('Stored_simple_cell.pt')
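A quick way to convince ourselves that the saved file is self-contained is a save/load round trip. The sketch below, assuming a temporary directory and a hypothetical file name `simple_cell.pt`, saves a scripted cell, confirms that the archive is a regular zip file, reloads it with `torch.jit.load`, and compares the outputs of the original and reloaded modules.

```python
import os
import tempfile
import zipfile

import torch

class MyCell(torch.nn.Module):
    # Minimal cell re-declared (without the gate) so this snippet runs on its own
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

scripted_cell = torch.jit.script(MyCell())
x, h = torch.rand(3, 4), torch.rand(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "simple_cell.pt")  # hypothetical file name
    scripted_cell.save(path)

    # The archive is a zip file bundling the compiled code and the weights
    is_zip = zipfile.is_zipfile(path)

    # Reload the module and check that it computes the same result
    loaded_cell = torch.jit.load(path)
    outputs_match = torch.allclose(scripted_cell(x, h)[0], loaded_cell(x, h)[0])

print(is_zip, outputs_match)
```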

## 2. ResNet-50 Example <a name="resnet-50"></a>

To build up to the ResNet-50 example, we use the building blocks provided by torchvision:

In [12]:
import torchvision

The torchvision package provides a predefined ResNet-50 model, which we instantiate and then trace with the JIT compiler:

In [14]:
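# An instance of our model, switched to inference mode so that layers such as
# batch normalization are traced with their evaluation behavior
model = torchvision.models.resnet50()
model.eval()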

# Providing an example input to our model
example_input = torch.rand(1, 3, 224, 224)

# Tracing the machine learning model
traced_script_module = torch.jit.trace(model, example_input)

The traced model can be evaluated just like a regular PyTorch model:

In [None]:
output = traced_script_module(torch.ones(1, 3, 224, 224))
output[0, :5]
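Before exporting a traced model, it is worth checking that it agrees with eager execution on an input it has not seen. The sketch below uses a small hypothetical convolutional network as a stand-in for ResNet-50, so it runs without torchvision, and compares the traced and eager outputs on a fresh batch of two images. Note that the batch size differs from the example input, which works here because none of the layers bake the batch size into the trace.

```python
import torch

# Hypothetical small stand-in for ResNet-50 (conv, batch norm, pooling, classifier)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
# Put batch norm into inference mode before tracing, since the trace
# records the mode-dependent behavior of the layers
model.eval()

example_input = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example_input)

# Compare traced and eager execution on a fresh input with a different batch size
fresh_input = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    eager_out = model(fresh_input)
    traced_out = traced(fresh_input)

print(traced_out.shape)  # torch.Size([2, 10])
print(torch.allclose(eager_out, traced_out, atol=1e-5))
```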

In [17]:
traced_script_module.save("traced_resnet_model.pt")

To connect this traced model to the C++ layer, we need a model loader on the C++ side. For this we write the following small loader, which includes the LibTorch header `torch/script.h` and deserializes the model with `torch::jit::load`:

```cpp
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: model_loader <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}
```
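The same load-or-fail logic can be mirrored in Python, which is handy for checking an archive before building the C++ loader. The sketch below defines a hypothetical helper `try_load` around `torch.jit.load`. In our understanding, errors thrown by the C++ deserializer (`c10::Error` in the loader above) surface in Python as a `RuntimeError`, which is what the helper catches; the file names used are illustrative.

```python
import os
import tempfile

import torch

def try_load(path):
    """Hypothetical helper mirroring the C++ loader: return the module, or None."""
    try:
        return torch.jit.load(path)
    except RuntimeError:
        # Corresponds to the `catch (const c10::Error& e)` branch in C++
        print("error loading the model")
        return None

with tempfile.TemporaryDirectory() as tmp:
    # A file that is not a TorchScript archive
    bogus_path = os.path.join(tmp, "not_a_model.pt")
    with open(bogus_path, "wb") as f:
        f.write(b"this is not a TorchScript archive")
    bad = try_load(bogus_path)

    # A valid archive produced by scripting a trivial built-in module
    good_path = os.path.join(tmp, "relu.pt")
    torch.jit.script(torch.nn.ReLU()).save(good_path)
    good = try_load(good_path)

print(bad is None, good is not None)
```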

We then use the following minimal `CMakeLists.txt` to build the model loader and link it against LibTorch:

```cmake
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet_test)

find_package(Torch REQUIRED)

add_executable(model_loader model_loader.cpp)
target_link_libraries(model_loader "${TORCH_LIBRARIES}")
# Recent LibTorch releases require C++17
set_property(TARGET model_loader PROPERTY CXX_STANDARD 17)
```

The source code for the model loader, as well as the `CMakeLists.txt`, is available from the following GitHub repository:

In [None]:
!git clone https://github.com/ludgerpaehler/TorchScriptTutorial.git

Next, we check the version of the CUDA compiler in Google Colab, then download the corresponding LibTorch build and unzip it:

In [None]:
!nvcc -V

In [None]:
!wget https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-latest.zip

In [None]:
!unzip libtorch-shared-with-deps-latest.zip

Following the typical CMake workflow, we navigate into the cloned repository, create a build directory, configure CMake with the path to the unzipped LibTorch, and build the loader:

In [None]:
%cd TorchScriptTutorial/

In [49]:
!mkdir build

In [None]:
%cd build

In [None]:
!cmake -DCMAKE_PREFIX_PATH=/content/libtorch ..

In [None]:
!cmake --build . --config Release
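As an alternative to downloading LibTorch separately, the pip-installed `torch` package ships its own CMake configuration files, and `torch.utils.cmake_prefix_path` points at the directory containing them. The sketch below locates `TorchConfig.cmake` there and prints the corresponding configure command. Whether building against the pip package is preferable to the standalone LibTorch download depends on your toolchain (for instance the C++ ABI setting and CUDA version the wheel was built with), so treat this as an option to evaluate rather than a drop-in replacement.

```python
import os

import torch

# Directory that contains Torch/TorchConfig.cmake inside the installed torch package
prefix = torch.utils.cmake_prefix_path
config = os.path.join(prefix, "Torch", "TorchConfig.cmake")

print(prefix)
print(os.path.isfile(config))

# Configure command that points CMake at the pip-installed torch instead
print(f"cmake -DCMAKE_PREFIX_PATH={prefix} ..")
```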

We then go back to the root folder, where we previously stored the traced model:

In [None]:
%cd ../..

Next, we check that we are in the right directory and that the models have been saved correctly:

In [55]:
!ls -l

total 4548536
-rw-r--r-- 1 root root        305 Mar  2 16:36 '=1.2.0'
drwxr-xr-x 2 root root       4096 Mar  2 17:08  build
drwxr-xr-x 6 root root       4096 Mar  2 09:20  libtorch
-rw-r--r-- 1 root root 2305344701 Mar  2 10:29  libtorch-shared-with-deps-latest.zip
-rw-r--r-- 1 root root 2305344701 Mar  2 10:29  libtorch-shared-with-deps-latest.zip.1
drwxr-xr-x 1 root root       4096 Feb 28 14:45  sample_data
-rw-r--r-- 1 root root       4679 Mar  2 17:20  Stored_simple_cell.pt
-rw-r--r-- 1 root root       4679 Mar  2 16:44  Stored_simple_cell.zip
drwxr-xr-x 4 root root       4096 Mar  2 17:32  TorchScriptTutorial
-rw-r--r-- 1 root root   46959061 Mar  2 16:52  traced_resnet_model.pt


We can now execute the model loader to make sure that the traced model can be loaded from C++.

In [None]:
!./TorchScriptTutorial/build/model_loader traced_resnet_model.pt

If the loader prints `ok`, we can conclude that our PyTorch model was traced and serialized correctly, and the traced model can now be passed on to further backends such as [TVM](https://tvm.apache.org) and [ONNX](https://onnx.ai). [IREE](https://openxla.github.io/iree/#importing-models-from-ml-frameworks) requires our model to be legalized to MLIR's `linalg` dialect, which we are unable to test in the same notebook.

## 3. Writing our own PyTorch Model <a name="own-model"></a>

To define our own PyTorch model, we only have to write down the model, trace and save it, and then check that it loads with the compiled model loader:

In [None]:
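class OurMLModel(torch.nn.Module):

    def __init__(self):
        super(OurMLModel, self).__init__()
        # Define the layers of the model here
        ...

    def forward(self, x):
        # Compute and return the output tensor for the input `x`;
        # the tracing step calls this with `example_input` of shape (1, 3, 224, 224)
        ...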
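For illustration, here is one hypothetical way to fill in the template: a tiny image classifier whose architecture (one strided convolution, global average pooling, and a linear layer with 10 outputs) is our own choice and not part of the original pipeline. It accepts the same `(1, 3, 224, 224)` input used for ResNet-50, so the tracing and saving cells that follow apply to it unchanged.

```python
import torch

class OurMLModel(torch.nn.Module):
    """Hypothetical filled-in template: a tiny image classifier."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, 3, H, W); the strided conv halves H and W, pooling reduces them to 1
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

# Eager sanity check before tracing
out = OurMLModel()(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```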

We then trace the model with the JIT, reusing the `example_input` from the ResNet-50 section:

In [None]:
own_ml_model = OurMLModel()
traced_own_model = torch.jit.trace(own_ml_model, example_input)

Then we save the traced model:

In [None]:
traced_own_model.save('traced_own_model.pt')

Finally, we check that it loads with the compiled model loader:

In [None]:
!./TorchScriptTutorial/build/model_loader traced_own_model.pt