## Results for the toy experiment with synthetic datasets and analytical solutions

To familiarize yourself with the two syntetic datasets, take a look at the ```VisualizeDatasets``` notebook.

In [1]:
import sys
sys.path.append("..")
import os
os.chdir("..")

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from data.four_bars import FourBars, ColorBar 

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


A helper function to sample suitable distortions.

In [10]:
def sample_scale_distort(ndim=3, max_val=10):
    """ Sample a distortion matrix A that keeps the factors in the valid range [0, 10] """
    while True:
        ent_matrix = torch.randn(ndim,ndim)
        ent_matrix_norm = ent_matrix / ent_matrix.norm(dim=0, keepdim=True)
        A = torch.norm(torch.eig(ent_matrix_norm)[0], dim=1)
        if torch.min(A) > 0.3 and torch.max(A) < 3.0:
            print(A)
            break
    return ent_matrix_norm

A helper function to compute the gradients of the faithful encoder distorted by the matrix $M$ via the decoder. 
According to the transfer lemma stated in the paper, we know that

$$J_f(g(z))^\top J_g(z) = I$$

so we can compute $J_f$ from our knowledge of $J_g(z)$. $J_g$ is computed via torch, as the generator is differentiable in pytorch.

In [102]:
from common.attributions import generator_jacobian
class GTGenWrapper():
    def __init__(self, gt_gen):
        self.my_gt = gt_gen
    def decode(self, latent):
        #print(latent.dtype)
        return self.my_gt.sample_observations_from_factors(latent, ret_torch=True)
    
def compute_mde_gradients(facts_org, m, syn_dataset):
    """ Compute the gradients of the faithful encoder. 
        facts: latent variables values [num_samples, num_facts]
        As f = M f* they are M*grad(f*)
    """
    batch = 64
    grad_list = []
    for i in range(0, len(facts_org), batch):
        grad_matrix = generator_jacobian(GTGenWrapper(syn_dataset), facts_org[i:i+64]) # [B, Z, H, W, C]
        grad_list.append(grad_matrix)

    grad_matrix = torch.cat(grad_list, dim=0)
    grad_matrix_size = grad_matrix.shape
    #print(grad_matrix_size)
    ## Multipy with m.
    jg = grad_matrix.reshape(grad_matrix_size[0], grad_matrix_size[1], -1) # [B, Z, N]
    jf = jg/jg.pow(2).sum(dim=2, keepdim=True) # Denormalize to have jg*jf=I
    # Find min.
    min_index = torch.argmin(torch.min(jg.pow(2).sum(dim=2), dim=1)[0])
    #print(jg.pow(2).sum(dim=2)[min_index])
    #print(facts_org[min_index])
    grad_matrix_multiply = jf.transpose(0,1).reshape(grad_matrix_size[1], -1) #[Z, N]
    # Invert norming.
    #print(grad_matrix_multiply.shape)
    grad_matrix_multiply = m.matmul(grad_matrix_multiply).reshape([grad_matrix_size[1], grad_matrix_size[0]] + list(grad_matrix_size[2:])).transpose(0,1)
    return grad_matrix_multiply

After stating some helpers, we can now start implementing the analytical solutions

In [79]:
""" We implement the analytical solutions here in a batch-wise fashion. """

### HELPER FUNCTIONS
def build_pair_indices(n):
    list1 = []
    list2 = []
    for i in range(n):
        list1.append(i*torch.ones(i,dtype=torch.int64))
        list2.append(torch.arange(i, dtype=torch.int64))
    return torch.cat(list1), torch.cat(list2)

def find_non_sing_submatrix(f):
    """ Return a non singular, square submatrix of M."""
    #print(f.shape)
    # 1st step: Find non singular submatrix.
    ind_set = []
    clm_count = 0
    for i in range(0, len(f)):
        ind_set.append(clm_count)
        while torch.svd(f[:,ind_set], compute_uv=False)[1][-1] < 1e-5: # linear independence check. Continue if independent.
            clm_count += 1
            ind_set[-1] = clm_count
        #print("Appending column", clm_count, "to set.")
        clm_count += 1
        if len(ind_set) == f.size(0):
            break
    #print(ind_set)
    if len(ind_set) < f.size(0):
        return None
    else:
        return f[:, ind_set]
    
