In [1]:
import argparse
import os
import sys
import json
import shutil
import pathlib
import random
from itertools import islice
import time
from copy import deepcopy
import math
from pprint import pprint
from qj_global import qj
import logging

# comet_ml is an optional dependency: record whether it could be imported.
try:
    from comet_ml import Experiment
    COMET_AVAIL = True
except ImportError:
    # Only an unimportable package counts as "unavailable". A bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    COMET_AVAIL = False

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, Timer, TerminateOnNan, TimeLimit
from ignite.metrics import RunningAverage
from ignite.contrib.metrics import GpuInfo
from ignite.utils import setup_logger

# adding the folder containing the folder `disentanglement_via_mechanism_sparsity` to sys.path
# (the `__file__`-based variant below is disabled; `__file__` is normally not defined inside a notebook kernel)
# sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
# NOTE(review): machine-specific absolute path -- the project imports that follow only resolve on a
# machine that has the repository checked out under this directory; adjust before running elsewhere.
sys.path.insert(0, str('/Users/vitoria/Documents/GitHub/'))
from disentanglement_via_mechanism_sparsity.universal_logger.logger import UniversalLogger
from disentanglement_via_mechanism_sparsity.metrics import MyMetrics, linear_regression_metric, mean_corr_coef, edge_errors
from disentanglement_via_mechanism_sparsity.plot import plot_01_matrix
from disentanglement_via_mechanism_sparsity.data.synthetic import get_ToyManifoldDatasets
from disentanglement_via_mechanism_sparsity.model.ilcm_vae import ILCM_VAE
from disentanglement_via_mechanism_sparsity.model.latent_models_vae import FCGaussianLatentModel


In [2]:
# Dataset configuration; both values are passed to get_ToyManifoldDatasets in a later cell.
manifold = "nn"
# NOTE(review): the recorded output of the get_ToyManifoldDatasets cell is
# "NotImplementedError: The transition model action_sparsity_trivial is not implemented."
# -- confirm which transition-model names the installed package accepts.
transition_model = "action_sparsity_trivial"
# transition_model = "action_sparsity_non_trivial"

In [3]:
# datasets = get_ToyManifoldDatasets(manifold, transition_model, split=(0.8, 0.1, 0.1),
#                                        z_dim=10, x_dim=20, num_samples=int(1e6),
#                                        no_norm=True, discrete=True)

In [3]:
# Small configuration (z_dim=2, x_dim=6, num_samples=100) of the toy-manifold datasets.
datasets = get_ToyManifoldDatasets(
    manifold,
    transition_model,
    split=(0.8, 0.1, 0.1),
    z_dim=2,
    x_dim=6,
    num_samples=100,
    no_norm=True,
    discrete=True,
)

NotImplementedError: The transition model action_sparsity_trivial is not implemented.

In [4]:
# NOTE(review): index 4 into the value returned by get_ToyManifoldDatasets is a magic number;
# which object it selects is not visible here -- confirm against that function.
# With drop_last=True the loader yields no batch at all when the dataset holds fewer than
# batch_size (500) items -- check this against num_samples in the dataset cell.
train_loader = data.DataLoader(datasets[4], batch_size=500, shuffle=False, drop_last=True)

In [5]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7f8d201ea790>

In [6]:
train_batch = next(iter(train_loader))

In [7]:
train_batch[7].shape

torch.Size([500, 1, 5, 20])

In [8]:
train_batch[5].shape

torch.Size([500, 1, 5, 10])

In [9]:
len(datasets[4].__getitem__(5))

8

In [10]:
datasets[4].disc_x.shape

torch.Size([800000, 5, 20])

In [11]:
datasets[4].p_x.shape

torch.Size([800000, 5, 20])

In [12]:
datasets[4].disc_z.shape

torch.Size([800000, 5, 10])

In [13]:
datasets[4].p_z.shape

torch.Size([800000, 5, 10])

# Generate actions dataset

In [5]:
d_z = 40

In [6]:
# a_init = np.eye(d_z)

In [7]:
# a_init

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

