# To run this notebook, clone the following repository so that `rnn_fxpts` can be imported in the second cell:

## https://github.com/garrettkatz/rnn-fxpts.git

In [None]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import rnn_fxpts as rfx
import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt
import pickle
from sklearn.decomposition import PCA


from indl.display import turbo_cmap
from sklearn.model_selection import train_test_split

# Testing RFX

In [None]:
# 2-unit toy network: a scaled identity plus small random perturbations.
# No seed is set, so W0 (and everything derived from it) changes every run.
N0 = 2
W0 = 1.25*np.eye(N0) + 0.1*np.random.randn(N0,N0)

In [None]:
# Run the rnn_fxpts solver on the toy weights.
fxpts0, fiber0 = rfx.run_solver(W0)

In [None]:
print(fxpts0.shape, fiber0.shape)

In [None]:
# Fixed points (rows 0 and 1 of fxpts0) as dots, plus every ordered pairing of
# the first three rows of fiber0 as traces.
# NOTE(review): fiber0 presumably has N0 + 1 rows (state plus one extra fiber
# coordinate) — confirm against rnn_fxpts before relying on the row meaning.
plt.plot(fxpts0[0], fxpts0[1], 'o')
plt.plot(fiber0[0], fiber0[1])
plt.plot(fiber0[1], fiber0[2])
plt.plot(fiber0[0], fiber0[2])
plt.plot(fiber0[1], fiber0[0])
plt.plot(fiber0[2], fiber0[1])
plt.plot(fiber0[2], fiber0[0])
plt.xlim((-1.5,1.5))
plt.ylim((-1.5,1.5))

In [None]:
# rnn_fxpts' own visualisation of the same solution.
rfx.show_fiber(W0, fxpts0, fiber0)

# Training our model

### Loading data (New version)

In [None]:
# Adjust `data_path` in the next cell to match your local setup.

In [None]:
# Download and unzip the Kaggle dataset into data_path unless that directory
# already exists. Uses an IPython `!` shell escape, so it needs the `kaggle` CLI
# (and, presumably, configured Kaggle API credentials).
data_path = Path.cwd().parent.parent / 'Data' / 'Preprocessed'
if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")

In [None]:
from misc.misc import sess_infos, load_macaque_pfc, dec_from_enc

# Options forwarded to load_macaque_pfc for every session loaded in this section.
load_kwargs = dict(
    valid_outcomes=(0,),          # outcome code 0 only; use (0, 9) to include trials with incorrect behaviour
    zscore=True,
    dprime_range=(1.0, np.inf),   # use (-np.inf, np.inf) to include all trials
    time_range=(-np.inf, 1.45),
    verbose=False,
    y_type='sacClass',
    samples_last=True,
    # resample_X=20,
)

In [None]:
# Load the "new version" export of session 1 as spike rates.
test_sess_ix = 1
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
print(f"\nImporting session {sess_id}")
X_rates, Y_new, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
# One-hot targets over the 8 saccade classes.
Y_class = tf.keras.utils.to_categorical(Y_new, num_classes=8)

### Loading data (Old version)

In [None]:
# The "old version" export of the same session: exp_code with its last
# character replaced by "_v1+" (assumes exp_code ends in '+' — TODO confirm).
sess_id_old = sess_info['exp_code'][:-1] + "_v1+"
print(f"\nImporting session {sess_id_old}")
X_rates_old, Y_old, ax_info_old = load_macaque_pfc(data_path, sess_id_old, x_chunk='spikerates', **load_kwargs)
Y_class_old = tf.keras.utils.to_categorical(Y_old, num_classes=8)

In [None]:
# Compare the data shapes of the two exports.
print(X_rates.shape, X_rates_old.shape)

### Making the model

In [None]:
from indl.model import parts
from indl.model.helper import check_inputs
from indl.regularizers import KernelLengthRegularizer

@check_inputs
def make_model(
    _input,
    num_classes,
    filt=8,
    kernLength=25,
    ds_rate=10,
    n_rnn=64,
    n_rnn2=64,
    dropoutRate=0.25,
    activation='relu',
    l1_reg=0.000, l2_reg=0.000,
    norm_rate=0.25,
    latent_dim=16,
    return_model=True
):
    """Build the temporal-conv -> depthwise-conv -> LSTM -> bottleneck classifier.

    Later cells index ``model.layers`` by position and look up the LSTM by the
    name 'rnn1', so do not insert or remove layers before it.

    Args:
        _input: model input as prepared by ``check_inputs`` (assumed to be a
            Keras tensor shaped (batch, channels, samples) or
            (batch, channels, samples, 1) — TODO confirm).
        num_classes: number of classes passed to ``parts.Classify``.
        filt: number of temporal Conv2D filters.
        kernLength: temporal kernel length, in samples.
        ds_rate: temporal average-pooling factor applied before the LSTM.
        n_rnn: units of the first LSTM ('rnn1').
        n_rnn2: units of an optional second LSTM ('rnn2'); 0 disables it.
        dropoutRate: dropout after the conv stack and after each LSTM.
        activation: activation after the conv stack, each LSTM and in the bottleneck.
        l1_reg, l2_reg: L1 / L2 penalties on each LSTM's kernel and recurrent kernel.
        norm_rate: forwarded to ``parts.Classify``.
        latent_dim: width of ``parts.Bottleneck``.
        return_model: if False, return the output tensor instead of a Model.

    Returns:
        ``tf.keras.Model`` from ``_input`` to the class outputs, or the output
        tensor when ``return_model`` is False.
    """
    def _lstm_reg():
        # Honour both penalties; previously l1_reg was accepted but ignored.
        return tf.keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg)

    inputs = _input
    input_shape = _input.shape.as_list()
    # Conv2D needs a trailing feature axis: (channels, samples) -> (channels, samples, 1).
    if len(input_shape) < 4:
        input_shape = input_shape + [1]
    # Reshape ignores the batch dimension. The sample count is kept fixed, so the
    # model only accepts segments with the training duration.
    _y = tf.keras.layers.Reshape(input_shape[1:])(inputs)
    # Temporal filters spanning 1 channel x kernLength samples (other args are Keras defaults).
    _y = tf.keras.layers.Conv2D(filt, (1, kernLength), padding='valid')(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    # One depthwise kernel spanning every channel collapses the channel axis to 1.
    _y = tf.keras.layers.DepthwiseConv2D((_y.shape.as_list()[1], 1), padding='valid', depth_multiplier=1)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.AveragePooling2D(pool_size=(1, ds_rate))(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)
    # Drop the collapsed channel axis: (1, time, filt) -> (time, filt) sequence for the LSTM.
    _y = tf.keras.layers.Reshape(_y.shape.as_list()[2:])(_y)
    _y = tf.keras.layers.LSTM(n_rnn,
                              kernel_regularizer=_lstm_reg(),
                              recurrent_regularizer=_lstm_reg(),
                              return_sequences=n_rnn2 > 0,
                              stateful=False,
                              name='rnn1')(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    if n_rnn2 > 0:
        _y = tf.keras.layers.LSTM(n_rnn2,
                                  kernel_regularizer=_lstm_reg(),
                                  recurrent_regularizer=_lstm_reg(),
                                  return_sequences=False,
                                  stateful=False,
                                  name='rnn2')(_y)
        _y = tf.keras.layers.Activation(activation)(_y)
        _y = tf.keras.layers.BatchNormalization()(_y)
        _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    # Dense bottleneck, then the classification head.
    _y = parts.Bottleneck(_y, latent_dim=latent_dim, activation=activation)
    outputs = parts.Classify(_y, n_classes=num_classes, norm_rate=norm_rate)

    if return_model is False:
        return outputs
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)

In [None]:
## Model Parameters
LABEL_SMOOTHING = 0.2

model_kwargs = dict(
    filt=8,
    kernLength=25,
    ds_rate=9,
    n_rnn=64,
    n_rnn2=0,
    dropoutRate=0.40,
    activation='relu',
    l1_reg=0.0000, l2_reg=0.001,
    norm_rate=0.25,
    latent_dim=64
)

### Training

In [None]:
from sklearn.model_selection import StratifiedKFold

# Cross-validation folds, mini-batch size and training epochs per fold.
N_SPLITS, BATCH_SIZE, EPOCHS = 10, 16, 180


def _merge_histories(histories):
    """Concatenate per-fold Keras ``History.history`` dicts into a single dict.

    Folds are separated by a NaN so plotted curves break between folds. Value
    lists are copied, so the input dicts are not modified.
    """
    merged = {}
    for fold_history in histories:
        for key, values in fold_history.items():
            if key in merged:
                merged[key].append(np.nan)
                merged[key].extend(values)
            else:
                merged[key] = list(values)
    return merged


def get_hists_acc(sess_id, new=True, verbose=1):
    """Cross-validate the LSTM decoder on one session.

    Trains one model per stratified fold (``N_SPLITS`` folds, fixed shuffle
    seed). For each fold it keeps the checkpoint with the best validation
    accuracy and scores that checkpoint on the fold's held-out trials.

    Args:
        sess_id: session code passed to ``load_macaque_pfc``; also embedded in
            the checkpoint file names.
        new: selects the tag in the checkpoint names
            ``r2c_lstm_{sess_id}{'new' if new else 'old'}_split{k}.h5``
            (later cells reload these files).
        verbose: verbosity for ``model.fit`` and ``ModelCheckpoint``.

    Returns:
        ``(history, accuracy)``: per-epoch metrics of all folds concatenated
        with NaN separators, and the overall held-out accuracy in percent.
    """
    print(f"Processing session {sess_id}...")
    X_rates, Y_class, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
    # One-hot targets do not depend on the fold, so build them once.
    y_onehot = tf.keras.utils.to_categorical(Y_class, num_classes=8)
    data_tag = 'new' if new else 'old'

    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=0)
    histories = []
    per_fold_eval = []
    per_fold_true = []

    for split_ix, (trn, vld) in enumerate(splitter.split(X_rates, Y_class)):
        print(f"\tSplit {split_ix + 1} of {N_SPLITS}")

        ds_train = tf.data.Dataset.from_tensor_slices((X_rates[trn], y_onehot[trn]))
        ds_valid = tf.data.Dataset.from_tensor_slices((X_rates[vld], y_onehot[vld]))

        # Cast to GPU-friendly dtypes.
        ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))
        ds_valid = ds_valid.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))

        # TODO: augmentations (random slicing?)

        ds_train = ds_train.shuffle(len(trn) + 1)
        ds_train = ds_train.batch(BATCH_SIZE, drop_remainder=True)
        ds_valid = ds_valid.batch(BATCH_SIZE, drop_remainder=False)

        # Fresh Keras state and identical seeds for every fold.
        tf.keras.backend.clear_session()
        randseed = 12345
        random.seed(randseed)
        np.random.seed(randseed)
        tf.random.set_seed(randseed)

        model = make_model(ds_train.element_spec[0], y_onehot.shape[-1], **model_kwargs)
        optim = tf.keras.optimizers.Nadam(learning_rate=0.001)
        loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
        model.compile(optimizer=optim, loss=loss_obj, metrics=['accuracy'])

        best_model_path = f'r2c_lstm_{sess_id}{data_tag}_split{split_ix}.h5'
        callbacks = [
            # Overwrite the checkpoint only when `val_accuracy` improves, so the
            # file ends up holding this fold's best epoch.
            tf.keras.callbacks.ModelCheckpoint(
                filepath=best_model_path,
                save_best_only=True,
                monitor='val_accuracy',
                verbose=verbose)
        ]

        hist = model.fit(x=ds_train, epochs=EPOCHS,
                         verbose=verbose,
                         validation_data=ds_valid,
                         callbacks=callbacks)
        histories.append(hist.history)

        # Score the best checkpoint rather than the final-epoch weights.
        model = tf.keras.models.load_model(best_model_path)
        per_fold_eval.append(model(X_rates[vld]).numpy())
        per_fold_true.append(Y_class[vld])

    history = _merge_histories(histories)
    pred_y = np.concatenate([np.argmax(probs, axis=1) for probs in per_fold_eval])
    true_y = np.concatenate(per_fold_true).flatten()
    accuracy = 100 * np.sum(pred_y == true_y) / len(pred_y)
    print(f"Session {sess_id} overall accuracy: {accuracy}%")

    return history, accuracy

