In [1]:
#@title Import libraries
import sys
import warnings

import pickle
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


warnings.filterwarnings("ignore")

# RL libraries
sys.path.append('resources')  # add source directoy to path
from resources import rnn, rnn_training, bandits, rnn_utils

In [2]:
# train model (BASELINE)
# Fit a single RL-RNN (no ensemble) on data simulated from a ground-truth Q-learning agent,
# validate it on a held-out test set and compare its dynamics with the ground truth.

train = True  # if False, fit_model is called with epochs=0 (no training)
checkpoint = False  # load parameters from `params_path` before training
data = False  # False: simulate train/test data from the ground-truth agent; True: load `path_data`

path_data = 'data/dataset_train.pkl'
params_path = 'params/params_lstm_b3.pkl'  # overwritten if data is False (adapted to the ground truth model)

# rnn parameters
hidden_size = 4
last_output = False
last_state = False
use_lstm = False

# ensemble parameters
evolution_interval = None
sampling_replacement = False
n_submodels = 1
ensemble = rnn_training.ensemble_types.NONE
voting_type = rnn.EnsembleRNN.MEDIAN  # necessary if ensemble==True

# training parameters
epochs = 5
n_steps_per_call = 16  # None for full sequence
batch_size = None  # None for one batch per epoch
learning_rate = 1e-2
convergence_threshold = 1e-6

# ground truth parameters
gen_alpha = .25
gen_beta = 1
forget_rate = 0.1  # possible values: 0., 0.1
perseverance_bias = 0.
correlated_update = False  # possible values: True, False

# environment parameters
n_actions = 2
sigma = 0.1
n_trials_per_session = 200
n_sessions = 256
correlated_reward = False
non_binary_reward = False

# tracked variables in the RNN
x_train_list = ['xQf', 'xQr', 'xQc']
control_list = ['ca', 'ca[k-1]', 'cr']
sindy_feature_list = x_train_list + control_list

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if not data:
    # simulate training and test data from the ground-truth Q-learning agent
    environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_reward=non_binary_reward, correlated_reward=correlated_reward)
    agent = bandits.AgentQ(gen_alpha, gen_beta, n_actions, forget_rate, perseverance_bias, correlated_update)

    dataset_train, experiment_list_train = bandits.create_dataset(
        agent=agent,
        environment=environment,
        n_trials_per_session=n_trials_per_session,
        n_sessions=n_sessions,
        device=device)

    dataset_test, experiment_list_test = bandits.create_dataset(
        agent=agent,
        environment=environment,
        n_trials_per_session=200,
        n_sessions=1024,
        device=device)

    params_path = rnn_utils.parameter_file_naming(
        'params/params',
        use_lstm,
        last_output,
        last_state,
        gen_beta,
        forget_rate,
        perseverance_bias,
        correlated_update,
        non_binary_reward,
        verbose=True,
    )

else:
    # load data (pickle executes code on load: only use trusted files)
    # NOTE(review): only the training set is loaded here. The validation and the analysis
    # below use `dataset_test`, `experiment_list_test` and `agent`, which are only defined
    # when data is False.
    with open(path_data, 'rb') as f:
        dataset_train = pickle.load(f)

if ensemble > -1 and n_submodels == 1:
    # print instead of `Warning(...)`: constructing a Warning object reports nothing, and
    # warnings.warn would be swallowed by the global filterwarnings("ignore") in the import cell.
    print('Ensemble is actived but n_submodels is set to 1. Deactivating ensemble...')
    ensemble = rnn_training.ensemble_types.NONE

# define model(s): always a list with one entry per submodel, so that the optimizer
# and checkpoint code below can iterate over it
if use_lstm:
    # NOTE(review): wrapped in a list (a bare Module is not iterable); confirm that
    # fit_model accepts a list holding an LSTM.
    model = [rnn.LSTM(
        n_actions=n_actions,
        hidden_size=hidden_size,
        init_value=0.5,
        device=device,
    ).to(device)]
else:
    model = [
        rnn.RLRNN(
            n_actions=n_actions,
            hidden_size=hidden_size,
            init_value=0.5,
            last_output=last_output,
            last_state=last_state,
            device=device,
            list_sindy_signals=sindy_feature_list,
        ).to(device)
        for _ in range(n_submodels)
    ]

optimizer_rnn = [torch.optim.Adam(submodel.parameters(), lr=learning_rate) for submodel in model]