In [8]:
# a_init.shape

(40, 40)

In [9]:
# zeros = np.zeros((1,d_z))

In [11]:
# a_new = np.concatenate([a_init, zeros],0)

In [18]:
G = np.eye(d_z)

In [15]:
G.shape

(40, 40)

In [16]:
G

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

In [12]:
k=3

In [13]:
ones_k = np.ones(k)

In [14]:
ones_k.shape

(3,)

In [15]:
# Seed NumPy's global generator so that lambda_0 -- and any later np.random.* draws
# made in execution order after this cell -- are identical on every Restart & Run All.
# The seed value itself is arbitrary.
np.random.seed(0)
lambda_0 = np.random.normal(size=k*d_z)

In [16]:
lambda_0.shape

(120,)

In [22]:
# samples = []
# for l in range(G.shape[1]):
#     extended_mask = np.kron(G[:,l], ones_k)
#     normal_a_l = np.random.normal(size=1)
#     lambda_a = normal_a_l * extended_mask
#     samples.append(lambda_0 + lambda_a)

In [19]:
# Row l of `lambdas` is lambda_0 plus one scalar normal draw added on the entries
# selected by column l of G (each entry of that column expanded via the Kronecker
# product with ones_k).
lambdas = np.zeros((d_z, k*d_z))
for l, g_column in enumerate(G.T):  # rows of G.T are the columns of G
    normal_a_l = np.random.normal(size=1)
    extended_mask = np.kron(g_column, ones_k)
    lambda_a = normal_a_l * extended_mask
    lambdas[l] = lambda_0 + lambda_a

In [20]:
lambdas.shape

(40, 120)

In [25]:
exp_lambdas = np.exp(lambdas)

In [26]:
exp_lambdas.shape

(40, 120)

In [27]:
exp_lambdas

array([[9.1306335 , 7.87568937, 0.31118157, ..., 0.43172241, 2.23359479,
        2.89165596],
       [4.93208104, 4.2541997 , 0.1680905 , ..., 0.43172241, 2.23359479,
        2.89165596],
       [4.93208104, 4.2541997 , 0.1680905 , ..., 0.43172241, 2.23359479,
        2.89165596],
       ...,
       [4.93208104, 4.2541997 , 0.1680905 , ..., 0.43172241, 2.23359479,
        2.89165596],
       [4.93208104, 4.2541997 , 0.1680905 , ..., 0.43172241, 2.23359479,
        2.89165596],
       [4.93208104, 4.2541997 , 0.1680905 , ..., 1.47439088, 7.62803062,
        9.87539919]])

In [28]:
# Split the last axis (length k*d_z) into (d_z, k). Using d_z and k instead of the
# literals 40 and 3 keeps this cell valid when either constant is changed.
reshaped = exp_lambdas.reshape((d_z, d_z, k))

In [29]:
reshaped.shape

(40, 40, 3)

In [30]:
marg = reshaped.sum(axis=-1)

In [32]:
marg.shape

(40, 40)

In [34]:
# Recompute from `reshaped` with keepdims=True instead of expanding `marg` in place:
# the result is the same array with a trailing singleton axis, but re-running this
# cell no longer appends another axis each time.
marg = reshaped.sum(axis=-1, keepdims=True)

In [35]:
marg.shape

(40, 40, 1)

In [42]:
# Block of ones to be appended along the last axis as the unnormalised weight
# (1 = exp(0)) of the last class. Sized from d_z rather than a literal 40.
ones_c = np.ones((d_z, d_z, 1))

In [48]:
# NOTE: despite the name, `logits` holds exponentiated values (`reshaped` comes from
# np.exp(lambdas)) plus the appended ones, i.e. unnormalised weights, not log-scale scores.
logits = np.concatenate([reshaped, ones_c], -1)

In [49]:
logits.shape

(40, 40, 4)

In [50]:
# Normalise along the last axis: `marg` is the sum of `reshaped` over that axis and the
# appended column of ones contributes the extra 1, so (1 + marg) is the sum of `logits`.
p = logits / (1 + marg)