#### Train for new dataset

In [None]:
from indl.metrics import quickplot_history

# Cross-validate on the new-version export and plot the fold-concatenated history.
history, accuracy = get_hists_acc(sess_id,new=True, verbose=2)
quickplot_history(history)

#### Train for old dataset

In [None]:
from indl.metrics import quickplot_history

# Same for the old-version (_v1+) export. Note: this overwrites `history` and
# `accuracy` from the previous cell.
history, accuracy = get_hists_acc(sess_id_old,new=False, verbose=2)
quickplot_history(history)

### Finding fixed points of the LSTM's recurrent (cell-candidate) weights
#### Inspecting the saved best model from each of the ten splits

In [None]:
# WARNING: Running the next cell may take several hours on a CPU.
# To skip it, jump to the "load pickle" cell (four cells below) and run the cells after it to plot the results.

In [None]:
NCOMP = 2  # number of PCs for a fixed-point projection (not used in this cell)
# NOTE(review): the recurrent weights are multiplied by 10 before solving; the
# reason is not stated here — confirm before changing.
W_SCALE = 10
fxpts = []


def _candidate_recurrent_weights(model_path, n_units):
    """Return the cell-candidate block of a saved model's 'rnn1' recurrent kernel.

    Keras stores an LSTM recurrent kernel as (units, 4 * units) with the gate
    blocks ordered input, forget, cell candidate, output. Columns
    [2 * units, 3 * units) are therefore the candidate (cell-update) weights.
    """
    model = tf.keras.models.load_model(model_path)
    recurrent_kernel = model.get_layer('rnn1').get_weights()[1]
    return recurrent_kernel[:, n_units * 2: n_units * 3]


# fxpts is interleaved per split: [new_0, old_0, new_1, old_1, ...]. The file
# names match those written by get_hists_acc for sess_id / sess_id_old.
for i in range(N_SPLITS):
    for tag, sid in (('new', sess_id), ('old', sess_id_old)):
        print(f'Working on Split #{i}, {tag} data')
        W = _candidate_recurrent_weights(f'r2c_lstm_{sid}{tag}_split{i}.h5', model_kwargs['n_rnn']) * W_SCALE
        fxpt, _ = rfx.run_solver(W)
        fxpts.append(fxpt)
    


In [None]:
# Cache the (slow) fixed-point results so the solver does not have to be rerun.
Path('fixed_points.pkl').write_bytes(pickle.dumps(fxpts))

In [None]:
# fig, axs = plt.subplots(2, 5, figsize=(20, 10))
# for i in range(N_SPLITS):
#     axs[int(i/5), i%5].plot(fxpts[2*i][0], fxpts[2*i][1], 'bo')
#     axs[int(i/5), i%5].set_title(f'Split #{i}')
    
#     axs[int(i/5), i%5].plot(fxpts[2*i+1][0], fxpts[2*i+1][1], 'go')
    
# for ax in axs.flat:
#     ax.set(xlabel='Fixed Point PC0', ylabel='Fixed Point PC1')

# for ax in axs.flat:
#     ax.label_outer()
    
# fig.legend(["New Data", "Old Data"], loc = (0.4, 0), ncol=5 )

In [None]:
# Reload the cached fixed points. Only unpickle a file this notebook wrote
# itself: unpickling untrusted data can execute arbitrary code.
fxpts_list = pickle.loads(Path('fixed_points.pkl').read_bytes())

In [None]:
# Number of cached fixed-point arrays (two per split: new, then old) and the shape of the first.
print(len(fxpts_list), fxpts_list[0].shape)

In [None]:
# fig, axs = plt.subplots(2, 5, figsize=(20, 10))
# for i in range(N_SPLITS):
#     axs[int(i/5), i%5].plot(fxpts_list[2*i][0], fxpts_list[2*i][1], 'bo')
#     axs[int(i/5), i%5].set_title(f'Split #{i}')
    
#     axs[int(i/5), i%5].plot(fxpts_list[2*i+1][0], fxpts_list[2*i+1][1], 'go')
    
# for ax in axs.flat:
#     ax.set(xlabel='Fixed Point PC0', ylabel='Fixed Point PC1')

# for ax in axs.flat:
#     ax.label_outer()
    
# fig.legend(["New Data", "Old Data"], loc = (0.4, 0), ncol=5 )

In [None]:
# model = tf.keras.models.load_model(f'r2c_lstm_sra3_1_j_050_00+new_split0.h5')
# lstm = model.layers[10].get_weights()
# output = model.layers[9].output
# factor_model = tf.keras.Model(model.input, output)
# factor_model.summary()

In [None]:
# factors = factor_model(X_rates)
# print(X_rates.shape, factors.shape)

In [None]:
# inputs = tf.keras.layers.Input(shape=factors.shape[1:])
# state_lstm = tf.keras.layers.LSTM(model_kwargs['n_rnn'],
#                                   kernel_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
#                                   recurrent_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
#                                   return_sequences=True,
#                                   name='state_rnn1')(inputs)
# state_model = tf.keras.Model(inputs, state_lstm)

# state_model.layers[-1].set_weights(lstm)

# state_model.summary()

In [None]:
# state_outputs = state_model(factors)
# print(state_outputs.shape)
# print(state_outputs[0].shape, state_outputs[1].shape, state_outputs[2].shape)

In [None]:
# tmp = np.reshape(state_outputs, (state_outputs.shape[0] * state_outputs.shape[0], state_outputs.shape[2]))
# print(tmp.shape)

In [None]:
# pca = PCA(n_components=2)
# tmp = pca.fit_transform(tmp)
# print(tmp.shape)

In [None]:
# states_pc = np.reshape(tmp, (state_outputs.shape[0], state_outputs.shape[1], tmp.shape[1]))
# print(states_pc.shape)

In [None]:
# fxpts_list[0].shape

### Looping over all splits for the new and old data

