### Imports

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

from prl_utils import (
    Mode,
    read_hdf5,
    get_features,
)

In [None]:
# @title Constants
N_TRAIN_AGENT = 3000  # agents per training file (unless the filename says '20000agent')
N_VAL_AGENT = 200     # agents in the held-out test file
NUM_TRIAL = 2000      # trials per agent; per-agent sequences are reshaped/truncated to this length
mode = Mode.PRL2_intractable  # selects which label column is predicted (see get_latent_labels)

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and the per-epoch
# batch shuffling done by `fit` are reproducible across Restart & Run All.
SEED = 0
tf.keras.utils.set_random_seed(SEED)

In [None]:
def get_latent_labels(data, num_agents, num_trial, mode):
  """Extract the per-trial prediction target for `mode` as a float32 matrix.

  Args:
    data: table indexable by column name (e.g. a DataFrame) holding
      `num_agents * num_trial` rows. Rows are assumed agent-major (all trials of
      agent 0, then agent 1, ...), as implied by the C-order reshape below.
    num_agents: number of agents in `data`.
    num_trial: number of trials per agent.
    mode: `Mode.PRL2_intractable` selects the `which_state` column;
      `Mode.PRL2` selects the `rpe_history` column.

  Returns:
    np.ndarray of shape (num_agents, num_trial) and dtype float32.

  Raises:
    ValueError: if `mode` has no associated label column. Previously this returned
      None and only failed later with an unrelated error.
  """
  if mode == Mode.PRL2_intractable:
    column = 'which_state'
  elif mode == Mode.PRL2:
    column = 'rpe_history'
  else:
    raise ValueError(f'No latent label column defined for mode {mode!r}')
  return data[column].to_numpy().astype(np.float32).reshape((num_agents, num_trial))


### Read data

In [None]:
# @title Read data

list_of_train_files = [
  '3000agent_2000t_2ParamRL_intractable_fixed_inattention_tau_validation.h5',
]

all_train_features = []
train_labels = []
num_trial = NUM_TRIAL
for train_file in list_of_train_files:
  # HDF5 files go through the project reader; anything else is read as CSV.
  data = read_hdf5(train_file) if train_file.endswith('.h5') else pd.read_csv(train_file)

  # The agent count is encoded in the filename for the large 20000-agent set.
  num_agents = 20000 if '20000agent' in train_file else N_TRAIN_AGENT
  all_train_features.append(get_features(data, num_agents, num_trial, mode=mode))
  # Reuse the same label extraction as the test set so train/test targets cannot drift apart.
  train_labels.append(get_latent_labels(data, num_agents, num_trial, mode))

  del data  # free the raw table before the next file is loaded

# Stack all files along the agent axis.
all_train_features = tf.concat(all_train_features, 0)
train_labels = tf.concat(train_labels, 0)

# Only the first n_agent agents (and first NUM_TRIAL trials) are used for training.
n_agent = 500
all_train_features = all_train_features[:n_agent, :NUM_TRIAL, :]
train_labels = train_labels[:n_agent, :NUM_TRIAL]

In [None]:
from sklearn.utils import shuffle

# Permute agents with one shared permutation so each feature sequence stays paired
# with its label sequence; random_state=0 fixes the order across runs.
# NOTE: despite its name, `normalized_train_labels` is not rescaled. For the
# intractable task it is only one-hot encoded below.
train_features, normalized_train_labels = shuffle(
    all_train_features.numpy(),
    train_labels.numpy(),
    random_state=0,
)

if mode == Mode.PRL2_intractable:
    # which_state -> one-hot over 2 classes (assumes the labels are 0/1 — TODO confirm).
    normalized_train_labels = tf.keras.utils.to_categorical(
        normalized_train_labels,
        num_classes=2,
    )

print(train_features.shape, len(normalized_train_labels))

In [None]:
def _shift_for_teacher_forcing(labels):
    """Build teacher-forcing decoder input/target sequences from per-trial labels.

    The decoder input is the label sequence shifted right by one step, with a zero
    dummy step at the front. The target is the unshifted sequence with a zero dummy
    step at the end. Both therefore have length T + 1 along axis 1.

    Args:
        labels: array of shape (batch, T) for scalar labels, or (batch, T, C) for
            one-hot labels.

    Returns:
        (decoder_input, decoder_target), each of shape (batch, T + 1[, C]) with the
        same dtype as `labels`.
    """
    labels = np.asarray(labels)
    # Pad only the time axis (axis 1); leave the batch axis and any class axis untouched.
    trailing = [(0, 0)] * (labels.ndim - 2)
    decoder_input = np.pad(labels, [(0, 0), (1, 0)] + trailing)   # dummy at the front
    decoder_target = np.pad(labels, [(0, 0), (0, 1)] + trailing)  # dummy at the end
    return decoder_input, decoder_target


