In [11]:
import numpy as np
import torch
from torch import nn

from crc.eval import compute_multiview_r2
from crc.methods.shared import FCEncoder

In [9]:
def construct_invertible_mlp(
    n: int = 20,
    n_layers: int = 2,
    n_iter_cond_thresh: int = 10000,
    cond_thresh_ratio: float = 0.25,
    weight_matrix_init = "pcl",
    act_fct = "leaky_relu",
    random_state = None,
):
    """
    Create an (approximately) invertible mixing network based on an MLP.
    Based on the mixing code by Hyvarinen et al.

    The network is a stack of ``n_layers`` bias-free ``n x n`` linear layers
    with an activation between consecutive layers (none after the last one).
    All parameters are frozen (``requires_grad=False``) before returning.

    Args:
        n: Dimensionality of the input and output data
        n_layers: Number of layers in the MLP.
        n_iter_cond_thresh: How many random matrices to use as a pool to find weights.
            Only used when ``weight_matrix_init == "pcl"``.
        cond_thresh_ratio: Relative threshold how much the invertibility
            (based on the condition number) can be violated in each layer.
            The condition-number threshold is the value at this quantile of
            the sorted pool; only used when ``weight_matrix_init == "pcl"``.
        weight_matrix_init: How to initialize the weight matrices.
            ``"pcl"``: column-normalized uniform(-1, 1) matrices, rejection
            sampled until the condition number is below the threshold.
            ``"rvs"``: random orthogonal matrices (``scipy.stats.ortho_group``).
            ``"expand"``: keep the default ``nn.Linear`` initialization.
        act_fct: Activation function for hidden layers. One of ``"relu"``,
            ``"leaky_relu"``, ``"elu"``, ``"smooth_leaky_relu"``, ``"softplus"``.
        random_state: Optional NumPy random source (e.g. ``np.random.RandomState``)
            used to draw the ``"pcl"`` / ``"rvs"`` matrices. ``None`` (default)
            uses the global NumPy RNG. It does not affect the torch-side
            default initialization used by ``"expand"``.

    Returns:
        An ``nn.Sequential`` mixing network with frozen parameters.

    Raises:
        NotImplementedError: If ``act_fct == "max_out"``.
        ValueError: If ``weight_matrix_init == "pcl"`` and ``n_iter_cond_thresh < 1``.
        Exception: If ``act_fct`` or ``weight_matrix_init`` is not recognized.
    """

    class SmoothLeakyReLU(nn.Module):
        """Smooth leaky ReLU: ``alpha * x + (1 - alpha) * softplus(x)``."""

        def __init__(self, alpha=0.2):
            super().__init__()
            self.alpha = alpha

        def forward(self, x):
            # softplus(x) == log(1 + exp(x)), but the naive formula overflows
            # to inf for large x; the built-in is numerically stable.
            return self.alpha * x + (1 - self.alpha) * nn.functional.softplus(x)

    def get_act_fct(name):
        """Map an activation name to ``(module class, constructor kwargs)``."""
        if name == "relu":
            return torch.nn.ReLU, {}
        elif name == "leaky_relu":
            return torch.nn.LeakyReLU, {"negative_slope": 0.2}
        elif name == "elu":
            return torch.nn.ELU, {"alpha": 1.0}
        elif name == "max_out":
            raise NotImplementedError
        elif name == "smooth_leaky_relu":
            return SmoothLeakyReLU, {"alpha": 0.2}
        elif name == "softplus":
            return torch.nn.Softplus, {"beta": 1}
        else:
            raise Exception(f"activation function {name} not defined.")

    act_cls, act_kwargs = get_act_fct(act_fct)

    # Random source for the NumPy-side sampling; the np.random module itself
    # exposes the same ``uniform`` API as a RandomState / Generator.
    rng = np.random if random_state is None else random_state

    # Subfunction to normalize mixing matrix
    def l2_normalize(Amat, axis=0):
        # axis: 0=column-normalization, 1=row-normalization
        # keepdims makes the norms broadcast along the reduced axis, so that
        # axis=1 really normalizes rows (without it, columns get divided).
        l2norm = np.sqrt(np.sum(Amat * Amat, axis=axis, keepdims=True))
        return Amat / l2norm

    def sample_pcl_matrix():
        """Draw one column-normalized uniform(-1, 1) candidate matrix."""
        return l2_normalize(rng.uniform(-1, 1, (n, n)), axis=0)

    cond_thresh = None
    if weight_matrix_init == "pcl":
        if n_iter_cond_thresh < 1:
            raise ValueError(
                "n_iter_cond_thresh must be >= 1 for 'pcl' initialization"
            )
        # Empirical distribution of condition numbers over a random pool.
        cond_pool = np.sort(
            [np.linalg.cond(sample_pcl_matrix()) for _ in range(n_iter_cond_thresh)]
        )
        # Clamp so that e.g. cond_thresh_ratio == 1.0 selects the largest
        # pool value instead of indexing past the end.
        thresh_idx = int(n_iter_cond_thresh * cond_thresh_ratio)
        thresh_idx = min(max(thresh_idx, 0), n_iter_cond_thresh - 1)
        cond_thresh = cond_pool[thresh_idx]

    layers = []
    for i in range(n_layers):
        lin_layer = nn.Linear(n, n, bias=False)

        if weight_matrix_init == "pcl":
            # Rejection-sample until the candidate is well-conditioned enough.
            # Written as `<=` so that a NaN condition number is rejected too.
            while True:
                weight_matrix = sample_pcl_matrix()
                if np.linalg.cond(weight_matrix) <= cond_thresh:
                    break
        elif weight_matrix_init == "rvs":
            # Imported lazily so SciPy is only required for this init mode.
            from scipy.stats import ortho_group

            weight_matrix = ortho_group.rvs(n, random_state=random_state)
        elif weight_matrix_init == "expand":
            # Keep the default nn.Linear initialization.
            weight_matrix = None
        else:
            raise Exception(f"weight matrix {weight_matrix_init} not implemented")

        if weight_matrix is not None:
            with torch.no_grad():
                lin_layer.weight.copy_(
                    torch.as_tensor(weight_matrix, dtype=torch.float32)
                )

        layers.append(lin_layer)

        # No activation after the final linear layer.
        if i < n_layers - 1:
            layers.append(act_cls(**act_kwargs))

    mixing_net = nn.Sequential(*layers)

    # fix parameters
    for p in mixing_net.parameters():
        p.requires_grad_(False)

    return mixing_net