In [None]:
hidden_states = []
# fixed_points = []

for i in range(N_SPLITS):
#     print(f'Working on Split #{i}, new data')
#     model = tf.keras.models.load_model(f'r2c_lstm_sra3_1_j_050_00+_split{i}.h5')
#     lstm = model.layers[10].get_weights()
#     output = model.layers[9].output
#     factor_model = tf.keras.Model(model.input, output)
#     factors = factor_model(X_rates)
#     inputs = tf.keras.layers.Input(shape=factors.shape[1:])
#     state_lstm = tf.keras.layers.LSTM(model_kwargs['n_rnn'],
#                                       kernel_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
#                                       recurrent_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
#                                       return_sequences=True,
#                                       name='state_rnn1')(inputs)
#     state_model = tf.keras.Model(inputs, state_lstm)
#     state_model.layers[-1].set_weights(lstm)
#     state_outputs = state_model(factors)
    
#     pca = PCA(n_components=2)
#     tmp = np.reshape(state_outputs, (state_outputs.shape[0] * state_outputs.shape[1], state_outputs.shape[2]))
#     st_out_pc = pca.fit_transform(tmp)
#     tmp = np.reshape(st_out_pc, (state_outputs.shape[0], state_outputs.shape[1], st_out_pc.shape[1]))
#     hidden_states.append(tmp)
    
#     tmp = fxpts_list[2*i].T
#     tmp = pca.transform(tmp)
#     fixed_points.append(tmp.T)
    
    print(f'Working on Split #{i}, old data')
    model = tf.keras.models.load_model(f'r2c_lstm_sra3_1_j_050_00_v1+old_split{i}.h5')
    lstm = model.layers[10].get_weights()
    output = model.layers[9].output
    factor_model = tf.keras.Model(model.input, output)
    factors = factor_model(X_rates_old)
    inputs = tf.keras.layers.Input(shape=factors.shape[1:])
    state_lstm = tf.keras.layers.LSTM(model_kwargs['n_rnn'],
                                      kernel_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
                                      recurrent_regularizer=tf.keras.regularizers.l2(model_kwargs['l2_reg']),
                                      return_sequences=True,
                                      name='state_rnn1')(inputs)
    state_model = tf.keras.Model(inputs, state_lstm)

    state_model.layers[-1].set_weights(lstm)
    state_outputs = state_model(factors)
    
    pca = PCA(n_components=2)
    tmp = np.reshape(state_outputs, (state_outputs.shape[0] * state_outputs.shape[1], state_outputs.shape[2]))
    st_out_pc = pca.fit_transform(tmp)
    tmp = np.reshape(st_out_pc, (state_outputs.shape[0], state_outputs.shape[1], st_out_pc.shape[1]))
    hidden_states.append(tmp)
    
#     tmp = fxpts_list[2*i+1].T
#     tmp = pca.transform(tmp)
#     fixed_points.append(tmp.T)

In [None]:
print(len(hidden_states), hidden_states[0].shape)#, len(fixed_points), fixed_points[0].shape)

In [None]:
# hidden_states_norm = hidden_states
# fxpts_list_norm = fxpts_list
# for i in range(2*N_SPLITS):
#     hidden_states_norm[i] =  2 * ((hidden_states_norm[i] - np.min(hidden_states_norm[i]))/(np.max(hidden_states_norm[i]) - np.min(hidden_states_norm[i]))) - 1
#     fxpts_list_norm[i] =  2 * ((fxpts_list_norm[i] - np.min(fxpts_list_norm[i]))/(np.max(fxpts_list_norm[i]) - np.min(fxpts_list_norm[i]))) - 1

In [None]:
color_map = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:pink', 'tab:olive', 'tab:cyan', 'tab:purple']
int(Y_new[100])

In [None]:
# for j in range(285):
#     lbl = f'States Class {int(Y_new[j])}'
#     plt.plot(hidden_states[0][j, :, 0], hidden_states[0][j, :, 1], color_map[int(Y_new[j])], label=lbl)
    
# plt.legend()

In [None]:
for j in range(285):
    plt.plot(hidden_states[0][j, :, 0], hidden_states[0][j, :, 1], color_map[int(Y_old[j])])

In [None]:
fig, axs = plt.subplots(2, 5, figsize=(20, 10))
fig.suptitle("New Data")
for i in range(N_SPLITS):
    for j in range(X_rates.shape[0]):
        axs[int(i/5), i%5].plot(hidden_states[2*i][j, :, 0], hidden_states[2*i][j, :, 1], color_map[int(Y_new[j])])
    axs[int(i/5), i%5].set_title(f'Split #{i}')
    
    axs[int(i/5), i%5].plot(fixed_points[2*i][0], fixed_points[2*i][1], 'ko')
    
for ax in axs.flat:
    ax.set(xlabel='PC0', ylabel='PC1')

for ax in axs.flat:
    ax.label_outer()
    
fig.legend(["States Class 0", "States Class 1", "States Class 2", "States Class 3", "States Class 4", "States Class 5",
            "States Class 6", "States Class 7", "Fixed Points"], loc = (0.4, 0), ncol=5 )


fig, axs = plt.subplots(2, 5, figsize=(20, 10))
fig.suptitle("Old Data")
for i in range(N_SPLITS):
    for j in range(X_rates_old.shape[0]):
        axs[int(i/5), i%5].plot(hidden_states[2*i+1][j, :, 0], hidden_states[2*i+1][j, :, 1], color_map[int(Y_old[j])])
    axs[int(i/5), i%5].set_title(f'Split #{i}')
    
    axs[int(i/5), i%5].plot(fixed_points[2*i+1][0], fixed_points[2*i+1][1], 'ko')
    
for ax in axs.flat:
    ax.set(xlabel='PC0', ylabel='PC1')

for ax in axs.flat:
    ax.label_outer()
    
fig.legend(["States Class 0", "States Class 1", "States Class 2", "States Class 3", "States Class 4", "States Class 5",
            "States Class 6", "States Class 7", "Fixed Points"], loc = (0.4, 0), ncol=5 )

# Training a Rule Model

In [None]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
import tensorflow as tf
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
import pickle
from sklearn.decomposition import PCA
from scipy import signal
from scipy import stats
from sklearn.model_selection import train_test_split
from indl.fileio import from_neuropype_h5
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.model_selection import KFold
from sklearn.manifold import TSNE
from itertools import cycle

import os

# When started from the 'Analysis' folder, move two levels up so the relative
# paths below (and files written to the working directory) resolve from there.
if Path.cwd().stem == 'Analysis':
    os.chdir(Path.cwd().parent.parent)
    
    
# Download and unzip the Kaggle dataset (IPython `!` shell escape) unless the
# directory already exists.
data_path = Path.cwd() / 'StudyLocationRule'/ 'Data' / 'Preprocessed'
if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")
    
from misc.misc import sess_infos, load_macaque_pfc, dec_from_enc

# load_macaque_pfc options for the rule-model section (no z-scoring).
# load_kwargs: outcome code 0 only, over the whole trial.
load_kwargs = dict(
    valid_outcomes=(0,),
    zscore=False,
    dprime_range=(1.0, np.inf),    # use (-np.inf, np.inf) to include all trials
    time_range=(-np.inf, np.inf),
    verbose=False,
    y_type='sacClass',
    samples_last=True,
    # resample_X=20,
)
# load_kwargs_ul: outcome code 9 only (the incorrect-behaviour trials), up to 1.45.
load_kwargs_ul = dict(
    valid_outcomes=(9,),
    zscore=False,
    dprime_range=(1.0, np.inf),
    time_range=(-np.inf, 1.45),
    verbose=False,
    y_type='sacClass',
    samples_last=True,
)
# load_kwargs_all: outcome codes 0 and 9, i.e. including trials with incorrect behaviour.
load_kwargs_all = dict(
    valid_outcomes=(0, 9),
    zscore=False,
    dprime_range=(1.0, np.inf),
    time_range=(-np.inf, 1.45),
    verbose=False,
    y_type='sacClass',
    samples_last=True,
)

## Model Parameters
BATCH_SIZE, EPOCHS, LABEL_SMOOTHING = 16, 150, 0.2

model_kwargs = {
    'filt': 8,
    'kernLength': 25,
    'ds_rate': 9,
    'n_rnn': 64,
    'n_rnn2': 0,
    'dropoutRate': 0.40,
    'activation': 'relu',
    'l1_reg': 0.0001,
    'l2_reg': 0.001,
    'norm_rate': 0.25,
    'latent_dim': 64,
}

from indl.model import parts
from indl.regularizers import KernelLengthRegularizer

