# Transformer Conversion

In [1]:
import sys
import torch
from botorch.models.transforms.input import AffineInputTransform

sys.path.append("../")
from lume_model.models import TorchModel, TorchModule

In [2]:
# Load the example model definition and wrap it in a TorchModule.
model_file = "../tests/test_files/california_regression/torch_model.yml"
torch_model = TorchModel(model_file)
torch_module = TorchModule(model=torch_model)

In [3]:
# conversion
def convert_torch_transformer(t: torch.nn.Linear) -> AffineInputTransform:
    """Creates an AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.

    A linear layer computes ``x @ W.T + b`` whereas the affine transform computes
    ``(x - offset) / coefficient``. For a diagonal weight matrix both agree when
    ``coefficient = 1 / diag(W)`` and ``offset = -b / diag(W)``.

    Only the diagonal of ``t.weight`` is read, so any off-diagonal entries are
    ignored by the conversion.

    Args:
        t: The torch transformer to convert. It must have a bias.

    Returns:
        AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.
    """
    # Compute from detached tensors so the results are leaf tensors. Otherwise
    # they would carry the autograd graph of the source module, and changing
    # requires_grad on such non-leaf tensors can raise a RuntimeError (e.g. for
    # a frozen bias combined with a trainable weight).
    diagonal = t.weight.detach().diagonal()
    bias = t.bias.detach()
    m = AffineInputTransform(
        d=bias.size(-1),
        coefficient=1 / diagonal,
        offset=-bias / diagonal,
    ).to(bias.dtype)
    # Carry over the trainability flags of the source parameters.
    m.offset.requires_grad = t.bias.requires_grad
    m.coefficient.requires_grad = t.weight.requires_grad
    # A freshly constructed module is in training mode; only eval needs syncing.
    if not t.training:
        m.eval()
    return m


def convert_botorch_transformer(t: AffineInputTransform) -> torch.nn.Linear:
    """Creates a torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.

    The affine transform computes ``(x - offset) / coefficient`` whereas a linear
    layer computes ``x @ W.T + b``. Both agree for ``W = diag(1 / coefficient)``
    and ``b = -offset / coefficient``.

    Args:
        t: The botorch transformer to convert.

    Returns:
        torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.
    """
    # Detach so the new parameters are built from plain values, independent of
    # any autograd graph attached to the source tensors.
    offset = t.offset.detach()
    coefficient = t.coefficient.detach()
    d = offset.size(-1)
    bias = -offset / coefficient
    # Build the diagonal weight with the dtype and device of the bias. A matrix
    # created with the defaults would fail for tensors on another device and
    # would leave weight and bias in different dtypes for low-precision inputs.
    weight_matrix = torch.eye(d, dtype=bias.dtype, device=bias.device) / coefficient
    m = torch.nn.Linear(in_features=d, out_features=d).to(
        device=bias.device, dtype=bias.dtype
    )
    m.bias = torch.nn.Parameter(bias)
    m.weight = torch.nn.Parameter(weight_matrix)
    # Carry over the trainability flags of the source tensors.
    m.bias.requires_grad = t.offset.requires_grad
    m.weight.requires_grad = t.coefficient.requires_grad
    # A freshly constructed module is in training mode; only eval needs syncing.
    if not t.training:
        m.eval()
    return m

In [4]:
# Round-trip check on one random sample: botorch -> torch -> botorch.
input_dict = torch_model.random_input(n_samples=1)
ordered_values = [input_dict[name] for name in torch_module.input_order]
x = torch.tensor(ordered_values).unsqueeze(0)
botorch_transformer = torch_model.input_transformers[0].to(x.dtype)
torch_transformer = convert_botorch_transformer(botorch_transformer)
converted_botorch_transformer = convert_torch_transformer(torch_transformer)

# Each conversion step must reproduce the output of the module it was built from.
for reference, candidate in (
    (botorch_transformer, torch_transformer),
    (torch_transformer, converted_botorch_transformer),
):
    print(torch.allclose(reference(x), candidate(x), atol=1e-6))

True
True
