# Transformer Conversion

In [1]:
import sys
import torch
from botorch.models.transforms.input import AffineInputTransform

sys.path.append("../")
from lume_model.models import TorchModel, TorchModule

In [2]:
# Load the exemplary california_regression model from its YAML configuration
# file and wrap it in a TorchModule, which is called directly on tensors below.
torch_model = TorchModel(
    "../../tests/test_files/california_regression/torch_model.yml"
)
torch_module = TorchModule(model=torch_model)

Loaded PyTorch model from file: Sequential(
  (0): Linear(in_features=8, out_features=24, bias=True)
  (1): ReLU()
  (2): Linear(in_features=24, out_features=12, bias=True)
  (3): ReLU()
  (4): Linear(in_features=12, out_features=6, bias=True)
  (5): ReLU()
  (6): Linear(in_features=6, out_features=1, bias=True)
)


In [3]:
# conversion
def convert_torch_transformer(t: torch.nn.Linear) -> AffineInputTransform:
    """Creates an AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.

    Only element-wise linear layers can be mirrored: the weight matrix must be
    square and diagonal with non-zero diagonal entries ``w``. The layer
    ``y = x * w + b`` is then rewritten in the affine form
    ``y = (x - offset) / coefficient`` with ``coefficient = 1 / w`` and
    ``offset = -b / w``. A layer without bias is treated as ``b = 0``.

    Args:
        t: The torch transformer to convert.

    Returns:
        AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.

    Raises:
        ValueError: If the weight matrix is not square and diagonal, or if it
            has a zero on its diagonal.
    """
    with torch.no_grad():
        # Work on detached values so the new buffers are autograd leaves:
        # assigning ``requires_grad`` on a non-leaf tensor (e.g. one computed
        # from a trainable parameter) raises a RuntimeError.
        weight = t.weight.detach()
        if weight.size(0) != weight.size(1):
            raise ValueError(
                f"Expected a square weight matrix, got shape {tuple(weight.shape)}."
            )
        diagonal = weight.diagonal()
        # An affine transform acts element-wise, so any off-diagonal weight
        # would be silently dropped by the conversion.
        if not torch.equal(weight, torch.diag(diagonal)):
            raise ValueError(
                "Only linear layers with a diagonal weight matrix can be converted."
            )
        if (diagonal == 0).any():
            raise ValueError(
                "The diagonal of the weight matrix must not contain zeros."
            )
        bias = torch.zeros_like(diagonal) if t.bias is None else t.bias.detach()
        coefficient = 1 / diagonal
        offset = -bias / diagonal

    dtype = weight.dtype if t.bias is None else t.bias.dtype
    m = AffineInputTransform(
        d=diagonal.size(-1),
        coefficient=coefficient,
        offset=offset,
    ).to(dtype)
    # Mirror the trainability of the source parameters on the new buffers.
    m.offset.requires_grad = t.bias is not None and t.bias.requires_grad
    m.coefficient.requires_grad = t.weight.requires_grad
    if not t.training:
        m.eval()
    return m


def convert_botorch_transformer(t: AffineInputTransform) -> torch.nn.Linear:
    """Creates a torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.

    The affine transform ``y = (x - offset) / coefficient`` is expressed as a
    linear layer with diagonal weight ``1 / coefficient`` and bias
    ``-offset / coefficient``.

    NOTE(review): options of the botorch transform such as ``indices`` or
    ``reverse`` are not taken into account here; confirm they are unused for
    the transformers passed to this function.

    Args:
        t: The botorch transformer to convert.

    Returns:
        torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.
    """
    coefficient = t.coefficient
    d = t.offset.size(-1)
    m = torch.nn.Linear(
        in_features=d,
        out_features=d,
        dtype=t.offset.dtype,
        device=t.offset.device,
    )
    m.bias = torch.nn.Parameter(-t.offset / coefficient)
    # Build the identity on the coefficient's device and dtype: a default
    # (CPU, float32) tensor fails for CUDA coefficients and can promote the
    # weight to a different dtype than the bias.
    identity = torch.eye(d, dtype=coefficient.dtype, device=coefficient.device)
    m.weight = torch.nn.Parameter(identity / coefficient)
    # Mirror the trainability of the source buffers on the new parameters.
    m.bias.requires_grad = t.offset.requires_grad
    m.weight.requires_grad = coefficient.requires_grad
    if not t.training:
        m.eval()
    return m

In [4]:
# Test on an exemplary input: convert the botorch transformers of the loaded
# model to torch.nn.Linear layers, build a new TorchModel from them around the
# same underlying network and check that both modules give the same output.
input_dict = torch_model.random_input(n_samples=1)
x = torch.tensor([input_dict[k] for k in torch_module.input_order]).unsqueeze(0)

torch_input_transformers = [
    convert_botorch_transformer(t) for t in torch_model.input_transformers
]
torch_output_transformers = [
    convert_botorch_transformer(t) for t in torch_model.output_transformers
]
new_torch_model = TorchModel(
    input_variables=torch_model.input_variables,
    output_variables=torch_model.output_variables,
    model=torch_model.model,
    input_transformers=torch_input_transformers,
    output_transformers=torch_output_transformers,
)
new_torch_module = TorchModule(model=new_torch_model)

# torch.allclose reduces over all elements, so the check also works for
# batched inputs or several outputs (``isclose(...).item()`` only works when
# the result has exactly one element).
print(torch.allclose(torch_module(x), new_torch_module(x)))

True