def find_solution_ima(orth_1, orth_2):
    """ The solution according to IMA theory. 
        Let orth_1 be Sigma(z_1) and orth_2 be Sigma(z_2) as stated in App. B 7.2. in our paper.
    """
    U = torch.inverse(torch.cholesky(orth_1)) 
    Gamma = U @ orth_2 @ U.transpose(-2, -1) 
    # Compute eigentvectors of Gamma, batchwise computation due to old version of torch.
    ds_list, q_list = [], []
    for i in range(len(orth_1)):
        Dslash, Q = torch.symeig(Gamma[i], eigenvectors=True)
        ds_list.append(Dslash), q_list.append(Q)
    Q = torch.stack(q_list)
    return Q.transpose(-2, -1) @ U

def find_solution_dma(jf):
    """ Find solution with DMA, using the jacobians. """
    sol_list = []
    for i in range(len(jf)):
        sol_list.append(torch.inverse(find_non_sing_submatrix(jf[i])))
    Q = torch.stack(sol_list)
    return Q

In [80]:
""" Compute the DCI scores via their entropy formulation: Normalize rows and then compute column-wise entropies."""
def dci_ent(prod, n_dims = 4):
    if len(prod.shape) == 2:
        prod = prod.unsqueeze(0)
    score_list = []
    for i in range(len(prod)):
        prodi = torch.abs(prod[i]) / torch.sum(torch.abs(prod[i]), dim=0)
        #print(prod)
        score = torch.sum(-torch.log(prodi+1e-8)*prodi, dim=0)
        score_list.append(1.0-(torch.mean(score)/np.log(n_dims)))
    return np.array(score_list)

In [108]:
## Use this flag to switch between datasets.
use_colorbar = False
n_intervals = 11
if use_colorbar:
    syn_dataset = ColorBar(n_intervals, nonlin_colors=True)
    num_factors = 3
else:
    syn_dataset = FourBars(n_intervals)
    num_factors = 4

In [109]:
import torch
n_runs = 5
n_samples = 20 # How many samples to draw

In [110]:
res_ima_list = []
res_dma_list = []
for run in range(n_runs):
    print("Starting run ", run)
    facts= torch.rand(n_samples, num_factors)*10
    A = sample_scale_distort(num_factors)
    grad_matrix = compute_mde_gradients(facts.float(), A, syn_dataset)
    grad_matrix = grad_matrix.reshape(len(grad_matrix), grad_matrix.size(1), -1).clone()
    orthogonality = torch.bmm(grad_matrix, grad_matrix.transpose(1,2))
    
    ## IMA (requires gradients at two points, test all combinations of samples against each other)
    res1, res2 = build_pair_indices(n_runs)
    res =  find_solution_ima(orthogonality[res1], orthogonality[res2])
    prod = res@A
    res_ima_list.append(dci_ent(prod, num_factors).mean())
    
    ## DMA (test only one point)
    res_dma = find_solution_dma(grad_matrix)
    prod = res_dma@A
    res_dma_list.append(dci_ent(prod, num_factors).mean())

res_dma_list = np.array(res_dma_list)
res_ima_list = np.array(res_ima_list)

Starting run  0
tensor([0.8417, 0.8740, 0.8740, 0.4231])
Starting run  1
tensor([0.9611, 1.2589, 0.5646, 0.5646])
Starting run  2
tensor([0.7611, 0.7877, 0.7877, 1.1722])
Starting run  3
tensor([1.1274, 0.4057, 0.4057, 1.0892])
Starting run  4
tensor([0.5171, 0.5171, 1.2037, 0.6642])


In [111]:
print("IMA-DCI:", res_ima_list.mean(), "+-", res_ima_list.std())
print("DMA-DCI:", res_dma_list.mean(), "+-", res_dma_list.std())

IMA-DCI: 0.14445782 +- 0.030627714
DMA-DCI: 0.9999982 +- 9.344822e-07


We see that, when using ```ColorBar```, we obtain the perfect score with the IMA method, whereas for ```FourBars```, the DMA method yields DCI=1.

Note: Scores for non-working methods are highly volatile depending on the matrices sampled, and therefore do not exactly match the numbers given in the paper.