def make_model(
    _input,
    num_classes,
    filt=8,
    kernLength=25,
    ds_rate=10,
    n_rnn=64,
    n_rnn2=64,
    dropoutRate=0.25,
    activation='relu',
    l1_reg=0.000, l2_reg=0.000,
    norm_rate=0.25,
    latent_dim=16,
    return_model=True
):
    """Build the conv -> LSTM rule classifier, sized from an example input array.

    Later cells index ``model.layers`` by position and look up the LSTM by the
    name 'rnn1', so do not insert or remove layers before it.

    Args:
        _input: array-like with ``.shape`` (trials, channels, samples); only its
            shape is used, to create the ``Input`` layer and pick conv sizes.
        num_classes: number of softmax outputs.
        filt, kernLength, ds_rate: temporal filter count, kernel length and
            pooling factor. They are overridden for short inputs
            (``_input.shape[2] < 30``).
        n_rnn: units of LSTM 'rnn1'.
        n_rnn2: units of an optional LSTM 'rnn2'; 0 disables it.
        dropoutRate: dropout after the conv stack and after each LSTM.
        activation: activation after the conv stack, each LSTM and the latent Dense.
        l1_reg, l2_reg: L1 / L2 penalties on each LSTM's kernel and recurrent kernel.
        norm_rate: unused here (the head is a plain softmax Dense); kept for
            signature compatibility.
        latent_dim: width of the latent Dense layer.
        return_model: if False, return the output tensor instead of a Model.

    Returns:
        A ``tf.keras.Model`` (or the output tensor when ``return_model`` is False).
    """
    def _lstm_reg():
        # Honour both penalties; previously l1_reg was accepted but ignored.
        return tf.keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg)

    inputs = tf.keras.layers.Input(shape=_input.shape[1:])

    # Shrink kernel / pooling for short sample axes (presumably so the 'valid'
    # convolution and pooling still leave time steps for the LSTM).
    if _input.shape[2] < 10:
        kernLength = 4
        filt = 4
        ds_rate = 4
    elif _input.shape[2] < 20:
        kernLength = 8
        ds_rate = 8
    elif _input.shape[2] < 30:
        kernLength = 16

    input_shape = list(_input.shape)
    # Conv2D needs a trailing feature axis: (channels, samples) -> (channels, samples, 1).
    if len(input_shape) < 4:
        input_shape = input_shape + [1]
    # Reshape ignores the batch dimension; the sample count stays fixed.
    _y = tf.keras.layers.Reshape(input_shape[1:])(inputs)
    # Temporal filters spanning 1 channel x kernLength samples (other args are Keras defaults).
    _y = tf.keras.layers.Conv2D(filt, (1, kernLength), padding='valid')(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    # One depthwise kernel spanning every channel collapses the channel axis to 1.
    _y = tf.keras.layers.DepthwiseConv2D((_y.shape.as_list()[1], 1), padding='valid', depth_multiplier=1)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.AveragePooling2D(pool_size=(1, ds_rate))(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)
    # Drop the collapsed channel axis: (1, time, filt) -> (time, filt) sequence for the LSTM.
    _y = tf.keras.layers.Reshape(_y.shape.as_list()[2:])(_y)
    _y = tf.keras.layers.LSTM(n_rnn,
                              kernel_regularizer=_lstm_reg(),
                              recurrent_regularizer=_lstm_reg(),
                              return_sequences=n_rnn2 > 0,
                              stateful=False,
                              name='rnn1')(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    if n_rnn2 > 0:
        _y = tf.keras.layers.LSTM(n_rnn2,
                                  kernel_regularizer=_lstm_reg(),
                                  recurrent_regularizer=_lstm_reg(),
                                  return_sequences=False,
                                  stateful=False,
                                  name='rnn2')(_y)
        _y = tf.keras.layers.Activation(activation)(_y)
        _y = tf.keras.layers.BatchNormalization()(_y)
        _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    # Latent Dense layer, then the softmax classifier.
    _y = tf.keras.layers.Dense(latent_dim, activation=activation)(_y)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(_y)

    if return_model is False:
        return outputs
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)

In [None]:
# Session 0 as spike trains. Later cells treat X_rates as 0/1 spike trains
# (they test X == 1) and index Y by trial.
test_sess_ix = 0
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
print(f"\nImporting session {sess_id}")
X_rates, Y, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spiketrains', **load_kwargs)
print(X_rates.shape, Y.shape)

In [None]:
# Session 1 spike *rates*, loaded only to inspect how saccade classes are coded.
# The data and metadata go into separate names so that X_rates / Y / ax_info
# from the session-0 spike-train cell stay intact. The rule-labelling and
# spike-binning cells below threshold X_rates == 1 and index Y with a trial
# mask built from ax_info, so all three must come from the same session.
test_sess_ix = 1
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
print(f"\nImporting session {sess_id}")
X_rates_s1, Y_class, ax_info_s1 = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
# classes: sorted distinct labels; _y: each trial's index into `classes`.
classes, _y = np.unique(Y_class, return_inverse=True)

In [None]:
# Per-trial index into `classes` (inverse indices from np.unique).
_y.ravel()

In [None]:
# Distinct class labels.
classes

In [None]:
# Raw class labels as loaded.
Y_class

In [None]:
# Axis/trial metadata returned by load_macaque_pfc.
ax_info

In [None]:
# Per-trial target rule, cue colour and trial index as NumPy arrays.
target, color, trial = (np.array(ax_info['instance_data'][field])
                        for field in ('TargetRule', 'CueColour', 'TrialIndex'))

In [None]:
# (TargetRule, CueColour) -> rule label. Combinations not listed get 0 and are
# dropped. ('DD', 'g') is coded 12 at first because 0 means "drop"; it is
# recoded to 0 once the unlabelled trials are removed.
_RULE_LABELS = {
    ('DD', 'g'): 12,
    ('UU', 'b'): 1,
    ('DR', 'r'): 2,
    ('UL', 'b'): 3,
    ('DL', 'b'): 4,
    ('UR', 'g'): 5,
    ('LL', 'r'): 6,
    ('RR', 'g'): 7,
    ('UU', 'r'): 8,
    ('DR', 'g'): 9,
    ('DL', 'r'): 10,
    ('UR', 'r'): 11,
}
label = np.array([_RULE_LABELS.get((tgt, col), 0) for tgt, col in zip(target, color)], dtype=float)

keep_idx = np.flatnonzero(label > 0)

new_label = label[keep_idx].astype(int)
new_X = X_rates[keep_idx]
new_Y = Y[keep_idx].flatten()
new_target = target[keep_idx]
new_color = color[keep_idx]
new_trial = trial[keep_idx]

new_label[new_label == 12] = 0

print(new_X.shape, new_label.shape, new_Y.shape, new_target.shape, new_color.shape, new_trial.shape)

In [None]:
# Re-bin the 0/1 spike trains into 5-sample bins: a bin is 1 when any sample
# falling in it equals 1.
spikes = np.zeros(new_X.shape[:2] + (new_X.shape[2] // 5 + 1,))
spike_tr, spike_ch, spike_t = np.nonzero(new_X == 1)
spikes[spike_tr, spike_ch, spike_t // 5] = 1

In [None]:
# Raster of trial 0 in 5-sample bins: one '|' tick per (bin, channel) holding a
# spike, drawn with a single plot call instead of one artist per spike.
ch_idx, t_idx = np.nonzero(spikes[0] == 1)
plt.plot(t_idx, ch_idx, '|')
# NOTE(review): bins 120 / 170 / 370 presumably mark task events — confirm.
# The lines span all channels rather than a hard-coded 36.
plt.vlines([120, 170, 370], -1, spikes.shape[1], colors='grey')
plt.show()

In [None]:
# Train one model on all kept trials (no validation data is passed to fit),
# using the 5-sample-binned spike trains as input and saccade class as target.
Y_class = tf.keras.utils.to_categorical(new_Y, num_classes=8)

ds_train = tf.data.Dataset.from_tensor_slices((spikes, Y_class))

# cast data types to GPU-friendly types.
ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))

# TODO: augmentations (random slicing?)

ds_train = ds_train.shuffle(len(new_Y) + 1)
ds_train = ds_train.batch(BATCH_SIZE, drop_remainder=True)

tf.keras.backend.clear_session()

# Seed every RNG after clearing the session so model initialisation is repeatable.
randseed = 12345
random.seed(randseed)
np.random.seed(randseed)
tf.random.set_seed(randseed)

model = make_model(
    spikes,
    Y_class.shape[-1],
    **model_kwargs
)
optim = tf.keras.optimizers.Nadam(learning_rate=0.001)
loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
model.compile(optimizer=optim, loss=loss_obj, metrics=['accuracy'])


hist = model.fit(x=ds_train, epochs=EPOCHS, verbose=1)

In [None]:
model.summary()

In [None]:
# Recover the LSTM output at every time step. 'rnn1' was built with
# return_sequences=False (n_rnn2 == 0), so copy its weights into an otherwise
# identical LSTM that returns sequences and drive it with rnn1's own inputs.
# Layers are looked up by name rather than by fragile positional indices.
rnn1_layer = model.get_layer('rnn1')
lstm = rnn1_layer.get_weights()
output = rnn1_layer.input
factor_model = tf.keras.Model(model.input, output)
factors = factor_model(spikes)
inputs = tf.keras.layers.Input(shape=factors.shape[1:])
# Regularizers only affect training, so the inference copy omits them.
state_lstm = tf.keras.layers.LSTM(rnn1_layer.units,
                                  return_sequences=True,
                                  name='state_rnn1')(inputs)
state_model = tf.keras.Model(inputs, state_lstm)
state_model.layers[-1].set_weights(lstm)
state_outputs = state_model(factors)

In [None]:
# Shapes of the per-step LSTM outputs and of the LSTM inputs ("factors").
print(np.array(state_outputs).shape,np.array(factors).shape)

In [None]:
# Per-step LSTM outputs of trial 0, one marker series per unit.
plt.plot(state_outputs[0],'|')
plt.show()

In [None]:
# LSTM inputs of trial 0, one marker series per feature.
plt.plot(factors[0],'|')
plt.show()

In [None]:
# Flatten (trials, steps, units) to (trials*steps, units), reduce to 32 PCs,
# embed in 2-D with t-SNE, then restore the (trials, steps, 2) layout.
# TSNE is given no random_state, so the embedding differs between runs.
pca = PCA(n_components=32)
tmp = np.reshape(state_outputs, (state_outputs.shape[0] * state_outputs.shape[1], state_outputs.shape[2]))
pca_values = pca.fit_transform(tmp)
tsne_model = TSNE(n_components=2, perplexity=10)
tsne_values = tsne_model.fit_transform(pca_values)
hidden_states = np.reshape(tsne_values, (state_outputs.shape[0], state_outputs.shape[1], tsne_values.shape[1]))

In [None]:
print(hidden_states.shape)

In [None]:
color_map = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:pink', 'tab:olive', 'tab:cyan', 'tab:purple']
# One 2-D (t-SNE) trajectory per trial, colored by saccade class. The class is
# cast to int because list indexing rejects NumPy float labels.
for traj, sac_class in zip(hidden_states, new_Y):
    plt.plot(traj[:, 0], traj[:, 1], '.-', color=color_map[int(sac_class)])
plt.show()

In [None]:
# Largest saccade class label among the kept trials.
max(new_Y)

In [None]:
# Per-time-step LSTM ('rnn1') outputs, shape (trials, steps, units). The model
# was trained on the 5-sample-binned `spikes`, not on the raw `new_X` trains
# (different sample count), and 'rnn1' itself returns only its final step.
# Therefore run the return_sequences copy (`state_model`) on the model's own
# LSTM inputs.
lstm = state_model(factor_model(spikes))
print(lstm.shape)

In [None]:
# Expects `lstm` as (trials, steps, units): time course of unit 1 for trial 100.
plt.plot(lstm[100,:,1])

In [None]:
# Convert to a NumPy array before saving with savemat.
lstm = np.array(lstm)

In [None]:
from scipy.io import savemat
data = {'lstm': lstm,
        'label': label,
        'spikes': new_X,
        'y': new_Y,
        'trial': new_trial,
        'target': new_target,
        'color': new_color}
name = 'lstm_output.mat'
savemat(name, data)

In [None]:
from scipy.io import savemat

# Export every session's spike trains (re-binned 5x coarser in time) and trial
# metadata to '<exp_code without "+">_cor.mat' in the working directory.

# Integer code per target rule, in this order. Rules not listed map to 0, exactly as
# the former if/elif chain left them at 0.
TARGET_RULE_CODES = {rule_name: code for code, rule_name in
                     enumerate(['UU', 'UR', 'RR', 'DR', 'DD', 'DL', 'LL', 'UL'])}
# Time-bin factor applied to the spike trains.
SPIKE_BIN = 5

for test_sess_ix in range(8):
    sess_info = sess_infos[test_sess_ix]
    sess_id = sess_info['exp_code']
    print(f"\nImporting session {sess_id}")
    X, Y, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spiketrains', **load_kwargs)
    target = np.array(ax_info['instance_data']['TargetRule'])
    color = np.array(ax_info['instance_data']['CueColour'])
    trial = np.array(ax_info['instance_data']['TrialIndex'])
    # Table lookup instead of an 8-way if/elif chain; float dtype as np.zeros produced before.
    targets = np.array([TARGET_RULE_CODES.get(t, 0) for t in np.ravel(target)], dtype=float)
    # Vectorised re-binning: every sample equal to 1 sets bin (sample // SPIKE_BIN) of the
    # same trial/channel, replacing the per-trial, per-channel Python loops.
    spikes = np.zeros((np.size(X, 0), np.size(X, 1), np.size(X, 2) // SPIKE_BIN + 1))
    tr_idx, ch_idx, t_idx = np.nonzero(X == 1)
    spikes[tr_idx, ch_idx, t_idx // SPIKE_BIN] = 1
    data = {'spikes': spikes,
            'saccades': Y,
            'targets': targets,
            'colors': color,
            'trial': trial}
    name = f'{sess_id.replace("+", "")}_cor.mat'
    savemat(name, data)

In [None]:
# Display the session metadata list imported from misc.misc.
sess_infos

## Manual Loading

In [None]:
import tensorflow as tf

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

In [None]:
# TF1-style equivalent of the memory-growth setting above, kept for reference.
# from tensorflow.compat.v1 import ConfigProto
# from tensorflow.compat.v1 import InteractiveSession
# config = ConfigProto()
# config.gpu_options.allow_growth = True
# session = InteractiveSession(config=config)

In [None]:
import indl

In [None]:
# Automatically reload edited project modules (IPython autoreload extension).
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
import tensorflow as tf
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
import pickle
from sklearn.decomposition import PCA
from scipy import signal
from scipy import stats
from sklearn.model_selection import train_test_split
from indl.fileio import from_neuropype_h5
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.model_selection import KFold
from sklearn.manifold import TSNE
from sklearn.decomposition import FactorAnalysis
from itertools import cycle

import os

# If the notebook was started from a folder named 'Analysis', move two levels up so the
# relative paths below ('StudyLocationRule/...', 'misc.misc') are resolved from there.
if Path.cwd().stem == 'Analysis':
    os.chdir(Path.cwd().parent.parent)
    
    
data_path = Path.cwd() / 'StudyLocationRule'/ 'Data' / 'Preprocessed'
# Fetch the preprocessed dataset from Kaggle only when the directory is missing.
if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")
    
from misc.misc import sess_infos, load_macaque_pfc, dec_from_enc

# Trial-selection options for load_macaque_pfc. The three variants below differ only
# in which behavioural outcome codes are accepted.
load_kwargs = {
    'valid_outcomes': (0,),  # Use (0, 9) to include trials with incorrect behaviour
    'zscore': True,
    'dprime_range': (1.0, np.inf),  # Use (-np.inf, np.inf) to include all trials.
    'time_range': (-np.inf, 1.25),
    'verbose': False,
    'y_type': 'sacClass',
    'samples_last': True,
    # 'resample_X': 20,
}
# Incorrect-behaviour trials only (outcome code 9).
load_kwargs_error = {**load_kwargs, 'valid_outcomes': (9,)}
# Correct and incorrect trials together.
load_kwargs_all = {**load_kwargs, 'valid_outcomes': (0, 9)}

# Baseline CNN/LSTM hyper-parameters: a single recurrent layer (n_rnn2=0).
model_kwargs = dict(
    filt=8,
    kernLength=20,
    ds_rate=5,
    n_rnn=64,
    n_rnn2=0,
    dropoutRate=0.40,
    activation='relu',
    l1_reg=0.0000,
    l2_reg=0.001,
    norm_rate=0.25,
    latent_dim=64,
)
# Two-LSTM variant with 16 filters and a 30-sample kernel.
model_kwargs1 = {**model_kwargs, 'filt': 16, 'kernLength': 30, 'n_rnn2': 64}
# Same as model_kwargs1 but with 32 filters.
model_kwargs2 = {**model_kwargs1, 'filt': 32}

# Cross-validation and training settings.
N_SPLITS = 10
BATCH_SIZE = 16
EPOCHS = 150
EPOCHS2 = 100
LABEL_SMOOTHING = 0.2

In [None]:
tf.test.is_gpu_available(
    cuda_only=False, min_cuda_compute_capability=None
)

In [None]:
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# print("Num GPUs Available: ", len(physical_devices))
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [None]:
# Load one session's segmented recording and per-trial metadata from the NeuroPype HDF5 file.
test_sess_ix = 4
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
# NOTE(review): this file name is hard-coded rather than derived from sess_id; confirm it
# matches sess_infos[test_sess_ix] before changing test_sess_ix.
segmented_path = Path.cwd() / 'StudyLocationRule' / 'Data' / 'Preprocessed' / 'sra3_1_m_077_0001_segmented.h5'

segmented_data = from_neuropype_h5(segmented_path)
# Keep only trials whose OutcomeCode is > -1.
outcome = np.array(segmented_data[2][1]['axes'][0]['data']['OutcomeCode'])
flag = np.argwhere(outcome>-1).flatten()
outcome = outcome[flag]
Y = np.array(segmented_data[2][1]['axes'][0]['data']['TargetClass']).flatten()[flag]
Y_class = tf.keras.utils.to_categorical(Y, num_classes=8)
X = segmented_data[2][1]['data'][flag]
# NaNs become 0, then axes 1 and 2 are swapped (assumes X is 3-D).
X = np.nan_to_num(X)
X = np.transpose(X, (0, 2, 1))
block = np.array(segmented_data[2][1]['axes'][0]['data']['Block']).flatten()[flag]
# border[k] = index of the last trial before the Block number increases.
b=np.diff(block, axis=0)
border=np.array(np.where(b>0)).flatten()
# Keep border 0, plus border i+1 whenever more than 30 trials separate it from border i.
# NOTE(review): range(len(border)-2) never compares the final pair of borders; the last
# border is instead kept when more than 30 trials follow it (appended as index -1).
# This also raises IndexError when there are no borders. Confirm the intended rule.
to_keep = [0]
for i in range(len(border)-2):
    if (border[i+1] - border[i]) > 30:
        to_keep.append(i+1)
if (len(outcome)-border[-1] > 30):
    to_keep.append(-1)
border = border[to_keep]
color = np.array(segmented_data[2][1]['axes'][0]['data']['CueColour']).flatten()[flag]
target = np.array(segmented_data[2][1]['axes'][0]['data']['TargetRule']).flatten()[flag]
classes = np.array(segmented_data[2][1]['axes'][0]['data']['TargetClass']).flatten()[flag]

print(border)
print(sess_id)
print(outcome.shape)
print(X.shape, Y.shape, np.unique(Y, return_counts=True))

In [None]:
# Condition code per trial, combining target rule and cue colour:
#   code = 3 * (position of the rule in TARGET_ORDER) + colour offset (r=0, g=1, anything else=2),
# which gives codes 0..23. Trials whose rule is not listed keep code 0.
# This table replaces the former 8x3 nested if/elif chain and yields identical codes.
TARGET_ORDER = ['UU', 'UR', 'RR', 'DR', 'DD', 'DL', 'LL', 'UL']
TARGET_BASE_CODE = {rule_name: 3 * k for k, rule_name in enumerate(TARGET_ORDER)}
COLOUR_OFFSET = {'r': 0, 'g': 1}

rule = np.zeros(np.size(X,0), dtype=int)
for i in range(len(rule)):
    if target[i] in TARGET_BASE_CODE:
        rule[i] = TARGET_BASE_CODE[target[i]] + COLOUR_OFFSET.get(color[i], 2)

In [None]:
np.unique(rule, return_counts=True)

In [None]:
rules = np.zeros(len(rule))
unique = np.unique(rule)
for i in range(len(rules)):
    for j in range(len(unique)):
        if rule[i] == unique[j]:
            rules[i] = j
rules = rules.astype(int)
np.unique(rules, return_counts=True)

In [None]:
# Running behavioural performance: percentage of correct trials (outcome == 0) in a window
# of `tot` trials, updated incrementally and recounted from scratch at each kept block border.
m_performance = np.zeros(len(outcome))
cor = 0
b=0
tot = 25
# Initial window: the first `tot` trials.
for i in range(tot):
    if outcome[i]==0:
        cor += 1

m_performance[:tot] = 100 * (cor / tot)
# for i in range(tot, len(outcome)):
i = tot
while i<len(outcome):
    # At a block border: recount the next `tot` trials, assign that value to all of them,
    # and jump past the window.
    # NOTE(review): border[k] is the last trial *before* a block change (from np.diff), so
    # this window starts one trial early. outcome[i+j] raises IndexError if a border lies
    # within `tot` trials of the end. A border that falls inside a skipped window (or
    # before index `tot`) is never matched, so `b` stops advancing. Confirm intent.
    if i == border[b]:
        cor = 0
        for j in range(tot):
            if outcome[i+j]==0:
                cor += 1
        m_performance[i:i+tot] = 100 * (cor / tot)
        i += tot
        b = (b+1)%len(border)
    # Entering trial equals the trial leaving the window: count unchanged.
    elif outcome[i] == outcome[i-tot]:
        m_performance[i] = m_performance[i-1]
        i += 1
    # A correct trial enters while a non-correct one leaves.
    elif outcome[i]==0:
        cor += 1
        m_performance[i] = 100 * (cor / tot)
        i += 1
    # Otherwise assumes a correct trial left the window.
    # NOTE(review): with more than one non-zero outcome code, two different error codes
    # would reach this branch and wrongly decrement the count.
    else:
        cor -= 1
        m_performance[i] = 100 * (cor / tot)
        i +=1

# Plot the running performance with 60% / 70% reference lines and split trials into
# high- (> 69%) and low- (< 61%) performance sets.
plt.plot(m_performance)
plt.hlines(70,-1,1300,'grey','dashed')
plt.hlines(60,-1,1300,'grey','dashed')
learned = np.argwhere(m_performance>69).flatten()
unlearned = np.argwhere(m_performance<61).flatten()
print(len(learned),len(unlearned))

In [None]:
def make_model(
    _input,
    num_classes,
    filt=8,
    kernLength=25,
    ds_rate=10,
    n_rnn=64,
    n_rnn2=64,
    dropoutRate=0.25,
    activation='relu',
    l1_reg=0.000, l2_reg=0.000,
    norm_rate=0.25,
    latent_dim=16,
    return_model=True
):
    """Build a CNN -> LSTM (-> LSTM) -> Dense -> softmax classifier.

    Layer sequence: Reshape (adds a trailing unit axis to 3-D input) -> Conv2D with
    a (1, kernLength) kernel along axis 2 -> BatchNorm -> DepthwiseConv2D whose
    kernel spans all of axis 1 -> BatchNorm -> activation -> AveragePooling2D by
    ds_rate along axis 2 -> Dropout -> Reshape to (steps, features) -> LSTM 'rnn1'
    [-> LSTM 'rnn2'] -> Dense(latent_dim) -> Dense softmax. Other cells in this
    notebook fetch layers by position (e.g. ``model.layers[10]``), so the layer
    order is relied upon.

    Args:
        _input: example data array; only its shape is used. Presumably
            (trials, channels, samples) — TODO confirm. Axis 2 is convolved over.
        num_classes: width of the softmax output.
        filt: number of Conv2D filters (forced to 4 when _input.shape[2] < 10).
        kernLength: Conv2D kernel length along axis 2; overridden to 4 / 8 / 16
            when _input.shape[2] is < 10 / < 20 / < 30.
        ds_rate: pooling factor along axis 2; overridden to 4 / 8 when
            _input.shape[2] is < 10 / < 20.
        n_rnn: units of the first LSTM.
        n_rnn2: units of the optional second LSTM; 0 disables it.
        dropoutRate: dropout after pooling and after each LSTM block.
        activation: activation after the depthwise conv, each LSTM and the Dense bottleneck.
        l1_reg: not used in this function.
        l2_reg: L2 penalty on the LSTM kernel and recurrent weights.
        norm_rate: not used (only referenced by commented-out code).
        latent_dim: width of the Dense bottleneck.
        return_model: when False, return the output tensor instead of a Model.

    Returns:
        A tf.keras Model, or the softmax output tensor if return_model is False.
    """
    
    inputs = tf.keras.layers.Input(shape=_input.shape[1:])
    
    # Shorter kernels/pooling (and fewer filters) for short inputs, presumably so the
    # 'valid' convolution and pooling still fit along axis 2.
    if _input.shape[2] < 10:
        kernLength = 4
        filt = 4
        ds_rate = 4
    elif _input.shape[2] < 20:
        kernLength = 8
        ds_rate = 8
    elif _input.shape[2] < 30:
        kernLength = 16
    
    input_shape = list(_input.shape)
    # The Conv layers are insensitive to the number of samples in the time dimension.
    # To make it possible for this trained model to be applied to segments of different
    # durations, we need to explicitly state that we don't care about the number of samples.
    # input_shape[2] = -1  # Comment out during debug
    # _y = layers.Reshape(input_shape[1:])(_input)  # Note that Reshape ignores the batch dimension.

    # Append a singleton axis so Conv2D sees (axis1, axis2, 1) per example.
    if len(input_shape) < 4:
        input_shape = input_shape + [1]
    # Spatial-temporal CNN front end.
    _y = tf.keras.layers.Reshape(input_shape[1:])(inputs)
    _y = tf.keras.layers.Conv2D(filt, (1, kernLength), padding='valid', data_format=None,
                                dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform',
                                bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
                                activity_regularizer=None, kernel_constraint=None, bias_constraint=None)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    # Kernel height equals the full size of axis 1, collapsing it to 1 per filter.
    _y = tf.keras.layers.DepthwiseConv2D((_y.shape.as_list()[1], 1), padding='valid',
                                      depth_multiplier=1, data_format=None, dilation_rate=(1, 1),
                                      activation=None, use_bias=True, depthwise_initializer='glorot_uniform',
                                      bias_initializer='zeros', depthwise_regularizer=None,
                                      bias_regularizer=None, activity_regularizer=None,
                                      depthwise_constraint=None, bias_constraint=None)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.AveragePooling2D(pool_size=(1, ds_rate))(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)
    # Drop the collapsed axis: (steps, filters) sequence for the LSTM.
    _y = tf.keras.layers.Reshape(_y.shape.as_list()[2:])(_y)
    # RNN: the first LSTM returns the full sequence only when a second LSTM follows.
    _y = tf.keras.layers.LSTM(n_rnn,
                              kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                              recurrent_regularizer=tf.keras.regularizers.l2(l2_reg),
                              return_sequences=n_rnn2 > 0,
                              stateful=False,
                              name='rnn1')(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)
    
    
    if n_rnn2 > 0:
        
        _y = tf.keras.layers.LSTM(n_rnn2,
                              kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                              recurrent_regularizer=tf.keras.regularizers.l2(l2_reg),
                              return_sequences=False,
                              stateful=False,
                              name='rnn2')(_y)
        _y = tf.keras.layers.Activation(activation)(_y)
        _y = tf.keras.layers.BatchNormalization()(_y)
        _y = tf.keras.layers.Dropout(dropoutRate)(_y)
    
    # Dense
    _y = tf.keras.layers.Dense(latent_dim, activation=activation)(_y)
#     _y = parts.Bottleneck(_y, latent_dim=latent_dim, activation=activation)
    
    # Classify
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(_y)
#     outputs = parts.Classify(_y, n_classes=num_classes, norm_rate=norm_rate)
    

    if return_model is False:
        return outputs
    else:
        return tf.keras.models.Model(inputs=inputs, outputs=outputs)

def make_model2(
    _input,
    num_classes,
    filt=32,
    kernLength=16,
    n_rnn=32,
    n_rnn2=0,
    dropoutRate=0.1,
    activation='tanh',
    l1_reg=0.010, l2_reg=0.010,
    norm_rate=0.25,
    latent_dim=32,
    return_model=True
):
    """Build a Conv1D -> LSTM (-> LSTM) -> Dense -> softmax classifier.

    Only the shape of ``_input`` is used: ``_input.shape[1:]`` becomes the Keras
    input shape, with axis 1 treated as time and axis 2 as features (Conv1D with
    ``data_format='channels_last'``).

    Args:
        _input: example data array; only its shape is used.
        num_classes: width of the softmax output layer.
        filt: number of Conv1D filters.
        kernLength: Conv1D kernel length ('valid' padding).
        n_rnn: units of the first LSTM ('rnn1').
        n_rnn2: units of an optional second LSTM ('rnn2'); 0 disables it.
        dropoutRate: dropout applied after the convolution and after each LSTM.
        activation: activation after each LSTM and in the Dense bottleneck.
        l1_reg: accepted for signature compatibility; not used.
        l2_reg: L2 penalty on the LSTM kernel and recurrent weights.
        norm_rate: accepted for signature compatibility; not used.
        latent_dim: width of the Dense bottleneck.
        return_model: when False, return the output tensor instead of a Model.

    Returns:
        A ``tf.keras.models.Model`` mapping inputs to class probabilities, or the
        softmax output tensor when ``return_model is False``.
    """
    kl = tf.keras.layers

    def _lstm_stage(tensor, units, layer_name, keep_sequence):
        # LSTM -> activation -> batch norm -> dropout. Layers are created in the same
        # order as before, so auto-generated names and model.layers indices (which other
        # cells rely on, e.g. model.layers[4]) are unchanged.
        tensor = kl.LSTM(units,
                         kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                         recurrent_regularizer=tf.keras.regularizers.l2(l2_reg),
                         return_sequences=keep_sequence,
                         stateful=False,
                         name=layer_name)(tensor)
        tensor = kl.Activation(activation)(tensor)
        tensor = kl.BatchNormalization()(tensor)
        return kl.Dropout(dropoutRate)(tensor)

    inputs = kl.Input(shape=_input.shape[1:])

    # Temporal convolution front end.
    features = kl.Conv1D(filt, kernLength, strides=1, padding='valid',
                         data_format='channels_last', dilation_rate=1, groups=1,
                         activation=None, use_bias=True, kernel_initializer='glorot_uniform',
                         bias_initializer='zeros', kernel_regularizer=None,
                         bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,
                         bias_constraint=None)(inputs)
    features = kl.BatchNormalization()(features)
    features = kl.Dropout(dropoutRate)(features)

    # The first LSTM only needs to emit the whole sequence when a second one follows.
    stacked = n_rnn2 > 0
    features = _lstm_stage(features, n_rnn, 'rnn1', stacked)
    if stacked:
        features = _lstm_stage(features, n_rnn2, 'rnn2', False)

    # Dense bottleneck followed by the softmax classifier.
    features = kl.Dense(latent_dim, activation=activation)(features)
    outputs = kl.Dense(num_classes, activation='softmax')(features)

    if return_model is False:
        return outputs
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)


def kfold_pred(sess_id,X_rates,Y_class,name, verbose=1):
    """Train and evaluate make_model2 with stratified K-fold cross-validation.

    For each of N_SPLITS folds, a fresh model is trained for EPOCHS epochs. The
    checkpoint with the best validation accuracy is saved to
    ``f'{name}_{sess_id}_split{k}.h5'`` and reloaded to predict the held-out fold.

    Args:
        sess_id: session identifier; used in checkpoint file names and the report.
        X_rates: array of trials, indexed along axis 0 and passed to make_model2.
        Y_class: integer class label per trial (an ndarray, indexed with fold indices).
        name: prefix for the checkpoint file names.
        verbose: Keras verbosity for fit() and the checkpoint callback.

    Returns:
        (history, accuracy, pred_y, true_y): a dict with each metric's per-epoch
        values concatenated across folds (NaN between folds), overall held-out
        accuracy in percent, and the predicted/true labels in fold order.
    """
    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=0)
    histories = []
    per_fold_eval = []
    per_fold_true = []

    # The one-hot targets are identical for every fold, so encode them once.
    _y = tf.keras.utils.to_categorical(Y_class, num_classes=np.max(Y_class)+1)

    for split_ix, (trn, vld) in enumerate(splitter.split(X_rates, Y_class)):
        print(f"\tSplit {split_ix + 1} of {N_SPLITS}")

        # Reset Keras state and seed every RNG *before* building the input pipeline.
        # Dataset.shuffle below has no explicit seed, so it is only reproducible when
        # the global TF seed is already set at the moment the op is created.
        tf.keras.backend.clear_session()
        randseed = 12345
        random.seed(randseed)
        np.random.seed(randseed)
        tf.random.set_seed(randseed)

        ds_train = tf.data.Dataset.from_tensor_slices((X_rates[trn], _y[trn]))
        ds_valid = tf.data.Dataset.from_tensor_slices((X_rates[vld], _y[vld]))

        # cast data types to GPU-friendly types.
        ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))
        ds_valid = ds_valid.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))

        # TODO: augmentations (random slicing?)

        # Full-buffer shuffle; the incomplete last training batch is dropped.
        ds_train = ds_train.shuffle(len(trn) + 1)
        ds_train = ds_train.batch(BATCH_SIZE, drop_remainder=True)
        ds_valid = ds_valid.batch(BATCH_SIZE, drop_remainder=False)

        model = make_model2(X_rates, _y.shape[-1])
        optim = tf.keras.optimizers.Nadam(learning_rate=0.001)
        loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
        model.compile(optimizer=optim, loss=loss_obj, metrics=['accuracy'])

        best_model_path = f'{name}_{sess_id}_split{split_ix}.h5'
        callbacks = [
            # Overwrite the checkpoint only when validation accuracy improves.
            tf.keras.callbacks.ModelCheckpoint(
                filepath=best_model_path,
                save_best_only=True,
                monitor='val_accuracy',
                mode='max',
                verbose=verbose)
        ]

        hist = model.fit(x=ds_train, epochs=EPOCHS,
                         verbose=verbose,
                         validation_data=ds_valid,
                         callbacks=callbacks)
        histories.append(hist.history)

        # Predict the held-out fold with the best checkpoint, not the last-epoch weights.
        model = tf.keras.models.load_model(best_model_path)
        per_fold_eval.append(model(X_rates[vld]).numpy())
        per_fold_true.append(Y_class[vld])

    # Concatenate each metric's curves across folds, with a NaN between folds.
    history = {}
    for h in histories:
        for k, v in h.items():
            if k not in history:
                # Copy, so the extends below do not mutate the first fold's History list.
                history[k] = list(v)
            else:
                history[k].append(np.nan)
                history[k].extend(v)

    pred_y = np.concatenate([np.argmax(probs, axis=1) for probs in per_fold_eval])
    true_y = np.concatenate(per_fold_true).flatten()
    accuracy = 100 * np.sum(pred_y == true_y) / len(pred_y)
    print(f"\n\nSession {sess_id} overall accuracy with CNN/LSTM Model: {accuracy}%")

    return history, accuracy, pred_y, true_y

