In [44]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm
import seaborn as sns
from numba import njit
from scipy.stats import skew
from scipy.stats import binom
from functools import partial
np.set_printoptions(suppress=True)

# Import our bayesflow lib
from deep_bayes.models import BayesFlow, InvariantNetwork
from deep_bayes.training import train_online
from deep_bayes.losses import maximum_likelihood_loss
from deep_bayes.viz import plot_true_est_scatter, plot_true_est_posterior
import deep_bayes.diagnostics as diag

In [48]:
import tensorflow as tf
# Only TensorFlow 1.x needs this: switch on eager execution explicitly.
# On any other version the call is skipped.
if tf.__version__.startswith('1'):
    tf.enable_eager_execution()

In [3]:
%matplotlib inline

# Summary Network Structure
Here, we define a permutation-invariant neural network that maps raw reaction time data to a fixed-size summary representation.
<br>
See https://arxiv.org/pdf/1901.06082.pdf (p.28) for more details.

In [7]:
class InvariantModule(tf.keras.Model):
    """Permutation-invariant building block (embed -> pool -> transform) after Bloem-Reddy and Teh (2019)."""

    def __init__(self, meta, pooler=tf.reduce_mean):
        """
        Builds the dense stack applied before pooling and the dense stack applied after it.
        ----------

        Arguments:
        meta   : dict -- hyperparameters; reads 'dense_inv_args' (keyword arguments passed to every Dense layer)
                 and 'n_dense_inv' (number of Dense layers in each of the two stacks)
        pooler : callable -- reduction invoked as pooler(tensor, axis=1); defaults to tf.reduce_mean
        """

        super(InvariantModule, self).__init__()

        def dense_stack():
            # Both stacks share the same layer settings but get their own, freshly created layers
            return tf.keras.Sequential([
                tf.keras.layers.Dense(**meta['dense_inv_args'])
                for _ in range(meta['n_dense_inv'])
            ])

        # Element-wise embedding applied before pooling
        self.module = dense_stack()

        # Reduction over axis 1
        self.pooler = pooler

        # Transformation of the pooled representation
        self.post_pooling_dense = dense_stack()

    def call(self, x):
        """
        Transforms the input into an invariant representation.
        ----------

        Arguments:
        x : tf.Tensor of shape (batch_size, n, m) - the input where n is the 'time' or 'samples' dimension
            over which pooling is performed and m is the input dimensionality
        ----------

        Returns:
        out : tf.Tensor of shape (batch_size, h_dim) -- the pooled and invariant representation of the input
        """

        # Embed every element, collapse axis 1, then refine the pooled summary
        embedded = self.module(x)
        summary = self.pooler(embedded, axis=1)
        return self.post_pooling_dense(summary)


class EquivariantModule(tf.keras.Model):
    """Implements an equivariant nn module as proposed by Bloem-Reddy and Teh (2019)."""

    def __init__(self, meta):
        """
        Creates an equivariant neural network consisting of a FC network with
        equal number of hidden units in each layer and an invariant module
        with the same FC structure.
        ----------

        Arguments:
        meta : dict -- a dictionary with hyperparameter name - values; this module reads
               'dense_equiv_args' and 'n_dense_equiv' and passes meta on to its InvariantModule
        """

        super(EquivariantModule, self).__init__()

        # Dense stack applied to every element after the invariant summary has been attached
        self.module = tf.keras.Sequential([
            tf.keras.layers.Dense(**meta['dense_equiv_args'])
            for _ in range(meta['n_dense_equiv'])
        ])

        # Produces the set-level summary that is shared by all elements
        self.invariant_module = InvariantModule(meta)

    def call(self, x):
        """
        Transforms the input into an equivariant representation.
        ----------

        Arguments:
        x : tf.Tensor of shape (batch_size, n, m) - the input where n is the 'time' or 'samples' dimension
            and m is the input dimensionality
        ----------

        Returns:
        out : tf.Tensor of shape (batch_size, n, h_dim) -- one output vector per input element (the n axis
              is kept, not pooled); h_dim is determined by meta['dense_equiv_args']
        """

        # Set-level summary with the n axis pooled away
        x_inv = self.invariant_module(x)

        # Repeat the summary n times along a new axis 1. n is read at run time through tf.shape, so this
        # also works when the set size is not known statically (int(x.shape[1]) raises in that case).
        multiples = [1, tf.shape(x)[1]] + [1] * (len(x_inv.shape) - 1)
        x_inv = tf.tile(tf.expand_dims(x_inv, axis=1), multiples)

        # Attach the summary to every element, then transform
        x = tf.concat((x_inv, x), axis=-1)
        out = self.module(x)
        return out


class InvariantNetwork(tf.keras.Model):
    """
    Parameterizes a permutationally invariant function
    according to Bloem-Reddy and Teh (2019).

    NOTE(review): a class of the same name is imported from deep_bayes.models at the top of
    the notebook; once this cell has run, the name refers to the class defined here.
    """

    def __init__(self, meta):
        """
        Stacks meta['n_equiv'] equivariant modules and finishes with one invariant module.
        ----------

        Arguments:
        meta : dict -- hyperparameter settings for the equivariant and invariant modules
        """

        super(InvariantNetwork, self).__init__()

        equivariant_blocks = [EquivariantModule(meta) for _ in range(meta['n_equiv'])]
        self.equiv = tf.keras.Sequential(equivariant_blocks)
        self.inv = InvariantModule(meta)

    def call(self, x, **kwargs):
        """
        Transforms the input into a permutationally invariant representation.
        The equivariant modules run first; the invariant module then pools their output.
        ----------

        Arguments:
        x : tf.Tensor of shape (batch_size, n, m) - the input where n is the
        'samples' dimension over which pooling is performed and m is the input dimensionality
        ----------

        Returns:
        out : tf.Tensor of shape (batch_size, h_dim) -- the pooled and invariant representation of the input
        """

        # Extra keyword arguments are accepted but not used
        return self.inv(self.equiv(x))

