In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

import plotting_functions as p_f
import cell_analyses as c_u
import data_utils as d_u
import model_utils as m_u

# Draw all figure text (titles, axis labels, tick labels) in a single colour.
COLOR = "black"
plt.rcParams.update(
    {
        "text.color": COLOR,
        "axes.labelcolor": COLOR,
        "xtick.color": COLOR,
        "ytick.color": COLOR,
    }
)

# Shared plot-styling constants.
fontsize, labelsize, legendsize = 15, 15, 12
linewidth = 3
legend_loc, legend_ncol = (0.45, 0.22), 2
window_len = 10
dpi = 300
cs = ['white', 'pink']

# Root figure directory plus one sub-directory per figure family.
fig_path = 'FIG PATH'

save_path_net = f'{fig_path}Subspace_Net/'
save_path_ae = f'{fig_path}Subspace_AE/'
save_path_vae = f'{fig_path}Subspace_VAE/'
save_path_pattern = f'{fig_path}Pattern/'

# One (initially empty) record per experiment family. Cells further down set
# an 'info' entry on a record and pass the record to p_f.get_tensorboard_df.
data = {
    experiment: {}
    for experiment in (
        'sup_shal_lin_factors',
        'sup_deep_nonlin_factors',
        'sup_deep_lin_factors',
        'sup_deep_2_subspace',
        'ae_lin_factors',
        'ae_nonlin_factors',
        'vae_shapes3d',
        'vae_shapes3d_baselines',
        'vae_shapes3d_longer',
        'vae_shapes3d_baselines_longer',
        'vae_dsprites_longer',
        'vae_dsprites_baselines_longer',
        'ae_categorical',
        'sup_shal_lin_factors_sparse',
        'sup_shal_lin_factors_sparse_weights',
    )
}

# Text fragments concatenated to build legend labels.
# The trailing ",  " separator is part of every fragment.
relu = "ReLu,  "
loss_act = r"$\mathcal{L}_{activity}$,  "
loss_weight = r"$\mathcal{L}_{weight}$,  "
loss_nonneg = r"$\mathcal{L}_{nonneg}$,  "
loss_sparse = r"$\mathcal{L}_{sparse}$,  "

# SubspaceNet

In [None]:
# Folder for the supervised SubspaceNet runs
# (alternatives: 'Group_embedder', 'Subspaces_vae', 'Pattern_Learning').
model_folder_name = "Subspaces"
path1 = f"SOME PATH{model_folder_name}/"
path2 = f"SOME PATH{model_folder_name}/path2/"
# Locations handed to p_f.get_tensorboard_df; path2 is listed first.
paths = [path2, path1]

### Supervised Shallow Net - Linear data -> Factors