In [51]:
p.shape

(40, 40, 4)

In [52]:
p[0,0].sum()

1.0

In [53]:

p[0,0]

array([0.49846493, 0.42995428, 0.01698821, 0.05459259])

In [66]:
# Indices at which p does not sum to 1 along the last axis.
np.nonzero(~np.isclose(p.sum(axis=-1), 1))

(array([], dtype=int64), array([], dtype=int64))

In [56]:
p.sum(axis=-1)

array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]])

In [57]:
p[0,1]

array([0.68796253, 0.21364317, 0.01710583, 0.08128847])

In [58]:
np.sum(p[0,1])

0.9999999999999999

In [67]:
from scipy.special import logsumexp

In [68]:
# Split the last axis (length k*d_z) into (d_z, k); sized from d_z and k rather than
# the literals 40 and 3 so the cell follows any change to those constants.
lambdas_reshaped = lambdas.reshape((d_z, d_z, k))

In [69]:
lambdas_reshaped.shape

(40, 40, 3)

In [71]:
logmarg = logsumexp(lambdas_reshaped, axis=-1)

In [73]:
marg_exp = np.exp(logmarg)

In [75]:
marg_exp

array([[17.31750444, 11.30186746,  2.17658522, ...,  1.61735887,
         1.41160597,  5.55697316],
       [ 9.35437124,  4.03999389,  2.17658522, ...,  1.61735887,
         1.41160597,  5.55697316],
       [ 9.35437124, 11.30186746,  0.17899144, ...,  1.61735887,
         1.41160597,  5.55697316],
       ...,
       [ 9.35437124, 11.30186746,  2.17658522, ...,  0.99113459,
         1.41160597,  5.55697316],
       [ 9.35437124, 11.30186746,  2.17658522, ...,  1.61735887,
         2.97447547,  5.55697316],
       [ 9.35437124, 11.30186746,  2.17658522, ...,  1.61735887,
         1.41160597, 18.97782069]])

array([[[17.31750444],
        [11.30186746],
        [ 2.17658522],
        ...,
        [ 1.61735887],
        [ 1.41160597],
        [ 5.55697316]],

       [[ 9.35437124],
        [ 4.03999389],
        [ 2.17658522],
        ...,
        [ 1.61735887],
        [ 1.41160597],
        [ 5.55697316]],

       [[ 9.35437124],
        [11.30186746],
        [ 0.17899144],
        ...,
        [ 1.61735887],
        [ 1.41160597],
        [ 5.55697316]],

       ...,

       [[ 9.35437124],
        [11.30186746],
        [ 2.17658522],
        ...,
        [ 0.99113459],
        [ 1.41160597],
        [ 5.55697316]],

       [[ 9.35437124],
        [11.30186746],
        [ 2.17658522],
        ...,
        [ 1.61735887],
        [ 2.97447547],
        [ 5.55697316]],

       [[ 9.35437124],
        [11.30186746],
        [ 2.17658522],
        ...,
        [ 1.61735887],
        [ 1.41160597],
        [18.97782069]]])

In [21]:
###
# (d_z, d_z, k) view of `lambdas`, sized from d_z and k so it stays consistent with
# zeros_c below instead of hard-coding 40 and 3.
lambdas_reshaped = lambdas.reshape((d_z, d_z, k))

# Block of zeros to be appended along the last axis for the last class
# (a zero on the log scale, i.e. weight exp(0) = 1); concatenated in the next cell.
zeros_c = np.zeros((d_z,d_z,1))

In [22]:
lambdas_c = np.concatenate([lambdas_reshaped, zeros_c], -1)

In [23]:
lambdas_c.shape

(40, 40, 4)

In [25]:
from scipy.special import softmax

In [33]:
p = softmax(lambdas_c, axis=-1)

In [34]:
p.shape

(40, 40, 4)

In [35]:
np.sum(p[0,0])

1.0

In [36]:
p[0,0]

array([0.27764123, 0.4868433 , 0.02337692, 0.21213855])

In [37]:
np.sum(p, axis=-1)

array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]])