# Imports, Paths, Variables, and Functions

In [None]:
import tensorflow as tf

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Ask TensorFlow to allocate GPU memory on demand rather than reserving the whole
# device up front: once through the TF2 device API for every visible GPU, and once
# through a TF1-compat session configuration.
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
import indl
from pathlib import Path
import sys
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
import pickle
from sklearn.decomposition import PCA
from scipy import signal
from scipy import stats
from sklearn.model_selection import train_test_split
from indl.fileio import from_neuropype_h5
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.model_selection import KFold
from sklearn.manifold import TSNE
from sklearn.decomposition import FactorAnalysis
from itertools import cycle

import os

# If the notebook was started from its 'Analysis' folder, move the working directory
# two levels up. NOTE(review): presumably the repository root, so that `misc` and the
# relative data paths below resolve — confirm.
if Path.cwd().stem == 'Analysis':
    os.chdir(Path.cwd().parent.parent)

# Project helpers: `sess_infos` is indexed per session and `load_macaque_pfc` loads a
# session's data in later cells.
from misc.misc import sess_infos, load_macaque_pfc, dec_from_enc    
    
# Preprocessed data directory. If it does not exist, the Kaggle dataset is downloaded
# and unzipped into it via IPython shell magic (`!`). Only the directory's existence is
# checked, so an existing but empty directory skips the download.
data_path = Path.cwd() / 'StudyLocationRule'/ 'Data' / 'Preprocessed'
if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")
    

# Keyword-argument presets for `load_macaque_pfc`. `load_kwargs` is the baseline;
# each variant below copies it and overrides only the keys that differ.
load_kwargs = {
    'valid_outcomes': (0, ),  # outcome codes to keep; (0, 9) also keeps incorrect trials
    'zscore': True,
    'dprime_range': (-np.inf, np.inf),  # (-np.inf, np.inf) keeps all trials
    'time_range': (-np.inf, np.inf),
    'verbose': False,
    'y_type': 'sacClass',
    'samples_last': True,
    # 'resample_X': 20,
}

# Correct and incorrect trials, dprime_range restricted to (-inf, -1).
load_kwargs_ul = {
    **load_kwargs,
    'valid_outcomes': (0, 9),
    'dprime_range': (-np.inf, -1),
}

# Incorrect trials only (code 9), dprime_range (1.0, inf), time_range capped at 1.45.
load_kwargs_error = {
    **load_kwargs,
    'valid_outcomes': (9,),
    'dprime_range': (1.0, np.inf),
    'time_range': (-np.inf, 1.45),
}

# Correct and incorrect trials, every dprime, time_range capped at 1.45.
load_kwargs_all = {
    **load_kwargs,
    'valid_outcomes': (0, 9),
    'time_range': (-np.inf, 1.45),
}

# Hyper-parameter presets for the CNN/LSTM model builder. Every preset includes a
# `ds_rate` key; check that the builder accepts it before passing a preset with `**`.
model_kwargs = {
    'filt': 8,
    'kernLength': 20,
    'ds_rate': 5,
    'n_rnn': 32,
    'n_rnn2': 0,  # 0 -> no second LSTM layer
    'dropoutRate': 0.40,
    'activation': 'relu',
    'l1_reg': 0.0000,
    'l2_reg': 0.001,
    'norm_rate': 0.25,
    'latent_dim': 32,
}
model_kwargs1 = {
    'filt': 16,
    'kernLength': 30,
    'ds_rate': 5,
    'n_rnn': 64,
    'n_rnn2': 64,
    'dropoutRate': 0.40,
    'activation': 'relu',
    'l1_reg': 0.0000,
    'l2_reg': 0.001,
    'norm_rate': 0.25,
    'latent_dim': 64,
}
# Same as model_kwargs1 but with twice as many convolution filters.
model_kwargs2 = {**model_kwargs1, 'filt': 32}

# Cross-validation / training settings read by kfold_pred.
N_SPLITS = 10          # number of stratified CV folds
BATCH_SIZE = 16
EPOCHS = 150           # training epochs per fold
EPOCHS2 = 100          # not referenced in the visible cells
LABEL_SMOOTHING = 0.2  # label smoothing for the categorical cross-entropy loss

