## Decode behaviour from spike rates using an RNN and CNN ensemble
### Setup
Environment Setup

Configure the local or Google Colab environments.

In [None]:
# Automatically reload edited modules (e.g. indl, misc) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import os
import sys
try:
    # Only on works on Google Colab
    from google.colab import files
    %tensorflow_version 2.x
    os.chdir('..')
    
    # Configure kaggle if necessary
    if not (Path.home() / '.kaggle').is_dir():
        uploaded = files.upload()  # Find the kaggle.json file in your ~/.kaggle directory.
        if 'kaggle.json' in uploaded.keys():
            !mkdir -p ~/.kaggle
            !mv kaggle.json ~/.kaggle/
            !chmod 600 ~/.kaggle/kaggle.json
    
    !pip install git+https://github.com/SachsLab/indl.git
    
    if Path.cwd().stem == 'MonkeyPFCSaccadeStudies':
        os.chdir(Path.cwd().parent)
    
    if not (Path.cwd() / 'MonkeyPFCSaccadeStudies').is_dir():
        !git clone --single-branch --recursive https://github.com/SachsLab/MonkeyPFCSaccadeStudies.git
        sys.path.append(str(Path.cwd() / 'MonkeyPFCSaccadeStudies'))
    
    os.chdir('MonkeyPFCSaccadeStudies')
        
    !pip install -q kaggle
    
    # Latest version of SKLearn
    !pip install -U scikit-learn
    
    IN_COLAB = True
    
except ModuleNotFoundError:    
    # chdir to MonkeyPFCSaccadeStudies
    if Path.cwd().stem == 'Analysis':
        os.chdir(Path.cwd().parent.parent)
        
    # Add indl repository to path.
    # Eventually this should already be pip installed, but it's still under heavy development so this is easier for now.
    check_dir = Path.cwd()
    while not (check_dir / 'Tools').is_dir():
        check_dir = check_dir / '..'
    indl_path = check_dir / 'Tools' / 'Neurophys' / 'indl'
    sys.path.append(str(indl_path))
    
    # Make sure the kaggle executable is on the PATH
    os.environ['PATH'] = os.environ['PATH'] + ';' + str(Path(sys.executable).parent / 'Scripts')
    
    IN_COLAB = False

# Best-effort removal of a `logs` directory left behind by previous runs.
logs_dir = Path.cwd() / 'logs'
if logs_dir.is_dir():
    import shutil
    try:
        shutil.rmtree(str(logs_dir))
    except PermissionError:
        print("Unable to remove logs directory.")

In [None]:
# Additional imports
import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt
from indl.display import turbo_cmap
from sklearn.model_selection import train_test_split

In [None]:
# Shared plot styling: larger fonts, thicker lines and a square 6.4 x 6.4 inch default figure.
FIGURE_STYLE = {
    'axes.titlesize': 24,
    'axes.labelsize': 20,
    'lines.linewidth': 2,
    'lines.markersize': 5,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 18,
    'figure.figsize': (6.4, 6.4),
}
plt.rcParams.update(FIGURE_STYLE)

### Download Data (if necessary)

In [None]:
if IN_COLAB:
    data_path = Path.cwd() / 'data' / 'monkey_pfc' / 'converted'
else:
    data_path = Path.cwd() / 'StudyLocationRule' / 'Data' / 'Preprocessed'

if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")

### (Prepare to) Load Data

We will use a custom function `load_macaque_pfc` to load the data into memory.

The `x_chunk` argument accepts one of four strings:
* 'analogsignals' - Returns 1 kHz LFPs (only if present in the dataset).
* 'gaze'          - Returns 2-channel gaze data.
* 'spikerates'    - Returns smoothed spikerates.
* 'spiketrains'

The `y_type` argument can be
* 'pair and choice' - returns Y as np.array of (target_pair, choice_within_pair)
* 'encoded input' - returns Y as np.array of shape (n_samples, 10) (explained below)
* any trial-table column name (e.g., 'sacClass') - returns Y as a vector of per-trial values from that column

The actual data we load depends on the particular analysis below.

In [None]:
from misc.misc import sess_infos, load_macaque_pfc, dec_from_enc