In [2]:
# Fixed seed so that a fresh Restart & Run All reproduces the same samples.
SEED = 42
rs = np.random.RandomState(SEED)
# construct_invertible_mlp draws its weight matrices from the global NumPy RNG
# and nn.Linear initialises from the global torch RNG, so seeding only `rs`
# would leave the mixing networks (and every result below) non-reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Ground-truth latents: 5000 standard-normal samples of a 5-dimensional vector.
z = rs.normal(loc=0.0, scale=1.0, size=(5000, 5))

# Groups of views that are evaluated together (presumably view indices;
# every group pairs view 0 with one of the other views).
subsets = (
    (0, 1),
    (0, 2),
    (0, 3),
)

# Latent columns regarded as shared within each subset, in the same order as
# `subsets`. They coincide with the columns of `z` fed to views 1, 2 and 3
# when the mixing functions are applied further down.
content_indices = (
    (0, 1, 2),
    (2, 4),
    (0, 3),
)

In [12]:
# One fixed, (approximately) invertible 3-layer mixing MLP per view; the width
# of each network equals the number of latents that view observes.
# Disabled alternative: randomly initialised FCEncoder networks, e.g.
#   FCEncoder(in_dim=5, latent_dim=5, hidden_dims=[5, 5, 5])
# The generator is consumed left to right, so the networks are still built in
# the order view 0, 1, 2, 3.
enc_view_0, enc_view_1, enc_view_2, enc_view_3 = (
    construct_invertible_mlp(
        n=view_dim,
        n_layers=3,
        n_iter_cond_thresh=25000,
        cond_thresh_ratio=0.001,
    )
    for view_dim in (5, 3, 2, 2)
)

In [13]:
# Push the latent columns seen by each view through that view's fixed mixing
# network and convert the result back to a NumPy array.
z_0, z_1, z_2, z_3 = (
    mixing(torch.as_tensor(z[:, latent_cols], dtype=torch.float32)).detach().numpy()
    for mixing, latent_cols in (
        (enc_view_0, [0, 1, 2, 3, 4]),
        (enc_view_1, [0, 1, 2]),
        (enc_view_2, [2, 4]),
        (enc_view_3, [0, 3]),
    )
)

# The views have different widths, so they are collected in a list rather
# than stacked into a single array.
z_hat = [z_0, z_1, z_2, z_3]

In [6]:
# z_hat is a plain list of per-view arrays with different widths, so it has no
# .shape attribute; show the shape of every view instead.
[view.shape for view in z_hat]

[(5000, 5), (5000, 3), (5000, 2), (5000, 2)]

In [14]:
# Score how well each ground-truth latent can be recovered from the mixed
# views; the result is indexed by metric name in the cells that follow.
r2_dict = compute_multiview_r2(
    z,
    z_hat,
    content_indices,
    subsets,
)

In [15]:
# Average nonlinear R^2; shown below as a 5 x 3 array (presumably one row per
# latent dimension and one column per subset).
r2_dict["avg_r2_nonlin"]

array([[0.98486829, 0.49067374, 0.98795771],
       [0.98537678, 0.48947362, 0.49063802],
       [0.98200491, 0.98644503, 0.48718567],
       [0.48453301, 0.47914021, 0.98828493],
       [0.48647658, 0.98535909, 0.48682604]])

In [16]:
# Average linear R^2, displayed in the same 5 x 3 layout as the nonlinear scores.
r2_dict["avg_r2_lin"]

array([[0.82736308, 0.39745582, 0.89062949],
       [0.77091651, 0.34900836, 0.36456917],
       [0.77459848, 0.76888531, 0.36426642],
       [0.37297095, 0.37443435, 0.75907766],
       [0.39454757, 0.76619026, 0.39784204]])

In [27]:
# Nonlinear R^2 at index 1 of the last axis (presumably view 1); entries are
# NaN in the columns where this index has no score.
r2_dict["r2_nonlin"][..., 1]

array([[ 0.99314519,         nan,         nan],
       [ 0.98885679,         nan,         nan],
       [ 0.98827947,         nan,         nan],
       [-0.0036829 ,         nan,         nan],
       [-0.00570299,         nan,         nan]])