In [1]:
import torch
from torch import exp, log, tanh, abs
import torch.nn as nn
import numpy as np

In [2]:
class FFNN(nn.Module):
    """Fully connected feed-forward network with ReLU between linear layers.

    One ``nn.Linear`` is created per consecutive pair of sizes in
    ``[in_size, *layer_sizes, out_size]``, and a ``nn.ReLU`` is placed between
    each pair of adjacent linear layers. No activation is applied to the raw
    input or to the final output.

    Args:
        in_size: number of input features.
        layer_sizes: widths of the hidden layers (may be empty).
        out_size: number of output features.
    """

    def __init__(self, in_size:int, layer_sizes:list[int], out_size:int) -> None:
        super().__init__()

        sizes = [in_size] + list(layer_sizes) + [out_size]

        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            # The activation goes only *between* linear layers. Putting a ReLU
            # in front of the first layer would zero out every negative input
            # feature before the network sees it.
            # NOTE: the first Linear therefore sits at index 0 of the
            # Sequential; checkpoints saved with a leading ReLU have their
            # parameter keys shifted by one.
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))

        self.layers = nn.Sequential(*layers)


    def forward(self, X:torch.Tensor):
        """Apply the stacked layers to ``X`` and return the result."""
        return self.layers(X)

In [3]:
class Layer(nn.Module):
    """Base class that layer implementations (e.g. ``BijectiveLayer``) derive from.

    The ``forward`` defined here is only a placeholder: it ignores its
    arguments and returns ``None``. Subclasses override it.
    """

    def forward(self, X:torch.Tensor, return_likelihood:bool=True):
        """Placeholder forward pass; does nothing and returns ``None``."""
        return None

In [4]:
class BijectiveLayer(Layer):
    """Invertible affine coupling layer.

    The input is split along its last dimension into a "skip" part (the first
    ``size // 2`` features), which is passed through unchanged, and the
    remaining ``size - size // 2`` features, which are transformed as
    ``s * x + t``. The shift ``t`` (one value per transformed feature) and the
    scale ``s = exp(tanh(.))`` (a single positive scalar per sample) are both
    predicted by ``self.ffnn`` from the skip part alone, so the map can be
    inverted in closed form by ``backward``.

    Args:
        size: total number of features; must be at least 2.
        hidden_sizes: hidden layer widths of the network predicting ``t``/``s``.
    """

    def __init__(self, size:int, hidden_sizes:list[int]) -> None:
        assert size > 1, "Layer size must be at least 2!"
        super().__init__()

        self.skip_size = size // 2
        self.other_size = size - self.skip_size # TODO: rename 'other' everywhere
        self.ffnn = FFNN(self.skip_size, hidden_sizes, self.other_size + 1) # returns t & s


    def _coupling_terms(self, skip_connection:torch.Tensor):
        """Return ``(t, s_log)`` predicted from the untouched part of the input.

        ``t`` has ``other_size`` features; ``s_log`` keeps a trailing dimension
        of size 1 so it broadcasts over the transformed features. ``tanh``
        bounds ``s_log`` to (-1, 1), i.e. the scale to (1/e, e).
        """
        coeffs = self.ffnn(skip_connection)
        t = coeffs[..., :-1]
        s_log = tanh(coeffs[..., -1:])
        return t, s_log


    def forward(self, X:torch.Tensor, return_likelihood:bool=True):
        """Apply the coupling transform.

        Args:
            X: tensor whose last dimension holds the features.
            return_likelihood: if true, also return the log-determinant term.

        Returns:
            ``Z`` (same shape as ``X``), or ``(Z, log_det)`` where ``log_det``
            has a trailing dimension of size 1 and equals log|det dZ/dX|.
        """
        skip_connection  = X[..., :self.skip_size]
        other_connection = X[..., self.skip_size:]

        t, s_log = self._coupling_terms(skip_connection)

        new_connection = exp(s_log) * other_connection + t
        Z = torch.cat((skip_connection, new_connection), dim=-1)

        if return_likelihood:
            # The Jacobian is block-triangular with `other_size` diagonal
            # entries equal to s = exp(s_log) > 0, so
            # log|det J| = other_size * s_log. The sign must be kept: a
            # contracting layer (s < 1) contributes a negative term.
            log_det = self.other_size * s_log
            return Z, log_det
        return Z


    def backward(self, Z:torch.Tensor):
        """Invert ``forward``: recover ``X`` from ``Z``.

        This is the inverse map of the layer, not an autograd backward pass.
        """
        skip_connection  = Z[..., :self.skip_size]
        other_connection = Z[..., self.skip_size:]

        # The skip part is identical in X and Z, so the same t and s that were
        # used in forward() can be recomputed here.
        t, s_log = self._coupling_terms(skip_connection)

        new_connection = (other_connection - t) / exp(s_log)
        X = torch.cat((skip_connection, new_connection), dim=-1)
        return X

In [6]:
# Round-trip sanity check: backward() must undo forward() on random inputs.
b = BijectiveLayer(23, (25, 50))

for _ in range(1000):
    a = torch.rand(300, 23)
    encoded = b.forward(a, return_likelihood=False)
    decoded = b.backward(encoded)
    assert torch.allclose(a, decoded, atol=1e-5)