In [None]:
# One-hidden-layer supervised net on linear data.
# Each legend label maps to {date string: [run numbers]} (presumably the
# training runs belonging to that configuration).
data['sup_shal_lin_factors']['info'] = {
    relu + loss_act + loss_weight: {"2022-08-20": [0, 1, 2]},
    loss_nonneg + loss_act + loss_weight: {"2022-08-20": [3, 4, 5]},
    relu + loss_act: {"2022-08-20": [9, 10, 11]},
    relu + loss_weight: {"2022-08-20": [12, 13, 14]},
    loss_act + loss_weight: {"2022-08-20": [6, 7, 8]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['sup_shal_lin_factors'] = p_f.get_tensorboard_df(
    data['sup_shal_lin_factors'],
    paths=paths,
)

In [None]:
# 'metrics/discrete_mil_0' curves, y-axis labelled MIR.
p_f.plot_train_curve(
    data['sup_shal_lin_factors'],
    'metrics/discrete_mil_0',
    'MIR',
    save_path_net + 'shal_lin_factors_mil',
    label_keep=-3,
    ylim=(0, 1.0),
)

In [None]:
# 'accuracies/r2' curves; the y-range is zoomed in to [0.9999, 1.0].
p_f.plot_train_curve(
    data['sup_shal_lin_factors'],
    'accuracies/r2',
    '$r^2$',
    save_path_net + 'shal_lin_factors_r2',
    label_keep=-3,
    ylim=(0.9999, 1.0),
)

### Supervised Shallow Net SPARSE neurons - Linear data -> Factors

In [None]:
# One-hidden-layer net with a sparsity term, swept over beta_sparse,
# with (first four labels) and without (last four) the weight-loss fragment.
# Each legend label maps to {date string: [run numbers]}.
data['sup_shal_lin_factors_sparse']['info'] = {
    loss_weight + r"$\beta_{sparse}=0.0001$": {"2022-09-27": [0, 2, 4]},
    loss_weight + r"$\beta_{sparse}=0.001$": {"2022-09-27": [1, 3, 6]},
    loss_weight + r"$\beta_{sparse}=0.01$": {"2022-09-27": [5, 7, 8]},
    loss_weight + r"$\beta_{sparse}=0.1$": {"2022-09-27": [9, 10, 11]},
    r"$\beta_{sparse}=0.0001$": {"2022-09-27": [12, 13, 14]},
    r"$\beta_{sparse}=0.001$": {"2022-09-27": [15, 16, 17]},
    r"$\beta_{sparse}=0.01$": {"2022-09-27": [18, 19, 21]},
    r"$\beta_{sparse}=0.1$": {"2022-09-27": [20, 22, 23]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['sup_shal_lin_factors_sparse'] = p_f.get_tensorboard_df(
    data['sup_shal_lin_factors_sparse'],
    paths=paths,
)

In [None]:
# 'metrics/discrete_mil_0' curves for the sparse-neuron sweep.
p_f.plot_train_curve(
    data['sup_shal_lin_factors_sparse'],
    'metrics/discrete_mil_0',
    'MIR',
    save_path_net + 'shal_lin_factors_sparse_mil',
    label_keep=None,
    ylim=(0, 1.0),
    legend_loc_=(0.05, 0.8),
)

### Supervised Shallow Net SPARSE weights - Linear data -> Factors

In [None]:
# One-hidden-layer net, sparse-weights variant, swept over beta_sparse,
# without (first four labels) and with (last four) the activity-loss fragment.
# Each legend label maps to {date string: [run numbers]}.
# A beta_sparse=1.0 entry ("2022-11-09": [0, 1, 2, 3, 4]) is currently left out.
data['sup_shal_lin_factors_sparse_weights']['info'] = {
    r"$\beta_{sparse}=0.0001$": {"2022-11-09": [17, 19, 20, 22, 25]},
    r"$\beta_{sparse}=0.001$": {"2022-11-09": [15, 16, 18, 21, 23, 24]},
    r"$\beta_{sparse}=0.01$": {"2022-11-09": [6, 7, 9, 13, 14]},
    r"$\beta_{sparse}=0.1$": {"2022-11-09": [5, 8, 10, 11, 12]},
    loss_act + r"$\beta_{sparse}=0.0001$": {"2022-11-10": [0, 2, 4]},
    loss_act + r"$\beta_{sparse}=0.001$": {"2022-11-10": [1, 3, 9]},
    loss_act + r"$\beta_{sparse}=0.01$": {"2022-11-10": [5, 6, 11]},
    loss_act + r"$\beta_{sparse}=0.1$": {"2022-11-10": [7, 8, 10]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['sup_shal_lin_factors_sparse_weights'] = p_f.get_tensorboard_df(
    data['sup_shal_lin_factors_sparse_weights'],
    paths=paths,
)

In [None]:
# 'metrics/discrete_mil_0' curves for the sparse-weights sweep.
p_f.plot_train_curve(
    data['sup_shal_lin_factors_sparse_weights'],
    'metrics/discrete_mil_0',
    'MIR',
    save_path_net + 'shal_lin_factors_sparse_weights_mil',
    label_keep=None,
    ylim=(0, 1.0),
    legend_loc_=(0.05, 0.8),
)

### Supervised Deep Net - Nonlinear data -> Factors

In [None]:
# Large network, WITH a non-linear function applied to the input.
# Each legend label maps to {date string: [run numbers]}.
data['sup_deep_nonlin_factors']['info'] = {
    "non-linear data factor dim = 6": {
        "2022-08-20": [15, 16],
        "2022-08-21": [0, 1, 7],
    },
    "non-linear data factor dim = 4": {"2022-08-21": [2, 3, 4, 5, 6]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['sup_deep_nonlin_factors'] = p_f.get_tensorboard_df(
    data['sup_deep_nonlin_factors'],
    paths=paths,
)

In [None]:
# Layer-wise plot: the metric tag and the save name are both passed as prefixes.
p_f.plot_training_curves_layers(
    data['sup_deep_nonlin_factors'],
    'metrics/discrete_mil_',
    'MIR',
    save_path_net + 'deep_nonlin_factors_mil_',
    label_keep=None,
    ylim=(0, 1),
    legend_loc_=(0.40, 0.16),
)

### Supervised Deep Net - Linear data -> Factors

In [None]:
# Large network, WITHOUT a non-linear function applied to the input.
# Each legend label maps to {date string: [run numbers]}.
data['sup_deep_lin_factors']['info'] = {
    "linear data factor dim = 6": {"2022-08-21": [23, 24, 25, 26, 27]},
    "linear data factor dim = 4": {"2022-08-21": [28, 29, 30, 31, 32]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['sup_deep_lin_factors'] = p_f.get_tensorboard_df(
    data['sup_deep_lin_factors'],
    paths=paths,
)

In [None]:
# Layer-wise plot: the metric tag and the save name are both passed as prefixes.
p_f.plot_training_curves_layers(
    data['sup_deep_lin_factors'],
    'metrics/discrete_mil_',
    'MIR',
    save_path_net + 'deep_lin_factors_mil_',
    label_keep=None,
    ylim=(0, 1),
    legend_loc_=(0.40, 0.16),
)

# Autoencoders

In [None]:
# Folder for the autoencoder / VAE runs
# (alternatives: 'Group_embedder', 'Subspaces_vae', 'Pattern_Learning').
model_folder_name = "Subspaces_vae"
path1 = f"SOME PATH{model_folder_name}/"
path2 = f"OTHER PATH{model_folder_name}/path2/"
# Locations handed to p_f.get_tensorboard_df; path2 is listed first.
paths = [path2, path1]

### Autoencoder - Linear synthetic data - factors

In [None]:
# Autoencoder on linear synthetic data.
# Each legend label maps to {date string: [run numbers]}.
data['ae_lin_factors']['info'] = {
    relu + loss_act + loss_weight: {"2022-08-21": [9, 10, 11, 12, 13]},
    loss_nonneg + loss_act + loss_weight: {"2022-08-21": [24, 25, 26, 27, 28]},
    relu + loss_act: {"2022-08-21": [14, 15, 16, 17, 18]},
    relu + loss_weight: {"2022-08-21": [19, 20, 21, 22, 23]},
    loss_act + loss_weight: {"2022-08-21": [29, 30, 31, 32]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['ae_lin_factors'] = p_f.get_tensorboard_df(
    data['ae_lin_factors'],
    paths=paths,
)

In [None]:
# 'metrics/discrete_mig' curves, y-axis labelled MIG.
p_f.plot_train_curve(
    data['ae_lin_factors'],
    'metrics/discrete_mig',
    'MIG',
    save_path_ae + 'lin_factors_mig',
    label_keep=-3,
    ylim=(0, 0.5),
)

In [None]:
# 'metrics/discrete_mil' curves, y-axis labelled MIR.
p_f.plot_train_curve(
    data['ae_lin_factors'],
    'metrics/discrete_mil',
    'MIR',
    save_path_ae + 'lin_factors_mil',
    label_keep=-3,
    ylim=(0, 1.0),
)

In [None]:
# 'accuracies/r2' curves; y-range restricted to [0.95, 1.0].
p_f.plot_train_curve(
    data['ae_lin_factors'],
    'accuracies/r2',
    '$r^2$',
    save_path_ae + 'lin_factors_r2',
    label_keep=-3,
    ylim=(0.95, 1.0),
)

### Autoencoder - Nonlinear synthetic data

In [None]:
# Autoencoder on nonlinear synthetic data.
# Each legend label maps to {date string: [run numbers]}.
#
# An earlier run table (all on "2022-08-21": runs 6-8, 34, 35 and 40-59) used
# to be assigned here and was then immediately overwritten by the table below
# without ever being read, so that dead assignment has been removed.
data['ae_nonlin_factors']['info'] = {
    relu + loss_act + loss_weight: {
        "2022-11-11": [2, 5, 6],
        "2022-11-12": [0, 4],
    },
    loss_nonneg + loss_act + loss_weight: {
        "2022-11-11": [0, 1, 3],
        "2022-11-12": [7, 8],
    },
    relu + loss_act: {
        "2022-11-11": [9, 11, 12],
        "2022-11-12": [1, 6],
    },
    relu + loss_weight: {
        "2022-11-11": [10, 13, 17],
        "2022-11-12": [2, 5],
    },
    loss_act + loss_weight: {
        "2022-11-11": [14, 15, 16],
        "2022-11-12": [3, 9],
    },
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['ae_nonlin_factors'] = p_f.get_tensorboard_df(data['ae_nonlin_factors'], paths=paths)

In [None]:
# 'metrics/discrete_mig' curves, y-axis labelled MIG.
p_f.plot_train_curve(
    data['ae_nonlin_factors'],
    'metrics/discrete_mig',
    'MIG',
    save_path_ae + 'nonlin_factors_mig',
    label_keep=-3,
    ylim=(0, 0.6),
)

In [None]:
# 'metrics/discrete_mil' curves, y-axis labelled MIR.
p_f.plot_train_curve(
    data['ae_nonlin_factors'],
    'metrics/discrete_mil',
    'MIR',
    save_path_ae + 'nonlin_factors_mil',
    label_keep=-3,
    ylim=(0, 1.0),
)

In [None]:
# 'accuracies/r2' curves; y-range restricted to [0.95, 1.0].
p_f.plot_train_curve(
    data['ae_nonlin_factors'],
    'accuracies/r2',
    '$r^2$',
    save_path_ae + 'nonlin_factors_r2',
    label_keep=-3,
    ylim=(0.95, 1.0),
)

### VAE - Shapes3D

In [None]:
# 500000 iterations.
# Sweep over beta_weight at beta_VAE = 1 with ReLU; each legend label maps to
# {date string: [run numbers]}. "checked" marks are kept from the original table.
data['vae_shapes3d_longer']['info'] = {
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=1.0$": {
        "2022-08-25": [0, 1, 2, 3, 4],  # checked
    },
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=0.1$": {
        "2022-08-25": [5, 6, 7, 8, 9],  # checked
    },
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=0.01$": {
        "2022-08-26": [10, 11, 12, 13, 14],  # checked
    },
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['vae_shapes3d_longer'] = p_f.get_tensorboard_df(
    data['vae_shapes3d_longer'],
    min_steps=400,
    paths=paths,
)

In [None]:
# 'metrics/discrete_mig' curves with a single-column legend.
p_f.plot_train_curve(
    data['vae_shapes3d_longer'],
    'metrics/discrete_mig',
    'MIG',
    save_path_vae + 'shapes3d_mig',
    label_keep=None,
    ylim=(0, 0.7),
    legend_ncol_=1,
    legend_loc_=(0.3, 0.2),
)

In [None]:
# 'metrics/discrete_mil' curves with a single-column legend.
p_f.plot_train_curve(
    data['vae_shapes3d_longer'],
    'metrics/discrete_mil',
    'MIR',
    save_path_vae + 'shapes3d_mil',
    label_keep=None,
    ylim=(0, 1.0),
    legend_ncol_=1,
    legend_loc_=(0.3, 0.2),
)

In [None]:
# 'accuracies/r2' curves; y-range restricted to [0.95, 1.0].
p_f.plot_train_curve(
    data['vae_shapes3d_longer'],
    'accuracies/r2',
    '$r^2$',
    save_path_vae + 'shapes3d_r2',
    label_keep=None,
    ylim=(0.95, 1.0),
    legend_ncol_=1,
    legend_loc_=(0.3, 0.2),
)

In [None]:
# 'metrics/num used latents' curves.
p_f.plot_train_curve(
    data['vae_shapes3d_longer'],
    'metrics/num used latents',
    'Number Used Latents',
    save_path_vae + 'shapes3d_num_latents',
    label_keep=None,
    ylim=(0, 10.0),
    legend_ncol_=1,
    legend_loc_=(0.3, 0.2),
)

### VAE - Shapes3D baselines

In [None]:
# Shapes3D baselines: beta_VAE in {16, 4, 1}, with and without the ReLU fragment.
# Each legend label maps to {date string: [run numbers]}; the
# ReLU / beta_VAE=16 entry has no runs listed.
data['vae_shapes3d_baselines_longer']['info'] = {
    r"$\beta_{VAE}=16$": {"2022-08-25": [10, 11, 12, 13, 14]},
    relu + r"$\beta_{VAE}=16$": {},
    r"$\beta_{VAE}=4$": {
        "2022-08-25": [20, 21, 22, 23],
        "2022-08-26": [0],
    },
    r"$\beta_{VAE}=1$": {"2022-08-26": [5, 6, 7, 8, 9]},
    relu + r"$\beta_{VAE}=4$": {"2022-08-25": [15, 16, 17, 18, 19]},
    relu + r"$\beta_{VAE}=1$": {"2022-08-26": [1, 2, 3, 4, 15]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['vae_shapes3d_baselines_longer'] = p_f.get_tensorboard_df(
    data['vae_shapes3d_baselines_longer'],
    min_steps=400,
    paths=paths,
)

In [None]:
# 'metrics/discrete_mig' curves for the Shapes3D baselines.
p_f.plot_train_curve(
    data['vae_shapes3d_baselines_longer'],
    'metrics/discrete_mig',
    'MIG',
    save_path_vae + 'shapes3d_baselines_mig',
    label_keep=None,
    ylim=(0, 0.7),
    legend_loc_=(0.0, 0.86),
)

In [None]:
# 'metrics/discrete_mil' curves for the Shapes3D baselines.
p_f.plot_train_curve(
    data['vae_shapes3d_baselines_longer'],
    'metrics/discrete_mil',
    'MIR',
    save_path_vae + 'shapes3d_baselines_mil',
    label_keep=None,
    ylim=(0, 1.0),
    legend_loc_=(0.3, 0.2),
)

In [None]:
# 'accuracies/r2' curves for the Shapes3D baselines.
p_f.plot_train_curve(
    data['vae_shapes3d_baselines_longer'],
    'accuracies/r2',
    '$r^2$',
    save_path_vae + 'shapes3d_baselines_r2',
    label_keep=None,
    ylim=(0.95, 1.0),
    legend_loc_=(0.3, 0.2),
)

In [None]:
# 'metrics/num used latents' curves for the Shapes3D baselines.
p_f.plot_train_curve(
    data['vae_shapes3d_baselines_longer'],
    'metrics/num used latents',
    'Number Used Latents',
    save_path_vae + 'shapes3d_baselines_num_latents',
    label_keep=None,
    ylim=(0, 10.0),
    legend_loc_=(0.3, 0.2),
)

### MIG vs Accuracy - for all models

In [None]:
# Scatter of last-row MIG against last-row r^2 on Shapes3D, one colour/legend
# entry per label, for the constrained models (data_1) and baselines (data_2).
data_1 = data['vae_shapes3d_longer']
data_2 = data['vae_shapes3d_baselines_longer']

# For every label group, take each entry's series truncated at the record's 'cutoff'.
dfs_all_mig_1 = [[run['metrics/discrete_mig'][:data_1['cutoff']] for run in group] for group in data_1['dfs']]
dfs_all_r2_1 = [[run['accuracies/r2'][:data_1['cutoff']] for run in group] for group in data_1['dfs']]
dfs_all_mig_2 = [[run['metrics/discrete_mig'][:data_2['cutoff']] for run in group] for group in data_2['dfs']]
dfs_all_r2_2 = [[run['accuracies/r2'][:data_2['cutoff']] for run in group] for group in data_2['dfs']]

# Put a group's series side by side (one column each) and keep only the last row.
final_mig_1 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_mig_1]
final_r2_1 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_r2_1]
final_mig_2 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_mig_2]
final_r2_2 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_r2_2]

for lab, mig, r2 in zip(
    data_1['labels'] + data_2['labels'],
    final_mig_1 + final_mig_2,
    final_r2_1 + final_r2_2,
):
    plt.scatter(mig, r2, label=lab, s=80)
plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.28))
plt.xlabel('MIG', fontsize=fontsize)
plt.ylabel('$r^2$', fontsize=fontsize)

plt.savefig(save_path_vae + 'shapes3d_mig_r2' + ".png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Scatter of MIG against negated 'losses/rec' (plotted as log likelihood) on
# Shapes3D, each averaged over the last 10 logged rows; one legend entry per label.
data_1 = data['vae_shapes3d_longer']
data_2 = data['vae_shapes3d_baselines_longer']
dfs_all_mig_1 = [[a['metrics/discrete_mig'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]
dfs_all_mig_2 = [[a['metrics/discrete_mig'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]
# Last 10 rows per group; .mean() below averages them column by column.
final_mig_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_1]
final_mig_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_2]
dfs_all_ll_1 = [[a['losses/rec'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]
dfs_all_ll_2 = [[a['losses/rec'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]
final_ll_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_1]
final_ll_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_2]

for lab, mig, l_l in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_ll_1 + final_ll_2):
    # The reconstruction loss is negated so that higher is better on the y-axis.
    plt.scatter(mig.mean(), -l_l.mean(), label=lab, s=80)
plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.25))
plt.xlabel('MIG', fontsize=fontsize)
plt.ylabel('Log Likelihood', fontsize=fontsize)  # fixed typo: was 'Log Liklihood'
plt.savefig(save_path_vae + 'shapes3d_mig_ll' + ".png", dpi=300, bbox_inches='tight')
plt.show()

### dsprites

In [None]:
# dSprites sweep over beta_weight at beta_VAE = 1 with ReLU.
# Each legend label maps to {date string: [run numbers]}.
data['vae_dsprites_longer']['info'] = {
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=10.0$": {
        "2022-08-28": [20, 21, 22, 23, 24],
    },
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=3.0$": {
        # Run 28 used to be listed twice ([25, 26, 27, 28, 28]), which would
        # count that run twice in any aggregate over this group.
        # NOTE(review): the neighbouring groups suggest the fifth run was
        # meant to be 29 - confirm that run exists before adding it.
        "2022-08-28": [25, 26, 27, 28],
    },
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=1.0$": {
        "2022-08-27": [9, 10, 18, 19, 20],
    },
    relu + r"$\beta_{VAE}=1$,  " + r"$\beta_{weight}=0.3$": {
        "2022-08-28": [6, 7, 8, 9, 10],
    },
}
# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['vae_dsprites_longer'] = p_f.get_tensorboard_df(data['vae_dsprites_longer'], min_steps=300, paths=paths)

In [None]:
# 'metrics/discrete_mig' curves for the dSprites sweep.
p_f.plot_train_curve(
    data['vae_dsprites_longer'],
    'metrics/discrete_mig',
    'MIG',
    save_path_vae + 'dsprites_mig',
    label_keep=None,
    ylim=(0, 0.4),
    legend_ncol_=1,
    legend_loc_=(0.01, 0.8),
)

In [None]:
# dSprites baselines: beta_VAE in {16, 4, 1}, with and without the ReLU fragment.
# Each legend label maps to {date string: [run numbers]}; entries with an
# empty dict have no runs listed.
data['vae_dsprites_baselines_longer']['info'] = {
    r"$\beta_{VAE}=16$": {
        "2022-08-28": [0, 1, 2, 3, 5],
    },
    relu + r"$\beta_{VAE}=16$": {
    },
    r"$\beta_{VAE}=4$": {
        "2022-08-27": [15, 16, 17],
        "2022-08-28": [12, 14],
    },
    relu + r"$\beta_{VAE}=4$": {
        "2022-08-27": [12, 13, 14],
        "2022-08-28": [11, 13],
    },
    r"$\beta_{VAE}=1$": {
    },
    # This key was a second copy of the ReLU / beta_VAE=16 label; duplicate keys
    # in a dict literal silently collapse into one entry. It now mirrors the
    # Shapes3D baseline table, which ends with ReLU / beta_VAE=1.
    # NOTE(review): confirm ReLU / beta_VAE=1 is the intended configuration.
    relu + r"$\beta_{VAE}=1$": {
    },
}
# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['vae_dsprites_baselines_longer'] = p_f.get_tensorboard_df(data['vae_dsprites_baselines_longer'], min_steps=400, paths=paths)

In [None]:
# 'metrics/discrete_mig' curves for the dSprites baselines.
p_f.plot_train_curve(
    data['vae_dsprites_baselines_longer'],
    'metrics/discrete_mig',
    'MIG',
    save_path_vae + 'dsprites_baselines_mig',
    label_keep=None,
    ylim=(0, 0.4),
    legend_ncol_=1,
    legend_loc_=(0.5, 0.5),
)

In [None]:
# Scatter of last-row MIG against last-row r^2 on dSprites, one colour/legend
# entry per label, for the constrained models (data_1) and baselines (data_2).
data_1 = data['vae_dsprites_longer']
data_2 = data['vae_dsprites_baselines_longer']

# For every label group, take each entry's series truncated at the record's 'cutoff'.
dfs_all_mig_1 = [[run['metrics/discrete_mig'][:data_1['cutoff']] for run in group] for group in data_1['dfs']]
dfs_all_r2_1 = [[run['accuracies/r2'][:data_1['cutoff']] for run in group] for group in data_1['dfs']]
dfs_all_mig_2 = [[run['metrics/discrete_mig'][:data_2['cutoff']] for run in group] for group in data_2['dfs']]
dfs_all_r2_2 = [[run['accuracies/r2'][:data_2['cutoff']] for run in group] for group in data_2['dfs']]

# Put a group's series side by side (one column each) and keep only the last row.
final_mig_1 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_mig_1]
final_r2_1 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_r2_1]
final_mig_2 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_mig_2]
final_r2_2 = [pd.concat(group, axis=1).iloc[-1:] for group in dfs_all_r2_2]

# NOTE(review): this overrides the notebook-wide legend_loc, but the
# plt.legend call below passes its own loc and does not read it.
legend_loc = (0.01, 0.06)

for lab, mig, r2 in zip(
    data_1['labels'] + data_2['labels'],
    final_mig_1 + final_mig_2,
    final_r2_1 + final_r2_2,
):
    plt.scatter(mig, r2, label=lab, s=80)
plt.legend(fontsize=legendsize, loc=(-0.03, 0.17), ncol=2)
plt.xlabel('MIG', fontsize=fontsize)
plt.ylabel('$r^2$', fontsize=fontsize)

plt.savefig(save_path_vae + 'dsprites_mig_r2' + ".png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Scatter of MIG against negated 'losses/rec' (plotted as log likelihood) on
# dSprites, each averaged over the last 10 logged rows; one legend entry per label.
data_1 = data['vae_dsprites_longer']
data_2 = data['vae_dsprites_baselines_longer']
dfs_all_mig_1 = [[a['metrics/discrete_mig'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]
dfs_all_mig_2 = [[a['metrics/discrete_mig'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]
# Last 10 rows per group; .mean() below averages them column by column.
final_mig_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_1]
final_mig_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_mig_2]
dfs_all_ll_1 = [[a['losses/rec'][:data_1['cutoff']] for a in x] for x in data_1['dfs']]
dfs_all_ll_2 = [[a['losses/rec'][:data_2['cutoff']] for a in x] for x in data_2['dfs']]
final_ll_1 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_1]
final_ll_2 = [pd.concat(a, axis=1).iloc[-10:] for a in dfs_all_ll_2]

for lab, mig, l_l in zip(data_1['labels'] + data_2['labels'], final_mig_1 + final_mig_2, final_ll_1 + final_ll_2):
    # The reconstruction loss is negated so that higher is better on the y-axis.
    plt.scatter(mig.mean(), -l_l.mean(), label=lab, s=80)
plt.legend(fontsize=legendsize, ncol=2, loc='center left', bbox_to_anchor=(-0.05, 0.37))
plt.xlabel('MIG', fontsize=fontsize)
plt.ylabel('Log Likelihood', fontsize=fontsize)  # fixed typo: was 'Log Liklihood'
plt.savefig(save_path_vae + 'dsprites_mig_ll' + ".png", dpi=300, bbox_inches='tight')
plt.show()

### Categorical

In [None]:
# Run table for the 'ae_categorical' experiments.
# Each legend label maps to {date string: [run numbers]}.
data['ae_categorical']['info'] = {
    relu + loss_act + loss_weight: {"2022-09-09": [5, 6, 7, 8, 9]},
    loss_nonneg + loss_act + loss_weight: {"2022-09-09": [0, 1, 2, 3, 4]},
    relu + loss_act: {"2022-09-09": [20, 21, 22, 23, 24]},
    relu + loss_weight: {"2022-09-09": [15, 16, 17, 18, 19]},
    loss_act + loss_weight: {"2022-09-09": [10, 11, 12, 13, 14]},
    loss_sparse: {"2022-09-09": [40, 41, 42, 43, 44]},
    loss_sparse + loss_weight: {"2022-09-09": [35, 36, 37, 38, 39]},
}

# The record is replaced by whatever p_f.get_tensorboard_df returns for it.
data['ae_categorical'] = p_f.get_tensorboard_df(
    data['ae_categorical'],
    paths=paths,
    min_steps=50,
)

In [None]:
# 'metrics/discrete_mig' curves for the categorical autoencoder, with an explicit cutoff.
p_f.plot_train_curve(
    data['ae_categorical'],
    'metrics/discrete_mig',
    'MIG',
    save_path_vae + 'ae_categorical_mig',
    label_keep=-3,
    ylim=(0, 0.5),
    cutoff=300000,
    legend_loc_=(0.05, 0.22),
)

In [None]:
# 'metrics/discrete_mil' curves for the categorical autoencoder.
# The axis label was 'MIL' here; every other 'discrete_mil' plot in this
# notebook labels the metric 'MIR', so this one now matches.
p_f.plot_train_curve(data['ae_categorical'], 'metrics/discrete_mil', 'MIR',\
                 save_path_vae + 'ae_categorical_mil', label_keep=-3, ylim=(0,0.7), cutoff=300000, legend_loc_=(0.05, 0.22))

## SUBSPACE NETWORK - MUTUAL INFO FIGS

In [None]:
# Locations and settings for the SubspaceNet mutual-information figures.
# NOTE(review): the folder is lower-case 'subspaces' here but "Subspaces" in
# the training-curve section - confirm on a case-sensitive filesystem.
model_folder_name = 'subspaces'
model_type = 'subspaces'
path1 = f'SOME PATH{model_folder_name}/'
path2 = f'ANOTHER PATH{model_folder_name}/path2/'
base_path = path1
cmap = 'binary'

### Linear Net - Linear data

In [None]:
date = '2022-08-20'
run = 1
index = None  # None -> use the highest-numbered saved checkpoint

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# collect data
params.graph_mode = False
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)

In [None]:
begin_layer = 1
# Pass the variable instead of a second hard-coded 1, so the neuron masks
# always line up with the xs[begin_layer:] slices used below.
neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)

# Model outputs for both metric datasets; every output is transposed.
xs = model(ds_metric['input'])
prob = xs[-1].numpy()
xs = [x.numpy().T for x in xs]
xs_2 = model(ds_metric_2['input'])
xs_2 = [x.numpy().T for x in xs_2]

# Each compute_mig result is unpacked below as (metric, (MI matrix, entropies)).
metrics = [
    c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)
    for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)
]
mi_matrices = [a[1][0] for a in metrics]
entropies = [a[1][1] for a in metrics]
metrics = [a[0] for a in metrics]

print(metrics)

# Factor names for the x-axis: drop the first 6 characters of keys containing 'value'.
tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]
plt.imshow(mi_matrices[-1], cmap=cmap)
#plt.colorbar()
ax = plt.gca()
plt.xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation = 60, ha="right")
plt.xlabel('Factors', fontsize=fontsize)
plt.ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_net + 'lin_lin_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