In [None]:
X[learned].shape

In [None]:
np.unique(rules, return_counts=True)

In [None]:
# make_model2's Conv1D is channels_last (axis 1 = time), so swap the last two axes back,
# undoing the transpose applied when the data were loaded.
# NOTE(review): 10 classes are used here only to print a summary; kfold_pred derives the
# class count from the labels it is given.
_X = np.transpose(X,(0,2,1))
model = make_model2(_X,10)
model.summary()

In [None]:
# 10-fold CV on all trials with the compacted condition labels; checkpoints are written as
# 'rd_all_{sess_id}_split{k}.h5'.
N_SPLITS = 10
hist, acc, y_pred, y_true = kfold_pred(sess_id,_X,rules,name='rd_all' ,verbose=1)

In [None]:
# For each CV fold, reload the best 'rd_all' checkpoint, run its first LSTM over all trials
# with return_sequences=True to expose the hidden state at every time step, and project
# that state onto 2 principal components (a separate PCA per fold).
hidden_states = []

for i in range(N_SPLITS):
    print(f'Working on Split #{i}')
    # Same naming scheme kfold_pred uses when saving (f'{name}_{sess_id}_split{k}.h5'),
    # so this follows test_sess_ix instead of a hard-coded session id.
    model = tf.keras.models.load_model(f'rd_all_{sess_id}_split{i}.h5')
    # Fetch the LSTM by the name make_model2 gives it rather than by position.
    trained_lstm = model.get_layer('rnn1')
    lstm = trained_lstm.get_weights()
    # Sub-model returning the tensor that feeds the LSTM.
    output = trained_lstm.input
    factor_model = tf.keras.Model(model.input, output)
    factors = factor_model.predict(_X)
    inputs = tf.keras.layers.Input(shape=factors.shape[1:])
    # Width read from the trained layer (previously hard-coded to 32) so set_weights matches.
    state_lstm = tf.keras.layers.LSTM(trained_lstm.units,
                                      kernel_regularizer=tf.keras.regularizers.l2(0.01),
                                      recurrent_regularizer=tf.keras.regularizers.l2(0.01),
                                      return_sequences=True,
                                      name='state_rnn1')(inputs)
    state_model = tf.keras.Model(inputs, state_lstm)
    state_model.layers[-1].set_weights(lstm)
    state_outputs = state_model(factors)

    pca = PCA(n_components=2)
    tmp = np.reshape(state_outputs, (state_outputs.shape[0] * state_outputs.shape[1], state_outputs.shape[2]))
    st_out_pc = pca.fit_transform(tmp)
    tmp = np.reshape(st_out_pc, (state_outputs.shape[0], state_outputs.shape[1], st_out_pc.shape[1]))
    hidden_states.append(tmp)