# Hyperparameter settings and model definition

In [8]:
# --- Summary network (passed to InvariantNetwork) ---
# NOTE(review): 'dense_post_args' is not read by the summary-network classes defined in this notebook.
summary_meta = {
    'dense_inv_args'   :  {'units': 64, 'activation': 'elu', 'kernel_initializer': 'glorot_normal'},
    'dense_equiv_args' :  {'units': 64, 'activation': 'elu', 'kernel_initializer': 'glorot_normal'},
    'dense_post_args'  :  {'units': 64, 'activation': 'elu', 'kernel_initializer': 'glorot_normal'},
    'n_equiv'          :  2,
    'n_dense_inv'      :  3,
    'n_dense_equiv'    :  3,
}

# --- Settings passed to BayesFlow ---
inv_meta = {
    'n_units': [128, 128, 128],
    'activation': 'elu',
    'w_decay': 0.0,
    'initializer': 'glorot_uniform',
}
n_inv_blocks = 4

# --- Forward model ---
param_names = [
    r'$v_1$',
    r'$v_2$',
    r'$a_1$',
    r'$a_2$',
    r'$\tau_{c}$',
    r'$\tau_{w}$',
]
theta_dim = len(param_names)
n_test = 300
n_obs_min = 60
n_obs_max = 60
n_obs_test = (60, 60)

# --- Data generator with the observation bounds pre-filled ---
# NOTE(review): data_generator is neither defined nor imported in the visible part of this
# notebook -- confirm it is in scope before this cell runs on a fresh kernel.
data_gen = partial(
    data_generator,
    n_obs_min=n_obs_min,
    n_obs_max=n_obs_max,
)

# --- Training loop ---
ckpt_file = "iat_bayesflow"
batch_size = 64
epochs = 50
iterations_per_epoch = 1000
n_samples_posterior = 2000
clip_value = 5.

# Adam step size, shared by both TensorFlow code paths below
learning_rate = 0.001
if tf.__version__.startswith('1'):
    # TF 1.x optimizer API
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
else:
    # Use the documented `learning_rate` keyword: `lr` is only a deprecated alias
    # and is rejected by newer Keras releases.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [9]:
# Summary network configured by summary_meta.
# NOTE(review): on a top-to-bottom run, InvariantNetwork here is the class defined in this
# notebook, not the one imported from deep_bayes.models.
summary_net = InvariantNetwork(summary_meta)

# Full model: BayesFlow with the summary network attached
model = BayesFlow(
    inv_meta,
    n_inv_blocks,
    theta_dim,
    summary_net=summary_net,
    permute=True,
)

# Checkpoint manager
Used for saving/loading the model.

In [13]:
# One checkpoint object tracks both the optimizer and the model
checkpoint = tf.train.Checkpoint(optimizer=optimizer, net=model)

# Build the directory from ckpt_file so it cannot drift from the configured checkpoint name
# (with ckpt_file = "iat_bayesflow" this is the same './checkpoints/iat_bayesflow' path as before).
manager = tf.train.CheckpointManager(checkpoint, './checkpoints/{}'.format(ckpt_file), max_to_keep=5)

# latest_checkpoint is None when the directory holds no checkpoint yet
checkpoint.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

Restored from ./checkpoints/iat_bayesflow\ckpt-57


# Inference on all real data

In [None]:
import pickle 
# 1. Store all data-set chunk names in a list os.listdir()

# 2. For each chunk

    # 2.1 Load chunk
    
    
    # 2.2 Preprocess chunk 
        # 2.2.1 Apply IAT exclusion criteria (conservative)
        # 2.2.2 Format data for NN, negative coding
        # 2.2.3 Add 0s for <0.3, and >10
        
        
    # 2.3 Estimate chunk
        # theta_samples = np.concatenate([model.sample(x, n_post_samples_sbc, to_numpy=True) 
        # for x in tf.split(sbc_data['x'].astype(np.float32), 10, axis=0)], axis=1)
        
        
    # 2.5 Compute summaries: means, medians, stds, Q0.025, Q0.975, post_corr
        
        
    # 2.4 Post-processing 
        # 2.4.1 find indices of implausible (out of prior) parameter means
        # 2.4.2 remove from estimates
        # 2.4.3 remove from datasets
    
    
    # 2.6 Store everything together (serialized, pickle.dump) as a dict with keys 
        # dict_to_store = {'data_array': data_chunk_clean, 'est_array': estimates_clean}
        
# 3. Celebrate

In [82]:
# # Get X_test into the correct format:

# rts = np.where(X_test[:, :, 1], -X_test[:, :, 0], X_test[:, :, 0])
# comps = X_test[:, :, 2]
# X_test = np.stack((rts, comps), axis=2)

# # Exclusion criterion (< 0.3)
# idx_300 = (np.abs(X_test[:, :, 0]) < 0.3).sum(axis=1) <= 12
# X_test = X_test[idx_300, :, :]

# # Exclusion criterion (> 10)
# idx_10000 = (np.abs(X_test)[:, :, 0] > 10.0).sum(axis=1) == 0
# X_test = X_test[idx_10000, :, :]

# # Keep only corresponding y
# y_test = y_test[idx_300, :]
# y_test = y_test[idx_10000, :]