# One vectorised pad over the whole batch replaces the former per-agent
# tf.unstack / tf.pad / tf.stack loop, and covers both one-hot (3-D) and scalar (2-D) labels.
decoder_input_data, decoder_target_data = _shift_for_teacher_forcing(normalized_train_labels)
print(decoder_input_data.shape, decoder_target_data.shape)

### Model Training

In [None]:
# @title Model definition

from tensorflow.keras import layers
# Take callbacks from tf.keras, like every other Keras object in this notebook; the
# standalone `keras` package is not guaranteed to be the same Keras as tf.keras.
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

from tensorflow.keras.layers import (
    Dense,
    Dropout,
    LSTM,
    Bidirectional,
    GRU,
    Concatenate,
)
from tensorflow import keras
from tensorflow.keras.optimizers import Adam, SGD

RNN = GRU  # recurrent cell used by both encoder and decoder (GRU or LSTM)
is_bidirection = True
# The intractable task predicts a one-hot latent state (2 classes); the other mode
# predicts a single value per trial.
is_classification = mode == Mode.PRL2_intractable
output_dim = 2 if is_classification else 1
# A softmax over a single unit is identically 1.0, so the scalar target needs a
# linear head and a regression loss. The classification settings are unchanged.
# NOTE(review): assumes rpe_history is a continuous target — confirm.
output_activation = 'softmax' if is_classification else 'linear'
op_loss = 'binary_crossentropy' if is_classification else 'mse'
batch_size = 128
units = 96
learning_rate = 3e-3
decay = 0  # Learning rate decay (NOTE: not currently passed to the optimizer)

n_agent = train_features.shape[0]
# Run tag used in saved artefact filenames (e.g. the loss plot).
identifier = f'{n_agent}a_ed_{NUM_TRIAL}t_{RNN.__name__}_B{batch_size}_U{units}_{learning_rate}'
if is_bidirection:
    identifier = 'bi_' + identifier

print(identifier)

# Encoder: reads the per-trial feature sequence (any length).
encoder_inputs = layers.Input(shape=(None, train_features.shape[2]))

if is_bidirection:
    encoder = Bidirectional(RNN(units, return_state=True, return_sequences=True))
else:
    encoder = RNN(units, return_state=True, return_sequences=True)

# We discard `encoder_outputs` and only keep the final states, which initialise the
# decoder. A bidirectional encoder returns forward and backward states separately.
# They are concatenated, which is why the decoder below has 2 * units.
if RNN.__name__ == 'GRU':
    if is_bidirection:
        encoder_outputs, forward_h, backward_h = encoder(encoder_inputs)
        state_h = Concatenate()([forward_h, backward_h])
    else:
        encoder_outputs, state_h = encoder(encoder_inputs)

    encoder_states = state_h
else:
    if is_bidirection:
        encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)
        state_h = Concatenate()([forward_h, backward_h])
        state_c = Concatenate()([forward_c, backward_c])
    else:
        encoder_outputs, state_h, state_c = encoder(encoder_inputs)
    encoder_states = [state_h, state_c]

# Decoder, initialised from `encoder_states`. It returns full output sequences and
# its internal states. The states are unused in the training model but are reused
# by the inference model.
decoder_inputs = layers.Input(shape=(None, output_dim))
decoder_units = units*2 if is_bidirection else units
# Name kept for the inference cell, which reuses this layer; it may be a GRU.
decoder_lstm = RNN(decoder_units, return_sequences=True, return_state=True)
if RNN.__name__ == 'GRU':
    decoder_outputs, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
else:
    decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

decoder_dense = Dense(output_dim, activation=output_activation)
decoder_outputs = decoder_dense(decoder_outputs)

# Training model: [encoder_input_data, decoder_input_data] -> decoder_target_data.
best_model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

In [None]:
# Adam at the configured learning rate; the loss was chosen in the model-definition cell.
optimiser = keras.optimizers.Adam(learning_rate=learning_rate)
best_model.compile(loss=op_loss, optimizer=optimiser)

# Stop after 30 epochs without a val_loss improvement and roll back to the best weights.
# validation_split holds out the last 20% of the (already shuffled) agents.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=30, restore_best_weights=True),
]
history = best_model.fit(
    x=[train_features, decoder_input_data],
    y=decoder_target_data,
    batch_size=batch_size,
    epochs=100,
    callbacks=callbacks,
    validation_split=0.2,
)

In [None]:
# Per-epoch training and validation loss; one line per key in history.history.
result = pd.DataFrame(history.history)
fig, ax = plt.subplots()
sns.lineplot(data=result, ax=ax)
ax.set(title=f'{identifier}: loss per epoch', xlabel='epochs', ylabel=f'{op_loss} loss')

# Save through the figure handle so the file is this plot, whichever figure is current.
fig.savefig(f'{identifier}_loss.png')

## Model Evaluation

In [None]:
test_original_num_trial = NUM_TRIAL

test_file = '3000agent_2000t_2ParamRL_intractable_fixed_inattention_tau_test.csv'
test_data = pd.read_csv(test_file)