from indl.model import parts
from indl.model.helper import check_inputs
from indl.regularizers import KernelLengthRegularizer

def make_model(
    _input,
    num_classes,
    filt=32,
    kernLength=16,
    n_rnn=32,
    n_rnn2=0,
    dropoutRate=0.1,
    activation='tanh',
    l1_reg=0.010, l2_reg=0.010,
    norm_rate=0.25,
    latent_dim=32,
    return_model=True,
    ds_rate=None,
):
    """Build a Conv1D -> LSTM (-> LSTM) -> Dense -> softmax classifier.

    Parameters
    ----------
    _input : array-like
        Example input; only ``_input.shape[1:]`` is used, to size the Input layer.
    num_classes : int
        Number of softmax output units.
    filt, kernLength : int
        Filter count and kernel length of the Conv1D layer. Padding is 'valid', so
        axis 1 of the input must have at least ``kernLength`` entries.
    n_rnn : int
        Units in the first LSTM (layer name 'rnn1').
    n_rnn2 : int
        Units in an optional second LSTM ('rnn2'); 0 disables it. When > 0, 'rnn1'
        returns full sequences so 'rnn2' can consume them.
    dropoutRate : float
        Dropout rate after the Conv1D block and after each LSTM block.
    activation : str
        Activation applied after each LSTM and in the latent Dense layer.
    l2_reg : float
        L2 penalty on the LSTM kernel and recurrent weights.
    l1_reg, norm_rate :
        Accepted but currently unused.
    latent_dim : int
        Width of the Dense layer that precedes the classifier.
    return_model : bool
        True -> return a ``tf.keras.Model``; False -> return only the output tensor.
    ds_rate :
        Accepted but unused; it was only referenced by the downsampling code that is
        no longer active. Accepting it lets the ``model_kwargs*`` presets (which
        contain a ``ds_rate`` key) be passed as ``make_model(X, n, **preset)``
        instead of raising TypeError.

    Notes
    -----
    Conv1D uses ``data_format='channels_last'`` and therefore slides along axis 1 of
    ``_input``. NOTE(review): the loader presets use ``samples_last=True`` and the
    decoding cells slice the *last* axis as time, which would make axis 1 the channel
    axis here — confirm the intended orientation.
    """
    inputs = tf.keras.layers.Input(shape=_input.shape[1:])

    # Convolution over axis 1 (explicit Keras defaults kept as in the original).
    _y = tf.keras.layers.Conv1D(filt, kernLength, strides=1, padding='valid',
                                data_format='channels_last', dilation_rate=1, groups=1,
                                activation=None, use_bias=True, kernel_initializer='glorot_uniform',
                                bias_initializer='zeros', kernel_regularizer=None,
                                bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,
                                bias_constraint=None)(inputs)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    # First LSTM: emits the whole sequence only when a second LSTM follows.
    _y = tf.keras.layers.LSTM(n_rnn,
                              kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                              recurrent_regularizer=tf.keras.regularizers.l2(l2_reg),
                              return_sequences=n_rnn2 > 0,
                              stateful=False,
                              name='rnn1')(_y)
    _y = tf.keras.layers.Activation(activation)(_y)
    _y = tf.keras.layers.BatchNormalization()(_y)
    _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    if n_rnn2 > 0:
        _y = tf.keras.layers.LSTM(n_rnn2,
                                  kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
                                  recurrent_regularizer=tf.keras.regularizers.l2(l2_reg),
                                  return_sequences=False,
                                  stateful=False,
                                  name='rnn2')(_y)
        _y = tf.keras.layers.Activation(activation)(_y)
        _y = tf.keras.layers.BatchNormalization()(_y)
        _y = tf.keras.layers.Dropout(dropoutRate)(_y)

    # Latent dense layer, then the softmax classifier.
    _y = tf.keras.layers.Dense(latent_dim, activation=activation)(_y)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(_y)

    if return_model is False:
        # NOTE(review): only the output tensor is returned; the Input layer created
        # above is not, so a caller cannot directly wrap this in a Model.
        return outputs
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)


