In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from matrix import make_confusion_matrix, subplot_confusion_matrix

from prl_utils import (
    Mode,
    read_hdf5,
    get_features,
)

In [None]:
# @title Constants
N_TRAIN_AGENT = 30_000  # agent count assumed for a training file unless its name says otherwise
N_VAL_AGENT = 3_000     # agent count embedded in the held-out test-file prefix
NUM_TRIAL = 2_000       # trials per agent; second axis of every (agents, trials) reshape
mode = Mode.PRL2_intractable  # selects the label column and the width of the model output

def get_latent_labels(data, num_agents, num_trial, mode):
  """Return the per-trial latent target for `mode` as a float32 matrix.

  Args:
    data: table-like object indexable by column name; the selected column
      must hold exactly `num_agents * num_trial` values.
    num_agents: number of rows in the returned matrix.
    num_trial: number of columns in the returned matrix.
    mode: `Mode.PRL2_intractable` reads the 'which_state' column,
      `Mode.PRL2` reads the 'rpe_history' column.

  Returns:
    np.ndarray of dtype float32 and shape (num_agents, num_trial).

  Raises:
    ValueError: if `mode` is neither of the two supported modes. Falling
      through would hand `None` to the caller and only fail much later.
  """
  if mode == Mode.PRL2_intractable:
    column = 'which_state'
  elif mode == Mode.PRL2:
    column = 'rpe_history'
  else:
    raise ValueError(f'Unsupported mode for latent labels: {mode!r}')

  return data[column].to_numpy().astype(np.float32).reshape((num_agents, num_trial))


In [None]:
# @title Read data

list_of_train_files = [
  #'5000agent_200t_2ParamRL.csv'
  'data/to_a0.7_30000agent_2000t_2ParamRL_intractable.csv',
]

# Per-file feature tensors and label matrices, concatenated along the agent
# axis once every file has been read.
all_train_features = []
train_labels = []
num_trial = NUM_TRIAL
for f in list_of_train_files:
  data = read_hdf5(f) if '.h5' in f else pd.read_csv(f)

  # Files tagged '20000agent' are the only ones that deviate from N_TRAIN_AGENT.
  num_agents = 20000 if '20000agent' in f else N_TRAIN_AGENT
  features = get_features(data, num_agents, num_trial, mode=mode)
  all_train_features.append(features)

  # Use the shared helper rather than repeating its column/reshape logic here.
  file_labels = get_latent_labels(data, num_agents, num_trial, mode)
  if file_labels is None:
    # Fail at the source instead of letting tf.concat choke on a None/empty list.
    raise ValueError(f'No latent labels available for mode {mode!r}')
  train_labels.append(file_labels)

all_train_features = tf.concat(all_train_features, 0)
train_labels = tf.concat(train_labels, 0)

In [None]:
from sklearn.utils import shuffle

# Permute agents with a fixed seed; both arrays are shuffled together so each
# feature row keeps its label row.
train_features, normalized_train_labels = shuffle(
    all_train_features.numpy(),
    train_labels.numpy(),
    random_state=0,
)

# The state-classification mode trains against one-hot targets over 2 classes.
if mode == Mode.PRL2_intractable:
    normalized_train_labels = tf.keras.utils.to_categorical(
        normalized_train_labels,
        num_classes=2,
    )

print(train_features.shape, len(normalized_train_labels))

In [None]:
from tensorflow.keras import layers
# Import callbacks from tensorflow.keras like everything else in this cell;
# mixing the standalone `keras` package with tf.keras objects is fragile.
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

from tensorflow.keras.layers import (
    Dense,
    Dropout,
    LSTM,
    Bidirectional,
    GRU,
    Concatenate,
)
from tensorflow import keras
from tensorflow.keras.optimizers import Adam, SGD

# --- Hyper-parameters -------------------------------------------------------
RNN = LSTM
is_bidirection = False
is_classification = mode == Mode.PRL2_intractable
output_dim = 2 if is_classification else 1
batch_size = 250
units = 128
learning_rate = 3e-3
# A softmax over a single unit is constantly 1.0, so the 1-unit regression head
# needs a linear activation and a regression loss to be able to learn anything.
output_activation = 'softmax' if is_classification else 'linear'
op_loss = 'binary_crossentropy' if is_classification else 'mse'
decay = 0 # Learning rate decay

identifier = f'trnn_{NUM_TRIAL}t_{RNN.__name__}_B{batch_size}_U{units}_{learning_rate}'
print(identifier)

# --- Encoder ----------------------------------------------------------------
# Variable-length sequences; the feature width comes from the training tensor.
encoder_inputs = layers.Input(shape=(None, train_features.shape[2]))

if is_bidirection:
    encoder = Bidirectional(RNN(units, return_state=True, return_sequences=True))
else:
    encoder = RNN(units, return_state=True, return_sequences=True)

# With return_state=True the layer returns the sequence output followed by its
# final state(s): one per direction for GRU, (h, c) per direction for LSTM.
if RNN.__name__ == 'GRU':
    if is_bidirection:
        encoder_outputs, forward_h, backward_h = encoder(encoder_inputs)
        state_h = Concatenate()([forward_h, backward_h])
    else:
        encoder_outputs, state_h = encoder(encoder_inputs)

    encoder_states = state_h
elif is_bidirection:
    # The sequence output must be bound to `encoder_outputs`; it is what the
    # dense head below consumes.
    encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)
    state_h = Concatenate()([forward_h, backward_h])
    state_c = Concatenate()([forward_c, backward_c])
    encoder_states = [state_h, state_c]
else:
    encoder_outputs, state_h, state_c = encoder(encoder_inputs)
    encoder_states = [state_h, state_c]