if checkpoint:
    # load trained parameters
    state_dict = torch.load(params_path, map_location=torch.device('cpu'))
    state_dict_model = state_dict['model']
    state_dict_optimizer = state_dict['optimizer']
    if isinstance(state_dict_model, dict):
        # a single saved model: load the same weights into every submodel
        for submodel, submodel_optimizer in zip(model, optimizer_rnn):
            submodel.load_state_dict(state_dict_model)
            submodel_optimizer.load_state_dict(state_dict_optimizer)
    elif isinstance(state_dict_model, list):
        print('Loading ensemble model...')
        for submodel, submodel_optimizer, state_dict_model_i, state_dict_optim_i in zip(
                model, optimizer_rnn, state_dict_model, state_dict_optimizer):
            submodel.load_state_dict(state_dict_model_i)
            submodel_optimizer.load_state_dict(state_dict_optim_i)
        # No `rnn = rnn.EnsembleRNN(...)` here: it shadowed the `rnn` module and its result was
        # never used. The submodels go to fit_model below together with ensemble_type/voting_type.
    print('Loaded parameters.')

if train:

    start_time = time.time()

    # Fit the hybrid RNN
    print('Training the hybrid RNN...')
    model, optimizer_rnn, _ = rnn_training.fit_model(
        model=model,
        dataset=dataset_train,
        optimizer=optimizer_rnn,
        convergence_threshold=convergence_threshold,
        epochs=epochs,
        batch_size=batch_size,
        n_submodels=n_submodels,
        ensemble_type=ensemble,
        voting_type=voting_type,
        sampling_replacement=sampling_replacement,
        evolution_interval=evolution_interval,
        n_steps_per_call=n_steps_per_call,
    )

    baseline_losses = []

    # validate model: fit_model is called without an optimizer and under torch.no_grad();
    # the recorded losses of the 10 repeats are (near-)identical
    print('\nValidating the trained hybrid RNN on a test dataset...')

    for _ in range(10):
        with torch.no_grad():
            model, _, loss = rnn_training.fit_model(
                model=model,
                dataset=dataset_test,
                n_steps_per_call=1,
            )
            baseline_losses.append(float(loss))

    print(f'Training took {time.time() - start_time:.2f} seconds.')

    # save trained parameters
    state_dict = {
        'model': model.state_dict() if isinstance(model, torch.nn.Module) else [model_i.state_dict() for model_i in model],
        'optimizer': optimizer_rnn.state_dict() if isinstance(optimizer_rnn, torch.optim.Adam) else [optim_i.state_dict() for optim_i in optimizer_rnn],
    }
    torch.save(state_dict, params_path)

    print(f'Saved RNN parameters to file {params_path}.')

else:
    model, _, _ = rnn_training.fit_model(
        model=model,
        dataset=dataset_train,
        epochs=0,
        n_submodels=n_submodels,
        ensemble_type=ensemble,
        voting_type=voting_type,
        verbose=True
    )


# Move the fitted network to the CPU and wrap it as an agent for the analysis below.
# NOTE(review): `environment` is rebuilt here but not used in the rest of this cell.
environment = bandits.EnvironmentBanditsDrift(0.1)
model.set_device(torch.device('cpu'))
model.to(torch.device('cpu'))
rnn_agent = bandits.AgentNetwork(model, n_actions=n_actions)
# NOTE(review): the recorded run of this cell ended with
# `IndexError: tensors used as indices must be long, byte or bool tensors` after validation
# had finished; the failing call is not visible in the truncated output. TODO investigate.

# Analysis
session_id = 0

choices = experiment_list_test[session_id].choices
rewards = experiment_list_test[session_id].rewards

list_probs = []
list_qs = []

# get q-values from groundtruth
qs_test, probs_test = bandits.get_update_dynamics(experiment_list_test[session_id], agent)
list_probs.append(np.expand_dims(probs_test, 0))
list_qs.append(np.expand_dims(qs_test, 0))

# get q-values from trained rnn
qs_rnn, probs_rnn = bandits.get_update_dynamics(experiment_list_test[session_id], rnn_agent)
list_probs.append(np.expand_dims(probs_rnn, 0))
list_qs.append(np.expand_dims(qs_rnn, 0))

colors = ['tab:blue', 'tab:orange', 'tab:pink', 'tab:grey']

# concatenate all choice probs and q-values (axis 0: ground truth, rnn)
probs = np.concatenate(list_probs, axis=0)
qs = np.concatenate(list_qs, axis=0)


def normalize(qs):
    """Min-max scale `qs` to [0, 1] along axis 1.

    Slices whose values are all equal map to 0 instead of NaN (the original 0/0).
    """
    q_min = np.min(qs, axis=1, keepdims=True)
    q_range = np.max(qs, axis=1, keepdims=True) - q_min
    return (qs - q_min) / np.where(q_range == 0, 1, q_range)