In [None]:
print(len(hidden_states), hidden_states[0].shape)

In [None]:
hidden_states_norm = hidden_states
for i in range(N_SPLITS):
    hidden_states_norm[i] =  2 * ((hidden_states_norm[i] - np.min(hidden_states_norm[i]))/(np.max(hidden_states_norm[i]) - np.min(hidden_states_norm[i]))) - 1

In [None]:
np.unique(Y,return_counts=True)

In [None]:
color_map = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:pink', 'tab:olive', 'tab:cyan', 'tab:purple']
for j in range(len(hidden_states_norm[0])):
#     lbl = f'States Class {int(Y_new[j])}'
    plt.plot(hidden_states[0][j, ::15, 0], hidden_states[0][j, :, 1], color_map[int(Y[j])])
plt.show()    
# plt.legend()

In [None]:
# Rebind model_kwargs to a larger CNN/LSTM configuration and summarise a make_model instance.
model_kwargs = dict(
    filt=32,
    kernLength=16,
    ds_rate=8,
    n_rnn=200,
    n_rnn2=200,
    dropoutRate=0.40,
    activation='relu',
    l1_reg=0.0001, l2_reg=0.001,
    norm_rate=0.25,
    latent_dim=64
)
model = make_model(X, Y_class.shape[-1], **model_kwargs)
model.summary()