def kfold_pred(sess_id,X_rates,Y_class,name, verbose=1):
    """Cross-validate the CNN/LSTM classifier with stratified K-fold splits.

    For each of ``N_SPLITS`` stratified folds (shuffled, ``random_state=0``), a fresh
    ``make_model`` network is trained for ``EPOCHS`` epochs. The epoch with the best
    validation accuracy is checkpointed to ``f'{name}_{sess_id}_split{k}.h5'``, then
    reloaded and used to predict that fold's held-out trials.

    Parameters
    ----------
    sess_id : str
        Session identifier; used only in checkpoint file names and the printed summary.
    X_rates : np.ndarray
        Input array indexed by trial on axis 0.
    Y_class : np.ndarray
        Non-negative integer class label per trial (one-hot width is ``max + 1``).
    name : str
        Prefix for the checkpoint files.
    verbose : int
        Passed to ``model.fit`` and ``ModelCheckpoint``.

    Returns
    -------
    history : dict
        Each training metric with all folds' curves concatenated, separated by NaN.
    accuracy : float
        Percentage of held-out trials predicted correctly, pooled over folds.
    pred_y, true_y : np.ndarray
        Predicted and true labels, concatenated in fold order.
    """
    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=0)
    # One-hot targets are the same for every fold, so build them once.
    _y = tf.keras.utils.to_categorical(Y_class, num_classes=np.max(Y_class)+1)
    histories = []
    per_fold_eval = []
    per_fold_true = []

    for split_ix, (trn, vld) in enumerate(splitter.split(X_rates, Y_class)):
        print(f"\tSplit {split_ix + 1} of {N_SPLITS}")

        ds_train = tf.data.Dataset.from_tensor_slices((X_rates[trn], _y[trn]))
        ds_valid = tf.data.Dataset.from_tensor_slices((X_rates[vld], _y[vld]))

        # Cast data types to GPU-friendly types.
        ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))
        ds_valid = ds_valid.map(lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.uint8)))

        # TODO: augmentations (random slicing?)

        ds_train = ds_train.shuffle(len(trn) + 1)
        # drop_remainder=True: a training fold smaller than BATCH_SIZE yields no batches.
        ds_train = ds_train.batch(BATCH_SIZE, drop_remainder=True)
        ds_valid = ds_valid.batch(BATCH_SIZE, drop_remainder=False)

        # Fresh graph and identical seeds, so every fold starts from the same state.
        tf.keras.backend.clear_session()
        randseed = 12345
        random.seed(randseed)
        np.random.seed(randseed)
        tf.random.set_seed(randseed)

        model = make_model(X_rates, _y.shape[-1])
        optim = tf.keras.optimizers.Nadam(learning_rate=0.001)
        loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING)
        model.compile(optimizer=optim, loss=loss_obj, metrics=['accuracy'])

        best_model_path = f'{name}_{sess_id}_split{split_ix}.h5'
        callbacks = [
            # Overwrite the checkpoint only when `val_accuracy` improves, so the file
            # ends up holding this fold's best epoch.
            tf.keras.callbacks.ModelCheckpoint(
                filepath=best_model_path,
                save_best_only=True,
                monitor='val_accuracy',
                verbose=verbose)
        ]

        hist = model.fit(x=ds_train, epochs=EPOCHS,
                         verbose=verbose,
                         validation_data=ds_valid,
                         callbacks=callbacks)
        histories.append(hist.history)

        # Predict the held-out trials with the best checkpoint, not the last epoch.
        model = tf.keras.models.load_model(best_model_path)
        per_fold_eval.append(model(X_rates[vld]).numpy())
        per_fold_true.append(Y_class[vld])

    # Combine histories into one dictionary; NaN marks each fold boundary.
    history = {}
    for h in histories:
        for k, v in h.items():
            if k not in history:
                # Copy so the per-fold History lists are not mutated by the appends below.
                history[k] = list(v)
            else:
                history[k].append(np.nan)
                history[k].extend(v)

    pred_y = np.concatenate([np.argmax(_, axis=1) for _ in per_fold_eval])
    true_y = np.concatenate(per_fold_true).flatten()
    accuracy = 100 * np.sum(pred_y == true_y) / len(pred_y)
    print(f"\n\nSession {sess_id} overall accuracy with CNN/LSTM Model: {accuracy}%")

    return history, accuracy, pred_y, true_y