qs = normalize(qs)
fig, axs = plt.subplots(4, 1, figsize=(20, 10))

reward_probs = np.stack([experiment_list_test[session_id].timeseries[:, i] for i in range(n_actions)], axis=0)
bandits.plot_session(
    compare=True,
    choices=choices,
    rewards=rewards,
    timeseries=reward_probs,
    timeseries_name='Reward Probs',
    labels=[f'Arm {a}' for a in range(n_actions)],
    color=['tab:purple', 'tab:cyan'],
    binary=not non_binary_reward,
    fig_ax=(fig, axs[0]),
)

bandits.plot_session(
    compare=True,
    choices=choices,
    rewards=rewards,
    timeseries=probs[:, :, 0],
    timeseries_name='Choice Probs',
    color=colors,
    labels=['Ground Truth', 'RNN'],
    binary=not non_binary_reward,
    fig_ax=(fig, axs[1]),
)

bandits.plot_session(
    compare=True,
    choices=choices,
    rewards=rewards,
    timeseries=qs[:, :, 0],
    timeseries_name='Q-Values',
    color=colors,
    binary=not non_binary_reward,
    fig_ax=(fig, axs[2]),
)

dqs_arms = normalize(-1 * np.diff(qs, axis=2))

bandits.plot_session(
    compare=True,
    choices=choices,
    rewards=rewards,
    timeseries=dqs_arms[:, :, 0],
    timeseries_name='dQ/dActions',
    color=colors,
    binary=not non_binary_reward,
    fig_ax=(fig, axs[3]),
)

plt.show()



Automatically generated name for model parameter file: params/params_rnn_b1_f01.pkl.
Training the hybrid RNN...
Epoch 1/5 --- Loss: 0.6818449; Time: 11.6449s; Convergence value: 3.18e-01
Epoch 2/5 --- Loss: 0.6814595; Time: 11.8765s; Convergence value: 1.45e-01
Epoch 3/5 --- Loss: 0.6810587; Time: 12.0010s; Convergence value: 8.51e-02
Epoch 4/5 --- Loss: 0.6806923; Time: 11.7729s; Convergence value: 5.33e-02
Epoch 5/5 --- Loss: 0.6802283; Time: 11.7369s; Convergence value: 4.10e-04
Maximum number of training epochs reached.
Model did not converge yet.

Validating the trained hybrid RNN on a test dataset...
Epoch 1/1 --- Loss: 0.6868224; Time: 1.0394s; Convergence value: nan
Maximum number of training epochs reached.
Model did not converge yet.
Epoch 1/1 --- Loss: 0.6868225; Time: 0.7296s; Convergence value: nan
Maximum number of training epochs reached.
Model did not converge yet.
Epoch 1/1 --- Loss: 0.6868224; Time: 0.7618s; Convergence value: nan
Maximum number of training epochs rea

IndexError: tensors used as indices must be long, byte or bool tensors

In [3]:
# validation loss of each of the 10 repeated evaluations of the baseline model
list(baseline_losses)

[0.6868224143981934,
 0.6868224740028381,
 0.6868224143981934,
 0.6868224740028381,
 0.6868224740028381,
 0.6868224740028381,
 0.6868223547935486,
 0.6868224740028381,
 0.6868224740028381,
 0.6868224143981934]

In [5]:
# Grid search: train an ensemble of RL-RNNs for every combination of ensemble size,
# number of sessions, trials per session and sampling scheme. Each configuration is
# validated on the same held-out test set, and its losses are appended to losses.csv.
nsubmod = [4, 32, 64]  # number of submodels
sessionlist = [32, 64]  # number of sessions
tri = [50, 200]  # number of trials
rep = [True]  # sampling with replacement

# ---- settings shared by every configuration ----
train = True
checkpoint = False
data = False

path_data = 'data/dataset_train.pkl'

# rnn parameters
hidden_size = 4
last_output = False
last_state = False
use_lstm = False

# ensemble parameters (n_submodels, sampling_replacement and ensemble are set per configuration)
evolution_interval = None
voting_type = rnn.EnsembleRNN.MEDIAN  # necessary if ensemble==True

# training parameters
epochs = 5
n_steps_per_call = 16  # None for full sequence
batch_size = None  # None for one batch per epoch
learning_rate = 1e-2
convergence_threshold = 1e-6

# ground truth parameters
gen_alpha = .25
gen_beta = 3
forget_rate = 0.1  # possible values: 0., 0.1
perseverance_bias = 0.
correlated_update = False  # possible values: True, False