### Non-linear Net - nonlinear data

In [None]:
date = '2022-08-21'
run = 0
index = None  # None -> use the highest-numbered saved checkpoint
layer = -2    # which layer's MI matrix the next cell prints and plots

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# collect data
params.graph_mode = False
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)

In [None]:
begin_layer = 1
# Pass the variable instead of a second hard-coded 1, so the neuron masks
# always line up with the xs[begin_layer:] slices used below.
neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)

# Model outputs for both metric datasets; every output is transposed.
xs = model(ds_metric['input'])
prob = xs[-1].numpy()
xs = [x.numpy().T for x in xs]
xs_2 = model(ds_metric_2['input'])
xs_2 = [x.numpy().T for x in xs_2]

# Each compute_mig result is unpacked below as (metric, (MI matrix, entropies)).
metrics = [
    c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)
    for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)
]
mi_matrices = [a[1][0] for a in metrics]
entropies = [a[1][1] for a in metrics]
metrics = [a[0] for a in metrics]

# `layer` is set in the model-loading cell above.
print(metrics[layer])

# Factor names for the x-axis: drop the first 6 characters of keys containing 'value'.
tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]
plt.imshow(mi_matrices[layer], cmap=cmap)
#plt.colorbar()
ax = plt.gca()
plt.xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation = 60, ha="right")
plt.xlabel('Factors', fontsize=fontsize)
plt.ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_net + 'nonlin_nonlin_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