# Behavioural Analysis (Figure 1D)

In [None]:
test_sess_ix = 1
sess_info = sess_infos[test_sess_ix]
sess_id = sess_info['exp_code']
# Same directory as data_path (Path.cwd() / 'StudyLocationRule' / 'Data' / 'Preprocessed').
segmented_path = data_path / 'sra3_1_j_050_00_segmented_v2.h5'

segmented_data = from_neuropype_h5(segmented_path)
trial_info = segmented_data[2][1]['axes'][0]['data']
outcome = trial_info['OutcomeCode']
Y = np.array(trial_info['TargetClass'])
Y_class = tf.keras.utils.to_categorical(Y, num_classes=8)
X = np.nan_to_num(segmented_data[2][1]['data'])

# Smooth every (trial, channel) trace along axis 1 (presumably time) with a Gaussian
# window of 100 samples and std 20 samples, in one vectorised call. scipy.signal.gaussian
# was removed in SciPy 1.13; the window lives in scipy.signal.windows.
kernel = signal.windows.gaussian(100, 20)
X_conv = signal.fftconvolve(X, kernel[np.newaxis, :, np.newaxis], mode='same', axes=1)
X_conv = np.abs(X_conv.astype(X.dtype, copy=False))
# Swap the last two axes, then keep every 20th sample of the (new) last axis.
X_conv = np.transpose(X_conv, (0, 2, 1))
X_conv = X_conv[:, :, ::20]

block = np.array(trial_info['Block']).flatten()
b = np.diff(block, axis=0)
blck = np.array(np.where(b > 0)).flatten()
color = np.array(trial_info['CueColour']).flatten()
target = np.array(trial_info['TargetRule']).flatten()
classes = np.array(trial_info['TargetClass']).flatten()

In [None]:
# Block switches: indices i where block[i + 1] > block[i].
b = np.diff(block, axis=0)
border = np.array(np.where(b > 0)).flatten()
# Keep the first switch, plus every later switch that is > 30 trials after the previous
# switch and has > 30 trials left in the session. (Guarded for sessions with no switch.)
to_keep = [0] if len(border) else []
for i in range(len(border) - 1):
    if (border[i + 1] - border[i]) > 30 and (len(outcome)-border[i+1]) > 30:
        to_keep.append(i + 1)
border = border[to_keep]

# Moving percentage of correct trials (outcome code 0) over a `tot`-trial window,
# restarted at each kept block switch.
tot = 25
is_correct = np.asarray(outcome).ravel() == 0
m_performance = np.zeros(len(outcome))
cor = int(np.count_nonzero(is_correct[:tot]))
m_performance[:tot] = 100 * (cor / tot)
b = 0  # index into `border` of the next switch to look for
i = tot
while i < len(outcome):
    # Skip switches already behind the cursor (e.g. one inside the first window);
    # otherwise `b` would never advance and every later restart would be missed.
    while b < len(border) and border[b] < i:
        b += 1
    if b < len(border) and i == border[b]:
        # Restart the window at the switch. Slicing (not indexing) cannot run past
        # the end of the session.
        cor = int(np.count_nonzero(is_correct[i:i + tot]))
        m_performance[i:i + tot] = 100 * (cor / tot)
        i += tot
        b += 1
    else:
        # Slide by one trial: trial i enters, trial i - tot leaves. Comparing
        # correctness rather than raw outcome codes keeps `cor` equal to the number of
        # correct trials in the window even when incorrect trials carry different codes.
        cor += int(is_correct[i]) - int(is_correct[i - tot])
        m_performance[i] = 100 * (cor / tot)
        i += 1

In [None]:
# Moving percentage of correct trials computed in the previous cell.
plt.plot(m_performance)
plt.xlabel('Trial')
plt.ylabel(f'% correct ({tot}-trial moving window)')

# Temporal Analysis of Rule Decoding (Figure 3A)

In [None]:
sess_info = sess_infos[8]
sess_id = sess_info['exp_code']
sess_id = sess_id.replace("+", "")+"_v1"
# Uses the baseline `load_kwargs` preset (the name `load_args` is defined nowhere).
# NOTE(review): confirm this is the intended preset.
X_rates, _, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
times = np.array(ax_info['timestamps'])
target = np.array(ax_info['instance_data']['TargetRule'])
color = np.array(ax_info['instance_data']['CueColour'])