# environment parameters (n_sessions and n_trials_per_session are set per configuration)
n_actions = 2
sigma = 0.1
correlated_reward = False
non_binary_reward = False

# tracked variables in the RNN
x_train_list = ['xQf', 'xQr', 'xQc']
control_list = ['ca', 'ca[k-1]', 'cr']
sindy_feature_list = x_train_list + control_list

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if not data:
    environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_reward=non_binary_reward, correlated_reward=correlated_reward)
    agent = bandits.AgentQ(gen_alpha, gen_beta, n_actions, forget_rate, perseverance_bias, correlated_update)

    # One test set for all configurations. This makes the losses in losses.csv comparable
    # across the grid, and the 1024-session simulation runs only once.
    dataset_test, experiment_list_test = bandits.create_dataset(
        agent=agent,
        environment=environment,
        n_trials_per_session=200,
        n_sessions=1024,
        device=device)

# Loop variables must not be rebound inside the body (e.g. by the checkpoint loop),
# otherwise the int(...) conversions below break on the next iteration.
for n_sub in nsubmod:
    for ses in sessionlist:
        for t in tri:
            for r in rep:

                # ---- per-configuration settings ----
                sampling_replacement = bool(r)
                n_submodels = int(n_sub)
                n_sessions = int(ses)
                n_trials_per_session = int(t)
                # reset every iteration: the n_submodels == 1 check below may switch it off
                ensemble = rnn_training.ensemble_types.AVERAGE
                params_path = 'params/params_lstm_b3.pkl'  # overwritten if data is False (adapted to the ground truth model)

                if not data:
                    dataset_train, experiment_list_train = bandits.create_dataset(
                        agent=agent,
                        environment=environment,
                        n_trials_per_session=n_trials_per_session,
                        n_sessions=n_sessions,
                        device=device)

                    params_path = rnn_utils.parameter_file_naming(
                        'params/params',
                        use_lstm,
                        last_output,
                        last_state,
                        gen_beta,
                        forget_rate,
                        perseverance_bias,
                        correlated_update,
                        non_binary_reward,
                        verbose=True,
                    )
                else:
                    # load data (pickle executes code on load: only use trusted files)
                    with open(path_data, 'rb') as f:
                        dataset_train = pickle.load(f)

                # One parameter file per configuration; otherwise every run overwrites the previous one.
                config_tag = f'_sub{n_submodels}_ses{n_sessions}_tri{n_trials_per_session}_rep{int(sampling_replacement)}'
                if params_path.endswith('.pkl'):
                    params_path = params_path[:-len('.pkl')] + config_tag + '.pkl'
                else:
                    params_path = params_path + config_tag

                if ensemble > -1 and n_submodels == 1:
                    # print instead of `Warning(...)`: constructing a Warning object reports nothing,
                    # and warnings are silenced by filterwarnings("ignore") in the import cell.
                    print('Ensemble is actived but n_submodels is set to 1. Deactivating ensemble...')
                    ensemble = rnn_training.ensemble_types.NONE

                # define model(s): always a list with one entry per submodel
                if use_lstm:
                    # NOTE(review): wrapped in a list (a bare Module is not iterable); confirm that
                    # fit_model accepts a list holding an LSTM.
                    model = [rnn.LSTM(
                        n_actions=n_actions,
                        hidden_size=hidden_size,
                        init_value=0.5,
                        device=device,
                    ).to(device)]
                else:
                    model = [
                        rnn.RLRNN(
                            n_actions=n_actions,
                            hidden_size=hidden_size,
                            init_value=0.5,
                            last_output=last_output,
                            last_state=last_state,
                            device=device,
                            list_sindy_signals=sindy_feature_list,
                        ).to(device)
                        for _ in range(n_submodels)
                    ]

                optimizer_rnn = [torch.optim.Adam(submodel.parameters(), lr=learning_rate) for submodel in model]

                if checkpoint:
                    # load trained parameters
                    state_dict = torch.load(params_path, map_location=torch.device('cpu'))
                    state_dict_model = state_dict['model']
                    state_dict_optimizer = state_dict['optimizer']
                    if isinstance(state_dict_model, dict):
                        # a single saved model: load the same weights into every submodel
                        for submodel, submodel_optimizer in zip(model, optimizer_rnn):
                            submodel.load_state_dict(state_dict_model)
                            submodel_optimizer.load_state_dict(state_dict_optimizer)
                    elif isinstance(state_dict_model, list):
                        print('Loading ensemble model...')
                        for submodel, submodel_optimizer, state_dict_model_i, state_dict_optim_i in zip(
                                model, optimizer_rnn, state_dict_model, state_dict_optimizer):
                            submodel.load_state_dict(state_dict_model_i)
                            submodel_optimizer.load_state_dict(state_dict_optim_i)
                        # No `rnn = rnn.EnsembleRNN(...)` here: it shadowed the `rnn` module (breaking
                        # rnn.RLRNN / rnn.EnsembleRNN on the next iteration) and its result was never
                        # used. The submodels go to fit_model below with ensemble_type/voting_type.
                    print('Loaded parameters.')

                validation_losses = []  # stays empty when train is False

                if train:
                    start_time = time.time()

                    # Fit the hybrid RNN
                    print('Training the hybrid RNN...')
                    model, optimizer_rnn, _ = rnn_training.fit_model(
                        model=model,
                        dataset=dataset_train,
                        optimizer=optimizer_rnn,
                        convergence_threshold=convergence_threshold,
                        epochs=epochs,
                        batch_size=batch_size,
                        n_submodels=n_submodels,
                        ensemble_type=ensemble,
                        voting_type=voting_type,
                        sampling_replacement=sampling_replacement,
                        evolution_interval=evolution_interval,
                        n_steps_per_call=n_steps_per_call,
                    )

                    # validate model (fit_model without an optimizer, under torch.no_grad())
                    print('\nValidating the trained hybrid RNN on a test dataset...')

                    for _ in range(10):
                        with torch.no_grad():
                            model, _, loss = rnn_training.fit_model(
                                model=model,
                                dataset=dataset_test,
                                n_steps_per_call=1,
                            )
                        validation_losses.append(float(loss))

                    print(f'Training took {time.time() - start_time:.2f} seconds.')

                    # save trained parameters
                    state_dict = {
                        'model': model.state_dict() if isinstance(model, torch.nn.Module) else [model_i.state_dict() for model_i in model],
                        'optimizer': optimizer_rnn.state_dict() if isinstance(optimizer_rnn, torch.optim.Adam) else [optim_i.state_dict() for optim_i in optimizer_rnn],
                    }
                    torch.save(state_dict, params_path)

                    print(f'Saved RNN parameters to file {params_path}.')

                else:
                    model, _, _ = rnn_training.fit_model(
                        model=model,
                        dataset=dataset_train,
                        epochs=0,
                        n_submodels=n_submodels,
                        ensemble_type=ensemble,
                        voting_type=voting_type,
                        verbose=True
                    )

                # ---- append this configuration's validation losses to losses.csv ----
                n_rows = len(validation_losses)
                config_losses = pd.DataFrame({
                    # NOTE(review): every row carries the literal label 'model_name'; the remaining
                    # columns identify the configuration.
                    'model': ['model_name'] * n_rows,
                    'loss': validation_losses,
                    'Replacement': [r] * n_rows,
                    'Submodels': [n_submodels] * n_rows,
                    'Voting': ['median'] * n_rows,
                    'Ensemble': ['average'] * n_rows,
                    'Sessions': [n_sessions] * n_rows,
                    'Trials': [n_trials_per_session] * n_rows,
                })

                # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
                try:
                    losses = pd.concat([pd.read_csv("losses.csv"), config_losses])
                except FileNotFoundError:  # first run: no results file yet
                    losses = config_losses

                losses.to_csv("losses.csv", index=False)




Automatically generated name for model parameter file: params/params_rnn_b3_f01.pkl.
Training the hybrid RNN...
Epoch 1/5 --- Loss: 0.6664248; Time: 5.7079s; Convergence value: 3.34e-01
Epoch 2/5 --- Loss: 0.6431982; Time: 6.3183s; Convergence value: 1.64e-01
Epoch 3/5 --- Loss: 0.6423386; Time: 6.6044s; Convergence value: 9.70e-02
Epoch 4/5 --- Loss: 0.6184285; Time: 6.0435s; Convergence value: 6.90e-02
Epoch 5/5 --- Loss: 0.6121451; Time: 6.3203s; Convergence value: 1.28e-02
Maximum number of training epochs reached.
Model did not converge yet.

Validating the trained hybrid RNN on a test dataset...
Epoch 1/1 --- Loss: 0.5938452; Time: 0.8050s; Convergence value: nan
Maximum number of training epochs reached.
Model did not converge yet.
Epoch 1/1 --- Loss: 0.5938452; Time: 0.5766s; Convergence value: nan
Maximum number of training epochs reached.
Model did not converge yet.
Epoch 1/1 --- Loss: 0.5938452; Time: 0.5586s; Convergence value: nan
Maximum number of training epochs reached.