# Session 9: Running Deep Generative AI Models in Real-Time

Agenda:
- Introduction: PyTorch and TorchScript
- 8-bit linear quantization
- ONNX and Graph Optimizations

## Introduction: PyTorch and TorchScript

<center><img src="./assets/torchscript.png" width="50%" /></center>


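TorchScript is a statically typed subset of Python that PyTorch can compile into
an intermediate representation. A TorchScript program can be serialized to disk
and executed without a Python interpreter, for example from C++ through LibTorch.

**Example:**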

The `@torch.jit.script` decorator tells PyTorch to compile the function to
TorchScript, capturing its control flow and making it runnable in the C++
runtime.

```python
import torch
from typing import List

@torch.jit.script
def weighted_sum(tensors: List[torch.Tensor], weights: List[float]) -> torch.Tensor:
    """
    Compute a weighted sum over a list of tensors.
    
    Args:
        tensors (List[torch.Tensor]):   N tensors of identical shape.
        weights (List[float]):          N scalar weights (same length as `tensors`).

    Returns:
        torch.Tensor:   Σ_i  weights[i] * tensors[i]
    """

    # Basic runtime check (TorchScript preserves the assert)
    assert len(tensors) == len(weights), "Tensors and weights must be the same length."

    # The loop and control flow are compiled by TorchScript (scripting, unlike tracing, preserves them)
    out: torch.Tensor = torch.zeros_like(tensors[0])
    for t, w in zip(tensors, weights):
        out = out + t * w

    return out
```

<div class="alert alert-info">

**Note:** TorchScript is _statically typed_, meaning that every value must have a
single, _monomorphic_ type. Types are declared with Python 3 type hints (using the
`typing` module, e.g. `List[float]`) or with MyPy-style type comments; parameters
without an annotation are assumed to be `torch.Tensor`.

</div>
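To illustrate the note above, here is a minimal sketch (the function names `double_it` and `tensor_or_int` are hypothetical): an unannotated parameter is compiled as a `torch.Tensor`, while a function whose return type differs between branches is rejected at compile time because it has no single static type.

```python
import torch


# No annotation: TorchScript assumes `x` is a torch.Tensor.
@torch.jit.script
def double_it(x):
    return x * 2


# The compiled code shows the inferred signature (`x: Tensor`).
print(double_it.code)


def tensor_or_int(x: torch.Tensor, flag: bool):
    # One branch returns a Tensor, the other an int: not monomorphic.
    if flag:
        return x
    return 0


try:
    torch.jit.script(tensor_or_int)
    compiled = True
except Exception as err:  # TorchScript raises a compilation error here
    compiled = False
    print(f"Compilation failed with {type(err).__name__}")
```

In regular Python both functions are valid; only the TorchScript compiler enforces the static typing rules.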

The compiled function can then be serialized to a `.pt` file with:

```python
torch.jit.save(weighted_sum, "weighted_sum.pt")
```
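Before moving to C++, the serialized file can be checked by loading it back in Python with `torch.jit.load`. The sketch below re-declares `weighted_sum` so that it runs on its own, and writes to a temporary directory (an assumption for the example; any path works).

```python
import os
import tempfile
from typing import List

import torch


@torch.jit.script
def weighted_sum(tensors: List[torch.Tensor], weights: List[float]) -> torch.Tensor:
    assert len(tensors) == len(weights), "Tensors and weights must be the same length."
    out = torch.zeros_like(tensors[0])
    for t, w in zip(tensors, weights):
        out = out + t * w
    return out


path = os.path.join(tempfile.mkdtemp(), "weighted_sum.pt")
torch.jit.save(weighted_sum, path)

# The saved function is loaded back as a module; calling it runs `forward`.
loaded = torch.jit.load(path)

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])
result = loaded([a, b], [0.3, 0.7])
print(result)  # expected values: 7.3, 14.6, 21.9
```

These are the same inputs as in the C++ program, so both should print the same result.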

Finally, the compiled TorchScript module can be run in C++ using the `libtorch`
API.

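```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
    // Load the serialized TorchScript function
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("weighted_sum.pt");
    } catch (const c10::Error& e) {
        std::cerr << "Error loading weighted_sum.pt: " << e.what() << std::endl;
        return 1;
    }

    // Create the input tensors
    at::Tensor a = torch::tensor({1.0, 2.0, 3.0});
    at::Tensor b = torch::tensor({10.0, 20.0, 30.0});
    std::vector<at::Tensor> tensors = {a, b};

    // Create weights (List[float] in TorchScript maps to std::vector<double>)
    std::vector<double> weights = {0.3, 0.7};

    // Pack arguments into a list of IValue
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(tensors);
    inputs.push_back(weights);

    // Run the function (autograd is not needed for inference)
    torch::NoGradGuard no_grad;
    at::Tensor output = module.forward(inputs).toTensor();

    // Print the output
    std::cout << "Result: " << output << std::endl;

    return 0;
}
```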

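Because the C++ program does not go through the Python interpreter, it avoids
Python's per-call overhead (and the GIL), which can noticeably reduce latency for
small models or many short calls. For large models, most of the time is spent in
the tensor kernels, which are the same in both cases, so the gain is smaller.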

## 8-bit Linear Quantization
_Inspired by [https://huggingface.co/docs/optimum/en/concept_guides/quantization](https://huggingface.co/docs/optimum/en/concept_guides/quantization)._

![](./assets/float32_int8.png)

We want to lower the bit width of our weights and activations from 32-bit
floating-point numbers (`float32`) to 8-bit integers (`int8`). This reduces the
model size by a factor of 4 (1 byte instead of 4 per value) and can speed up
inference on hardware with fast integer arithmetic, which is useful for
real-time applications and deployment on edge devices. Libraries that implement
this kind of quantization include
[bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) and
[🤗 Optimum](https://github.com/huggingface/optimum).

With 8-bit quantization, `float32` values are represented as `int8` values, i.e.
integers in $[-128, 127]$ (or $[0, 255]$ for unsigned `uint8`). To do so, we use
an affine (linear) mapping between the `float32` space and the `int8` space,
given by the formula:

$$
x = S \cdot (x_q - Z)
$$

where $x$ is the `float32` value, $x_q$ is the quantized `int8` value, and $S$
and $Z$ are the quantization parameters. Namely:
- $S$ is the scale (a positive `float32`)
- $Z$ is the zero-point, which is the `int8` value corresponding to 0 in the
`float32` space.
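Inverting the formula gives the quantization step $x_q = \mathrm{clamp}(\mathrm{round}(x / S) + Z,\ q_{min},\ q_{max})$. If the `float32` values lie in a range $[\alpha, \beta]$ (widened to contain 0) that is mapped onto the full `int8` range $[q_{min}, q_{max}] = [-128, 127]$, then $S = (\beta - \alpha) / (q_{max} - q_{min})$ and $Z = \mathrm{round}(q_{min} - \alpha / S)$. The NumPy sketch below implements this asymmetric, per-tensor scheme; the helpers `compute_qparams`, `quantize` and `dequantize` are our own names, not a library API, and the random "weights" are a stand-in.

```python
import numpy as np

Q_MIN, Q_MAX = -128, 127  # range of int8


def compute_qparams(alpha: float, beta: float) -> tuple[float, int]:
    """Return the scale S and zero-point Z mapping [alpha, beta] onto [Q_MIN, Q_MAX]."""
    # Widen the range so that 0.0 is exactly representable.
    alpha, beta = min(alpha, 0.0), max(beta, 0.0)
    scale = (beta - alpha) / (Q_MAX - Q_MIN)
    if scale == 0.0:  # all-zero input: any positive scale works
        scale = 1.0
    zero_point = round(Q_MIN - alpha / scale)
    return scale, int(np.clip(zero_point, Q_MIN, Q_MAX))


def quantize(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # x_q = clamp(round(x / S) + Z, Q_MIN, Q_MAX)
    q = np.round(x / scale) + zero_point
    return np.clip(q, Q_MIN, Q_MAX).astype(np.int8)


def dequantize(x_q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # x = S * (x_q - Z)
    return (scale * (x_q.astype(np.float32) - zero_point)).astype(np.float32)


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.5, size=(256, 256)).astype(np.float32)  # stand-in weights

S, Z = compute_qparams(float(w.min()), float(w.max()))
w_q = quantize(w, S, Z)
w_hat = dequantize(w_q, S, Z)

print(f"S = {S:.6f}, Z = {Z}")
print(f"size: {w.nbytes} bytes -> {w_q.nbytes} bytes")  # 4x smaller
print(f"max abs error: {np.abs(w - w_hat).max():.6f} (bound S/2 = {S / 2:.6f})")
```

Inside $[\alpha, \beta]$ the rounding error is at most $S/2$, so a wider range (e.g. caused by a few outliers) directly means a coarser grid for all the other values.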

To compute these quantization parameters, we need to know the range of `float32`
values to represent. For weights, this range is known once the model is trained;
for activations, it depends on the inputs the model receives at inference time,
so it has to be estimated. Several strategies exist:

- (Post-Training) **Dynamic Quantization**: The range for each activation is
computed at _runtime_. This increases the cost of inference, but generally
achieves higher accuracy than static quantization.

- (Post-Training) **Static Quantization**: The range for each activation is
determined at _quantization time_ using a set of inputs called _calibration data_.

- **Quantization-Aware Training**: The range for each activation is computed
at _training time_, while "fake quantization" operators simulate the error
introduced by quantization so that the model can learn to compensate for it.
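As a concrete instance of post-training dynamic quantization, PyTorch provides `torch.ao.quantization.quantize_dynamic`, which replaces the selected layer types (here `nn.Linear`) with versions whose weights are stored as `int8` and whose activation ranges are computed on the fly. The toy model is a hypothetical stand-in, and the sketch assumes a PyTorch build with a quantized engine for the current CPU (e.g. FBGEMM on x86 or QNNPACK on ARM).

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

torch.manual_seed(0)

# Toy float32 model (stand-in for a real network)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Post-training dynamic quantization: Linear weights -> int8,
# activation ranges computed at runtime on every forward pass.
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(qmodel)

x = torch.randn(8, 64)
with torch.no_grad():
    y_fp32 = model(x)
    y_int8 = qmodel(x)

print(f"max abs difference: {(y_fp32 - y_int8).abs().max().item():.4f}")
```

`quantize_dynamic` returns a modified copy by default, so the original `float32` model stays available for comparison.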

## ONNX and Graph Optimizations

ONNX (Open Neural Network Exchange) is an open standard for representing Machine
Learning models. It defines a common **set of operators** (_opset_) and a 
common **file format** (`.onnx`) to facilitate _interoperability_ between
different deep learning frameworks.

<center><img src="assets/onnx.png" /></center>

Image from [https://microsoft.github.io/ai-at-edge/docs/onnx/](https://microsoft.github.io/ai-at-edge/docs/onnx/).

### Visualization of the ONNX model for `music-medium-800k`

We use [netron.app](https://netron.app) to visualize the model.

![](./assets/netron1.png)
![](./assets/netron2.png)

### Graph Optimizations

**[ONNX Runtime](https://onnxruntime.ai)** is a library for optimizing and
accelerating inference of ONNX models. It supports a wide range of hardware
backends, called _execution providers_ (CUDA, DirectML, CoreML, etc.), and offers
APIs for several programming languages.

In addition to **8-bit linear quantization**, ONNX Runtime supports several kinds
of [graph optimizations](https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html):

- _Constant Folding_ statically computes the parts of the graph that depend only
on constant initializers and replaces them with their results, so they are not
recomputed at runtime.

- _Redundant Node Elimination_ removes redundant nodes such as _Identity_,
_Slice_, _Unsqueeze_, or _Dropout_ (a no-op at inference time).

- _Node Fusion_ merges multiple nodes into a single, more efficient node.
For example, a _matrix multiplication_ followed by an _addition_ can be fused
into a single **Gemm** (General Matrix Multiplication) node.

**Example:**

<div style="display: flex; align-items: center; justify-content: center; margin: 0 auto; width: fit-content;">
  <img src="assets/graph_fusion_matmul_add.png" style="padding: 2em" width="50%">
  <span style="width:100%">-></span>
  <img src="assets/graph_fusion_gemm.png" style="padding: 2em" width="50%">
</div>

### Note: Other Machine Learning formats and standards

There are many others, for example _GGML_, _CoreML_, _Apache TVM_, and _LLVM MLIR_.