# Custom C++ extension


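References:
- PyTorch tutorial "Custom C++ and CUDA Extensions": https://pytorch.org/tutorials/advanced/cpp_extension.html
- PyTorch tutorial "Extending TorchScript with Custom C++ Operators": https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html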



In [None]:
from torch.utils.cpp_extension import load

# JIT-compile lltm.cpp into an importable extension module (build log is
# printed because verbose=True), then print the API it exposes.
lltm_cpp = load(
    name="lltm_cpp",
    sources=["lltm.cpp"],
    verbose=True,
)
help(lltm_cpp)

In [None]:
import torch
from torch import Tensor, nn, jit

In [None]:
# Smoke-test the extension directly: five random tensors with 3 rows each and
# 3, 6, 3, 3, 3 columns respectively (drawn in that order).
# NOTE(review): whether these shapes are mutually consistent depends on the
# kernel in lltm.cpp, which is not visible here — confirm.
lltm_cpp.forward(*(torch.randn(3, cols) for cols in (3, 6, 3, 3, 3)))

In [None]:
class MyModule(nn.Module):
    """Parameter-free module that hands its five inputs to ``lltm_cpp.forward``."""

    def forward(
        self, a0: Tensor, a1: Tensor, a2: Tensor, a3: Tensor, a4: Tensor
    ) -> Tensor:
        # NOTE(review): the ``-> Tensor`` annotation is only checked when the
        # module is scripted; lltm.cpp is not visible here, so confirm that
        # lltm_cpp.forward really returns a single Tensor.
        inputs = (a0, a1, a2, a3, a4)
        return lltm_cpp.forward(*inputs)

In [None]:
# Random inputs (3 rows; 3, 6, 3, 3, 3 columns), drawn in a0..a4 order.
a0, a1, a2, a3, a4 = (torch.randn(3, cols) for cols in (3, 6, 3, 3, 3))

# Run the same extension call through the nn.Module wrapper.
model = MyModule()
model(a0, a1, a2, a3, a4)

In [None]:
# Try to compile the wrapper with TorchScript.
# NOTE(review): TorchScript generally cannot compile calls into an arbitrary
# Python/pybind11 extension such as lltm_cpp, so this presumably raises; the
# torch.ops custom-op route later in the notebook is the scriptable
# alternative — confirm.
torch.jit.script(model)

In [None]:
# Scratch cell: display a random 2x3 tensor.
torch.randn((2, 3))

In [None]:
# List the names exposed by the compiled extension module (e.g. ``forward``).
dir(lltm_cpp)

In [None]:
import math
import torch
import time

import os

# Load the C++ custom operators compiled into lltm_ops.so (path is relative to
# the current working directory). After this call, the ops are reachable as
# torch.ops.lltm_ops.*, which LLTMFunction below relies on.
torch.ops.load_library("lltm_ops.so")


class LLTMFunction(torch.autograd.Function):
    """Autograd bridge to the ``lltm_ops`` custom operators.

    Forward and backward are delegated to ``torch.ops.lltm_ops.lltm_forward``
    and ``torch.ops.lltm_ops.lltm_backward``, which only exist after
    ``lltm_ops.so`` has been loaded via ``torch.ops.load_library``.
    """

    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        """Run the LLTM forward op and save what the backward op needs.

        Args:
            ctx: autograd context object.
            input, weights, bias, old_h, old_cell: tensors passed unchanged,
                in this order, to ``lltm_ops.lltm_forward``.

        Returns:
            ``(new_h, new_cell)``: the first two outputs of the op.
        """
        outputs = torch.ops.lltm_ops.lltm_forward(
            input, weights, bias, old_h, old_cell
        )
        new_h, new_cell = outputs[:2]
        # Everything after new_h, plus the weights, is handed back (in this
        # order) to lltm_backward. Building a fresh list works whether the op
        # returns a list or a tuple; ``tuple + list`` would raise TypeError.
        variables = [*outputs[1:], weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        """Compute input gradients with ``lltm_ops.lltm_backward``.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            grad_h, grad_cell: gradients w.r.t. ``new_h`` and ``new_cell``.

        Returns:
            Gradients for ``(input, weights, bias, old_h, old_cell)``, matching
            the argument order of ``forward``.
        """
        # NOTE(review): .contiguous() presumably satisfies a memory-layout
        # assumption in the C++ kernel — confirm against lltm_ops sources.
        # ``saved_tensors`` replaces the deprecated ``saved_variables`` accessor.
        outputs = torch.ops.lltm_ops.lltm_backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors
        )
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        # The op returns d_old_h first; reorder to forward's argument order.
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(torch.nn.Module):
    """LLTM cell whose computation is delegated to ``LLTMFunction``.

    Parameters:
        weights: shape ``(3 * state_size, input_features + state_size)``.
        bias: shape ``(3 * state_size,)``.
    Both are initialised uniformly in ``[-1/sqrt(state_size), 1/sqrt(state_size))``.
    """

    def __init__(self, input_features: int, state_size: int):
        """Create the cell's parameters and initialise them.

        Args:
            input_features: feature size of ``input`` (used to size ``weights``).
            state_size: size of the hidden/cell state (used to size both
                ``weights`` and ``bias`` and to scale the initialisation).
        """
        super().__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size)
        )
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(state_size), 1/sqrt(state_size))."""
        stdv = 1.0 / math.sqrt(self.state_size)
        # Initialise in place under no_grad rather than through ``.data``,
        # which silently bypasses autograd's version tracking.
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        """Advance the cell by one step.

        Args:
            input: input tensor passed through to ``LLTMFunction``.
            state: ``(old_h, old_cell)`` pair; it is unpacked into the last
                two arguments of ``LLTMFunction``.

        Returns:
            ``(new_h, new_cell)`` as produced by ``LLTMFunction``.
        """
        return LLTMFunction.apply(input, self.weights, self.bias, *state)