#### Imports

In [1]:
%matplotlib inline
import os 
import numpy as np
import matplotlib.pyplot as plt
import scipy
from scipy.io import loadmat

import tensorflow as tf
tf.config.run_functions_eagerly(True)
gpus = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[2], 'GPU')
tf.config.experimental.set_memory_growth(gpus[2], True)
from tensorflow.keras import Model

import matplotlib as mpl
import imp
mrsds = imp.load_source('mrsds', '/mnt/cup/people/orrenk/mrsds/mrsds-iclr/__init__.py')

from mrsds import mrsds_switching_svae as mrsds
from mrsds.mrsds_switching_svae import load_model
from mrsds.utils_analysis import get_msgs_and_norms, get_communication_models

#### Load mrsds model

In [2]:
# Locations: repository root, trained-model directory and the run config.
base_dir = '/mnt/cup/people/orrenk/mrsds/mrsds-iclr/'
model_dir = '/scratch/orrenk/testing/mika-6s_12_20210929-k2-d2-s-svae-test-models'
config_path = f'{base_dir}run-configs/mika-test.yaml'

# Model structure; presumably must match the settings the saved model was trained with.
num_regions, num_dims, num_states = 3, 2, 2
region_sizes = [121, 159, 220]  # one entry per region (presumably observed units per region)
trial_length = 184              # time bins per trial

In [3]:
# In the SVAE case load_model returns two forms of the dynamics model. Both
# share the same weights but process inputs of different sizes:
#   xtran      -- (batch, 1, latent_dims): single-timepoint rollouts used to
#                 form the structured posterior;
#   xtran_time -- (batch, T, latent_dims): evaluates the dynamics on multiple
#                 timepoints in parallel.
mrsds_model, (xtran, xtran_time), _ = load_model(
    model_dir, config_path,
    num_regions, num_dims,
    region_sizes, trial_length,
    num_states, load=True,
)

s-svae multistep
build model svae multitep mv
model seed 0


2024-06-03 18:05:58.782503: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-03 18:05:59.631075: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4964 MB memory:  -> device: 2, name: GeForce RTX 2080 Ti, pci bus id: 0000:60:00.0, compute capability: 7.5


xnets [<keras.engine.functional.Functional object at 0x7f54487a8be0>, <keras.engine.functional.Functional object at 0x7f54486a7190>] [<keras.engine.functional.Functional object at 0x7f54487a8e20>, <keras.engine.functional.Functional object at 0x7f5448699a00>]
num_states 2
transformer input shape 10 2
[20, 10]
[(None, 184, 10), (None, 184, 10), (None, 184, 10)] 2
input shape (184, 32) 10 3 2
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_15 (InputLayer)           [(None, 184, 32)]    0                                            
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 184, 32)      64          input_15[0][0]                   
______________________________________________________________________________________________

2024-06-03 18:06:09.005869: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.




#### Load data from saved latents file

In [4]:
# Saved latents, inputs, observations and metrics for the train/test splits.
fpath = ('/scratch/orrenk/testing/'
         'mika-6s_12_20210929-k2-d2-s-svae-test_latents.mat')
dat = loadmat(file_name=fpath)

In [5]:
# Variables stored in the .mat file (loadmat also adds the '__header__',
# '__version__' and '__globals__' metadata entries).
list(dat)

['__header__',
 '__version__',
 '__globals__',
 'xs_train',
 'xs_test',
 'zs_train',
 'zs_test',
 'zs_logprob',
 'zs_logprob_test',
 'us_train',
 'us_test',
 'ys_train',
 'ys_test',
 'ys_recon_train',
 'ys_recon_test',
 'msgs_train',
 'msgs_test',
 'norms_train',
 'norms_test',
 'train_ids',
 'test_ids',
 'train_lengths',
 'test_lengths',
 'latent_region_sizes',
 'loga_train',
 'logb_train',
 'loga_test',
 'logb_test',
 'elbos',
 'lrs',
 'log_pxs',
 'log_pys',
 'log_qxs',
 'train_mses',
 'train_r2s',
 'train_gen1_mses',
 'train_gen1_r2s',
 'test_mses',
 'test_r2s',
 'test_gen1_mses',
 'test_gen1_r2s',
 'cosmooth_mses',
 'cosmooth_r2s',
 'cosmooth_gen1_mses',
 'cosmooth_gen1_r2s']