# Keyword arguments shared by every `load_macaque_pfc` call in this notebook.
load_kwargs = dict(
    valid_outcomes=(0,),  # Use (0, 9) to include trials with incorrect behaviour
    zscore=True,
    dprime_range=(1.0, np.inf),  # Use (-np.inf, np.inf) to include all trials.
    time_range=(-np.inf, 1.45),
    verbose=False,
    y_type='sacClass',
    samples_last=True,
    # resample_X=20,
)

In [None]:
# Pick a single session to develop the decoder on.
test_sess_ix = 1
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
print(f"\nImporting session {sess_id}")

# Smoothed spike rates (X), per-trial labels selected by load_kwargs['y_type'] (Y), and ax_info.
X_rates, Y_class, ax_info = load_macaque_pfc(
    data_path,
    sess_id,
    x_chunk='spikerates',
    **load_kwargs,
)
# One-hot encode the labels into 8 classes.
Y_class = tf.keras.utils.to_categorical(Y_class, num_classes=8)

## Decoding trial class (0:7) from spike rates

### Making the model

In [None]:
from indl.model import parts
from indl.model.helper import check_inputs
from indl.regularizers import KernelLengthRegularizer
from tensorflow.keras import layers

@check_inputs
def make_model_RNN(
    _input,
    filt=8,
    kernLength=25,
    ds_rate=10,
    n_rnn=64,
    n_rnn2=64,
    dropoutRate=0.25,
    activation='relu',
    l1_reg=0.000, l2_reg=0.000,
    norm_rate=0.25,
    latent_dim=16,
    return_model=True,
    n_classes=8
):
    """Build a convolutional front-end + LSTM classifier.

    Pipeline: temporal Conv2D -> BatchNorm -> DepthwiseConv2D spanning the full height axis
    -> BatchNorm -> activation -> temporal average pooling -> dropout -> one or two LSTMs
    -> `parts.Bottleneck` -> `parts.Classify`.

    Args:
        _input: Model input. Call sites pass a `tf.TensorSpec`; presumably `check_inputs`
            turns it into a Keras input tensor (TODO confirm). A trailing singleton axis is
            appended when it has fewer than 4 dimensions (batch included).
        filt: Number of temporal convolution filters.
        kernLength: Length of the (1, kernLength) temporal convolution kernel.
        ds_rate: Temporal average-pooling factor applied before the LSTMs.
        n_rnn: Units in the first LSTM.
        n_rnn2: Units in the optional second LSTM; 0 (or less) disables it.
        dropoutRate: Dropout rate after the convolutional block and after each LSTM.
        activation: Activation used after the conv block, after each LSTM and in the bottleneck.
        l1_reg, l2_reg: L1/L2 penalties on the LSTM kernel and recurrent weights.
        norm_rate: Forwarded to `parts.Classify`.
        latent_dim: Width passed to `parts.Bottleneck`.
        return_model: If exactly False, return the output tensor; otherwise a `tf.keras.Model`.
        n_classes: Number of output classes (default 8).

    Returns:
        A `tf.keras.Model`, or the output tensor when `return_model is False`.
    """
    input_shape = _input.shape.as_list()
    if len(input_shape) < 4:
        # Conv2D needs a trailing "depth" axis.
        input_shape = input_shape + [1]
    # Reshape ignores the batch dimension. NOTE: this fixes the number of time samples at
    # build time; supporting variable-length segments would need -1 on the time axis (untested).
    _y = layers.Reshape(input_shape[1:])(_input)

    # Temporal filtering: each (1, kernLength) kernel slides along time within one row.
    _y = layers.Conv2D(filt, (1, kernLength), padding='valid')(_y)
    _y = layers.BatchNormalization()(_y)
    # Spatial filtering: a depthwise kernel as tall as the whole height axis collapses it to 1.
    _y = layers.DepthwiseConv2D((_y.shape.as_list()[1], 1), padding='valid', depth_multiplier=1)(_y)
    _y = layers.BatchNormalization()(_y)
    _y = layers.Activation(activation)(_y)
    _y = layers.AveragePooling2D(pool_size=(1, ds_rate))(_y)
    _y = layers.Dropout(dropoutRate)(_y)
    # Drop the collapsed height axis -> (time, features) sequence for the LSTMs.
    _y = layers.Reshape(_y.shape.as_list()[2:])(_y)

    # RNN stack: 'rnn1' always, 'rnn2' only when n_rnn2 > 0. Every layer except the last
    # returns full sequences so the next LSTM receives a sequence.
    rnn_stack = [('rnn1', n_rnn)]
    if n_rnn2 > 0:
        rnn_stack.append(('rnn2', n_rnn2))
    for layer_ix, (rnn_name, n_units) in enumerate(rnn_stack):
        is_last = layer_ix == len(rnn_stack) - 1
        _y = layers.LSTM(n_units,
                         # l1_l2 honours l1_reg (previously ignored); with l1_reg=0 it equals l2(l2_reg).
                         kernel_regularizer=tf.keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                         recurrent_regularizer=tf.keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg),
                         return_sequences=not is_last,
                         stateful=False,
                         name=rnn_name)(_y)
        _y = layers.Activation(activation)(_y)
        _y = layers.BatchNormalization()(_y)
        _y = layers.Dropout(dropoutRate)(_y)

    # Dense bottleneck, then the classification head.
    _y = parts.Bottleneck(_y, latent_dim=latent_dim, activation=activation)
    _y = parts.Classify(_y, n_classes=n_classes, norm_rate=norm_rate)

    if return_model is False:
        return _y
    return tf.keras.models.Model(inputs=_input, outputs=_y)