### Linear Net - no constraints

In [None]:
base_path = path1

date = '2022-08-20'
run = 6
index = None  # None -> use the highest-numbered saved checkpoint

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# collect data
params.graph_mode = False
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)

In [None]:
begin_layer = 1
# Pass the variable instead of a second hard-coded 1, so the neuron masks
# always line up with the xs[begin_layer:] slices used below.
neuron_used = m_u.important_neuroms(model, ds_metric['input'], ds_metric['image'], params, begin_layer=begin_layer)

# Model outputs for both metric datasets; every output is transposed.
xs = model(ds_metric['input'])
prob = xs[-1].numpy()
xs = [x.numpy().T for x in xs]
xs_2 = model(ds_metric_2['input'])
xs_2 = [x.numpy().T for x in xs_2]

# Each compute_mig result is unpacked below as (metric, (MI matrix, entropies)).
metrics = [
    c_u.compute_mig(x, ds_metric, x_2, ds_metric_2, dataset=params.dataset, neuron_used=n_u)
    for (x, x_2, n_u) in zip(xs[begin_layer:], xs_2[begin_layer:], neuron_used)
]
mi_matrices = [a[1][0] for a in metrics]
entropies = [a[1][1] for a in metrics]
metrics = [a[0] for a in metrics]

print(metrics)