# Labels are reshaped from one row per (agent, trial) into (N_VAL_AGENT, trials), so the
# file must hold exactly that many rows. Fail here with a readable message instead of
# an opaque reshape error.
# NOTE(review): the filename mentions 3000 agents while N_VAL_AGENT is 200 — confirm.
expected_rows = N_VAL_AGENT * test_original_num_trial
if len(test_data) != expected_rows:
    raise ValueError(
        f'{test_file}: expected {expected_rows} rows '
        f'({N_VAL_AGENT} agents x {test_original_num_trial} trials), got {len(test_data)}'
    )

test_features = get_features(test_data, N_VAL_AGENT, test_original_num_trial, mode=mode)
test_labels = get_latent_labels(test_data, N_VAL_AGENT, test_original_num_trial, mode)

# Truncate to the trial horizon used for training (a no-op while both equal NUM_TRIAL).
test_features = test_features[:, :NUM_TRIAL, :]
test_labels = test_labels[:, :NUM_TRIAL]

print(test_features.shape, len(test_labels))

### Inference Model

In [None]:
# @title Inference model
# Width of the decoder state: the bidirectional encoder concatenates forward and
# backward states, so the decoder was built with twice as many units.
latent_dim = units*2 if is_bidirection else units

# Encoder model: input sequence -> initial decoder state(s), reusing the trained layers.
encoder_model = keras.Model(encoder_inputs, encoder_states)

# A GRU carries one hidden state; an LSTM carries a hidden state and a cell state.
num_decoder_states = {'GRU': 1, 'LSTM': 2}.get(RNN.__name__)
if num_decoder_states is None:
    raise ValueError(f'Unsupported RNN type for inference: {RNN.__name__}')

decoder_states_inputs = [layers.Input(shape=(latent_dim,)) for _ in range(num_decoder_states)]

# Decoder model: one step at a time. It takes the previous step plus the state(s) and
# returns the step output plus the new state(s), reusing the trained decoder layers.
decoder_outputs, *decoder_states = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

In [None]:
def decode_sequence(input_seq, num_decoder_tokens, max_decoder_seq_length: int=NUM_TRIAL):
    """Autoregressively decode one label sequence for every input sequence in a batch.

    Uses the global `encoder_model` and `decoder_model` built in the inference cell.
    The decoder is seeded with an all-zero step, the same dummy step prepended during
    teacher-forced training. After that it is fed its own previous output.

    Args:
        input_seq: encoder input, shape (batch, trials, n_features).
        num_decoder_tokens: width of one decoder step. It must equal the decoder's
            output dimension (`output_dim`).
        max_decoder_seq_length: number of steps to decode.

    Returns:
        np.ndarray of shape (batch, max_decoder_seq_length). Each entry is the argmax
        class when the decoder emits several units, or the raw output value when it
        emits a single unit (argmax over one unit would always be 0).
    """
    batch = input_seq.shape[0]
    if max_decoder_seq_length <= 0:
        # The previous while-loop never terminated for non-positive lengths.
        return np.empty((batch, 0))

    # Keras returns a bare array for a single-output encoder (GRU) and a list for
    # several (LSTM). Normalise to a list so it can be concatenated with the decoder input.
    states_value = encoder_model.predict(input_seq, verbose=0)
    if not isinstance(states_value, (list, tuple)):
        states_value = [states_value]
    states_value = list(states_value)

    decoder_input = np.zeros((batch, 1, num_decoder_tokens))
    decoded_steps = []
    for _ in range(max_decoder_seq_length):
        # The whole batch is decoded together. predict_on_batch avoids per-call
        # progress bars and data-pipeline overhead across thousands of steps.
        # Output is [step_output, *new_states]: 1 state for GRU, 2 for LSTM.
        output_tokens, *states_value = decoder_model.predict_on_batch(
            [decoder_input] + states_value)
        if output_tokens.shape[-1] == 1:
            decoded_steps.append(output_tokens[..., 0])
        else:
            decoded_steps.append(np.argmax(output_tokens, axis=-1))
        # Feed the full output (probabilities, not a one-hot) back as the next input,
        # matching the previous behaviour.
        decoder_input = output_tokens

    return np.concatenate(decoded_steps, axis=1)

### Prediction

In [None]:
# Decode every test agent. Each decoder step must be `output_dim` wide (2 for the one-hot
# latent-state task) to match the `decoder_inputs` layer the model was built with.
prediction = decode_sequence(test_features, num_decoder_tokens=output_dim)
prediction.shape

In [None]:
# Compare the true label sequence with the decoded one for a single test agent.
# A separate name avoids overwriting the loss-history frame `result`.
agent = 12
agent_comparison = pd.DataFrame({'true_label': test_labels[agent], 'dl_label': prediction[agent]})

fig, ax = plt.subplots(figsize=(12, 5))
sns.lineplot(data=agent_comparison, ax=ax)
ax.set(title=f'Test agent {agent}: true vs decoded label', xlabel='trial', ylabel='label');