In [None]:
# 5-fold CV on the high-performance ("learned") trials.
# NOTE(review): this rebinds the global N_SPLITS that later cells loop over.
# NOTE(review): kfold_pred builds make_model2, whose Conv1D treats axis 1 as time. The
# 'rd_all' run passed the axis-swapped _X, but X[learned] is passed unswapped here, and with
# the raw condition codes `rule` rather than the compacted `rules`. Confirm this is intended.
N_SPLITS = 5
hist, acc, y_pred, y_true = kfold_pred(sess_id,X[learned],rule[learned],name='rd_l' ,verbose=1)

In [None]:
# Same evaluation on the low-performance ("unlearned") trials.
hist, acc, y_pred, y_true = kfold_pred(sess_id,X[unlearned],rule[unlearned],name='rd_ul' ,verbose=1)

In [None]:
# Run the learned trials through a saved rd_l checkpoint up to layer 10, then fit a
# 32-component PCA on the first trial's output and apply it to every trial.
# NOTE(review): the checkpoint name hard-codes session 'sra3_1_j_050_00+' rather than
# sess_id. layers[10] appears to match make_model's layout (first LSTM), not make_model2's,
# which is what kfold_pred trains. Confirm which run produced this file.
model = tf.keras.models.load_model(f'rd_l_sra3_1_j_050_00+_split0.h5')
output = model.layers[10].output
factor_model = tf.keras.Model(model.input, output)
# NOTE(review): `inpt` is not used below.
inpt = np.transpose(X[learned],(0,2,1))
XL = factor_model(X[learned])
YL = Y[learned]
pca_values = np.zeros((XL.shape[0],XL.shape[1], 32))
pca =PCA(n_components=32)
# FAs = [FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0)]
# TSNEs = [TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10)]
for i in range(XL.shape[0]):
    print(f'Trial {i}')
    trial = np.squeeze(XL[i])
    # The PCA basis is fitted on trial 0 only and reused for all later trials.
    if i==0:
        pca_values[i] = pca.fit_transform(trial)
