In [1]:
from abc import abstractmethod

import numpy as np
import torch
import torch.nn as nn

from bcnf.model.feature_network import FeatureNetwork


class ConditionalInvertibleLayer(nn.Module):
    log_det_J: float | torch.Tensor | None
    n_conditions: int

    @abstractmethod
    def forward(self, x: torch.Tensor, y: torch.Tensor, log_det_J: bool = False) -> torch.Tensor:
        pass

    @abstractmethod
    def inverse(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        pass


class ConditionalNestedNeuralNetwork(nn.Module):
    """
    Three-layer MLP that produces the coefficients of an affine coupling layer.

    The network consumes the input features (with the conditions appended when
    ``n_conditions > 0``) and emits ``2 * output_size`` values per row, which are
    split into a translation ``t`` and a scale coefficient ``s``; ``s`` is squashed
    with ``tanh`` before it is returned.
    """

    def __init__(self, input_size: int, output_size: int, hidden_size: int, n_conditions: int, dropout: float = 0.2) -> None:
        super().__init__()

        self.n_conditions = n_conditions

        # The conditions are concatenated to the input, so the first layer is widened by n_conditions.
        first_layer_width = input_size + n_conditions
        # One half of the last layer's outputs becomes t, the other half s.
        last_layer_width = output_size * 2

        stack = [
            nn.Linear(first_layer_width, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, last_layer_width),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the coupling coefficients.

        Parameters
        ----------
        x : torch.Tensor
            The input features; dimension 1 is the feature dimension.
        y : torch.Tensor
            The conditions; ignored when ``n_conditions == 0``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The translation ``t`` and the tanh-bounded scale coefficient ``tanh(s)``.
        """
        # Only append the conditions if the network was built to expect them
        features = torch.cat((x, y), dim=1) if self.n_conditions > 0 else x

        # First half of the outputs is the translation, second half the raw scale
        translation, raw_scale = torch.chunk(self.layers(features), 2, dim=1)

        return translation, torch.tanh(raw_scale)


class ConditionalAffineCouplingLayer(ConditionalInvertibleLayer):
    """
    Conditional affine coupling layer.

    The input is split along dimension 1 into ``x_a`` (first half) and ``x_b``
    (second half). ``x_a`` is passed through unchanged and, together with the
    conditions ``y``, determines the affine map applied to ``x_b``::

        t, log_s = nn(x_a, y)
        z_a = x_a
        z_b = exp(log_s) * x_b + t
    """

    def __init__(self, input_size: int, hidden_size: int, n_conditions: int, dropout: float = 0.2) -> None:
        """
        Parameters
        ----------
        input_size : int
            Number of features of the input; must be at least 2 so it can be split.
        hidden_size : int
            Width of the hidden layers of the nested network.
        n_conditions : int
            Number of condition features appended to ``x_a`` in the nested network.
        dropout : float
            Dropout probability used inside the nested network.

        Raises
        ------
        ValueError
            If ``input_size < 2``.
        """
        super(ConditionalAffineCouplingLayer, self).__init__()

        # With fewer than two features chunk() yields a single piece and the
        # split in forward()/inverse() would fail with an opaque unpacking error.
        if input_size < 2:
            raise ValueError(f"input_size must be at least 2 to be split into two halves, but got input_size = {input_size}")

        self.n_conditions = n_conditions
        self.log_det_J = None

        # chunk(2, dim=1) puts ceil(n / 2) features in the first piece and
        # floor(n / 2) in the second, so the nested network must map the former to the latter.
        n_passthrough = (input_size + 1) // 2
        n_transformed = input_size // 2

        # Create the nested neural network
        self.nn = ConditionalNestedNeuralNetwork(
            input_size=n_passthrough,
            output_size=n_transformed,
            hidden_size=hidden_size,
            n_conditions=n_conditions,
            dropout=dropout)

    def forward(self, x: torch.Tensor, y: torch.Tensor, log_det_J: bool = False) -> torch.Tensor:
        """
        Apply the coupling transformation to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            The input; it is split in two along dimension 1.
        y : torch.Tensor
            The conditions handed to the nested network.
        log_det_J : bool
            If True, ``self.log_det_J`` is set to ``log_s.sum(dim=1)``.
            If False, ``self.log_det_J`` is left untouched (it may hold a value
            from an earlier call).

        Returns
        -------
        torch.Tensor
            ``cat([z_a, z_b], dim=1)``.
        """
        # Split the input into two halves
        x_a, x_b = x.chunk(2, dim=1)

        # Get the coefficients from the neural network
        # (calling the module rather than .forward() so that registered hooks run)
        t, log_s = self.nn(x_a, y)

        # Apply the transformation
        z_a = x_a  # skip connection
        z_b = torch.exp(log_s) * x_b + t  # affine transformation

        # Calculate the log determinant of the Jacobian
        if log_det_J:
            self.log_det_J = log_s.sum(dim=1)

        # Return the output
        return torch.cat([z_a, z_b], dim=1)

    def inverse(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Undo ``forward``: recover ``x`` from ``z`` given the same conditions ``y``.

        Because ``z_a == x_a``, the nested network can be re-evaluated on ``z_a``
        to obtain the coefficients that were used in the forward direction.

        Returns
        -------
        torch.Tensor
            ``cat([x_a, x_b], dim=1)`` with ``x_b = (z_b - t) * exp(-log_s)``.
        """
        # Split the input into two halves
        z_a, z_b = z.chunk(2, dim=1)

        # Get the coefficients from the neural network
        t, log_s = self.nn(z_a, y)

        # Apply the inverse transformation
        x_a = z_a
        x_b = (z_b - t) * torch.exp(- log_s)

        # Return the output
        return torch.cat([x_a, x_b], dim=1)


class OrthonormalTransformation(ConditionalInvertibleLayer):
    """
    Fixed (non-trainable) random orthonormal linear map.

    The matrix is the Q factor of the QR decomposition of a random Gaussian
    matrix. The conditions ``y`` are accepted for interface compatibility but
    are not used.
    """

    def __init__(self, input_size: int) -> None:
        super().__init__()

        # The layer reports a constant log-determinant of zero
        self.log_det_J = 0

        # Draw a random square matrix and keep the Q factor of its QR decomposition
        gaussian_matrix = torch.randn(input_size, input_size)
        q_factor, _ = torch.linalg.qr(gaussian_matrix)

        # Stored as a frozen parameter, so it is excluded from gradient updates
        self.orthonormal_matrix = nn.Parameter(q_factor, requires_grad=False)

    def forward(self, x: torch.Tensor, y: torch.Tensor, log_det_J: bool = False) -> torch.Tensor:
        """Right-multiply ``x`` by the orthonormal matrix; ``y`` and ``log_det_J`` are ignored."""
        return torch.matmul(x, self.orthonormal_matrix)

    def inverse(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Right-multiply ``z`` by the transpose of the orthonormal matrix; ``y`` is ignored."""
        return torch.matmul(z, self.orthonormal_matrix.t())


class CondRealNVP(ConditionalInvertibleLayer):
    """
    Conditional RealNVP normalizing flow.

    The flow is a stack of conditional affine coupling layers with a fixed random
    orthonormal transformation between consecutive coupling layers. The conditions
    ``y`` are first passed through ``self.h`` (the feature network, or the identity
    when there are no conditions or no feature network) and then handed to every layer.
    """

    def __init__(self, input_size: int, hidden_size: int, n_epochs: int, n_conditions: int, feature_network: FeatureNetwork | None, dropout: float = 0.2, device: str = "cpu"):
        """
        Parameters
        ----------
        input_size : int
            Number of features of the flow's input / output.
        hidden_size : int
            Width of the hidden layers inside every coupling layer.
        n_epochs : int
            Number of affine coupling layers in the stack (not a number of training
            epochs). For ``n_epochs < 1`` a single coupling layer is still created.
        n_conditions : int
            Number of condition features each coupling layer receives.
        feature_network : FeatureNetwork | None
            Network applied to ``y`` before it reaches the layers. Replaced by the
            identity if it is None or if ``n_conditions == 0``.
        dropout : float
            Dropout probability used inside every coupling layer.
        device : str
            Device that ``sample`` moves the conditions and the latent draws to.
            The constructor itself does not move the module.
        """
        super(CondRealNVP, self).__init__()

        if n_conditions == 0 or feature_network is None:
            self.h = nn.Identity()
        else:
            self.h = feature_network

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_epochs = n_epochs
        self.n_conditions = n_conditions
        self.dropout = dropout
        self.device = device

        self.layers = self._build_network()

        self.log_det_J: float | torch.Tensor | None = None

    def _build_network(self) -> nn.ModuleList:
        """Build the alternating coupling / orthonormal stack, ending with a coupling layer."""
        # Create the network
        layers: list[ConditionalInvertibleLayer] = []
        for _ in range(self.n_epochs - 1):
            layers.append(ConditionalAffineCouplingLayer(self.input_size, self.hidden_size, self.n_conditions, dropout=self.dropout))
            layers.append(OrthonormalTransformation(self.input_size))

        # Add the final affine coupling layer
        layers.append(ConditionalAffineCouplingLayer(self.input_size, self.hidden_size, self.n_conditions, dropout=self.dropout))

        return nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, y: torch.Tensor, log_det_J: bool = False) -> torch.Tensor:
        """
        Map ``x`` to the latent space given the conditions ``y``.

        Parameters
        ----------
        x : torch.Tensor
            The input to transform.
        y : torch.Tensor
            The conditions; passed through ``self.h`` before use.
        log_det_J : bool
            If True, ``self.log_det_J`` is set to the sum of the layers'
            log-determinants. If False, ``self.log_det_J`` is left untouched.

        Returns
        -------
        torch.Tensor
            The transformed input.
        """
        # Apply the feature network to y
        y = self.h(y)

        # Apply the network, accumulating the log-determinant only when requested
        total_log_det_J: float | torch.Tensor = 0
        for layer in self.layers:
            x = layer(x, y, log_det_J)
            if log_det_J:
                total_log_det_J = total_log_det_J + layer.log_det_J

        if log_det_J:
            self.log_det_J = total_log_det_J

        return x

    def inverse(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Map the latent ``z`` back to the input space given the conditions ``y``."""
        # Apply the feature network to y
        y = self.h(y)

        # Apply the network in reverse
        for layer in reversed(self.layers):
            z = layer.inverse(z, y)

        return z

    def sample(self, n_samples: int, y: torch.Tensor, sigma: float = 1, outer: bool = False, verbose: bool = False) -> torch.Tensor:
        """
        Sample from the model.

        Latent vectors are drawn from a normal distribution with standard deviation
        ``sigma`` and pushed through ``inverse``. The shape checks below are applied
        to ``y`` as passed in, i.e. before the feature network.

        Note: this method neither switches the module to eval mode nor disables
        gradient tracking. The coupling layers contain dropout, so call ``eval()``
        beforehand if deterministic networks are wanted.

        Parameters
        ----------
        n_samples : int
            The number of samples to generate (per condition row if ``outer`` is True).
        y : torch.Tensor
            The conditions used for sampling.
            If 1st order tensor, ``len(y)`` must equal ``n_conditions`` and the same conditions are used for all samples.
            If 2nd order tensor and ``outer`` is False, ``y.shape`` must be ``(n_samples, n_conditions)`` and row i is used as the conditions for sample i.
            If 2nd order tensor and ``outer`` is True, ``y.shape`` must be ``(k, n_conditions)`` for any k and ``n_samples`` samples are drawn for each of the k rows.
        sigma : float
            The standard deviation of the normal distribution to sample from.
        outer : bool
            Only relevant for 2nd order ``y``, see above.
        verbose : bool
            If True, print debug information (the shape of ``y`` and the selected mode).

        Returns
        -------
        torch.Tensor
            The generated samples, of shape ``(n_samples, input_size)``, or
            ``(n_samples, y.shape[0], input_size)`` if ``y`` is 2nd order and ``outer`` is True.

        Raises
        ------
        ValueError
            If ``y`` is neither 1st nor 2nd order, or its shape does not match the selected mode.
        """

        if verbose:
            print(f'{y.shape=}')

        y = y.to(self.device)

        if y.ndim == 1:
            if verbose:
                print('Broadcasting')
            if len(y) != self.n_conditions:
                raise ValueError(f"y must have length {self.n_conditions}, but got len(y) = {len(y)}")

            # Use the same condition vector for every one of the n_samples draws
            z = sigma * torch.randn(n_samples, self.input_size).to(self.device)
            y = y.repeat(n_samples, 1)

            # Apply the inverse network
            return self.inverse(z, y).view(n_samples, self.input_size)
        elif y.ndim == 2:
            if outer:
                if verbose:
                    print('Outer')
                if y.shape[1] != self.n_conditions:
                    raise ValueError(f"y must have shape (n_samples_per_condition, {self.n_conditions}), but got y.shape = {y.shape}")

                n_samples_per_condition = y.shape[0]

                # Generate n_samples for each condition in y. Tiling y n_samples times
                # puts condition j at flat row i * n_samples_per_condition + j, which
                # is exactly element [i, j] after the view below.
                z = sigma * torch.randn(n_samples * n_samples_per_condition, self.input_size).to(self.device)
                y = y.repeat(n_samples, 1)

                # Apply the inverse network
                return self.inverse(z, y).view(n_samples, n_samples_per_condition, self.input_size)
            else:
                if verbose:
                    print('Matching')
                if y.shape[0] != n_samples or y.shape[1] != self.n_conditions:
                    raise ValueError(f"y must have shape (n_samples, {self.n_conditions}), but got y.shape = {y.shape}")

                # Use y_i as the condition for the i-th sample
                z = sigma * torch.randn(n_samples, self.input_size).to(self.device)

                # Apply the inverse network
                return self.inverse(z, y).view(n_samples, self.input_size)
        else:
            raise ValueError(f"y must be a 1st or 2nd order tensor, but got y.shape = {y.shape}")


In [2]:
# Seed PyTorch so the random weights, the random orthonormal matrices and the
# samples drawn in the cells below are reproducible after Restart & Run All.
torch.manual_seed(0)

# The cells below only draw samples, so put the model in eval mode right away:
# a freshly constructed module is in training mode, in which dropout is active.
cnf = CondRealNVP(
    input_size=2,
    hidden_size=256,
    n_epochs=8,
    n_conditions=7,
    feature_network=None,
    dropout=0.2,
    device="cpu"
).eval()

In [3]:
# Broadcast mode: a single 1-D condition vector is shared by all 13 samples.
shared_condition = torch.randn(7)
broadcast_sample = cnf.sample(n_samples=13, y=shared_condition, verbose=True)

print(f'{broadcast_sample.shape=}')

y.shape=torch.Size([7])
Broadcasting
broadcast_sample.shape=torch.Size([13, 2])


In [4]:
# Matching mode: row i of the (13, 7) condition matrix conditions sample i.
per_sample_conditions = torch.randn(13, 7)
matching_sample = cnf.sample(n_samples=13, y=per_sample_conditions, verbose=True)

print(f'{matching_sample.shape=}')

y.shape=torch.Size([13, 7])
Matching
matching_sample.shape=torch.Size([13, 2])


In [5]:
# Outer mode: 3 samples are drawn for each of the 5 condition rows.
condition_rows = torch.randn(5, 7)
outer_sample = cnf.sample(n_samples=3, y=condition_rows, outer=True, verbose=True)

print(f'{outer_sample.shape=}')

y.shape=torch.Size([5, 7])
Outer
outer_sample.shape=torch.Size([3, 5, 2])