# --- Per-trial head ---------------------------------------------------------
# Dense layers act on the last axis, so they are applied to every time step.
dense_layer_1 = Dense(units // 2, activation='relu')
dense_layer_2 = Dense(output_dim, activation=output_activation)

outputs = dense_layer_1(encoder_outputs)
outputs = dense_layer_2(outputs)

best_model = keras.Model(encoder_inputs, outputs)

In [None]:
# Adam at the configured learning rate, minimising the configured loss.
optimiser = keras.optimizers.Adam(learning_rate=learning_rate)
best_model.compile(loss=op_loss, optimizer=optimiser)

# Stop once val_loss has not improved for 30 epochs and roll back to the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=30, restore_best_weights=True)
callbacks = [early_stopping]

fit_options = dict(
    batch_size=batch_size,
    epochs=100,
    callbacks=callbacks,
    validation_split=0.2,
)
history = best_model.fit(train_features, normalized_train_labels, **fit_options)

In [None]:
# One column per logged metric, one row per epoch.
result = pd.DataFrame(history.history)

ax = sns.lineplot(data=result)
ax.set(xlabel='epochs', ylabel=f'{op_loss} loss')

plt.savefig(f'{identifier}.png')

In [None]:
# Checkpoint to evaluate. NOTE(review): the name encodes B512 while this
# notebook trains with batch_size=250 -- confirm this is the intended model.
SAVED_MODEL_PATH = '2000t_B512_U128_0.003_model'

if tf.io.gfile.exists(SAVED_MODEL_PATH):
    best_model = tf.keras.models.load_model(SAVED_MODEL_PATH)
else:
    # Keep the notebook runnable end-to-end on a machine without the checkpoint:
    # fall back to the model trained in the cells above.
    print(f'{SAVED_MODEL_PATH!r} not found; evaluating the model trained in this session.')

best_model.summary()

In [None]:
from numpy import load

prefix = f'to_a0.7_{N_VAL_AGENT}agent_{NUM_TRIAL}t_2ParamRL_intractable'

# An .npz archive keeps its file handle open until it is closed, so read the
# array inside a context manager; the extracted array remains valid afterwards.
with load(f'{prefix}_features_test.npz') as features_archive:
    test_features = features_archive['arr_0']

with load(f'{prefix}_labels_test.npz') as labels_archive:
    test_labels = labels_archive['arr_0']

# Mean label per agent (row-wise over the trial axis).
avg_attentive_states = np.mean(test_labels, axis=1)
print(test_features.shape)

In [None]:
# @title Inference model
# Run the model on the test features, then take the highest-scoring output
# index along the last axis as the prediction.
output_tokens = best_model.predict(test_features)
prediction = output_tokens.argmax(axis=-1)

# Peek at the predicted sequence for one agent (index 3).
prediction[3]

In [None]:
from sklearn.metrics import accuracy_score

# Accuracy of the predicted sequence against the true sequence, one value per agent.
per_agent_accuracy = [
    accuracy_score(y_true, y_pred)
    for y_true, y_pred in zip(test_labels, prediction)
]

t_score = pd.DataFrame({
    'avg_attentive_states': list(avg_attentive_states),
    'accuracy': per_agent_accuracy,
})
# Broadcast the overall mean accuracy into its own column.
t_score['mean'] = t_score['accuracy'].mean()
#t_score.to_csv(f'{RESULT_DIR}/{identifier}_accuracy_to_t.csv')

t_score['mean'].mean()

In [None]:
from sklearn.metrics import confusion_matrix

# One 2x2 confusion matrix per bin of an agent's average label:
# bin 0 <-> [0.2, 0.3), ..., bin 7 <-> 0.9 and above.
N_ATTENTION_BINS = 8
cf_matrixes = np.zeros((N_ATTENTION_BINS, 2, 2), dtype=int)
for ag in range(len(avg_attentive_states)):
    # Tenths digit shifted so 0.2 maps to bin 0, then clamped to the valid
    # range. Without the clamp, averages <= 0.1 (and exactly 1.0) yield a
    # negative index that silently wraps around into the top bins.
    decile = int(np.floor(avg_attentive_states[ag] * 10))
    idx = min(max(decile - 2, 0), N_ATTENTION_BINS - 1)

    # Pin labels to [0, 1] so the result is always 2x2; an agent whose truth and
    # prediction contain a single class would otherwise give a 1x1 matrix that
    # broadcasts into all four cells of the accumulator.
    cf_matrix = confusion_matrix(
        np.asarray(test_labels[ag]).astype(int),
        np.asarray(prediction[ag]).astype(int),
        labels=[0, 1],
    )
    cf_matrixes[idx] += cf_matrix

In [None]:
nrows, ncols = 2, 4
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * 5, nrows * 5),
    sharey=True,
    sharex=True,
)

# Panel captions, in the same order as the first axis of cf_matrixes.
avg_attentive_states_range = ['0.2-0.3', '0.3-0.4', '0.4-0.5', '0.5-0.6', '0.6-0.7', '0.7-0.8', '0.8-0.9', '>0.9']
for ax, range_label, bin_matrix in zip(axes.flat, avg_attentive_states_range, cf_matrixes):
    ax.set_title(f'avg attentive state percent range {range_label}', fontsize=10)
    subplot_confusion_matrix(
        bin_matrix,
        categories=["inattentive", "attentive"],
        percent="by_row",
        ax=ax,
        vmin=0,
        vmax=1,
    )

In [None]:
# Pool the per-bin matrices so the figure covers every test agent. The bare
# `cf_matrix` left over from the binning loop only holds the last agent's counts.
overall_cf_matrix = cf_matrixes.sum(axis=0)
make_confusion_matrix(overall_cf_matrix, categories=['inattentive', 'attentive'], percent='by_row')
plt.savefig(f"binary_{identifier}_confusion_matrix.png")