@check_inputs
def make_model_CNN(
        _input,
        F1=8, kernLength=25, F1_kern_reg=None,
        D=2, D_pooling=4,
        F2=8, F2_kernLength=16,
        F2_pooling=8,
        dropoutRate=0.25,
        activation='relu',
        l1_reg=0.000, l2_reg=0.000,
        norm_rate=0.25,
        latent_dim=16,
        return_model=True,
        n_classes=8
    ):
    """Build an EEGNet-style convolutional classifier (`parts.EEGNetEnc` + bottleneck + head).

    Args:
        _input: Model input. Call sites pass a `tf.TensorSpec`; presumably `check_inputs`
            turns it into a Keras input tensor (TODO confirm). Axis 2 is used as the number
            of time samples.
        F1, kernLength: Number and temporal length of the first-stage filters (EEGNetEnc).
        F1_kern_reg: Kernel regularizer for the first-stage filters. None builds
            `l1_l2(l1_reg, l2_reg)`. The string 'kern_length_regu' builds a
            `KernelLengthRegularizer`. Any other value is passed through unchanged.
        D, D_pooling: Forwarded to EEGNetEnc. D_pooling also divides the time axis (below).
        F2, F2_kernLength, F2_pooling: Forwarded to EEGNetEnc. F2_pooling also divides the time axis.
        dropoutRate: Forwarded to EEGNetEnc.
        activation: Activation used in the bottleneck.
        l1_reg, l2_reg: Only used when `F1_kern_reg` is None.
        norm_rate: Forwarded to `parts.Classify`.
        latent_dim: Width passed to `parts.Bottleneck`.
        return_model: If truthy, return a `tf.keras.Model`; otherwise the output tensor.
        n_classes: Number of output classes (default 8).

    Returns:
        A `tf.keras.Model`, or the output tensor when `return_model` is falsy.
    """
    if F1_kern_reg is None:
        F1_kern_reg = tf.keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg)
    elif isinstance(F1_kern_reg, str) and F1_kern_reg == 'kern_length_regu':
        F1_kern_reg = KernelLengthRegularizer((1, kernLength),
                                              window_scale=1e-4,
                                              window_func='poly',
                                              poly_exp=2,
                                              threshold=0.0015)

    # EEGNetEnc
    _y = parts.EEGNetEnc(_input,
                         F1=F1,
                         F1_kernLength=kernLength,
                         F1_kern_reg=F1_kern_reg,
                         D=D,
                         D_pooling=D_pooling,
                         F2=F2,
                         F2_pooling=F2_pooling,
                         F2_kernLength=F2_kernLength,
                         dropoutRate=dropoutRate)

    # Restore the time dimension that was stripped out by EEGNetEnc.
    # Assumes EEGNetEnc pools time by D_pooling then F2_pooling and yields F2 features per step — TODO confirm.
    n_time_out = _input.shape.as_list()[2] // D_pooling // F2_pooling
    _y = layers.Reshape((1, n_time_out, F2))(_y)

    # Dense bottleneck, then the classification head.
    _y = parts.Bottleneck(_y, latent_dim=latent_dim, activation=activation)
    _y = parts.Classify(_y, n_classes=n_classes, norm_rate=norm_rate)

    if return_model:
        return tf.keras.models.Model(inputs=[_input], outputs=[_y])
    else:
        return _y