# Factor names for the x-axis: drop the first 6 characters of keys containing 'value'.
tick_names = [x[6:] if 'value' in x else x for x in ds_metric.keys() if x not in ['image', 'input']]
plt.imshow(mi_matrices[-1], cmap=cmap)
#plt.colorbar()
ax = plt.gca()
plt.xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation = 60, ha="right")
plt.xlabel('Factors', fontsize=fontsize)
plt.ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_net + 'lin_lin_no_constraints_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

# Subspace VAE - MUTUAL INFO

In [None]:
# Locations and settings for the Subspace VAE mutual-information figures.
model_folder_name = 'Subspaces_vae'
model_type = 'subspaces_vae'
path1 = f'SOME PATH{model_folder_name}/'
path2 = f'ANOTHER PATH{model_folder_name}/path2/'
cmap = 'binary'

### Linear DeepNet - Linear Data

In [None]:
base_path = path1

date = '2022-08-21'
run = 9
index = None  # None -> use the highest-numbered saved checkpoint

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# Back-fill any parameter the stored params object lacks with the current default.
import parameters
par_new = parameters.default_params_subspaces_vae()
for key in par_new.keys():
    try:
        params[key]
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        params[key] = par_new[key]

# get data
params.graph_mode = False
# params.dataset = 'dsprites'  #'shapes3d'
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)
# run model
(logits, rec), latents = model(ds_metric['image'])
(_, mu, logvar) = [x.numpy().T for x in latents]
(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])
(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]

# When params.sample is set, keep latents whose mean exp(logvar) along axis 1 is <= 0.5.
neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None
metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,
                                   dataset=params.dataset, remove_unused=True, compute_dci=False)
print(metrics)

In [None]:
# Heatmap of the MI matrix, each column divided by the matching entropy entry.
mi_mat, entropy = mi_mat_
mi_mat_scaled = mi_mat / entropy[np.newaxis, :]

# Factor names for the x-axis: drop the first 6 characters of keys containing 'value'.
tick_names = [
    (name[6:] if 'value' in name else name)
    for name in ds_metric.keys()
    if name not in ['image', 'input']
]

plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)
ax = plt.gca()
plt.xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation=60, ha="right")
plt.xlabel('Factors', fontsize=fontsize)
plt.ylabel('Latents', fontsize=fontsize)
plt.savefig(f"{save_path_vae}lin_lin_MI.png", dpi=300, bbox_inches='tight')
plt.show()

### NonLinear DeepNet - NonLinear Data

In [None]:
base_path = path2

date = '2022-11-12'
run = 4
index = None  # None -> use the highest-numbered saved checkpoint

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# Back-fill any parameter the stored params object lacks with the current default.
import parameters
par_new = parameters.default_params_subspaces_vae()
for key in par_new.keys():
    try:
        params[key]
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        params[key] = par_new[key]

# get data
params.graph_mode = False
# params.dataset = 'dsprites'  #'shapes3d'
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)
# run model
(logits, rec), latents = model(ds_metric['image'])
(_, mu, logvar) = [x.numpy().T for x in latents]
(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])
(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]

# When params.sample is set, keep latents whose mean exp(logvar) along axis 1 is <= 0.5.
neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None
metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,
                                   dataset=params.dataset, remove_unused=True, compute_dci=False)
print(metrics)

In [None]:
# Heatmap of the MI matrix, each column divided by the matching entropy entry.
mi_mat, entropy = mi_mat_
mi_mat_scaled = mi_mat / entropy[np.newaxis, :]

# Factor names for the x-axis: drop the first 6 characters of keys containing 'value'.
tick_names = [
    (name[6:] if 'value' in name else name)
    for name in ds_metric.keys()
    if name not in ['image', 'input']
]

plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)
ax = plt.gca()
plt.xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation=60, ha="right")
plt.xlabel('Factors', fontsize=fontsize)
plt.ylabel('Latents', fontsize=fontsize)
plt.savefig(f"{save_path_vae}nonlin_nonlin_MI.png", dpi=300, bbox_inches='tight')
plt.show()

### Shapes 3d

In [None]:
base_path = path1