#### Run inference and get messages

In [6]:
# Number of input channels: size of the last axis of the saved training inputs.
input_len = dat['us_train'].shape[-1]

# Trials are padded to a common length; keep each training trial's true length.
trial_lengths_train = np.squeeze(dat['train_lengths'])

# Inference: forward pass of the loaded model on the training observations/inputs.
result_dict = mrsds_model(dat['ys_train'], dat['us_train'])

switching multistep
masks none. null mask: (309, 184, 500) (309, 184, 500)
inf net time 0.9694273471832275
inf time 0.9714515209197998
rollout time 10.333381414413452


2024-06-03 18:06:22.467750: I tensorflow/core/util/cuda_solvers.cc:180] Creating CudaSolver handles for stream 0x5654513a7830


vec map xlik time 0.09164214134216309
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
us passed to ztransition.
xlik overshoot time 0.9117112159729004
xlik time 1.1586964130401611
lxmeans 10 11
0
1
2
3
4
5
6
7
8
9
ylik time 0.2512197494506836
us passed to ztransition.
logpx_sum () tf.Tensor(-5980400.5, shape=(), dtype=float32)
logpxy (309,) tf.Tensor(-6033803000000.0, shape=(), dtype=float32)
elbo () tf.Tensor(-19526926000.0, shape=(), dtype=float32)


In [7]:
# Keys of the dict returned by calling mrsds_model on the training data.
result_dict.keys()

dict_keys(['ys', 'us', 'zs_logprob', 'reconstructed_ys', 'log_py', 'logpy_sum', 'log_px', 'logpx_sum', 'log_py_cosmooth', 'z_posterior', 'z_posterior_ll', 'x_sampled', 'psi_sampled', 'xsample_prior', 'elbo', 'diffs_norm', 'sequence_likelihood', 'xt_entropy', 'log_a', 'log_b'])

In [8]:
# Convert the needed model outputs from tensors to NumPy arrays.
# NOTE(review): zs_logprob is read from 'z_posterior_ll', not from the
# 'zs_logprob' key that result_dict also contains -- confirm which is intended.
xs, us, zs_logprob = (result_dict[key].numpy()
                      for key in ('x_sampled', 'us', 'z_posterior_ll'))
# NOTE(review): squeeze() drops *every* singleton axis, so a run with a single
# trial would also lose the trial axis -- confirm this is intended.
xs = xs.squeeze()
print(xs.shape, zs_logprob.shape)

(309, 184, 6) (309, 184, 2)


In [9]:
# Build the communication models from the multi-timepoint dynamics model.
communication_models = get_communication_models(
    xtran_time, num_regions=num_regions, num_states=num_states, input_len=input_len)

# Every region has the same latent dimensionality here.
latent_region_sizes = [num_dims for _ in range(num_regions)]

# Inter-region messages and their norms, computed from the sampled latents.
msgs, norms = get_msgs_and_norms(
    communication_models, xs, us, zs_logprob,
    num_states=num_states, num_regions=num_regions, latent_dim=num_dims,
    latent_region_sizes=latent_region_sizes,
    trial_lengths=trial_lengths_train, input_len=input_len,
)

In [10]:
# Messages are of size:
# (trials, time, num_regions, num_regions+num_inputs, latent_dim)
# i.e. per trial and time bin, for each region, one latent_dim-sized message
# from every region and every input channel (presumably indexed receiver x
# source -- confirm against get_msgs_and_norms).
print(msgs.shape)

(309, 184, 3, 5, 2)