# Encode each trial's (target rule, cue colour) pair as an integer code:
# 3 * (position of the rule in target_order) + colour offset, where the offset is
# 0 for 'r', 1 for 'g' and 2 for any other colour. Unlisted rules keep code 0.
target_order = ('UU', 'UR', 'RR', 'DR', 'DD', 'DL', 'LL', 'UL')
rule = np.zeros(np.size(X_rates, 0), dtype=int)
for trial_ix in range(len(rule)):
    for rule_ix, rule_name in enumerate(target_order):
        if target[trial_ix] == rule_name:
            if color[trial_ix] == 'r':
                colour_offset = 0
            elif color[trial_ix] == 'g':
                colour_offset = 1
            else:
                colour_offset = 2
            rule[trial_ix] = 3 * rule_ix + colour_offset
            break
# Relabel the codes that actually occur to consecutive integers 0..K-1.
_yu, rule = np.unique(rule, return_inverse=True)
rule_shuf = np.random.permutation(rule)

# Decode the rule from growing windows X_rates[:, :, :5t + 5] for t = 0..32, with the
# real and with shuffled labels. Results go in row 1 of each array (the row the
# plotting cell reads); the arrays are created here so the cell runs on its own.
n_windows = 33
rul_acc_dnn_all = np.zeros((2, n_windows))
rul_acc_dnn_shuf = np.zeros((2, n_windows))
for t in range(n_windows):
    X = X_rates[:, :, :t*5 + 5]
    _, rul_acc_dnn_all[1, t], _, _ = kfold_pred(sess_id, X, rule, name='R2R_TEMP', verbose=0)
    _, rul_acc_dnn_shuf[1, t], _, _ = kfold_pred(sess_id, X, rule_shuf, name='R2R_TEMP_SHUF', verbose=0)

In [None]:
# Rule-decoding accuracy per window vs. the shuffled-label baseline; the first and
# last window are left out of both curves.
# NOTE(review): window t covers samples [0, 5t + 5) but is drawn at times[5t], four
# samples before its last sample — confirm the intended alignment.
plt.figure()
t = times[::5][1:-1]
for curve, curve_label in ((rul_acc_dnn_all, "Rule Decoding"), (rul_acc_dnn_shuf, "Chance Level")):
    plt.plot(t, curve[1, 1:-1], linewidth=3, label=curve_label)
plt.legend()

# Temporal Analysis of Saccade Decoding (Figure 3B)

In [None]:
N_SPLITS = 10
sess = 2
sess_info = sess_infos[sess]
sess_id = sess_info['exp_code']
sess_id = sess_id.replace("+", "")+"_v1"
X_rates, Y, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates', **load_kwargs)
# SVC expects a 1-D label vector; ravel once so the real and shuffled runs see the
# same label layout (previously only the shuffled run was raveled).
Y = Y.ravel()
times = np.array(ax_info['timestamps'])


def _svc_growing_window_accuracy(rates, labels, n_windows=40, step=5, n_splits=10):
    """Mean K-fold SVC accuracy for growing windows ``rates[:, :, :t*step + step]``.

    Each window is flattened per trial and scored with unshuffled ``KFold``. Returns an
    array of length ``n_windows + 1``; the trailing slot is left at 0, matching the
    layout the plotting cell trims with ``[:-1]``.
    """
    acc = np.zeros(n_windows + 1)
    kf = KFold(n_splits=n_splits)
    for t in range(n_windows):
        window = rates[:, :, :t * step + step]
        flat = np.reshape(window, (window.shape[0], -1))
        scores = []
        for train_index, test_index in kf.split(flat):
            clf = SVC(verbose=False).fit(flat[train_index], labels[train_index])
            scores.append(clf.score(flat[test_index], labels[test_index]))
        acc[t] = np.mean(np.array(scores))
    return acc


sac_acc = _svc_growing_window_accuracy(X_rates, Y, n_splits=N_SPLITS)

## Chance Level
# Same analysis on permuted labels (reuses the data loaded above).
Y_shuf = np.random.permutation(Y)
sac_acc_shuf = _svc_growing_window_accuracy(X_rates, Y_shuf, n_splits=N_SPLITS)