date = '2022-08-25'
run = 3 #33
index = None  # None -> use the highest-numbered saved checkpoint

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # Checkpoint files look like 'subspaces_<int>.index'; take the largest <int>.
    index = max(
        int(fname.split('subspaces_')[1].split('.index')[0])
        for fname in os.listdir(model_path)
        if 'index' in fname
    )
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# Back-fill any parameter the stored params object lacks with the current default.
import parameters
par_new = parameters.default_params_subspaces_vae()
for key in par_new.keys():
    try:
        params[key]
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        params[key] = par_new[key]

params.graph_mode = False
# params.dataset = 'dsprites'  #'shapes3d'
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)
# run model
(logits, rec), latents = model(ds_metric['image'])
(_, mu, logvar) = [x.numpy().T for x in latents]
(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])
(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]

# When params.sample is set, keep latents whose mean exp(logvar) along axis 1 is <= 0.5.
neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None
# Unlike the two synthetic-data cells above, remove_unused is False here.
metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,
                                   dataset=params.dataset, remove_unused=False, compute_dci=False)
print(metrics)

In [None]:
# Heatmap of the scaled mutual-information matrix returned by compute_mig in the
# previous cell (y axis: latents, x axis: factors).
# Factor names are the ds_metric keys minus the excluded entries; keys containing
# 'value' or 'label' lose their first six characters.
tick_names = [
    name[6:] if ('value' in name or 'label' in name) else name
    for name in ds_metric.keys()
    if name not in ['image', 'target']
]
mi_mat, entropy = mi_mat_
# Divide every column by the matching entry of `entropy`.
mi_mat_scaled = mi_mat / entropy[np.newaxis, :]

plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)
ax = plt.gca()
ax.set_xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation=60, ha="right")
ax.set_xlabel('Factors', fontsize=fontsize)
ax.set_ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_vae + 'shapes3d_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Save two example input images; `latent` (the mu column of the first example) is
# the starting point of the latent traversal in the next cell.
sample_id = 2
latent = mu[:, sample_id]

# first example: sample `sample_id`, drawn on the current figure
example_image = ds_metric['image'][sample_id, ...].numpy()
plt.imshow(example_image)
plt.axis('off')
plt.savefig(save_path_vae + 'shapes3d_example_1' + ".png", dpi=300, bbox_inches='tight')

# second example: two samples further on, drawn on a new figure
plt.figure()
example_image = ds_metric['image'][sample_id+2, ...].numpy()
plt.imshow(example_image)
plt.axis('off')
plt.savefig(save_path_vae + 'shapes3d_example_2' + ".png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Latent traversals: starting from `latent`, sweep one latent dimension at a time
# over a fixed grid, decode every modified latent and show the images row by row.
# The last panel of each row is a histogram of that dimension's mu values.

import imageio

grid_points = 15
# The sweep starts at 0 when params.relu_latent_mu is set, otherwise at -2.
grid_start = 0 if params.relu_latent_mu else -2
grid = np.linspace(grid_start, 2, grid_points)

n_latent = params.latent_dim
latent_all = []
for l_dim in range(n_latent):
    # one-hot row selecting the dimension being swept
    mask = np.zeros((1, n_latent))
    mask[0, l_dim] = 1.0
    for g_ in grid:
        # zero out the swept dimension, then write the grid value into it
        latent_ = tf.identity(latent - mask * latent) + mask * g_
        latent_all.append(latent_)
latent_ = tf.concat(latent_all, axis=0)
_, image_pred = model.decode(latent_, apply_sigmoid=params.sigmoid_output)

n_cols = grid_points + 1
plt.figure(figsize=(n_cols, n_latent))
for l_dim in range(n_latent):
    row_offset = l_dim * n_cols
    # decoded images: image_pred holds grid_points consecutive images per latent dimension
    for g_point in range(grid_points):
        plt.subplot(n_latent, n_cols, row_offset + g_point + 1)
        plt.imshow(image_pred[l_dim * grid_points + g_point, ...])
        plt.axis('off')
    # final column: distribution of this latent over the metric dataset
    plt.subplot(n_latent, n_cols, row_offset + n_cols)
    _ = plt.hist(mu[l_dim, :], bins=np.linspace(-3, 3, num=20))
    plt.axis('off')

plt.savefig(save_path_vae + 'shapes3d_latent_traversal' + ".png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Build an animated GIF of the latent traversal: one frame per grid value, each
# frame showing the decoded image of every latent dimension side by side.
# Uses `image_pred`, `grid_points` and `imageio` from the latent-traversal cell above.
# The GIF and its temporary frames carry a 'shapes3d_' prefix so that this cell
# neither overwrites the GIF written for another dataset nor clobbers and deletes
# unrelated '<n>.png' files in the figure directory.
gif_path = os.path.join(save_path_vae, 'shapes3d_latent_traversal.gif')

filenames = []
for g_point in range(grid_points):
    plt.figure(figsize=(params.latent_dim, 1))
    for l_dim in range(params.latent_dim):
        plt.subplot(1, params.latent_dim, l_dim + 1)
        # image_pred holds grid_points consecutive images per latent dimension
        plt.imshow(image_pred[l_dim * grid_points + g_point, ...])
        plt.axis('off')

    # create file name and append it to a list
    filename = os.path.join(save_path_vae, f'shapes3d_latent_traversal_frame_{g_point}.png')
    filenames.append(filename)

    # save frame
    plt.savefig(filename)
    plt.close()

# build gif
with imageio.get_writer(gif_path, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)

# Remove the temporary frame files
for filename in set(filenames):
    os.remove(filename)

# show gif
from IPython.display import Image, display
with open(gif_path, 'rb') as gif_file:
    display(Image(gif_file.read()))

### dsprites

In [None]:
# Load the trained run listed under the "dsprites" section, run it on the metric
# datasets and compute the disentanglement metrics.
base_path = path1

date = '2022-08-27'  # '2022-08-24'
run = 9  # 33
index = None  # None -> pick the largest checkpoint index found in model_path

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # File names are parsed as 'subspaces_<index>.index...'; keep the largest <index>.
    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# Back-fill any parameter missing from the stored params with the current default.
# NOTE(review): `parameters` is imported after get_model, as in the original cell;
# presumably that call makes the module importable - confirm before moving the import.
import parameters
par_new = parameters.default_params_subspaces_vae()
for key in par_new.keys():
    try:
        params[key]
    except (KeyError, AttributeError):
        # Only a missing entry should fall back to the default; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        params[key] = par_new[key]

# get data
params.graph_mode = False
# params.dataset = 'dsprites'  #'shapes3d'
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)
# run model
(logits, rec), latents = model(ds_metric['image'])
# The model returns three latent tensors; the last two are kept (transposed) as mu and logvar.
(_, mu, logvar) = [x.numpy().T for x in latents]
(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])
(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]

# When params.sample is set, a latent counts as "used" if its mean exp(logvar) along
# axis 1 is at most 0.5 (axis 1 is presumably the sample axis after the transpose - confirm).
neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None
metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,
                                   dataset=params.dataset, remove_unused=False, compute_dci=False)
print(metrics)

In [None]:
# Heatmap of the scaled mutual-information matrix returned by compute_mig in the
# previous cell (y axis: latents, x axis: factors).
# Factor names are the ds_metric keys minus the excluded entries; keys containing
# 'value' or 'label' lose their first six characters.
tick_names = [
    name[6:] if ('value' in name or 'label' in name) else name
    for name in ds_metric.keys()
    if name not in ['image', 'target']
]
mi_mat, entropy = mi_mat_
# Divide every column by the matching entry of `entropy`.
mi_mat_scaled = mi_mat / entropy[np.newaxis, :]

plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)
#plt.colorbar()
ax = plt.gca()
ax.set_xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation=60, ha="right")
ax.set_xlabel('Factors', fontsize=fontsize)
ax.set_ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_vae + 'dsprites_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Save two example input images; `latent` (the mu column of the first example) is
# the starting point of the latent traversal in the next cell.
sample_id = 7
latent = mu[:, sample_id]