#         print('Factor Analysis')
#         pfa_values[i] = FAs[YL[i]].fit_transform(pca_values)
    else:
        pca_values[i] = pca.transform(trial)

In [None]:
factor_model.summary()
print(XL.shape)
a = np.transpose(X[learned],(0,2,1))
print(a.shape)

In [None]:
# For each fold's rd_l checkpoint, embed every time step of the learned trials' layer-10
# output: PCA (50) -> factor analysis (10) -> t-SNE (3), then fold back to
# (trials, steps, 3).
# NOTE(review): the file names hard-code session 'sra3_1_j_050_00+', and layers[10]
# matches make_model's layout; confirm these checkpoints come from that architecture.
hidden_states = []

for i in range(N_SPLITS):
    print(f'Working on Split #{i}, Learned Trials')
    model = tf.keras.models.load_model(f'rd_l_sra3_1_j_050_00+_split{i}.h5')
    output = model.layers[10].output
    factor_model = tf.keras.Model(model.input, output)
    factors = factor_model(X[learned])

    pca = PCA(n_components=50)
    tmp = np.reshape(factors, (factors.shape[0] * factors.shape[1], factors.shape[2]))
    pca_values = pca.fit_transform(tmp)
    fa_values = FactorAnalysis(n_components=10, random_state=0).fit_transform(pca_values)
    tsne_values = TSNE(n_components=3,perplexity=10).fit_transform(fa_values)
    # The last axis must be t-SNE's width (3). Using the factor-analysis width (10), as before,
    # made the element count mismatch so the reshape always raised.
    tmp = np.reshape(tsne_values, (factors.shape[0], factors.shape[1], tsne_values.shape[1]))
    hidden_states.append(tmp)

In [None]:
# Per-class PCA on the raw learned trials (axes 1 and 2 swapped): one 32-component PCA per
# class label, fitted on that class's first trial and applied to its later trials.
XL = np.transpose(X[learned], (0, 2, 1))
YL = Y[learned]
pca_values = np.zeros((XL.shape[0],XL.shape[1], 32))
# One PCA per class; PCAs[YL[i]] below assumes YL holds integer labels 0..7.
PCAs = [PCA(n_components=32),PCA(n_components=32),PCA(n_components=32),PCA(n_components=32),PCA(n_components=32),PCA(n_components=32),PCA(n_components=32),PCA(n_components=32)]
# FAs = [FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0),FactorAnalysis(n_components=10, random_state=0)]
# TSNEs = [TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10),TSNE(n_components=3,perplexity=10)]
for i in range(XL.shape[0]):
    print(f'Trial {i}')
    trial = np.squeeze(XL[i])
    # True only for the first trial that carries this class label.
    if i==np.argmax(YL==YL[i]):
        print(f'Found first trial of class {YL[i]}')
        print('PCA')
        pca_values[i] = PCAs[YL[i]].fit_transform(trial)
#         print('Factor Analysis')
#         pfa_values[i] = FAs[YL[i]].fit_transform(pca_values)
    else:
        print('PCA')
        pca_values[i] = PCAs[YL[i]].transform(trial)
#         print('Factor Analysis')
#         pfa_values[i] = FAs[YL[i]].transform(pca_values)
#     print('tSNE')
#     ptfa_values[i] = TSNEs[YL[i]].fit_transform(fa_values)

In [None]:
i==np.argmax(YL==YL[i])

In [None]:
YL

In [None]:
color_map = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:pink', 'tab:olive', 'tab:cyan', 'tab:purple']
for tr in range(285):
    plt.plot(pca_values[tr,:,0], pca_values[tr,:,1], '.-', color=color_map[YL[tr]])
plt.show()

In [None]:
i

In [None]:
model.summary()