In [None]:
plt.figure()
t = times[::5][:-2]
# sac_acc and sac_acc_shuf are both 1-D (one entry per window plus an unused trailing
# slot), so index them the same way; `sac_acc[2, :-1]` raised IndexError.
plt.plot(t, sac_acc[:-1]*100, linewidth=3, label="Saccade Decoding")
plt.plot(t, sac_acc_shuf[:-1]*100, linewidth=3, label="Chance Level")
plt.xlabel('Time')
plt.ylabel('Decoding accuracy (%)')
plt.legend()

# Latent Space Analysis (Figure 4, 5)

In [None]:
segmented_path = data_path / 'sra3_1_m_074_0001_segmented.h5'
segmented_data = from_neuropype_h5(segmented_path)
trial_table = segmented_data[2][1]['axes'][0]['data']

# Keep only trials whose outcome code is > -1.
outcome = np.array(trial_table['OutcomeCode'])
flag = np.argwhere(outcome > -1).flatten()
outcome = outcome[flag]

Y = np.array(trial_table['TargetClass']).flatten()[flag]
Y_class = tf.keras.utils.to_categorical(Y, num_classes=8)

# Kept trials, NaNs replaced by 0, last two axes swapped.
X_rates = np.transpose(np.nan_to_num(segmented_data[2][1]['data'][flag]), (0, 2, 1))

block = np.array(trial_table['Block']).flatten()[flag]
b = np.diff(block, axis=0)
border = np.array(np.where(b > 0)).flatten()
# Keep the first block switch, then any switch > 30 trials after the previous one
# with > 30 trials remaining.
to_keep = [0] + [
    ix + 1
    for ix in range(len(border) - 1)
    if (border[ix + 1] - border[ix]) > 30 and (len(outcome) - border[ix + 1]) > 30
]
border = border[to_keep]

color, target, classes = (
    np.array(trial_table[key]).flatten()[flag]
    for key in ('CueColour', 'TargetRule', 'TargetClass')
)
times = np.array(segmented_data[2][1]['axes'][1]['times']).flatten()

In [None]:
# Moving percentage of correct trials (outcome code 0) over a `tot`-trial window,
# restarted at each kept block switch in `border`.
tot = 25
is_correct = np.asarray(outcome).ravel() == 0
m_performance = np.zeros(len(outcome))
cor = int(np.count_nonzero(is_correct[:tot]))
m_performance[:tot] = 100 * (cor / tot)
b = 0  # index into `border` of the next switch to look for
i = tot
while i < len(outcome):
    # Skip switches already behind the cursor (e.g. one inside the first window);
    # otherwise `b` would never advance and every later restart would be missed.
    while b < len(border) and border[b] < i:
        b += 1
    if b < len(border) and i == border[b]:
        # Restart the window at the switch; slicing cannot run past the session end.
        cor = int(np.count_nonzero(is_correct[i:i + tot]))
        m_performance[i:i + tot] = 100 * (cor / tot)
        i += tot
        b += 1
    else:
        # Slide by one trial. Comparing correctness (not raw outcome codes) keeps
        # `cor` equal to the number of correct trials in the window, so it can no
        # longer go negative when incorrect trials carry different non-zero codes.
        cor += int(is_correct[i]) - int(is_correct[i - tot])
        m_performance[i] = 100 * (cor / tot)
        i += 1

# Defensive floor at 0 (in place), kept from the original.
np.maximum(m_performance, 0, out=m_performance)
learned = np.argwhere(m_performance > 80).flatten()
unlearned = np.argwhere(m_performance < 60).flatten()

In [None]:
# Encode each trial's (target rule, cue colour) pair as an integer code:
# 3 * (position of the rule in target_order) + colour offset, where the offset is
# 0 for 'r', 1 for 'g' and 2 for any other colour. Unlisted rules keep code 0.
target_order = ('UU', 'UR', 'RR', 'DR', 'DD', 'DL', 'LL', 'UL')
rule = np.zeros(np.size(X_rates, 0))
for trial_ix in range(len(rule)):
    for rule_ix, rule_name in enumerate(target_order):
        if target[trial_ix] == rule_name:
            if color[trial_ix] == 'r':
                colour_offset = 0
            elif color[trial_ix] == 'g':
                colour_offset = 1
            else:
                colour_offset = 2
            rule[trial_ix] = 3 * rule_ix + colour_offset
            break