### Training the ensemble

In [None]:
from sklearn.model_selection import StratifiedKFold

N_SPLITS = 10
BATCH_SIZE = 16
EPOCHS = 180
LABEL_SMOOTHING = 0.2


def model_kfold_train(sess_id, branch, verbose=1, model_params=None):
    """Train one model per stratified K-fold split and collect out-of-fold predictions.

    For each of `N_SPLITS` folds, a freshly seeded model is trained for `EPOCHS` epochs:
    `make_model_RNN` when `branch == 'RNN'`, otherwise `make_model_CNN`. The checkpoint with
    the best `val_accuracy` is reloaded and used to score that fold's held-out trials.

    Args:
        sess_id: Session identifier passed to `load_macaque_pfc`.
        branch: 'RNN' selects the RNN builder; any other value selects the CNN builder.
        verbose: Verbosity forwarded to `model.fit` and `ModelCheckpoint`.
        model_params: Keyword arguments for the model builder. Defaults to the
            notebook-global `model_kwargs` for backward compatibility.

    Returns:
        (per_fold_eval, per_fold_true): lists with one entry per fold. Each `per_fold_eval`
        entry is the model output for that fold's held-out trials. Each `per_fold_true`
        entry holds the matching raw labels.
    """
    if model_params is None:
        # Backward compatibility: notebook cells configure the builder through this global.
        model_params = model_kwargs
    is_rnn = branch == 'RNN'
    build_model = make_model_RNN if is_rnn else make_model_CNN
    # RNN checkpoints keep their historical 'lstm' tag; CNN checkpoints are no longer labelled 'lstm'.
    model_tag = 'lstm' if is_rnn else 'cnn'

    X_rates, Y_class, _ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
    # One-hot targets do not depend on the fold, so encode them once.
    Y_onehot = tf.keras.utils.to_categorical(Y_class, num_classes=8)

    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=0)
    per_fold_eval = []
    per_fold_true = []

    for split_ix, (trn, vld) in enumerate(splitter.split(X_rates, Y_class)):
        print(f"\tSplit {split_ix + 1} of {N_SPLITS}")

        ds_train = tf.data.Dataset.from_tensor_slices((X_rates[trn], Y_onehot[trn]))
        ds_valid = tf.data.Dataset.from_tensor_slices((X_rates[vld], Y_onehot[vld]))

        # cast data types to GPU-friendly types.
        ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))
        ds_valid = ds_valid.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))

        # TODO: augmentations (random slicing?)

        ds_train = ds_train.shuffle(len(trn) + 1)
        ds_train = ds_train.batch(BATCH_SIZE, drop_remainder=True)
        ds_valid = ds_valid.batch(BATCH_SIZE, drop_remainder=False)

        # Reset Keras state and reseed every RNG before building each fold's model.
        tf.keras.backend.clear_session()
        randseed = 12345
        random.seed(randseed)
        np.random.seed(randseed)
        tf.random.set_seed(randseed)

        model = build_model(ds_train.element_spec[0], **model_params)

        optim = tf.keras.optimizers.Nadam(learning_rate=0.001)
        loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
        model.compile(optimizer=optim, loss=loss_obj, metrics=['accuracy'])

        best_model_path = f'r2c_{model_tag}_{sess_id}_split{split_ix}.h5'
        callbacks = [
            # Overwrite the checkpoint only when `val_accuracy` improves, so the file holds
            # the best-validation-accuracy model of this fold.
            tf.keras.callbacks.ModelCheckpoint(
                filepath=best_model_path,
                save_best_only=True,
                monitor='val_accuracy',
                verbose=verbose)
        ]

        model.fit(x=ds_train, epochs=EPOCHS,
                  verbose=verbose,
                  validation_data=ds_valid,
                  callbacks=callbacks)

        # Score held-out trials with the best checkpoint (not the last epoch), in inference
        # mode and with the same float32 dtype used during training.
        model = tf.keras.models.load_model(best_model_path)
        X_vld = tf.cast(X_rates[vld], tf.float32)
        per_fold_eval.append(model(X_vld, training=False).numpy())
        per_fold_true.append(Y_class[vld])

    pred_y = np.concatenate([np.argmax(fold_scores, axis=1) for fold_scores in per_fold_eval])
    true_y = np.concatenate(per_fold_true).flatten()
    accuracy = 100 * np.sum(pred_y == true_y) / len(pred_y)
    print(f"Session {sess_id} overall accuracy: {accuracy}%")

    return per_fold_eval, per_fold_true


