# Disentangled representation of intention in macaque prefrontal cortex

Herein we train a model to yield a low-dimensional latent vector that encodes only the intended saccade target, and is disentangled from another low-dimensional timeseries that encodes simple dynamics related to target-agnostic task events.

## Resources

[Notes on beta-VAE in general](https://colab.research.google.com/github/SachsLab/IntracranialNeurophysDL/blob/master/notebooks/05_04_betaVAE_TFP.ipynb).

[Disentangled sequential autoencoders paper, by Li and Mandt, ICML 2018](https://arxiv.org/pdf/1803.02991.pdf), with an implementation in [TF Probability by google](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/disentangled_vae.py) and in [pytorch by yatindandi/Disentangled-Sequential-Autoencoder](https://github.com/yatindandi/Disentangled-Sequential-Autoencoder).

Further extensions of this concept can be found in [Swapping Autoencoder for Deep Image Manipulation by Park et al., 2020](https://arxiv.org/pdf/2007.00653.pdf) with a [pytorch implementation](https://github.com/rosinality/swapping-autoencoder-pytorch).
* discriminator for real vs fake when keeping dynamic latent but swapping static latent from another trial.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import os
import sys


try:
    # See if we are running on google.colab
    from google.colab import files
    %tensorflow_version 2.x
    os.chdir('..')
    
    if not (Path.home() / '.kaggle').is_dir():
        # Configure kaggle
        uploaded = files.upload()  # Find the kaggle.json file in your ~/.kaggle directory.
        if 'kaggle.json' in uploaded.keys():
            !mkdir -p ~/.kaggle
            !mv kaggle.json ~/.kaggle/
            !chmod 600 ~/.kaggle/kaggle.json
            
    if Path.cwd().stem == 'MonkeyPFCSaccadeStudies':
        os.chdir(Path.cwd().parent)
    if not (Path.cwd() / 'MonkeyPFCSaccadeStudies').is_dir():
        !git clone --single-branch --recursive https://github.com/SachsLab/MonkeyPFCSaccadeStudies.git
        sys.path.append(str(Path.cwd() / 'MonkeyPFCSaccadeStudies'))
    os.chdir('MonkeyPFCSaccadeStudies')
        
    !pip install git+https://github.com/SachsLab/indl.git
    !pip install -q kaggle
    !pip install --upgrade tensorflow-probability
    IN_COLAB = True

except ModuleNotFoundError:
    IN_COLAB = False
    import sys
    if Path.cwd().stem == 'Analysis':
        os.chdir(Path.cwd().parent.parent)
    # Make sure the kaggle executable is on the PATH
    os.environ['PATH'] = os.environ['PATH'] + ';' + str(Path(sys.executable).parent / 'Scripts')

# Try to clear any logs from previous runs
# Best-effort clean-up: delete the `logs` folder left in the working directory
# by an earlier run. A PermissionError is reported but never fatal.
logs_dir = Path.cwd() / 'logs'
if logs_dir.is_dir():
    import shutil
    try:
        shutil.rmtree(str(logs_dir))
    except PermissionError:
        print("Unable to remove logs directory.")

In [None]:
# Additional imports
from functools import partial
import math
import numpy as np
import random
import tensorflow as tf
import tensorflow.keras.layers as tfkl
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from indl.display import turbo_cmap
# Choose the matplotlib style sheet. The condition is hard-coded, so the dark
# theme is always applied; change it to `False` to use 'seaborn-poster' instead.
if True:
    plt.style.use('dark_background')
else:
    plt.style.use('seaborn-poster')
# Notebook-wide defaults for font sizes, line/marker sizes and figure size.
plt.rcParams.update({
    'axes.titlesize': 24,
    'axes.labelsize': 20,
    'lines.linewidth': 1,
    'lines.markersize': 5,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 18,
    'figure.figsize': (8, 6.4)
})


## Download Data (if necessary)

In [None]:
# The data directory differs between Colab and a local checkout of the repository.
if IN_COLAB:
    data_path = Path.cwd() / 'data' / 'monkey_pfc' / 'converted'
else:
    data_path = Path.cwd() / 'StudyLocationRule' / 'Data' / 'Preprocessed'

# Only download (and unzip) the dataset with the kaggle CLI when the directory
# does not exist yet; an existing directory is assumed to hold the data.
if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")

## Get Data

We will use a custom function `load_macaque_pfc` to load the data into memory.

There are 4 different strings that can be passed to the `x_chunk` argument of `load_macaque_pfc`:
* 'analogsignals' - if present. Returns 1 kHz LFPs
* 'gaze'          - Returns 2-channel gaze data.
* 'spikerates'    - Returns smoothed spikerates
* 'spiketrains'

The `y_type` argument can be
* 'pair and choice' - returns Y as np.array of (target_pair, choice_within_pair)
* 'encoded input' - returns Y as np.array of shape (n_samples, 10) (explained below)
* 'replace with column name' - returns Y as a vector of per-trial values. e.g., 'sacClass'

The actual data we load depends on the particular analysis below.

In [None]:
from misc.misc import sess_infos, load_macaque_pfc

# Keyword arguments forwarded unchanged to `load_macaque_pfc` (see the call below).
load_kwargs = {
    'valid_outcomes': (0, 9),  # Use (0, 9) to include trials with incorrect behaviour
    'zscore': False,
    'dprime_range': (-np.inf, np.inf),  # Use (-np.inf, np.inf) to include all trials.
    'time_range': (-np.inf, 1.35),  # np.inf),
    'verbose': True,
    'y_type': 'sacClass',  # per-trial value taken from the 'sacClass' column
    'samples_last': False,
    # Also reused below to rescale the spike trains to spikes per bin
    # (presumably the resampling factor / bin width in msec -- TODO confirm).
    'resample_X': 10
}

### Load Data

In [None]:
# Pick one session from `sess_infos` and load its spike trains and saccade classes.
test_sess_ix = 1
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
print(f"\nImporting session {sess_id}")

# Different x_chunk values: 'analogsignals' (i.e. LFPs), 'spikerates', 'spiketrains', 'gaze'
# Rates...
#X_rates, Y_class, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
# Set baseline to 0. Makes reconstruction with relu easier.
#X_rates = X_rates - np.min(np.min(X_rates, axis=0, keepdims=True), axis=1, keepdims=True)
#X_rates = X_rates.astype(np.float32)

X_spikes, Y_class, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spiketrains',
                                              **load_kwargs)
X_spikes = X_spikes * load_kwargs['resample_X']  # spks/msec -> spks/bin
X_spikes = X_spikes.astype(np.float32)
# One-hot encode the class labels; the number of classes is fixed at 8 here.
Y_class = tf.keras.utils.to_categorical(Y_class, num_classes=8)

In [None]:
# Quick look at the data: all timesteps of trial index 1, every 6th channel.
plt.plot(X_spikes[1, :, ::6])

## Model

### Hyperparameters

In [None]:
# Notebook-wide hyperparameters. Several are used as default argument values by
# the model classes defined below, so this cell must be run before those cells.
USE_READIN = False
N_HIDDEN_STATIC = 128                # Number of RNN cells in static encoder network
LATENT_SIZE_STATIC = 64              # Size of static latent vector f | g_0
STATIC_LATENT_OFF_DIAG = False
DYNAMIC_GRAPH = 'full'               # 'none', 'factorized', 'full', 'controller'
N_HIDDEN_DYNAMIC = 12                # Number of RNN cells in dynamic encoder
LATENT_SIZE_DYNAMIC = 2              # Size of dynamic latent vector z_t | u_t
DYNAMIC_LATENT_OFF_DIAG = False
N_HIDDEN_GEN = 256                   # Number of RNN cells in generator
N_FACTORS = 36                       # Number of latent factors ?? | f_t
# NUM_RECONSTRUCTION_SAMPLES = 1
BATCH_SIZE = 16
NUM_SAMPLES = 4
RANDOM_SEED = 1337
N_EPOCHS = 150
MAX_GRAD_NORM = 200.0
DROPOUT_RATE = 0.025
L2_REG = 2e-5

### Helper Imports

Much of our model will be created by creating model blocks comprising multiple layers/transformations. Here we import some helper functions and classes for simplifying model creation.

The AutoShape classes are necessary because otherwise a block that is a tf.keras.Model subclass does not report its shape properly.

In [None]:
from indl.model.tfp import scale_shift
from indl.model import parts
from indl.model.autoshape_mixin import AutoShapeMixin


# Thin subclasses that combine AutoShapeMixin with stock Keras / TFP / TF-Addons
# layers. They add no behaviour of their own; the mixin is listed first so that
# its methods take precedence over those of the wrapped layer class.
class BidirectionalAutoShape(AutoShapeMixin, tfkl.Bidirectional):
    pass
class GRUAutoShape(AutoShapeMixin, tfkl.GRU):
    pass
class DenseAutoShape(AutoShapeMixin, tfkl.Dense):
    pass
class DistLambdaAutoShape(AutoShapeMixin, tfpl.DistributionLambda):
    pass
class GRUCellAutoShape(AutoShapeMixin, tfkl.GRUCell):
    pass
class DropoutAutoShape(AutoShapeMixin, tfkl.Dropout):
    pass
class WeightNormAutoShape(AutoShapeMixin, tfa.layers.WeightNormalization):
    pass

In [None]:
# from indl.data.augmentations import random_slice
# Build the training dataset of (spike trains, one-hot class) pairs.
ds = tf.data.Dataset.from_tensor_slices((X_spikes, Y_class))
# Any augmentations. e.g., random slicing.
# p_random_slice = partial(random_slice, max_offset=3, axis=0)
# ds = ds.map(p_random_slice)
# Shuffle buffer is larger than the number of trials, i.e. a full shuffle.
ds = ds.shuffle(X_spikes.shape[0] + 1)
# drop_remainder=True keeps every batch at exactly BATCH_SIZE trials.
ds = ds.batch(BATCH_SIZE, drop_remainder=True)
print(ds.element_spec)

# Shape of one batch of X with the batch dimension replaced by None.
input_shape = ds.element_spec[0].shape.as_list()
input_shape[0] = None  # Batch dim
input_shape = tuple(input_shape)
# The last two axes are treated as (time, sensor) throughout the notebook.
n_times, n_sensors = input_shape[-2:]

### Read-In Feature Extraction (Optional)

This optional part of the model does some mild feature extraction on the input data. The intention is to transform the data into a common dimensionality and space.

Input shape: `(batch, samples, channels)`

Output shape: `(batch, samples//pooling, n_kernels*depth_multiplier)`

In [None]:
def ReadIn(input_shape, n_kernels=6, kern_length=25, depth_multiplier=2, activation=tf.nn.leaky_relu, pooling=5, dropout_rate=0.25):
    """
    Build the optional read-in feature extractor as a `tf.keras.Sequential`.

    The layers are, in order: a temporal convolution, a depthwise spatial
    convolution spanning all sensors, an activation, and average pooling
    along time.

    :param input_shape: Shape whose last two entries are (n_times, n_sensors).
        Must be a tuple because it is concatenated with `(1,)`.
    :param n_kernels: Number of temporal filters.
    :param kern_length: Length (in samples) of each temporal filter.
    :param depth_multiplier: Spatial filters learned per temporal filter.
    :param activation: Activation applied after the spatial filter.
    :param pooling: Pool size of the temporal average pooling.
    :param dropout_rate: Currently unused; the dropout layer is disabled.
    :return: Model mapping (batch, n_times, n_sensors) to
        (batch, n_times // pooling, n_kernels * depth_multiplier).
    """
    times_by_sensors = input_shape[-2:]
    n_times, n_sensors = times_by_sensors
    out_features = n_kernels * depth_multiplier

    stages = [
        tfkl.Input(shape=times_by_sensors),
        # Add a trailing channel axis so that 2D convolutions can be used.
        tfkl.Reshape(times_by_sensors + (1,)),
        tfkl.Conv2D(n_kernels, (kern_length, 1), padding='same', use_bias=False,
                    name="temporal_filter"),
        # (disabled) tfkl.BatchNormalization(name="temporal_filter_bnorm"),
        # 'valid' padding with a kernel as wide as the sensor axis collapses that axis to 1.
        tfkl.DepthwiseConv2D((1, n_sensors), padding='valid',
                             depth_multiplier=depth_multiplier, use_bias=False,
                             name="spatial_filter"),
        # (disabled) tfkl.BatchNormalization(name="spatial_filter_bnorm"),
        tfkl.Activation(activation),
        tfkl.AveragePooling2D((pooling, 1), name="temporal_smoothing"),
        # (disabled) tfkl.Dropout(dropout_rate),
        # Drop the singleton sensor axis.
        tfkl.Reshape((n_times // pooling, out_features)),
    ]
    return tf.keras.Sequential(stages, name="read_in")

# Reset Keras' global state, then build the read-in block and print its layer table.
# `read_in` is referenced by later cells even when USE_READIN is False.
K.clear_session()
read_in = ReadIn(input_shape)
read_in.summary()

# temp_input = tf.random.uniform(ds.element_spec[0].shape)
# tmp = read_in(temp_input)
# print(tmp)

### Encoder
Posterior distributions `q` of latents $f$ (static) and $z_t$ (dynamic) are conditioned on input sequence $x_t$.

Latent posterior, Static only:
$$q(f | x_{1:T})$$

Latent posterior, Static and Dynamic Factorized:
$$q(z_{1:T}, f | x_{1:T}) = q(f | x_{1:T}) \prod_{t=1}^T q(z_t | x_t)$$

Latent posterior, Static and Dynamic Full:
$$q(z_{1:T}, f | x_{1:T}) = q(f | x_{1:T}) q(z_{1:T} | f, x_{1:T})$$
Note that _q(z)_ depends on _f_

#### Static Encoder

Transform the full sequence of "features" (`inputs` or `ReadIn(inputs)`) through (1) a bidirectional GRU and then (2) affine layers to yield the parameters of the static latent posterior distribution:
$$q(f | x_{1:T})$$
This distribution is a multivariate normal, optionally with off-diagonal elements allowed.

Model loss will include the KL divergence between the static latent posterior and a prior; the prior is a learnable multivariate normal diagonal. The prior is initialized with a mean of 0 and a stddev of 1 but these are trainable by default.

See [this notebook](https://colab.research.google.com/github/SachsLab/IntracranialNeurophysDL/blob/master/notebooks/05_04_betaVAE_TFP.ipynb) under the section "**Define the latent prior**" for a discussion on the merits of allowing off-diagonal elements on the prior.

In [None]:
from indl.model.tfp import LearnableMultivariateNormalDiag  # For prior
from indl.model.tfp import make_mvn_prior  # , make_mvn_dist_fn


# Inverse softplus of 1, i.e. softplus(0 + scale_shift) == 1, so an unconstrained
# scale output of 0 maps to a standard deviation of 1.
# NOTE: this rebinds the name `scale_shift` that an earlier cell imported from indl.model.tfp.
scale_shift = np.log(np.exp(1) - 1).astype(np.float32)


class StaticEncoder(AutoShapeMixin, tf.keras.Model):
    """
    Probabilistic encoder for the static (per-sequence) latent variable `f`.

    The input sequence goes through dropout and a bidirectional GRU that returns
    only its final output. Two Dense layers then map that hidden state to the
    loc and the unconstrained scale of the posterior `q(f | x_{1:T})`, which is
    a multivariate normal (diagonal, or lower-triangular scale when the
    notebook-level flag STATIC_LATENT_OFF_DIAG is set).

    The model also owns `static_prior_factory`, a learnable diagonal
    multivariate normal used as the prior over `f`.
    """
    def __init__(self, units=64, latent_size=32, dropout_rate=DROPOUT_RATE, **kwargs):
        """
        :param units: Number of GRU units per direction.
        :param latent_size: Dimensionality of the static latent `f`.
        :param dropout_rate: Dropout rate applied to the inputs.
        :param kwargs: Forwarded to `tf.keras.Model`.
        """
        super().__init__(**kwargs)
        self.units = units
        self.latent_size = latent_size
        
        # Model layers parameterizations do not depend on input shape so we can initialize
        #  them here instead of .build()
        self.dropout = DropoutAutoShape(dropout_rate)
        # input --> bidirectional GRU --> loc & scale --> MVN distribution
        self.static_latent_rnn = BidirectionalAutoShape(
            GRUAutoShape(self.units, return_sequences=False),
            merge_mode="concat", name="static_latent_rnn")
        self.static_latent_loc = DenseAutoShape(self.latent_size, name="static_latent_loc")
        if STATIC_LATENT_OFF_DIAG:
            # params_size counts loc + scale entries; subtract the loc entries.
            self.static_latent_scale = DenseAutoShape(
                tfpl.MultivariateNormalTriL.params_size(self.latent_size) - self.latent_size,
                name="static_latent_scale")
            self.shift_scale = tfp.bijectors.FillScaleTriL()
            self.static_latent_posterior = DistLambdaAutoShape(
                make_distribution_fn=lambda t: tfd.MultivariateNormalTriL(loc=t[0], scale_tril=t[1]),
#                 convert_to_tensor_fn=lambda s: s.sample(n_samples),
                name="static_latent_posterior")
        else:
            self.static_latent_scale = DenseAutoShape(
                tfpl.IndependentNormal.params_size(self.latent_size) - self.latent_size,
                name="static_latent_scale")
            # Shifted softplus keeps the scale positive; the epsilon keeps it away from 0.
            self.shift_scale = lambda x: tf.math.softplus(x + scale_shift) + 1e-5
            self.static_latent_posterior = DistLambdaAutoShape(
                make_distribution_fn=lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1]),
#                 convert_to_tensor_fn=lambda s: s.sample(n_samples),
                name="static_latent_posterior")
            
        # Define the static prior
        #if STATIC_PRIOR_OFF_DIAG: self.prior = make_mvn_prior(latent_size, trainable=False)
        # TODO: prior variance constant kappa=0.1
        self.static_prior_factory = LearnableMultivariateNormalDiag(self.latent_size)
        self.static_prior_factory.build(input_shape=(0,))
        
    def call(self, inputs, training=None):
        """
        :param inputs: Feature sequence fed to the bidirectional GRU.
        :param training: Keras training flag; controls whether dropout is active.
        :return: The posterior distribution `q(f | x_{1:T})`.
        """
        _inputs = self.dropout(inputs, training=training)
        # Feed the dropout output (not the raw inputs) to the RNN; previously the
        # dropout result was computed and then discarded.
        _hidden_state = self.static_latent_rnn(_inputs)
        _loc = self.static_latent_loc(_hidden_state)
        # Both distribution parameters are affine read-outs of the RNN state;
        # previously the scale was derived from `_loc`.
        _scale = self.static_latent_scale(_hidden_state)
        _shifted_scale = self.shift_scale(_scale)
        _q_f = self.static_latent_posterior([_loc, _shifted_scale])
        return _q_f
    
# Instantiate the static encoder and push one random batch through it so that
# its weights get created and `.summary()` can report them.
K.clear_session()

static_encoder = StaticEncoder(units=N_HIDDEN_STATIC, latent_size=LATENT_SIZE_STATIC)

# The encoder consumes either the read-in output or the raw input, both with a concrete batch size.
features_shape = (BATCH_SIZE,) + (read_in.output_shape[1:] if USE_READIN else input_shape[1:])
print(f"Input shape: {(None,) + features_shape[1:]}")
dummy_latent = static_encoder(tf.random.uniform(features_shape))
static_encoder.summary()

#### Dynamic Encoder ($z_t$)

Input features are transformed through (1) a bidirectional GRU (`return_sequences=True`), (2) a second, unidirectional GRU, and (3) a pair of affine layers to yield the parameters of the dynamic posterior latent distribution:

$q(z_t | x_{1:T})$

This is a multivariate normal distribution **at each timestep**, conditioned on features from $x_t$ **and optionally concatenated with static latent factors $f$** in the full not-factorized model.

##### Compared to LFADS

The [LFADS model](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6380887/pdf/nihms-1500948.pdf) 'initial condition' encoder is similar to this model's StaticEncoder.
The LFADS model also has an 'input controller' module, whereas here we have a DynamicEncoder. They are similar in some respects, but also different.

We have a diagram that outlines LFADS architecture as well as the present model, highlighting the differences, found in this study folder `/Output/Figures/DEAE_vs_LFADS.svg`.

* [LFADS source in tensorflow](https://github.com/tensorflow/models/tree/master/research/lfads)
* [LFADS in JAX](https://github.com/google-research/computation-thru-dynamics/tree/master/lfads_tutorial)
* [hierarchical LFADS in pytorch](https://github.com/lyprince/hierarchical_lfads)

Other differences:

* LFADS uses in-cell value clipping in the bidirectional GRU
* LFADS' l2-regularization coefficient is on a scheduler


$$y = \hat{w} x, \quad \text{where} \quad \hat{w}_{ij} = \frac{w_{ij}}{\lVert w_{i:} \rVert}$$

##### Dynamic Encoder - Factorized

The full factorized encoder model is thus:
$$q(z_{1:T}, f | x_{1:T}) = q(f | x_{1:T}) \prod_{t=1}^T q(z_t | x_t)$$

Model loss will include the KL divergence between dynamic latent posterior distribution and the prior distribution; the prior is a multivariate normal diagonal with the same shape as the dynamic posterior -- i.e. one multivariate (LATENT_SIZE_DYNAMIC) distribution per timestep (n_timesteps).

In [None]:
from indl.model.tfp import LearnableMultivariateNormalDiagCell  # For prior


class DynamicEncoder(tf.keras.Model):
    """
    Probabilistic encoder for the time-variant latent variable `z_t`.

    The posterior is a diagonal multivariate normal distribution on
    `R^{latent_size}` at each timestep `t`. Two graphs are supported:

    * `factorized=True`: `q(z_t | x_t)`. Each timestep of the features
      (optionally processed by ReadIn) goes through a one-hidden-layer neural
      net; `z_t` is independent of the static latent variable `f`, and the
      static sample passed to `call` is ignored.
    * `factorized=False`: `q(z_{1:T} | f, x_{1:T})`. The features are
      concatenated with a sample of `f` tiled over time, then processed by a
      bidirectional GRU followed by a unidirectional GRU.

    The model also owns `dynamic_prior_cell`, a learnable recurrent prior over
    `z_t` that can be sampled with `sample_dynamic_prior`.
    """
    def __init__(self, hidden_size=16, latent_size=4, factorized=True,
                 l2_reg=L2_REG, name="dynamic_encoder"):
        """
        :param hidden_size: Units in the hidden layer(s) and in the prior cell.
        :param latent_size: Dimensionality of the dynamic latent `z_t`.
        :param factorized: Select the factorized (True) or full (False) graph.
        :param l2_reg: L2 coefficient on the recurrent kernels (full graph only).
        :param name: Keras model name.
        """
        super().__init__(name=name)
        self.hidden_size = hidden_size  # Number of units hidden layer and in dynamic_prior_cell
        self.latent_size = latent_size  # Dimensionality of latent posterior
        self.factorized = factorized
        # The dynamic prior - a GRU-type cell with learnable params, and learnable Dense layers
        #  to generate a MVNDiag for each sequence timestep.
        self.dynamic_prior_cell = LearnableMultivariateNormalDiagCell(self.hidden_size, self.latent_size,
                                                                      cell_type='gru')
        if self.factorized:
            self.hidden_layer = DenseAutoShape(self.hidden_size, activation=tf.nn.leaky_relu,
                                               name="dyn_hidden")
            self.dynamic_latent_rnn1 = self.dynamic_latent_rnn2 = None
        else:
            self.hidden_layer = None
            self.dynamic_latent_rnn1 = BidirectionalAutoShape(
                GRUAutoShape(self.hidden_size, return_sequences=True,
                             recurrent_regularizer=tf.keras.regularizers.l2(l=l2_reg)),
                merge_mode="sum", name="dyn_hidden1")
            self.dynamic_latent_rnn2 = tfkl.GRU(self.hidden_size, return_sequences=True,
                                                recurrent_regularizer=tf.keras.regularizers.l2(l=l2_reg),
                                                name="dyn_hidden2")
            # can't use GRUAutoShape here because mismatch between input_spec being list / single item.
        
        self.loc = DenseAutoShape(self.latent_size, name="dyn_loc")
        self.unxf_scale = DenseAutoShape(self.latent_size, name="dyn_scale")
        # if DYNAMIC_LATENT_OFF_DIAG: ??? else:
        self.q_z_layer = DistLambdaAutoShape(
            make_distribution_fn=lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1]),
#             convert_to_tensor_fn=lambda s: s.sample(n_samples),
            name="q_z"
        )
        
    def build(self, input_shapes):
        """
        Record the number of timesteps and force the prior cell to build.

        :param input_shapes: Pair (static_shape, features_shape).
        """
        static_shape, features_shape = input_shapes
        self.n_times = features_shape[-2]
        # We can't .build our prior because its .call requires 2 inputs (sample, state)
        # so instead we call the cell with its zero-state, effectively forcing it to build.
        sample_batch_shape = (1,) + features_shape[1:-2]
        sample0, state0 = self.dynamic_prior_cell.zero_state(sample_batch_shape)
        self.dynamic_prior_cell(sample0, state0)
#         super().build(input_shapes)
        
    def call(self, inputs):
        """
        :param inputs: Pair (static_sample, features). `static_sample` is only
            used by the full (non-factorized) graph.
        :return: The posterior distribution over `z_t` for every timestep.
        """
        static_sample, features = inputs
        if self.factorized:
            _hidden = self.hidden_layer(features)
        else:
            # We explicitly broadcast `x` and `f` to the same shape other than the final
            # dimension, because `tf.concat` can't automatically do this. This will
            # entail adding a `timesteps` dimension to `f` to give the shape `(...,
            # batch, timesteps, latent)`, and then broadcasting the sample shapes of
            # both tensors to the same shape.
            timesteps = tf.shape(input=features)[-2]
            static_sample = static_sample[..., tf.newaxis, :] + tf.zeros([timesteps, 1])
            sample_shape_static = tf.shape(input=static_sample)[:-3]
            sample_shape_features = tf.shape(input=features)[:-3]
            broadcast_shape_features = tf.concat((sample_shape_static, [1, 1, 1]), 0)
            broadcast_shape_static = tf.concat((sample_shape_features, [1, 1, 1]), 0)
            features = features + tf.zeros(broadcast_shape_features)
            static_sample = static_sample + tf.zeros(broadcast_shape_static)
            # `combined` will have shape (..., batch, timesteps, hidden+latent).
            combined = tf.concat((features, static_sample), axis=-1)
            # Remember the leading (sample..., batch) dims BEFORE collapsing them.
            # The RNNs need a single batch axis, and the original layout has to be
            # restored afterwards. (Previously `combined` itself was overwritten
            # by the collapsed tensor, so any sample dims were never restored.)
            leading_shape = tf.shape(input=combined)[:-2]
            collapsed_shape = tf.concat(([-1], tf.shape(input=combined)[-2:]), axis=0)
            collapsed = tf.reshape(combined, collapsed_shape)
            _hidden = self.dynamic_latent_rnn1(collapsed)
            _hidden = self.dynamic_latent_rnn2(_hidden)
            expanded_shape = tf.concat((leading_shape, tf.shape(input=_hidden)[1:]), axis=0)
            _hidden = tf.reshape(_hidden, expanded_shape)  # (sample, batch, T, hidden_size)
        loc = self.loc(_hidden)
        unxf_scale = self.unxf_scale(_hidden)
        # Shifted softplus keeps the scale positive; the epsilon keeps it away from 0.
        scale = tf.math.softplus(unxf_scale + scale_shift) + 1e-5
        q_z = self.q_z_layer([loc, scale])
        return q_z
    
    def call_full(self, inputs):
        """Placeholder for an alternative full-graph implementation; always raises."""
        raise NotImplementedError("Just keeping this code here for later reference. Ignore for now.")
        # Reference sketch only (it was unreachable and used undefined names):
        # # _features needs to be repeated NUM_SAMPLES on a new samples axis at axis=0.
        # _x2 = _features[tf.newaxis, ...] + tf.zeros([NUM_SAMPLES, 1, 1, 1])
        # # Concatenate _x2 (features) and _static_sample
        # _x2 = tfkl.Concatenate()([_x2, _static_sample])  # (samples, batch, timesteps, feat_dim+latent_static)
        # # Collapse samples + batch dims  -- required by LSTM
        # _x2 = tf.reshape(_x2, [-1] + _x2.shape.as_list()[-2:])  # (samples*batch, T, feat+lat_stat)
        # # Run _x2 through bidirectional lstm then a simple RNN,
        # # then use output to parameterize distribution over latent variable z_t.
        # _x2 = tfkl.Bidirectional(
        #     tfkl.GRU(self.hidden_size, return_sequences=True),
        #     merge_mode="sum")(_x2)
        # _x2 = tfkl.GRU(self.hidden_size, return_sequences=True)(_x2)
        # # Restore samples dim?
        # _x2 = tf.reshape(_x2, [NUM_SAMPLES, -1, n_timesteps, self.hidden_size])
    
    def sample_dynamic_prior(self, timesteps, samples=1, batches=1, fixed=False):
        """
        Samples from self.dynamic_prior_cell `timesteps` times.
        On each step, the previous (sample, state) is fed back into the cell
        (zero_state used for 0th step).
        
        The cell returns a multivariate normal diagonal distribution for each timestep.
        We collect each timestep-dist's params (loc and scale), then use them to create
        the return value: a single MVN diag dist that has a dimension for timesteps.
        
        The cell returns a full dist for each timestep so that we can 'sample' it.
        If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
        to doing a generative RNN (init state = zeros, return_sequences=True) then passing
        those values through a pair of Dense layers to parameterize a single MVNDiag.
        
        :param timesteps: Number of timesteps to sample for each sequence.
        :param samples: Number of samples to draw from the latent distribution.
        :param batches: Number of sequences to sample.
        :param fixed: Boolean for whether or not to share the same random
            sample across all sequences in batch.
        :return: Tuple (sample, dist): the per-timestep samples stacked on axis 2,
            and a MultivariateNormalDiag built from the stacked loc / scale_diag.
        """
        if fixed:
            sample_batch_size = 1
        else:
            sample_batch_size = batches

        sample, state = self.dynamic_prior_cell.zero_state([samples, sample_batch_size])
        locs = []
        scale_diags = []
        sample_list = []
        for _ in range(timesteps):
            dist, state = self.dynamic_prior_cell(sample, state)
            sample = dist.sample()
            locs.append(dist.parameters["loc"])
            scale_diags.append(dist.parameters["scale_diag"])
            sample_list.append(sample)

        # Insert the time axis at position 2 of every stacked tensor.
        sample = tf.stack(sample_list, axis=2)
        loc = tf.stack(locs, axis=2)
        scale_diag = tf.stack(scale_diags, axis=2)

        if fixed:  # tile along the batch axis
            sample = sample + tf.zeros([batches, 1, 1])

        # TODO: Move 1 of the batch dims into event dims
        return sample, tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

# Instantiate the dynamic encoder, run one random batch through it so that its
# weights get created, then draw one sequence from its learnable prior.
K.clear_session()
dynamic_encoder = DynamicEncoder(hidden_size=N_HIDDEN_DYNAMIC, latent_size=LATENT_SIZE_DYNAMIC,
                                 factorized=DYNAMIC_GRAPH == 'factorized')

dummy_static = tf.random.uniform((BATCH_SIZE, LATENT_SIZE_STATIC))  # or None if factorized
dummy_features_shape = (BATCH_SIZE,) + (read_in.output_shape[1:] if USE_READIN else input_shape[1:])
# Use the shape computed in this cell; previously it was computed and then
# ignored in favour of `features_shape` left over from the static-encoder cell.
dummy_features = tf.random.uniform(dummy_features_shape)
dynamic_encoder((dummy_static, dummy_features))
dynamic_encoder.summary()


dyn_prior_samp, dyn_prior = dynamic_encoder.sample_dynamic_prior(
    dummy_features.shape[-2], samples=1, batches=1)
print("dynamic prior: ", dyn_prior)

In [None]:
# Scratch cell: step a stand-alone prior cell by hand, mirroring the loop in
# DynamicEncoder.sample_dynamic_prior. The step count (161) is hard-coded here.
dynamic_prior_cell = LearnableMultivariateNormalDiagCell(3, 4, cell_type='gru')
sample, state = dynamic_prior_cell.zero_state([1, 1])
locs = []
scale_diags = []
sample_list = []
for _ in range(161):
    # Feed the previous sample and state back in, then keep the new distribution's parameters.
    dist, state = dynamic_prior_cell(sample, state)
    sample = dist.sample()
    locs.append(dist.parameters["loc"])
    scale_diags.append(dist.parameters["scale_diag"])
    sample_list.append(sample)

In [None]:
# Stack the per-step samples along a new axis 2 and display the result.
tf.stack(sample_list, axis=2)

### Latents to Factors

From latent distributions (f and z_t) to factors.
Both latent distributions are assumed sampled before inputting samples to this module.
The static sample (f) goes through an affine then gives the initial condition for a generative RNN; the dynamic sample (z_t) provides the input to the generative RNN. z_t is generally much lower dimension than f.

The RNN evolves (feeding its own output to its input on the next step). In the end it returns the sequence of states. The states are then transformed through a linear layer to give latent factors.

In [None]:
from indl.model.recurrent import GenerativeRNN


class GenerateFactors(AutoShapeMixin, tf.keras.Model):
    """
    Generator that turns latent samples into a sequence of factors.

    A sample of the static latent `f` and a sequence of samples of the dynamic
    latent `z_{1:T}` drive a GRU; its output sequence passes through dropout
    and a Dense layer to give `factors` values per timestep.
    """
    def __init__(self, factors=8, units=64, combine_latents='state_and_input',
                 dropout_rate=DROPOUT_RATE, l2_reg=L2_REG, **kwargs):
        """
        :param factors: Number of factors produced per timestep.
        :param units: Number of GRU units in the generator.
        :param combine_latents: 'tile_concat' or 'state_and_input'.
        :param dropout_rate: Dropout rate applied to the GRU output sequence.
        :param l2_reg: L2 coefficient on the GRU's recurrent kernel.
        :param kwargs: Forwarded to `tf.keras.Model`.
        """
        super().__init__(**kwargs)
        assert combine_latents in ['state_and_input', 'tile_concat']
        self.combine_latents = combine_latents
        # The static -> initial-state projection only exists in 'state_and_input' mode.
        needs_state_projection = self.combine_latents == 'state_and_input'
        self.static_to_state0 = (
            DenseAutoShape(units, name="static_to_state0") if needs_state_projection else None
        )
        # Alternative kept for reference: a GenerativeRNN wrapping
        # GRUCellAutoShape(units, use_bias=False) with timesteps=timesteps,
        # return_sequences=True, name="gen_rnn".
        recurrent_penalty = tf.keras.regularizers.l2(l=l2_reg)
        self.gen_rnn = tfkl.GRU(units, return_sequences=True,
                                recurrent_regularizer=recurrent_penalty,
                                name="gen_rnn")
        self.dropout = DropoutAutoShape(dropout_rate)
        # Alternative kept for reference:
        # WeightNormAutoShape(tfkl.Dense(factors), name="gen_factors_normed")
        self.gen_factors = DenseAutoShape(factors, name="gen_factors")

    def call(self, inputs, training=None):
        """
        inputs is a tuple (static, dynamic)
        If dynamic is not used, then pass in tf.zeros, with dims (batch_size, timesteps, 1)
        """
        static, dynamic = inputs
        if self.combine_latents != 'state_and_input':
            # 'tile_concat': repeat static along time and feed it to the GRU
            # together with dynamic.
            n_steps = tf.shape(input=dynamic)[-2]
            tiled_static = static[..., tf.newaxis, :] + tf.zeros([n_steps, 1])
            rnn_inputs = tf.concat([dynamic, tiled_static], axis=-1)
            rnn_sequence = self.gen_rnn(inputs=rnn_inputs, initial_state=None)
        else:
            # 'state_and_input': static sets the GRU's initial state, dynamic is its input.
            state0 = self.static_to_state0(static)
            rnn_sequence = self.gen_rnn(inputs=dynamic, initial_state=state0)
        dropped = self.dropout(rnn_sequence, training=training)
        return self.gen_factors(dropped)
        

# Instantiate the factor generator and run one random batch through it.
K.clear_session()
# With ReadIn enabled, the time axis is shortened by its pooling factor and the
# number of factors must match the number of ReadIn output features.
factor_times = n_times // read_in.get_layer("temporal_smoothing").pool_size[0] if USE_READIN else n_times
n_factors = N_FACTORS if not USE_READIN else\
    read_in.get_layer("temporal_filter").filters * read_in.get_layer("spatial_filter").depth_multiplier

gen_fac = GenerateFactors(factors=n_factors, units=N_HIDDEN_GEN,
                          combine_latents='state_and_input')

dummy_static = tf.random.uniform((BATCH_SIZE, LATENT_SIZE_STATIC))
# Without a dynamic graph the generator still needs an input sequence; use 1 feature.
if DYNAMIC_GRAPH != 'none':
    dummy_dynamic = tf.random.uniform((BATCH_SIZE, factor_times, LATENT_SIZE_DYNAMIC))
else:
    dummy_dynamic = tf.random.uniform((BATCH_SIZE, factor_times, 1))
generated_factors = gen_fac((dummy_static, dummy_dynamic))
gen_fac.summary()

## Read-Out: Factors to Reconstructed Features

Optional. To be used when ReadIn is used.

TODO: Currently broken.

In [None]:
def ReadOut(input_shape, out_time, out_space,
            n_kernels=6, kern_length=25, pooling=5, factors=10, units=32, name="readout"):
    """
    Build the optional read-out block as a `tf.keras.Sequential`.

    The factors are up-sampled along a new sensor axis and along time, then
    convolved to reconstruct features of shape (out_time, out_space).

    :param input_shape: (n_times, n_factors) of the incoming factors, without
        the batch dimension.
    :param out_time: Number of timesteps in the reconstruction.
    :param out_space: Number of sensors in the reconstruction.
    :param n_kernels: Number of filters in the separable temporal convolution.
    :param kern_length: Length of the separable temporal convolution kernel.
    :param pooling: Temporal up-sampling factor.
    :param factors: Currently unused.
    :param units: Currently unused.
    :param name: Keras model name.
    :return: Model mapping (batch, n_times, n_factors) to
        (batch, out_time, out_space) whenever the up-sampled and convolved
        sequence is no longer than `out_time`; a longer sequence is not cropped.
    """
    input_shape = tuple(input_shape)  # tolerate lists; tuples are concatenated below
    # Use the time length of THIS block's input. Previously the notebook-global
    # `n_times` (length of the raw data) was read here, which miscalculated the
    # padding and made the function depend on hidden global state.
    in_times = input_shape[-2]
    # Length after up-sampling by `pooling` and a 'valid' convolution of `kern_length`.
    n_samps = in_times * pooling - kern_length + 1
    missing = max(0, out_time - n_samps)
    req_padding = (int(math.ceil(missing / 2)), missing // 2)  # left, right
    return tf.keras.Sequential([
        tfkl.Input(shape=input_shape),
        # (time, factors) -> (time, 1, factors): add a singleton sensor axis.
        tfkl.Reshape(input_shape[:-1] + (1,) + input_shape[-1:]),
        tfkl.UpSampling2D(size=(1, out_space)),
        tfkl.DepthwiseConv2D(kernel_size=(1, out_space), padding='same', depth_multiplier=1),
        tfkl.UpSampling2D(size=(pooling, 1)),
        tfkl.SeparableConv2D(n_kernels, (kern_length, 1)),
        # Pad the time axis only; the sensor axis is left untouched.
        tfkl.ZeroPadding2D(padding=(req_padding, (0, 0))),
        tfkl.Conv2D(1, (out_time, 1), padding='same'),
        # Drop the trailing singleton channel axis.
        tfkl.Lambda(lambda x: x[..., 0])],
        name=name)


# Build the read-out block and print its layer table. This runs unconditionally,
# i.e. even when USE_READIN is False.
K.clear_session()
read_out = ReadOut((factor_times, n_factors), input_shape[-2], input_shape[-1])
read_out.summary()

## Reconstruct Spike Trains

From factors or reconstructed features to rates.
A dense layer maps these inputs to log-rates, which are exponentiated to give the rates of Poisson distributions.

In [None]:
class OutDist(tf.keras.Model):
    """
    Output head: inputs --> Dense (log-rates) --> Poisson distribution.

    The result is a `tfd` distribution object rather than a Tensor, which is
    why this is a subclassed model and not a `tf.keras.Sequential` or a
    functional `tf.keras.Model(inputs, outputs)`. Returning a distribution
    lets the cost be computed from its log-probability.
    """

    def __init__(self, out_space,
                 name='outdist', **kwargs):
        """
        :param out_space: Number of output channels (units of the Dense layer).
        :param name: Keras model name.
        :param kwargs: Forwarded to `tf.keras.Model`.
        """
        super().__init__(name=name, **kwargs)
        self.out_space = out_space
        self.to_log_rates = DenseAutoShape(self.out_space, name="log_rate")
        # The attribute name `q_z_layer` is kept for compatibility, although the
        # layer produces the output distribution.
        self.q_z_layer = DistLambdaAutoShape(
            make_distribution_fn=lambda t: tfd.Poisson(rate=tf.exp(t)),
            name="p_out"
        )

    def call(self, inputs):
        """
        :param inputs: Factors (or reconstructed features) to convert to rates.
        :return: Poisson distribution whose two right-most batch dimensions
            have been reinterpreted as event dimensions.
        """
        elementwise_poisson = self.q_z_layer(self.to_log_rates(inputs))
        # Fold the last two batch dims into the event shape, so that log_prob
        # sums over them.
        return tfd.Independent(elementwise_poisson, reinterpreted_batch_ndims=2)
    
# Instantiate the output head and run one random batch of factors through it.
K.clear_session()
# Same derivation of factor_times / n_factors as in the GenerateFactors cell.
factor_times = n_times // read_in.get_layer("temporal_smoothing").pool_size[0] if USE_READIN else n_times
n_factors = N_FACTORS if not USE_READIN else\
    read_in.get_layer("temporal_filter").filters * read_in.get_layer("spatial_filter").depth_multiplier
out_dist = OutDist(n_sensors)
tmp_fac = tf.random.uniform((BATCH_SIZE, factor_times, n_factors))
tmp_out = out_dist(tmp_fac)
out_dist.summary()

### Full AutoEncoder

In [None]:
class AutoEncoder(tf.keras.Model):
    """
    Disentangled sequential auto-encoder.

    Data flow (see `call`):
        inputs --> dropout --> (optional ReadIn) --> features
        features --> StaticEncoder --> q_f
        [f_sample, features] --> DynamicEncoder --> q_z   (only if a dynamic encoder exists)
        (f_sample, z_sample) --> GenerateFactors --> facs
        facs --> (optional ReadOut) --> OutDist --> p_full

    Args:
        readin_kernels: Number of read-in kernels; 0 disables ReadIn/ReadOut.
        readin_kern_length, readin_depth_multiplier, readin_pooling:
            Read-in / read-out settings, forwarded to ReadIn / ReadOut.
        static_units, static_latent_size: StaticEncoder settings.
        dynamic_graph: 'factorized' or 'full' create a DynamicEncoder; any
            other value (e.g. 'none') runs without a dynamic encoder.
        dynamic_units, dynamic_latent_size: DynamicEncoder settings.
        gen_units, gen_combine_latents, n_factors: GenerateFactors settings.
        dropout_rate, l2_reg: Regularization settings forwarded to sub-models.
        name: Keras model name.

    `train_step` reads the module-level `kl_beta` variable to weight the KL
    terms; it must exist before `fit` is called.
    """
    def __init__(self,
                 readin_kernels=6 if USE_READIN else 0, readin_kern_length=25,
                 readin_depth_multiplier=2, readin_pooling=5,
                 static_units=N_HIDDEN_STATIC, static_latent_size=LATENT_SIZE_STATIC,
                 dynamic_graph=DYNAMIC_GRAPH,
                 dynamic_units=N_HIDDEN_DYNAMIC, dynamic_latent_size=LATENT_SIZE_DYNAMIC,
                 gen_units=N_HIDDEN_GEN, gen_combine_latents='state_and_input',
                 n_factors=N_FACTORS,
                 dropout_rate=DROPOUT_RATE, l2_reg=L2_REG,
                 name='autoencoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.readin_kernels = readin_kernels
        self.readin_kern_length = readin_kern_length
        self.readin_depth_multiplier = readin_depth_multiplier
        self.readin_pooling = readin_pooling

        self.input_dropout = DropoutAutoShape(dropout_rate)
        self.static_encoder = StaticEncoder(units=static_units,
                                            latent_size=static_latent_size,
                                            dropout_rate=dropout_rate)
        self.dynamic_graph = dynamic_graph
        if self.dynamic_graph in ['factorized', 'full']:
            self.dynamic_encoder = DynamicEncoder(hidden_size=dynamic_units,
                                                  latent_size=dynamic_latent_size,
                                                  factorized=dynamic_graph == 'factorized',
                                                  l2_reg=l2_reg)
        else:
            self.dynamic_encoder = None

        self.gen_units = gen_units
        # Fixed: this previously read the non-existent attribute `readin_n_kernels`.
        self.n_factors = self.readin_kernels * self.readin_depth_multiplier if USE_READIN else n_factors
        self.gen_combine_latents = gen_combine_latents
        self.gen_facs = GenerateFactors(factors=self.n_factors, units=self.gen_units,
                                        combine_latents=self.gen_combine_latents,
                                        dropout_rate=dropout_rate, l2_reg=l2_reg)

        if self.readin_kernels > 0:
            # TODO: Currently broken
            self.read_out = ReadOut(n_kernels=self.readin_kernels,
                                    kern_length=self.readin_kern_length,
                                    pooling=self.readin_pooling)

    def build(self, input_shape):
        """Create the layers whose construction depends on the input shape."""
        n_times, n_sensors = input_shape[-2:]

        if self.readin_kernels > 0:
            self.readin = ReadIn(input_shape,
                                 n_kernels=self.readin_kernels,
                                 pooling=self.readin_pooling,
                                 depth_multiplier=self.readin_depth_multiplier)

        self.out_dist = OutDist(n_sensors)
        super().build(input_shape)

    def call(self, inputs, return_intermediates=False, training=None):
        """
        Run the auto-encoder.

        Returns:
            `p_full` (the output distribution) by default. With
            `return_intermediates=True`, the 7-tuple
            `(p_full, features, q_f, q_z, facs, log_rates, p_full)`;
            `q_z` is None when there is no dynamic encoder.
        """
        inputs = self.input_dropout(inputs, training=training)
        # TODO: randomly set 30% of inputs to 0. Save the mask.

        features = self.readin(inputs) if self.readin_kernels > 0 else inputs
        q_f = self.static_encoder(features, training=training)
        f_sample = tf.convert_to_tensor(q_f)
        # Test the encoder itself rather than the graph name so that this stays
        # consistent with __init__, which creates no encoder for any
        # dynamic_graph other than 'factorized' / 'full'.
        if self.dynamic_encoder is not None:
            q_z = self.dynamic_encoder([f_sample, features])
            z_sample = tf.convert_to_tensor(q_z)  # Might create sample dim
        else:
            q_z = None
            # Placeholder dynamic sample: zeros with a single latent dimension.
            dummy_dynamic_sample_shape = f_sample.shape[:-1] + (features.shape[-2], 1)
            z_sample = tf.zeros(dummy_dynamic_sample_shape)
        facs = self.gen_facs((f_sample, z_sample), training=training)
        log_rates = self.read_out(facs) if self.readin_kernels > 0 else facs
        p_full = self.out_dist(log_rates)

        if not return_intermediates:
            return p_full
        return p_full, features, q_f, q_z, facs, log_rates, p_full

    def train_step(self, data):
        """
        One optimization step on the negative ELBO plus the layer losses.

        Args:
            data: `(inputs, targets)` pair; only `inputs` is used.

        Returns:
            Dict of scalar metrics reported to Keras.
        """
        inputs, _ = data
        with tf.GradientTape() as tape:
            # TODO: coordinated dropout mask cd_mask on 30% of samples; set to 0
            # Probably unnecessary while generated rates are so smooth.

            p_full, features, q_f, q_z, _, _, _ = self(inputs,
                                                       return_intermediates=True,
                                                       training=True)

            # TODO: Do not allow BPTT through ~cd_mask samples.

            # Reconstruction log-likelihood: p(output|input).
            # Not necessary to sum over time axis because event shape is (time, space)
            recon_post_log_prob = p_full.log_prob(inputs)

            # KL Divergence - analytical
            # Static
            static_prior = self.static_encoder.static_prior_factory()
            stat_kl = tfd.kl_divergence(q_f, static_prior)

            # Dynamic
            if self.dynamic_encoder is not None:
                # q_z is computed from `features`, so the prior must span the
                # feature time axis (identical to the input time axis when no
                # read-in is used).
                _, dynamic_prior = self.dynamic_encoder.sample_dynamic_prior(
                    features.shape[-2], samples=1, batches=1
                )
                dyn_kl = tfd.kl_divergence(q_z, dynamic_prior)
                # TODO: Check if necessary (maybe q_z needs batch dim reinterp)
                dyn_kl = tf.reduce_mean(dyn_kl, axis=-1)
                dyn_kl = tf.squeeze(dyn_kl)
            else:
                dyn_kl = tf.zeros(stat_kl.shape)

            elbo = recon_post_log_prob - kl_beta * (stat_kl + dyn_kl)
            elbo = tf.reduce_mean(input_tensor=elbo)
            l2_loss = tf.reduce_sum(self.losses)
            loss = -elbo + l2_loss

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # TODO: gradient clipping, e.g. tf.clip_by_global_norm(gradients, MAX_GRAD_NORM)
        # TODO: more l2?
        #  -with scheduler!

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Return dictionary of metrics:
        return {
            'neg.log.like': tf.reduce_mean(-recon_post_log_prob),
            'kl_beta': kl_beta,
            'static KL': tf.reduce_mean(stat_kl),
            'dynamic KL': tf.reduce_mean(dyn_kl),
            'l2 loss': l2_loss
        }

In [None]:
# Smoke test: build an AutoEncoder with default settings, run one random batch
# through it, and print the summaries of its sub-models.
K.clear_session()
K.set_floatx('float32')
tf.random.set_seed(RANDOM_SEED)
ae_model = AutoEncoder()
# Random stand-in batch with the same trailing dimensions as the real input.
dummy_input = tf.random.uniform((BATCH_SIZE,) + input_shape[1:])
print(f"dummy_input.shape: {dummy_input.shape}")
dummy_output = ae_model(dummy_input)
print(f"dummy_output: {dummy_output}")
ae_model.static_encoder.summary()
# The dynamic encoder only exists for some dynamic_graph settings.
if ae_model.dynamic_graph != 'none':
    ae_model.dynamic_encoder.summary()
ae_model.gen_facs.summary()
ae_model.out_dist.summary()
ae_model.summary()
# Visualize: https://github.com/lutzroeder/netron
# Visualize: https://github.com/lutzroeder/netron

## Training

### KL Beta Cycling

In [None]:
# Module-level KL weight, read by AutoEncoder.train_step and written by kl_beta_update.
K.clear_session()
kl_beta = K.variable(value=0.0)
kl_beta._trainable = False  # It isn't trained. We set it explicitly with the callback.

def kl_beta_update(epoch_ix, N_epochs=N_EPOCHS, M_cycles=3, R_increasing=0.8):
    """
    Set the module-level `kl_beta` according to a cyclical annealing schedule.

    Training is split into `M_cycles` cycles of equal length. Within each
    cycle, beta ramps linearly from 0 to 1 over the first `R_increasing`
    fraction of the cycle and then stays at 1.

    Args:
        epoch_ix: Zero-based index of the epoch that is about to start.
        N_epochs: Total number of training epochs.
        M_cycles: Number of annealing cycles.
        R_increasing: Fraction of each cycle spent ramping up.
    """
    # Cycle length in epochs. Clamp to >= 1 so that N_epochs < M_cycles does
    # not produce a zero-length cycle and a ZeroDivisionError below; in that
    # degenerate case every epoch is the start of a cycle and beta stays at 0.
    T = max(1, N_epochs // M_cycles)
    # Position within the current cycle, in [0, 1).
    tau = (epoch_ix % T) / T
    new_beta_value = tf.minimum(1.0, tau / R_increasing)
    K.set_value(kl_beta, new_beta_value)

### Hyperparameter Tuning

https://github.com/optuna/optuna/blob/master/examples/keras_simple.py

https://neptune.ai/blog/optuna-vs-hyperopt

In [None]:
import optuna
from sklearn.linear_model import LogisticRegressionCV


# Switches for the hyperparameter search below.
# True: objective is the classifier score on the static latents (maximized);
# False: objective is the reconstruction negative log-likelihood (minimized).
CLASSIFIER_OBJECTIVE = True
TUNE_STATIC = True   # search static encoder sizes (else fixed values are used)
TUNE_DYNAMIC = True  # search dynamic encoder sizes (else the model has no dynamic graph)
TUNE_REGU = False    # search dropout rate and L2 weight
TUNE_LR = False      # search the Adam learning rate


def create_model(trial):
    """
    Build an AutoEncoder whose hyperparameters come from an Optuna trial.

    The TUNE_* module flags select which groups of hyperparameters are
    sampled from `trial`; groups that are switched off use fixed values.
    The generator size and number of factors are always sampled.

    Args:
        trial: Optuna trial used to suggest hyperparameter values.

    Returns:
        A new, uncompiled AutoEncoder.
    """
    # Static encoder (fixed fallback: 128 units, 8 latent dims).
    static_units, static_latent_size = 128, 8
    if TUNE_STATIC:
        # RNN cells in static encoder: 128
        static_units = trial.suggest_int("static_units", 4, 256, log=True)
        # f | g_0: 64
        static_latent_size = trial.suggest_int("static_latent_size", 2, 128, log=True)

    # Dynamic encoder (fallback: no dynamic graph at all).
    # Other graph options: 'none', 'factorized', 'full', 'controller'.
    dynamic_graph, dynamic_units, dynamic_latent_size = 'none', None, None
    if TUNE_DYNAMIC:
        dynamic_graph = 'full'
        # RNN cells in dynamic encoder: 12
        dynamic_units = trial.suggest_int("dynamic_units", 2, 32)
        # z_t | u_t: 2
        dynamic_latent_size = trial.suggest_int("dynamic_latent_size", 2, 8)

    # Generator. The latent-combination mode is left at the AutoEncoder
    # default ('state_and_input'); the alternative would be 'tile_concat'.
    # RNN cells in generator: 256
    gen_units = trial.suggest_int("gen_units", 8, 128, log=True)
    gen_factors = trial.suggest_int("gen_factors", 4, 32)

    # Regularization (fixed fallback: dropout 0.3, L2 2e-5).
    dropout_rate, l2_reg = 0.3, 2e-5
    if TUNE_REGU:
        dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.5)  # 0.025
        l2_reg = trial.suggest_float("l2_reg", 1e-6, 0.1, log=True)  # 2e-5

    return AutoEncoder(
        static_units=static_units,
        static_latent_size=static_latent_size,
        dynamic_graph=dynamic_graph,
        dynamic_units=dynamic_units,
        dynamic_latent_size=dynamic_latent_size,
        gen_units=gen_units,
        n_factors=gen_factors,
        dropout_rate=dropout_rate,
        l2_reg=l2_reg,
    )


def objective(trial):
    """
    Optuna objective: train a freshly sampled AutoEncoder and score it.

    Args:
        trial: Optuna trial supplying the hyperparameters.

    Returns:
        If CLASSIFIER_OBJECTIVE is False: the minimum reconstruction negative
        log-likelihood over the last fifth of the epochs (to be minimized).
        Otherwise: the accuracy of a logistic-regression classifier predicting
        the class id from the static latent means (to be maximized). Note that
        this accuracy is measured on the same data the classifier was fit on.
    """
    n_epochs = 50
    K.clear_session()
    tf.random.set_seed(RANDOM_SEED)
    model = create_model(trial)
    lr = trial.suggest_float("adam_learning_rate", 1e-5, 1e-1, log=True) if TUNE_LR else 2e-3

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr))
    print(trial.params)

    callbacks = [
        tf.keras.callbacks.LambdaCallback(
            on_epoch_begin=lambda epoch, logs: kl_beta_update(epoch, N_epochs=n_epochs, M_cycles=1))
    ]
    if not CLASSIFIER_OBJECTIVE:
        # Only prune on 'neg.log.like' when that is also the quantity being
        # optimized. In classifier mode the study direction is "maximize", so
        # reporting the NLL as an intermediate value would make the pruner
        # favour the trials with the *worst* reconstruction.
        callbacks.insert(0, optuna.integration.TFKerasPruningCallback(trial, 'neg.log.like'))

    history = model.fit(ds, epochs=n_epochs,
                        verbose=2,
                        callbacks=callbacks)
    if not CLASSIFIER_OBJECTIVE:
        return np.min(history.history['neg.log.like'][-n_epochs//5:])

    # Collect class ids and static latent means over the whole dataset.
    # The leading empty arrays keep dtype/shape well-defined even if `ds` is empty.
    class_id_chunks = [np.zeros((0,), dtype=int)]
    static_chunks = [np.zeros((0, model.static_encoder.latent_size))]
    for batch in ds:
        class_id_chunks.append(np.argmax(batch[1].numpy(), axis=1))
        q_f = model(batch[0], return_intermediates=True)[2]
        static_chunks.append(q_f.mean().numpy())
    class_ids = np.concatenate(class_id_chunks)
    static_latents = np.concatenate(static_chunks, axis=0)

    clf = make_pipeline(StandardScaler(), LogisticRegressionCV(cv=5, random_state=0, max_iter=5000))
    clf.fit(static_latents, class_ids)
    return clf.score(static_latents, class_ids)


# Run the hyperparameter search: TPE sampling with Hyperband pruning, 120 trials.
# The optimization direction follows the objective selected by CLASSIFIER_OBJECTIVE.
study = optuna.create_study(direction="maximize" if CLASSIFIER_OBJECTIVE else "minimize",
                            sampler=optuna.samplers.TPESampler(),
                            pruner=optuna.pruners.HyperbandPruner(min_resource=2)
                           )
study.optimize(objective, n_trials=120)

In [None]:
# Report how many trials ran and the value / hyperparameters of the best one.
print("Number of finished trials: ", len(study.trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

In [None]:
# Contour plot of the objective over pairs of hyperparameters.
optuna.visualization.plot_contour(study)

In [None]:
# Parallel-coordinate plot of the hyperparameters of all trials.
optuna.visualization.plot_parallel_coordinate(study)

## Train the model

In [None]:
# Train the final model with hand-picked hyperparameters.
K.clear_session()
tf.random.set_seed(RANDOM_SEED)
ae_model = AutoEncoder(static_units=148, static_latent_size=24,
                       dynamic_graph='full',
                       dynamic_units=8, dynamic_latent_size=2,
                       gen_units=10,
                       n_factors=10,
                       dropout_rate=0.3, l2_reg=2e-6)

ae_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-3))
# The callback updates the KL weight at the start of every epoch
# (kl_beta_update's default number of cycles applies here).
history = ae_model.fit(ds, epochs=N_EPOCHS,
                       verbose=2,
                       callbacks=[
                           tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch,
                                                             logs: kl_beta_update(epoch, N_epochs=N_EPOCHS))
                           # https://keras.io/api/callbacks/lambda_callback/
                       ])

In [None]:
# Training curves: reconstruction NLL on the left axis, KL terms on the right axis.
print(history.history.keys())
fig, ax1 = plt.subplots()
nll_line, = ax1.plot(history.history['neg.log.like'], linewidth=3, color='C0', label='AE loss')
ax1.set_ylabel('Recon. Neg.Log.Likelihood', color='C0')
ax1.set_ylim([800, 2000])
ax1.set_xlabel('Epochs')
ax1.tick_params(axis='y', labelcolor='C0')
ax2 = ax1.twinx()
static_kl_line, = ax2.plot(history.history['static KL'], color='C1', label='static KL')
dynamic_kl_line, = ax2.plot(history.history['dynamic KL'], color='C2', label='dynamic KL')
ax2.set_ylabel('KL Divergence', color='C1')
ax2.set_ylim([0, 20])
ax2.tick_params(axis='y', labelcolor='C1')
# Pass the handles of both axes explicitly; a plain ax2.legend() would only
# list the lines drawn on ax2 and silently drop the 'AE loss' entry.
ax2.legend(handles=[nll_line, static_kl_line, dynamic_kl_line], facecolor='k')

## Visualize Latents

* Visualize the priors to confirm that their locations are non-zero, indicating that they were learned.
* Calculate 2-comp t-SNE on static latents
* Plot 1 trial: input (spike counts), static latents in t-SNE space, dynamic latents, generated factors, recon rates.
* Plot all trials, colour-coded by target: static latents in t-SNE space, dynamic latents as 2-D trajectories; plot per-target averages

In [None]:
# Colour tables used to colour-code trials by class in the plots below.
# class_colors: one matplotlib single-letter colour per class.
class_colors = np.array(['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'])
# plt.get_cmap is used instead of plt.cm.get_cmap, which newer matplotlib
# releases deprecated and then removed.
turbo_cmap = plt.get_cmap('turbo')
# c_categ: 8 RGBA colours evenly spaced over the 'turbo' colormap, indexed by class id.
c_categ = turbo_cmap(np.linspace(0, 1, 8))

In [None]:
# Collect class ids and latent means for every trial in the dataset.
t_vec = ax_info['timestamps']
# Time vector of the features: sub-sampled by the read-in pooling when the read-in is used.
feat_t_vec = t_vec[ae_model.readin.pooling//2::ae_model.readin.pooling]\
                if USE_READIN else t_vec
# Ask the model itself (not the global DYNAMIC_GRAPH default) whether it has a
# dynamic encoder; ae_model may have been built with a different dynamic_graph.
has_dynamic = ae_model.dynamic_encoder is not None

# Accumulate per-batch chunks and concatenate once at the end. The leading
# empty arrays keep dtype/shape well-defined even if `ds` yields no batches.
class_id_chunks = [np.zeros((0,), dtype=int)]
static_chunks = [np.zeros((0, ae_model.static_encoder.latent_size))]
dynamic_chunks = [np.zeros((0, len(feat_t_vec),
                            ae_model.dynamic_encoder.latent_size if has_dynamic else 1))]
# NOTE: `batch` and the model outputs of the last iteration stay in scope and
# are reused by the single-trial plot further down.
for batch in ds:
    class_id_chunks.append(np.argmax(batch[1].numpy(), axis=1))
    p_full, features, q_f, q_z, facs, log_rates, p_full = ae_model(batch[0], return_intermediates=True)
    static_chunks.append(q_f.mean().numpy())
    if has_dynamic:
        dynamic_chunks.append(q_z.mean().numpy())
class_ids = np.concatenate(class_id_chunks)
static_latents = np.concatenate(static_chunks, axis=0)
dynamic_latents = np.concatenate(dynamic_chunks, axis=0)

# Prior means, for comparison against the posteriors.
static_prior = ae_model.static_encoder.static_prior_factory().mean().numpy()
if has_dynamic:
    dynamic_prior_samp, dynamic_prior_dist = ae_model.dynamic_encoder.sample_dynamic_prior(
        len(feat_t_vec), samples=1, batches=1)
    dynamic_prior = dynamic_prior_dist.mean().numpy().squeeze()

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA


TEST_PERPLEXITY = 30  # 10, 30

# Precede TSNE with a PCA.
# All components are kept, so this only rotates/decorrelates the latents.
pca = PCA(n_components=static_latents.shape[-1])
static_latents_pca = pca.fit_transform(static_latents)

# Calculate t-SNE
# t-SNE is stochastic; seed it like the rest of the notebook so that the
# embedding (and everything derived from it) is reproducible across runs.
tsne_model = TSNE(n_components=2, perplexity=TEST_PERPLEXITY, random_state=RANDOM_SEED)
static_latents_tsne = tsne_model.fit_transform(static_latents_pca)

In [None]:
# Fit an affine map (weights W, offset b) from the static latents to their
# t-SNE embedding by least squares, then rank the latent dimensions by the
# total absolute weight they contribute. Highest-ranked dims are plotted below.
design = np.column_stack((static_latents, np.ones(len(static_latents))))
solution, *_ = np.linalg.lstsq(design, static_latents_tsne, rcond=None)
# Last row of the solution is the offset; the rows above are per-dimension weights.
W, b = solution[:-1, :], solution[-1, :]
important_latent_dims = np.argsort(np.abs(W).sum(axis=1))[::-1]

In [None]:
# Scatter plot per class: the most important static latent dims, two per panel,
# coloured by class id, with the static prior mean marked by a '+'.
plot_range = [-5.0, 5.0]
# print(plt.style.available)
plt.style.use('dark_background')
fig = plt.figure(figsize=[8, 6], tight_layout=True)
axes = fig.subplots(2, 2)

# Use up to one pair of dims per panel, but never more pairs than the latent
# space provides (a fixed reshape to (4, 2) fails with fewer than 8 dims).
n_pairs = min(axes.size, len(important_latent_dims) // 2)
dim_pairs = important_latent_dims[:2 * n_pairs].reshape((n_pairs, 2))

# axes.flat walks the grid row by row, i.e. pair k lands on axes[k // 2, k % 2].
for ax, dim_pair in zip(axes.flat, dim_pairs):
    dat = static_latents[:, dim_pair]
    ax.scatter(dat[:, 0], dat[:, 1],
               c=c_categ[class_ids])
    # White marker: black ('k') is invisible on the dark_background style.
    ax.scatter(static_prior[dim_pair[0]], static_prior[dim_pair[1]],
               s=200, c='w', marker='+')
    ax.set_xlabel(f"dim {dim_pair[0]}")
    ax.set_ylabel(f"dim {dim_pair[1]}")

# Hide panels that received no data.
for ax in axes.flat[n_pairs:]:
    ax.axis('off')
    

In [None]:
# Plot the mean of the dynamic prior (only computed when a dynamic graph is configured).
if DYNAMIC_GRAPH != 'none':
    # Left: each prior dimension over time, offset vertically by 0.2 per dimension.
    plt.subplot(1, 2, 1)
    plt.plot(feat_t_vec, dynamic_prior + 0.2*np.arange(dynamic_prior.shape[-1]).reshape((1, -1)))
    # Right: trajectory of the first two prior dimensions against each other.
    plt.subplot(1, 2, 2)
    plt.plot(dynamic_prior[:, 0], dynamic_prior[:, 1])

In [None]:
# Plot one trial through every stage of the model.
# `batch`, `q_f`, `q_z`, `facs` and `p_full` are whatever the last iteration of
# the data-collection loop left behind, i.e. they describe the last batch of `ds`.
# batch (inputs) p_full, features, q_f, q_z, facs, log_rates, p_full
tr_ix = 1  # in batch

fig = plt.figure(figsize=[10, 6], tight_layout=True)

# Panel 1: model input.
plt.subplot(2, 3, 1)
plt.plot(t_vec, batch[0][tr_ix])
plt.title('Binned Spike Counts')
#plt.ylim([0, 6])

# Panel 2: mean of the static latent posterior, one dot per latent dimension.
plt.subplot(2, 3, 2)
# TODO: better way to represent per-trial static latent? t-sne space?
plt.plot(q_f.mean()[tr_ix], '.')
plt.title('Static Latent')

# Panel 3: mean of the dynamic latent posterior over time (skipped without a dynamic graph).
if DYNAMIC_GRAPH != 'none':
    plt.subplot(2, 3, 3)
    plt.plot(feat_t_vec, q_z.mean()[tr_ix])
    plt.title('Dynamic Latent')

# Panel 4: generated factors.
plt.subplot(2, 3, 4)
plt.plot(feat_t_vec, facs[tr_ix])
plt.title('Factors')

# Panel 5: mean of the output distribution, i.e. the reconstructed rates.
plt.subplot(2, 3, 5)
plt.plot(t_vec, p_full.mean()[tr_ix])
plt.title('Recon. Rates')
#plt.ylim([0, 6])
plt.tight_layout()
plt.show()

In [None]:
def plot_tsne(x_vals, y_vals, perplexity, title='Model Output'):
    """
    Scatter a 2-D t-SNE embedding on the current axes, coloured by class.

    Args:
        x_vals: Array whose first two columns are the embedding coordinates.
        y_vals: Integer class ids, used to index the module-level `c_categ`
            colour table.
        perplexity: Perplexity used for the embedding; shown in the title only.
        title: Title prefix.
    """
    plt.scatter(x=x_vals[:, 0], y=x_vals[:, 1], c=c_categ[y_vals])
    plt.xlabel('t-SNE D-1')
    plt.ylabel('t-SNE D-2')
    plt.title(title + ' (Ppx: {})'.format(perplexity))

In [None]:
# t-SNE embedding of the static latents, coloured by class id.
fig = plt.figure(figsize=(8, 8))
plt.subplot(1, 1, 1)
plot_tsne(static_latents_tsne, class_ids, TEST_PERPLEXITY, title='Latents')

In [None]:
from sklearn.linear_model import LogisticRegressionCV

# Predict the class id from the static latents with a standardized logistic regression.
# NOTE(review): the printed score is computed on the same data the classifier was fit
# on, so it is a training accuracy rather than a held-out estimate.
clf = make_pipeline(StandardScaler(), LogisticRegressionCV(cv=5, random_state=0, 
                                                           max_iter=4000))
clf.fit(static_latents, class_ids)
print(clf.score(static_latents, class_ids))

In [None]:
# Baseline for comparison: fit the same classifier pipeline on the spike rates,
# with each trial flattened to one feature vector.
# NOTE: this re-fits `clf`, overwriting the classifier fitted on the static latents.
X_rates, Y_class, _ = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates',
                                        **{**load_kwargs, 'resample_X': 1})
X_rates = X_rates.reshape(X_rates.shape[0], -1)
Y_class = Y_class.ravel()
clf.fit(X_rates, Y_class)
# As above, this is the score on the training data itself.
print(clf.score(X_rates, Y_class))