rule = rule.astype(int)

In [None]:
# The two rule codes compared below. Per the rule-encoding cell, 9 is target 'DR'
# with a red cue and 10 is 'DR' with a green cue.
rule1, rule2 = 9, 10

In [None]:
# Trials of either rule, split into those in learned (performance > 80) and
# unlearned (performance < 60) stretches.
rl = np.array(np.where((rule == rule1) | (rule == rule2))).flatten()
l_rl = [r for r in rl if r in learned]
ul_rl = [r for r in rl if r not in learned and r in unlearned]

In [None]:
# Binary labels for the two-rule comparison: 1 for rule2 trials, 0 for rule1 trials.
_y_l, _y_ul = [(rule[idx] == rule2).astype(rule.dtype) for idx in (l_rl, ul_rl)]
tmp = rule[ul_rl]

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

step = 20
length = len(times[::step])

_X = np.transpose(X_rates, (0, 2, 1))

class_colors = np.array(['tab:red', 'tab:green', 'tab:blue', 'tab:pink','tab:cyan','tab:orange',
                         'tab:olive','tab:purple','maroon', 'lime', 'navy', 'sienna', 'tan', 'black', 'grey'])

TEST_PERPLEXITY = [10]


def _tsne_trajectory(trials):
    """Embed growing windows of ``trials`` in 2-D, one embedding per window.

    For t = 1..length-1, the first ``step * t`` entries of axis 1 of every trial are
    flattened, reduced to 16 PCA components and embedded with t-SNE. Returns an
    (n_trials, length, 2) array whose slot 0 is left at zero.
    """
    trajectory = np.zeros((trials.shape[0], length, 2))
    pca = PCA(n_components=16)
    tsne_model = TSNE(n_components=2, perplexity=TEST_PERPLEXITY[-1])
    for t in np.arange(1, length):
        window = trials[:, :step * t, :]
        pca_values = pca.fit_transform(window.reshape([-1, np.prod(window.shape[1:])]))
        trajectory[:, t, :] = tsne_model.fit_transform(pca_values)
    return trajectory


traj_l = _tsne_trajectory(_X[l_rl])
traj_ul = _tsne_trajectory(_X[ul_rl])

In [None]:
from sklearn.metrics.cluster import calinski_harabasz_score as chs

# Calinski-Harabasz score of the rule1/rule2 labelling in each 2-D embedding, for the
# learned and unlearned trial groups. Slot 0 of the trajectories is never embedded
# (it stays all zeros), so its score is meaningless and is left as NaN (not drawn).
score_l = np.full(traj_l.shape[1], np.nan)
score_ul = np.full(traj_ul.shape[1], np.nan)
for i in range(1, len(score_l)):
    score_l[i] = chs(traj_l[:, i, :], _y_l)
    score_ul[i] = chs(traj_ul[:, i, :], _y_ul)
t = times[::step]
plt.figure()
plt.plot(t, score_l, lw=3, label='Learned')
plt.plot(t, score_ul, lw=3, label='Unlearned')
plt.xlabel('Time')
plt.ylabel('Calinski-Harabasz score')
# Without this the 'Learned'/'Unlearned' labels were never displayed.
plt.legend()

In [None]:
# Palette indexed by the binary labels in _y_l. Named label_colors (not `color`) so it
# does not overwrite the per-trial cue-colour array `color` that the rule-encoding
# cell reads; re-running that cell after this one used to break it.
label_colors = np.array(['tab:red', 'tab:green'])
# Embeddings at trajectory slots 1 and 6 (windows of step*1 and step*6 samples).
for time_ix, title in ((1, 'Learned Trials: Pre-Color-Cue'),
                       (6, 'Learned Trials: Color-Cue Period')):
    plt.figure()
    plt.scatter(traj_l[:, time_ix, 0], traj_l[:, time_ix, 1], color=label_colors[_y_l])
    plt.title(title)
    plt.xlabel('Latent 1')
    plt.ylabel('Latent 2')