# first example: sample `sample_id`, drawn on the current figure
example_image = ds_metric['image'][sample_id, ...].numpy()
plt.imshow(example_image, cmap=cmap)
plt.axis('off')
plt.savefig(save_path_vae + 'dsprites_example_1' + ".png", dpi=300, bbox_inches='tight')

# second example: two samples further on, drawn on a new figure
plt.figure()
example_image = ds_metric['image'][sample_id+2, ...].numpy()
plt.imshow(example_image, cmap=cmap)
plt.axis('off')
plt.savefig(save_path_vae + 'dsprites_example_2' + ".png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Latent traversals: starting from `latent`, sweep one latent dimension at a time
# over a fixed grid, decode every modified latent and show the images row by row.
# The last panel of each row is a histogram of that dimension's mu values.

import imageio

grid_points = 15
# The sweep starts at 0 when params.relu_latent_mu is set, otherwise at -2.
grid_start = 0 if params.relu_latent_mu else -2
grid = np.linspace(grid_start, 2, grid_points)

n_latent = params.latent_dim
latent_all = []
for l_dim in range(n_latent):
    # one-hot row selecting the dimension being swept
    mask = np.zeros((1, n_latent))
    mask[0, l_dim] = 1.0
    for g_ in grid:
        # zero out the swept dimension, then write the grid value into it
        latent_ = tf.identity(latent - mask * latent) + mask * g_
        latent_all.append(latent_)
latent_ = tf.concat(latent_all, axis=0)
_, image_pred = model.decode(latent_, apply_sigmoid=params.sigmoid_output)

n_cols = grid_points + 1
plt.figure(figsize=(n_cols, n_latent))
for l_dim in range(n_latent):
    row_offset = l_dim * n_cols
    # decoded images: image_pred holds grid_points consecutive images per latent dimension
    for g_point in range(grid_points):
        plt.subplot(n_latent, n_cols, row_offset + g_point + 1)
        plt.imshow(image_pred[l_dim * grid_points + g_point, ...])
        plt.axis('off')
    # final column: distribution of this latent over the metric dataset
    plt.subplot(n_latent, n_cols, row_offset + n_cols)
    _ = plt.hist(mu[l_dim, :], bins=np.linspace(-3, 3, num=20))
    plt.axis('off')

plt.savefig(save_path_vae + 'dsprites_latent_traversal' + ".png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Build an animated GIF of the latent traversal: one frame per grid value, each
# frame showing the decoded image of every latent dimension stacked vertically.
# Uses `image_pred`, `grid_points` and `imageio` from the latent-traversal cell above.
# The GIF and its temporary frames carry a 'dsprites_' prefix so that this cell
# neither overwrites the GIF written for another dataset nor clobbers and deletes
# unrelated '<n>.png' files in the figure directory.
gif_path = os.path.join(save_path_vae, 'dsprites_latent_traversal.gif')

filenames = []
for g_point in range(grid_points):
    plt.figure(figsize=(1, params.latent_dim))
    for l_dim in range(params.latent_dim):
        plt.subplot(params.latent_dim, 1, l_dim + 1)
        # image_pred holds grid_points consecutive images per latent dimension
        plt.imshow(image_pred[l_dim * grid_points + g_point, ...])
        plt.axis('off')

    # create file name and append it to a list
    filename = os.path.join(save_path_vae, f'dsprites_latent_traversal_frame_{g_point}.png')
    filenames.append(filename)

    # save frame
    plt.savefig(filename)
    plt.close()

# build gif
with imageio.get_writer(gif_path, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)

# Remove the temporary frame files
for filename in set(filenames):
    os.remove(filename)

# show gif
from IPython.display import Image, display
with open(gif_path, 'rb') as gif_file:
    display(Image(gif_file.read()))

### Categories

In [None]:
# Load the trained run listed under the "Categories" section, run it on the metric
# datasets and compute the disentanglement metrics.
base_path = path1

date = '2022-09-09'
run = 5
index = None  # None -> pick the largest checkpoint index found in model_path

# Get directories for the requested run
run_path, train_path, model_path, save_path, script_path = p_f.set_directories(date, run, base_path=base_path)
if index is None:
    # File names are parsed as 'subspaces_<index>.index...'; keep the largest <index>.
    index = max(int(x.split('subspaces_')[1].split('.index')[0]) for x in os.listdir(model_path) if 'index' in x)
    print(index)
# Load model from file
model, params, stored_mu, stored_du = p_f.get_model(model_path + '/subspaces_' + str(index), script_path, save_path, use_old_scripts=True, model_type=model_type)

# Back-fill any parameter missing from the stored params with the current default.
# NOTE(review): `parameters` is imported after get_model, as in the original cell;
# presumably that call makes the module importable - confirm before moving the import.
import parameters
par_new = parameters.default_params_subspaces_vae()
for key in par_new.keys():
    try:
        params[key]
    except (KeyError, AttributeError):
        # Only a missing entry should fall back to the default; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        params[key] = par_new[key]

# get data
params.graph_mode = False
# params.dataset = 'dsprites'  #'shapes3d'
ds_batch, (ds_metric, ds_metric_2) = d_u.data_subspaces(params)
# run model
(logits, rec), latents = model(ds_metric['image'])
# The model returns three latent tensors; the last two are kept (transposed) as mu and logvar.
(_, mu, logvar) = [x.numpy().T for x in latents]
(logits_2, rec_2), latents_2 = model(ds_metric_2['image'])
(_, mu_2, logvar_2) = [x.numpy().T for x in latents_2]

# When params.sample is set, a latent counts as "used" if its mean exp(logvar) along
# axis 1 is at most 0.5 (axis 1 is presumably the sample axis after the transpose - confirm).
neuron_used = np.mean(np.exp(logvar), axis=1) <= 0.5 if params.sample else None
metrics, mi_mat_ = c_u.compute_mig(mu, ds_metric, mu_2, ds_metric_2, neuron_used=neuron_used,
                                   dataset=params.dataset, remove_unused=False, compute_dci=False)
print(metrics)

In [None]:
# Heatmap of the scaled mutual-information matrix returned by compute_mig in the
# previous cell (y axis: latents, x axis: factors).
# Factor names are the ds_metric keys minus the excluded entries; keys containing
# 'value' or 'label' lose their first six characters.
tick_names = [
    name[6:] if ('value' in name or 'label' in name) else name
    for name in ds_metric.keys()
    if name not in ['image', 'target', 'input']
]
mi_mat, entropy = mi_mat_
# Divide every column by the matching entry of `entropy`.
mi_mat_scaled = mi_mat / entropy[np.newaxis, :]

plt.imshow(mi_mat_scaled, vmin=0.0, vmax=mi_mat_scaled.max(), cmap=cmap)
ax = plt.gca()
ax.set_xticks(np.arange(len(tick_names)))
ax.set_xticklabels(tick_names, rotation=60, ha="right")
ax.set_xlabel('Factors', fontsize=fontsize)
ax.set_ylabel('Latents', fontsize=fontsize)
plt.savefig(save_path_vae + 'categories_MI' + ".png", dpi=300, bbox_inches='tight')
plt.show()

# Pattern Learning

In [None]:
# Root folders of the Pattern_Learning runs. This rebinds model_folder_name, path1,
# path2 and paths, which were set for the subspace models earlier in the notebook.
model_folder_name = 'Pattern_Learning'
path2 = 'OTHER PATH' + model_folder_name + '/path2/'
path1 = 'SOME PATH' + model_folder_name + '/'
# path2 is listed first
paths = [path2, path1]

In [None]:
# Fresh registry for the pattern-learning experiments: one independent empty dict
# per experiment group. This replaces the `data` dict built at the top of the notebook.
data = {group: {} for group in ('pattern_learning_old', 'pattern_learning')}

### Cosine curves

In [None]:
# Runs to load, as {label: {date: [run numbers]}}.
relu_runs = {
    "2022-08-30": list(range(10, 20)),
    "2022-08-31": list(range(13, 18)),
}
no_relu_runs = {
    "2022-08-30": list(range(5, 10)),
    "2022-08-31": [7, 9, 10, 11, 12],
}
data['pattern_learning']['info'] = {'ReLu': relu_runs, 'No ReLu': no_relu_runs}

# Load the tensorboard logs of those runs; the returned object replaces the
# 'pattern_learning' entry (later cells read its 'info' and 'dfs' keys).
data['pattern_learning'] = p_f.get_tensorboard_df(
    data['pattern_learning'],
    paths=paths,
    min_steps=2000,
)

In [None]:
# Collect the names reported for the loaded data frames without printing them
# (to_print=False). NOTE(review): presumably these are the available scalar tags,
# e.g. 'metrics/cosine' used in the next cell - confirm in plotting_functions.
names = p_f.show_df_names(data['pattern_learning']['dfs'], to_print=False)

In [None]:
# Training curve of the 'metrics/cosine' scalar for the runs registered above;
# the figure path is built from save_path_pattern + 'cosine'.
p_f.plot_train_curve(
    data['pattern_learning'],
    'metrics/cosine',
    'Cosine Distance',
    save_path_pattern + 'cosine',
    label_keep=None,
    ylim=(0, 1.0),
    figsize=(4, 4),
)

### Compare Firing maps around object locations

In [None]:
# For every registered run, load the model and store the mean spatial correlation
# returned by get_mean_spatial_corrs, grouped by experiment type.
index = None
base_path = path2

results = {}
for experiment_type, experiment_info in data['pattern_learning']['info'].items():
    print(experiment_type)
    # the same list is stored in `results` and filled below
    type_corrs = results[experiment_type] = []
    for date, runs in experiment_info.items():
        for run in runs:
            params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)
            object_all_type = p_f.get_object_surround(model, params)
            mean_corr, mean_corr_cells = p_f.get_mean_spatial_corrs(object_all_type, params)
            type_corrs.append(mean_corr)

In [None]:
# One scatter column per experiment type: every point is the mean correlation of
# one run collected in `results`.
for experiment_type, corrs in results.items():
    plt.scatter([experiment_type for _ in corrs], corrs, label=experiment_type)

ax = plt.gca()
# Hide the x ticks (the legend already names the groups). An empty list removes
# them; the string 'off' is not a valid tick specification for plt.xticks.
plt.xticks([])
plt.xlabel('Model / Data type', fontsize=fontsize)
plt.ylabel('Spatial Correlation Around Objects', fontsize=fontsize)
plt.legend(loc='center left', bbox_to_anchor=(0.15,0.3), fontsize=legendsize)
plt.savefig(save_path_pattern + 'object_correlations' + ".png", dpi=300, bbox_inches='tight')
plt.show()

### Relu + Factorised data (1)

In [None]:
# Select the run for this section and load its data.
base_path = path2
date = '2022-08-31'
run = 1
index = None

params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)
# Section-specific prefix stored on `info` for the plotting helpers below.
info.save_path = save_path_pattern + 'ReLu_FactoredData_1_'

