# Integrating Simulations into Machine Learning

## About

A question we increasingly ask ourselves with modern algorithms is how to integrate large simulations with machine learning components. Especially when we want to do this in an intrusive fashion, we are faced with a number of tough choices, some of which are counter-intuitive at first glance. The problem arises mostly when working with a simulation code written in a classical HPC language such as C, C++, or Fortran that we then seek to integrate with a Python machine learning framework such as PyTorch, TensorFlow, or JAX. While some decide to rewrite their simulations in machine learning DSLs, or in programming languages with broader automatic differentiation support such as Julia, we focus on keeping the HPC simulation intact and instead highlight the potential interfaces between the two. The literature shows the downside of relying too heavily on machine learning DSLs in particular:

![](https://i.imgur.com/K2V1JbC.png)

(Source: [Dr. JIT: A Just-In-Time Compiler for Differentiable Rendering](https://arxiv.org/abs/2202.01284))

> Our performance is in large parts governed by our ability to map our operations on the optimized set of computation primitives offered by the machine learning DSL.

## Outline

* [1. Integrating C into Machine Learning](#c-into-ml)
  * [1.1 CTypes](#ctypes-ml-c)
* [2. Integrating C++ into Machine Learning](#cpp-into-ml)
  * [2.1 PyTorch C++ Extension](#pytorch-cpp-extension)
    * [2.1.1 PyBind11](#pybind11-cpp)
* [3. Integrating Fortran into Machine Learning](#fortran-into-ml)

## 1. Integrating C into Machine Learning <a name="c-into-ml"></a>

C offers a variety of pathways to approach this problem, the most prominent of which are [CTypes](https://docs.python.org/3/library/ctypes.html) and [Cython](https://cython.org). For the integration of differentiated code into machine learning frameworks, we take PyTorch as the example, but exposing the differentiated code as a library applies beyond PyTorch and extends to the TensorFlow/JAX ecosystem.

![](https://i.imgur.com/pokz4ge.png)

> All approaches described for C apply equally to code differentiated with Tapenade and with Enzyme.

### 1.1 CTypes <a name="ctypes-ml-c"></a>

CTypes is a foreign function library for Python that provides C-compatible data types and hence allows calling functions in DLLs or, most importantly for us, shared libraries. By wrapping these libraries, we can call them from pure Python and combine the two languages with as little friction as possible. We begin by defining a test library in C, which uses Enzyme's forward mode to compute a Jacobian-vector product (JVP) of `f`:

```c
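// vjp_enzyme.c
// Enzyme replaces calls to __enzyme_fwddiff at compile time; since f returns
// void, its forward-mode derivative returns void as well.
extern void __enzyme_fwddiff(void*, double*, double*, double*, double*);

// Primal function: out[i] = x[i] - x[i-1] / x[i], with x[-1] taken as 0.
void f(double x[100], double out[100]) {
    double prev = 0.0;  // double, not int, so the derivative flows through prev
    for (int i = 0; i < 100; i++) {
        out[i] = x[i] - prev / x[i];
        prev = x[i];
    }
}

// Forward mode: for a direction v, writes f(x) to out and J(x) v to dout.
void jvpf(double x[100], double v[100], double out[100], double dout[100]) {
    __enzyme_fwddiff((void*)f, x, v, out, dout);
}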
```

We then compile this file into a shared library, which we can expose to Python through CTypes. Assuming an existing Enzyme installation, the steps take the following form:

```bash
clang vjp_enzyme.c -S -emit-llvm -o input.ll -O2 -fPIC -fno-vectorize -fno-slp-vectorize -fno-unroll-loops
opt input.ll -load=/path/to/Enzyme/LLVMEnzyme-<LLVM version number>.so -enzyme -o output.ll -S -enable-new-pm=0
opt output.ll -O2 -o output_opt.ll -S
clang output_opt.ll -shared -fPIC -o libvjp_enzyme.so
```

We can then call the JVP from Python with:

```python
import ctypes
import numpy as np

# CTypes loads shared libraries (.so), not static archives (.a)
lib = ctypes.CDLL('./libvjp_enzyme.so')

# Declare the C signature: void jvpf(double*, double*, double*, double*)
c_double_p = ctypes.POINTER(ctypes.c_double)
lib.jvpf.argtypes = [c_double_p] * 4
lib.jvpf.restype = None

# Initializing the values: primal input x and tangent direction v
x = np.arange(1, 101, dtype='float64') ** 2
v = np.ones(100)

# Setting shadow structures manually: primal output and its tangent
out = np.zeros(100)
dout = np.zeros(100)

args = [a.ctypes.data_as(c_double_p) for a in (x, v, out, dout)]
lib.jvpf(*args)

# "Inspect" the Jacobian-vector product J(x) v
print(dout)
```
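To check the numbers in `dout` independently of Enzyme, `f` can be mirrored in NumPy and its JVP compared against a central finite difference. This is a sketch that assumes `prev` is carried as a `double`, so the derivative also flows through $x_{i-1}$; the helper names `f_ref`, `jvp_ref` and `jvp_fd` are our own.

```python
import numpy as np


def f_ref(x):
    """NumPy mirror of the C function f, with prev held as a double."""
    prev = np.concatenate(([0.0], x[:-1]))  # prev_i = x_{i-1}, prev_0 = 0
    return x - prev / x


def jvp_ref(x, v):
    """Analytic Jacobian-vector product of f_ref at x in direction v."""
    prev = np.concatenate(([0.0], x[:-1]))
    dprev = np.concatenate(([0.0], v[:-1]))
    # d/dt [x - prev / x] = v - (dprev * x - prev * v) / x**2
    return v - (dprev * x - prev * v) / x**2


def jvp_fd(x, v, eps=1e-5):
    """Central finite-difference approximation of the same JVP."""
    return (f_ref(x + eps * v) - f_ref(x - eps * v)) / (2 * eps)


x = np.arange(1, 101, dtype="float64") ** 2
v = np.ones(100)
print(np.allclose(jvp_ref(x, v), jvp_fd(x, v), rtol=1e-5, atol=1e-5))
```

Once the Enzyme library is built, `np.allclose(dout, jvp_ref(x, v))` should hold as long as the C code matches the `prev` assumption above.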

There are multiple ways to make the hand-off from CTypes to PyTorch smooth and almost seamless. Some of them will be shown in the upcoming example _PINN with PyTorch and Tapenade_.
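One simple building block is zero-copy memory sharing: `torch.from_numpy` wraps a NumPy array without copying, so whatever a C routine writes through the CTypes pointer is immediately visible in the tensor, and a contiguous float64 CPU tensor can be handed to C through `Tensor.data_ptr()`. A minimal sketch, in which a write through the raw pointer stands in for the actual `jvpf` call:

```python
import ctypes

import numpy as np
import torch

dout = np.zeros(100)             # buffer the C routine would write into
dout_t = torch.from_numpy(dout)  # zero-copy view sharing memory with dout

# Stand-in for the C call: write through a raw double* as jvpf would.
ptr = dout.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
ptr[0] = 42.0
print(dout_t[0].item())  # the tensor sees the write without any copy

# Reverse direction: pass a contiguous float64 CPU tensor to C directly.
x_t = torch.arange(1, 101, dtype=torch.float64) ** 2
x_ptr = ctypes.cast(x_t.data_ptr(), ctypes.POINTER(ctypes.c_double))
print(x_ptr[1])  # reads x_t[1] through the raw pointer
```

Non-contiguous tensors need `.contiguous()` first, and the tensor must stay alive for as long as the C side uses the pointer.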

> We will go more in-depth on the use of CTypes in conjunction with PyTorch in the next hands-on example with Tapenade!

> Given the CTypes example above, we will not spell out the Cython syntax here and instead refer to the [Cython documentation](https://cython.readthedocs.io/en/latest/src/tutorial/cython_tutorial.html) for the details needed to replicate it with Cython.
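The four build commands of the CTypes section can also be driven from Python, e.g. from a notebook. This is a minimal sketch: `enzyme_build_commands` and `build_enzyme_library` are hypothetical helpers, the plugin path is a placeholder to adapt to your Enzyme installation, and the flags mirror the commands above, with the last step emitting a shared object that `ctypes.CDLL` can load.

```python
import subprocess
from pathlib import Path


def enzyme_build_commands(src, plugin, lib_name="libvjp_enzyme.so"):
    """Return the four build steps as argv lists (hypothetical helper)."""
    return [
        # 1. Lower C to LLVM IR; vectorization and unrolling off before differentiation.
        ["clang", src, "-S", "-emit-llvm", "-o", "input.ll", "-O2", "-fPIC",
         "-fno-vectorize", "-fno-slp-vectorize", "-fno-unroll-loops"],
        # 2. Load the Enzyme plugin (legacy pass manager) to synthesize derivatives.
        ["opt", "input.ll", f"-load={plugin}", "-enzyme", "-o", "output.ll", "-S",
         "-enable-new-pm=0"],
        # 3. Re-optimize the differentiated IR.
        ["opt", "output.ll", "-O2", "-o", "output_opt.ll", "-S"],
        # 4. Emit a shared object, the format ctypes.CDLL can load.
        ["clang", "output_opt.ll", "-shared", "-fPIC", "-o", lib_name],
    ]


def build_enzyme_library(src, plugin, lib_name="libvjp_enzyme.so"):
    """Run the steps in order; raises CalledProcessError on the first failure."""
    for cmd in enzyme_build_commands(src, plugin, lib_name):
        subprocess.run(cmd, check=True)
    return Path(lib_name).resolve()
```

Usage would look like `build_enzyme_library("vjp_enzyme.c", "/path/to/Enzyme/LLVMEnzyme-<LLVM version number>.so")`. Returning argv lists rather than shell strings avoids quoting issues and keeps the command construction testable without a compiler.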

## 2. Integrating C++ into Machine Learning <a name="cpp-into-ml"></a>

For the integration of differentiated code into machine learning frameworks, we again take PyTorch as the example, but the integration path of exposing the differentiated code as a library applies beyond PyTorch and extends to the TensorFlow/JAX ecosystem.

![](https://i.imgur.com/GncIVOw.png)

Two excellent binding packages exist for C++:

* [PyBind11](https://pybind11.readthedocs.io/en/stable/)
* [Nanobind](https://nanobind.readthedocs.io/en/latest/)

Before using either, however, we first have to define the derivative functions. Since the CTypes and Cython options from C carry over unchanged, we will not discuss them further in this section and instead look more closely at [PyTorch's C++ Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html), which allows Enzyme to operate within PyTorch's just-in-time (JIT) compilation of extensions, so that the derivatives are generated as part of compiling the individual functions.

### 2.1 PyTorch's C++ Extension <a name="pytorch-cpp-extension"></a>

Based on [PyTorch's C++ example](https://pytorch.org/tutorials/advanced/cpp_extension.html), we now extend PyTorch's `Function` and `Module` definitions to suit our purposes. The key point is that we can include Enzyme and hence simply attach the gradient information to PyTorch's existing infrastructure for external functions. Taking the long-long-term memory (LLTM) cell from that tutorial as the example, we start from the following PyTorch code:

```python
import math

import torch
import torch.nn.functional as F


class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        # 3 * state_size for input gate, output gate and candidate cell gate.
        # input_features + state_size because we will multiply with [input, h].
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)

        # Compute the input, output and candidate cell gates with one MM.
        gate_weights = F.linear(X, self.weights, self.bias)
        # Split the combined gate weight matrix into its components.
        gates = gate_weights.chunk(3, dim=1)

        input_gate = torch.sigmoid(gates[0])
        output_gate = torch.sigmoid(gates[1])
        # Here we use an ELU instead of the usual tanh.
        candidate_cell = F.elu(gates[2])

        # Compute the new cell state.
        new_cell = old_cell + candidate_cell * input_gate
        # Compute the new hidden state and output.
        new_h = torch.tanh(new_cell) * output_gate

        return new_h, new_cell
```

Such a plain PyTorch module uses the optimized kernels PyTorch implements in `ATen`, but at the same time PyTorch enables us to **rewrite parts of our simulations in C++**. We can exploit this interface to include differentiated simulations in the machine learning model.


###### Ahead-of-Time Compilation

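Ahead-of-time compilation is performed through `setuptools`. The code below uses `cpp_extension.CppExtension`, a convenience wrapper around `setuptools.Extension` that adds the PyTorch include paths and libraries, together with `cpp_extension.BuildExtension` to drive the C++ build: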

```python
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='lltm_cpp',
      ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})
```

###### Just-in-Time Compilation

For JIT compilation, we instead call `torch.utils.cpp_extension.load` directly from Python, without a `setup.py`:

```python
from torch.utils.cpp_extension import load

lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])
```

Under the hood, this will:

* Create a temporary directory to store build artifacts in
* Store a `Ninja` build file in the temporary directory
* Compile the source files into a shared library
* Import the shared library as a Python module

Integrating Enzyme into such a pipeline takes the following shape. At the top level, the calling code looks like this:

```python
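import numpy as np
import torch

import lltm  # extension module providing the Enzyme wrapper (not shown here)

a = torch.from_numpy(np.array([[1, 2, 3, 4.]], dtype=np.float32))
a.requires_grad_(True)
# Differentiate function "f" from test.cpp with Enzyme and apply it to a
b = lltm.Enzyme("test.cpp", "f").apply(a).sum()
b.backward()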
```

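Under the hood, this calls the same Enzyme compile-differentiate-link pipeline as in the C section and integrates it into the JIT compilation: the source is compiled to LLVM IR, differentiated by the Enzyme pass, linked into a shared library, and the generated derivative `diffe<function>` is loaded with `dlopen`/`dlsym`: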

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <dlfcn.h>
#include <functional>
#include <string>

// ... (remainder of the extension elided)

// Compile `filename` to LLVM IR, differentiate it with Enzyme, build a shared
// library from the result, and return the generated `diffe<function>`.
std::function<void(void*, void*, size_t, void*)> diffecompile(std::string filename, std::string function) {
    int res;
    char data[1024];

    // Temporary base name for the intermediate LLVM IR file.
    char buffer[L_tmpnam];
    tmpnam(buffer);

    // 1. Emit LLVM IR without vectorization or unrolling, as Enzyme prefers.
    snprintf(data, sizeof(data), "/usr/bin/clang++-12 -O3 %s -DTF_ENZYME=1 -fno-exceptions -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -Xclang -new-struct-path-tbaa -S -emit-llvm -o %s.ll", filename.c_str(), buffer);
    printf("running compile - %s\n", data);
    res = system(data);
    assert(res == 0);

    // 2. Run the Enzyme pass plus cleanup passes over the IR.
    snprintf(data, sizeof(data), "/usr/bin/opt-12 %s.ll -load=%s -S -enzyme -mem2reg -instcombine -simplifycfg -adce -loop-deletion -simplifycfg -o %s.ll", buffer, "/content/Enzyme-0.0.49/build/Enzyme/LLVMEnzyme-12.so", buffer);
    printf("running opt - %s\n", data);
    res = system(data);
    assert(res == 0);

    // 3. Compile the differentiated IR into a shared library.
    char buffer2[L_tmpnam];
    tmpnam(buffer2);
    std::string libpath = std::string(buffer2) + ".so";
    snprintf(data, sizeof(data), "/usr/bin/clang++-12 -fPIC -shared %s.ll -o %s", buffer, libpath.c_str());
    printf("running library - %s\n", data);
    res = system(data);
    assert(res == 0);

    // 4. Load the library and look up the generated derivative `diffe<function>`.
    void* lib = dlopen(libpath.c_str(), RTLD_LAZY);
    assert(lib);
    std::string tofind = "diffe" + function;
    void* sym = dlsym(lib, tofind.c_str());
    assert(sym);
    auto diffef = (void(*)(void*, void*, size_t, void*))sym;
    return diffef;
}
```

#### 2.1.1 PyBind11 <a name="pybind11-cpp"></a>

Since PyBind11 and its successor nanobind work in largely the same way, we restrict ourselves here to PyBind11, which is also what PyTorch's C++ extension uses for its Python bindings.

At the end of the C++ file, we use PyBind11 to expose these primitives to Python, which makes them available inside the machine learning framework:

```cpp
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
```

To recap, we are providing PyTorch with:
* Forward function evaluation (no gradients)
* Backward function evaluation, i.e. reverse-mode differentiated function
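On the Python side, these two bindings are typically wired into autograd through a `torch.autograd.Function`, whose `forward` calls the forward binding and whose `backward` calls the reverse-mode derivative, as in PyTorch's C++ extension tutorial. The sketch below runs without a compiler: `ext_forward` and `ext_backward` are hypothetical stand-ins, here for $f(x) = \sin(x)\,x$ elementwise, for what would be `lltm_cpp.forward`/`lltm_cpp.backward` or an Enzyme-generated derivative.

```python
import torch


# Hypothetical stand-ins for the compiled bindings (e.g. lltm_cpp.forward / .backward).
def ext_forward(x):
    return torch.sin(x) * x


def ext_backward(grad_out, x):
    # Vector-Jacobian product of f(x) = sin(x) * x, as reverse mode would return it.
    return grad_out * (torch.cos(x) * x + torch.sin(x))


class ExtFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)          # keep the inputs the VJP needs
        return ext_forward(x)             # autograd does not record ops in here

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return ext_backward(grad_out, x)  # one gradient per forward input


x = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64, requires_grad=True)
ExtFunction.apply(x).sum().backward()
print(torch.autograd.gradcheck(ExtFunction.apply, (x,)))
```

`gradcheck` compares the hand-supplied backward against finite differences, which is a cheap way to validate an externally differentiated kernel before training with it.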

> If you want to use PyTorch's newer features, such as `vmap` from `torch.func` (formerly functorch) or forward-mode automatic differentiation, you need to provide additional definitions to PyTorch, e.g. a `jvp` or `vmap` rule on your `torch.autograd.Function`.
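For forward-mode AD, for instance, `torch.autograd.Function` accepts a `jvp` staticmethod, which is where a forward-mode derivative such as the output of Enzyme's `__enzyme_fwddiff` would be called. A sketch with hypothetical stand-ins (`ext_forward`, `ext_jvp`), exercised through `torch.autograd.forward_ad`:

```python
import torch
import torch.autograd.forward_ad as fwAD


# Hypothetical stand-ins for a compiled primal and its forward-mode derivative.
def ext_forward(x):
    return torch.sin(x) * x


def ext_jvp(x, v):
    # Jacobian-vector product of f(x) = sin(x) * x in direction v.
    return (torch.cos(x) * x + torch.sin(x)) * v


class ExtFunctionFwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_forward(x)  # makes x available in jvp via ctx.saved_tensors
        return ext_forward(x)

    @staticmethod
    def jvp(ctx, x_tangent):
        (x,) = ctx.saved_tensors
        return ext_jvp(x, x_tangent)


x = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64)
v = torch.ones_like(x)
with fwAD.dual_level():
    out = ExtFunctionFwd.apply(fwAD.make_dual(x, v))
    tangent = fwAD.unpack_dual(out).tangent
print(tangent)
```

Supporting `vmap` from `torch.func` additionally requires the `setup_context` style of `torch.autograd.Function` together with a `vmap` staticmethod or `generate_vmap_rule=True`.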

> If you work in JAX, you have to provide even more transformation rules, since a custom operation must also define its behaviour under the `vmap` transform (a batching rule). See [here](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) for more information.

## 3. Integrating Fortran into Machine Learning <a name="fortran-into-ml"></a>

For the integration of differentiated code into machine learning frameworks, we again take PyTorch as the example, but exposing the differentiated code as a library applies beyond PyTorch and extends to the TensorFlow/JAX ecosystem.

![](https://i.imgur.com/LNpFhQN.png)

As with many other libraries (e.g. for I/O), Fortran has to rely on C APIs to interface with machine learning frameworks, typically through the `iso_c_binding` module and `bind(C)` procedures. We therefore refer to the CTypes section above and to the _PINN with PyTorch and Tapenade_ example.