In [None]:
# Indices of trials with outcome code 0 (correct) and 9 (incorrect).
cor = np.array(np.where(outcome == 0)).flatten()
icor = np.array(np.where(outcome == 9)).flatten()
# Correct trials inside learned stretches; incorrect trials inside unlearned ones.
c_l = [trial_ix for trial_ix in cor if trial_ix in learned]
ic_ul = [trial_ix for trial_ix in icor if trial_ix in unlearned]
# Rule-code counts in each subset.
for subset in (c_l, ic_ul):
    print(np.unique(rule[subset], return_counts=True))

In [None]:
# Cross-validate rule decoding on correct trials from learned stretches.
# NOTE(review): `sess_id` is whatever an earlier cell last assigned; it is not derived
# from segmented_path, so checkpoint names may not reflect this file's session.
rates, rules = X_rates[c_l], rule[c_l]
hist, acc, y_pred, y_true = kfold_pred(sess_id, rates, rules, name='cor_l_rd', verbose=1)

In [None]:
from tensorflow.keras.models import load_model
# Reload the last fold's best checkpoint. kfold_pred saves to
# f'{name}_{sess_id}_split{split_ix}.h5' with split_ix in 0..N_SPLITS-1, so build the
# name from the same pieces. A hard-coded session id does not match, because earlier
# cells append '_v1' to sess_id.
model = load_model(f"cor_l_rd_{sess_id}_split{N_SPLITS - 1}.h5")

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

class_colors = np.array(['tab:blue', 'tab:red', 'tab:green', 'tab:pink','tab:cyan','tab:orange',
                         'tab:olive','tab:purple','maroon', 'lime', 'navy', 'sienna', 'tan', 'black', 'grey'])

TEST_PERPLEXITY = [10]
X = X_rates[c_l]
# Relabel the rule codes present in these trials to consecutive integers 0..K-1.
# These index into class_colors, so at most 15 distinct rules can be drawn.
_yu, _y = np.unique(rule[c_l].astype(int), return_inverse=True)


def plot_tsne(x_vals, y_vals, perplexity, title='Model Output'):
    """Scatter a 2-D embedding coloured by class index.

    ``perplexity`` is accepted for call compatibility but not used.
    """
    plt.scatter(x=x_vals[:, 0], y=x_vals[:, 1], color=class_colors[y_vals])
    plt.xlabel('Latent D-1')
    plt.ylabel('Latent D-2')
    plt.title(title)


fig = plt.figure(figsize=(18, 6))

# Panel 1: PCA (50 components) + t-SNE of the flattened input rates. This embedding
# was computed before but never drawn.
pca = PCA(n_components=50)
pca_values = pca.fit_transform(X.reshape([-1, np.prod(X.shape[1:])]))
tsne_model = TSNE(n_components=2, perplexity=TEST_PERPLEXITY[-1])
tsne_values = tsne_model.fit_transform(pca_values)
plt.subplot(1, 3, 1)
plot_tsne(tsne_values, _y.ravel(), TEST_PERPLEXITY[-1], title='Input Rates')

# Activations of layer `output_layer` (counted from the end) of the reloaded model,
# computed in batches of `tbs` trials.
output_layer = -5
tbs = 30  # tsne batch size
truncated_model = tf.keras.Model(model.input, model.layers[output_layer].output)
flattened_output = []
for start_ix in range(0, X.shape[0], tbs):
    flattened_output.append(truncated_model(X[start_ix:start_ix+tbs, :, :].astype(np.float32)))
flattened_output = tf.concat(flattened_output, 0)
tf.keras.backend.clear_session()

# Remaining panels: PCA (30 components) + t-SNE of those activations, one per perplexity.
for p_ix, perplexity in enumerate(TEST_PERPLEXITY):
    pca = PCA(n_components=30)
    pca_values = pca.fit_transform(flattened_output)
    tsne_model = TSNE(n_components=2, perplexity=perplexity)
    tsne_values = tsne_model.fit_transform(pca_values)

    plt.subplot(1, 3, p_ix + 2)
    plot_tsne(tsne_values, _y.ravel(), perplexity, title='LSTM Output Latent States')

plt.tight_layout()