In [None]:
# Correlations around objects for the loaded run; mean_corr_cells is plotted
# against the grid scores further down.
object_all_type = p_f.get_object_surround(model, params)
mean_corr, mean_corr_cells = p_f.get_mean_spatial_corrs(object_all_type, params)

# NumPy copy of the representations, iterated over in the grid-score cell.
g_to_plot = g.numpy()

In [None]:
"""
Grid score and scale
"""
import cell_analyses as ca

# Options passed to the autocorrelation and grid-score helpers.
torus = False
ring = True
fit_ellipse = True

ent_or_mem = params.ent_dim
width = params.width

scores_all = []
for g_env in g_to_plot:
    # one [index, score, scale, theta, norm_firing] row per cell
    module_analysis = []
    for i in range(ent_or_mem):
        # progress output: every tenth cell index
        if i % 10 == 0:
            print(str(i), end=' ')

        # get cell and reshape its activity to a (height, width) rate map
        cell = g_env[:, i]
        rate_map = cell.reshape((params.height, params.width))
        # autocorrelogram with any remaining NaNs zeroed before scoring
        auto = p_f.autocorr2d_no_nans(rate_map, torus=torus)
        auto[np.isnan(auto)] = 0
        score, scale, theta = ca.grid_score_scale_analysis(auto, fit_ellipse=fit_ellipse, ring=ring)
        norm_firing = np.mean(np.square(cell))
        module_analysis.append([i, score, scale, theta, norm_firing])

    # split the per-cell rows into one array per quantity
    scores = np.asarray([row[1] for row in module_analysis])
    scales = np.asarray([row[2] for row in module_analysis])
    thetas = np.asarray([row[3] for row in module_analysis])
    norm_firings = np.asarray([row[4] for row in module_analysis])

    scores_all.append(scores)
# per-cell grid score averaged over the entries of g_to_plot, ignoring NaNs
scores_mean = np.nanmean(scores_all, axis=0)

In [None]:
# Plot all cells for the loaded run via the plotting_functions helper.
# NOTE(review): presumably figures are written under info.save_path set above - confirm.
p_f.plot_pattern_all_cells(inputs, g, params, info)

In [None]:
# Plot the pattern metrics and keep the two returned values; spat_corr is the
# x axis of the scatter plot in the next cell.
change, spat_corr = p_f.plot_pattern_metrics(model, g, info)

In [None]:
# Twin-axis scatter against spat_corr: grid score on the left axis (blue triangles)
# and object patch correlation on the right axis (green circles).
fig, ax1 = plt.subplots(figsize=(4, 4))
ax2 = ax1.twinx()

# left axis
ax1.scatter(spat_corr, scores_mean, c='b', marker="^", label='Grid score', alpha=0.5)
ax1.set_xlabel('Spatial correlation', fontsize=fontsize)
ax1.set_ylabel('Grid score', color='b', fontsize=fontsize)

# right axis
ax2.scatter(spat_corr, mean_corr_cells, c='g', marker='o', label='Object patch correlation', alpha=0.5)
ax2.set_ylabel('Object patch correlation', color='g', fontsize=fontsize)

# NOTE(review): unlike the other figures, the file name has no explicit extension
# and ends in str(index) ('None' when index is None) - confirm this is intended.
plt.savefig(save_path_pattern + 'spat_corr_vs_grid_score_and_patch_corr' + '_' + str(index), dpi=300, bbox_inches='tight')
plt.show()

### NoRelu + Factorised data (1)

In [None]:
# Select the run for this section and load its data.
base_path = path1
date = '2022-07-21'
run = 6
index = None

params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)
# Section-specific prefix stored on `info` for the plotting helpers below.
info.save_path = save_path_pattern + 'NoReLu_FactoredData_1_'

In [None]:
# Plot all cells for the loaded run via the plotting_functions helper.
# NOTE(review): presumably figures are written under info.save_path set above - confirm.
p_f.plot_pattern_all_cells(inputs, g, params, info)

In [None]:
# Plot the pattern metrics for this run. The return value is not captured here,
# so the notebook echoes it as the cell output.
p_f.plot_pattern_metrics(model, g, info)

### Relu + Entangled Data

In [None]:
# Select the run for this section and load its data.
base_path = path2
date = '2022-08-30'
run = 14
index = None

params, g, inputs, model, info = p_f.get_pattern_data(date, run, index, base_path)
# Section-specific prefix stored on `info` for the plotting helpers below.
# NOTE(review): 'Enatngled' looks like a typo for 'Entangled'; left unchanged because
# it determines the output file names.
info.save_path = save_path_pattern + 'ReLu_EnatngledData_1_'

In [None]:
# Plot all cells for the loaded run via the plotting_functions helper.
# NOTE(review): presumably figures are written under info.save_path set above - confirm.
p_f.plot_pattern_all_cells(inputs, g, params, info)

In [None]:
# Plot the pattern metrics for this run. The return value is not captured here,
# so the notebook echoes it as the cell output.
p_f.plot_pattern_metrics(model, g, info)