In [None]:
def add_fold_predictions(running, new):
    """Return the fold-by-fold sum of two lists of per-fold prediction arrays.

    Folds can contain different numbers of trials, so arrays are added fold by fold instead
    of being stacked into one ndarray. Stacking ragged arrays raises in NumPy >= 1.24.
    When `running` is None, a float64 copy of `new` is returned.
    """
    if running is None:
        return [np.array(fold, dtype=np.float64) for fold in new]
    assert len(running) == len(new), "Ensemble members must share the same fold split."
    return [acc + fold for acc, fold in zip(running, new)]


# Soft-voting ensemble: sum the out-of-fold class scores of several RNN configurations and one CNN.
# A larger grid (n_rnn in {32, 64} x n_rnn2 in {0, 32, 64} x latent_dim in {0, 32, 64}) can be
# substituted into the three lists below.
rnn1 = [32, 32, 64, 64]   # units in the first LSTM
rnn2 = [0, 32, 32, 64]    # units in the second LSTM (0 = no second LSTM)
dense = [32, 32, 32, 64]  # bottleneck width (latent_dim)

pfe = None  # running sum of per-fold predictions across ensemble members
for i, (n_rnn, n_rnn2, n_dense) in enumerate(zip(rnn1, rnn2, dense)):
    print(f"Training RNN Model # {i + 1}")
    # `model_kfold_train` reads the model-builder arguments from the global `model_kwargs`.
    model_kwargs = dict(
        filt=8,
        kernLength=25,
        ds_rate=9,
        n_rnn=n_rnn,
        n_rnn2=n_rnn2,
        dropoutRate=0.30,
        activation='relu',
        l1_reg=0.0000, l2_reg=0.003,
        norm_rate=0.25,
        latent_dim=n_dense
    )
    pfe_model, _ = model_kfold_train(sess_id, 'RNN', verbose=0)
    pfe = add_fold_predictions(pfe, pfe_model)

print("Training the CNN Model")
model_kwargs = dict(
    F1=8, kernLength=25, F1_kern_reg=None,
    D=2, D_pooling=4,
    F2=8, F2_kernLength=16,
    F2_pooling=8,
    dropoutRate=0.30,
    activation='relu',
    l1_reg=0.000, l2_reg=0.003,
    norm_rate=0.25,
    latent_dim=16
)
pfe_model, pft = model_kfold_train(sess_id, 'CNN', verbose=0)
pfe = add_fold_predictions(pfe, pfe_model)

# All members use the same StratifiedKFold (random_state=0), so fold order and labels line up.
pred_y = np.concatenate([np.argmax(fold_scores, axis=1) for fold_scores in pfe])
true_y = np.concatenate(pft).flatten()
accuracy = 100 * np.sum(pred_y == true_y) / len(pred_y)
print(f"Session {sess_id} overall ensemble accuracy: {accuracy}%")

### Testing on all sessions

In [None]:
# TODO
# hists = []
# accs = []
# for sess_info in sess_infos:
#     _hist, _acc = get_hists_acc(sess_info['exp_code'], verbose=0)
#     hists.append(_hist)
#     accs.append(_acc)
    
# print(accs)