In [2]:
#@title Import libraries
import sys
import warnings

import pickle
import time
import torch
import numpy as np
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")


In [31]:
# RL libraries
sys.path.append('resources')  # add source directoy to path
from resources import rnn, rnn_training, bandits, rnn_utils


In [39]:

# train model
train = True
checkpoint = False
data = False

use_lstm = False

path_data = 'data/dataset_train.pkl'
params_path = 'params/params_lstm_b3.pkl'  # overwritten if data is False (adapted to the ground truth model)

# rnn parameters
hidden_size = 4
last_output = False
last_state = False
use_habit = False

# ensemble parameters
sampling_replacement = True
n_submodels = 1   # baseline model / no ensembling
ensemble = False
voting_type = rnn.EnsembleRNN.MEAN  # necessary if ensemble==True, can be mean or median

# tracked variables in the RNN
x_train_list = ['xQf','xQr']
control_list = ['ca','ca[k-1]', 'cr']
if use_habit:
  x_train_list += ['xHf']
sindy_feature_list = x_train_list + control_list

# training parameters
epochs = 8   # change to 100 (madd)
n_steps_per_call = 10  # None for full sequence
batch_size = None  # None for one batch per epoch
learning_rate = 1e-2
convergence_threshold = 1e-6

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if not data:
  # agent parameters
  agent_kw = 'basic'  #@param ['basic', 'quad_q'] 
  gen_alpha = .25 #@param
  gen_beta = 3 #@param
  forget_rate = 0. #@param
  perseverance_bias = 0. #@param
  # environment parameters
  non_binary_reward = False #@param
  n_actions = 2 #@param
  sigma = .1  #@param

  # dataset parameters
  n_trials_per_session = 200  #@param
  n_sessions = 64  #@param


  # setup
  environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_rewards=non_binary_reward)
  agent = bandits.AgentQ(gen_alpha, gen_beta, n_actions, forget_rate, perseverance_bias)  

  dataset_train, experiment_list_train = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=n_trials_per_session,
      n_sessions=n_sessions,
      device=device)

  dataset_test, experiment_list_test = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=200,  #n_trials_per_session,
      n_sessions=1024,
      device=device)
  
  params_path = rnn_utils.parameter_file_naming(
      'params/params',
      use_lstm,
      last_output,
      last_state,
      use_habit,
      gen_beta,
      forget_rate,
      perseverance_bias,
      non_binary_reward,
      verbose=True,
  )
  
else:
  # load data
  with open(path_data, 'rb') as f:
      dataset_train = pickle.load(f)

# define model
if use_lstm:
  model = rnn.LSTM(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      device=device,
      ).to(device)
else:
  model = rnn.RLRNN(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      last_output=last_output,
      last_state=last_state,
      use_habit=use_habit,
      device=device,
      list_sindy_signals=sindy_feature_list,
      ).to(device)

optimizer_rnn = torch.optim.Adam(model.parameters(), lr=learning_rate)


Automatically generated name for model parameter file: params/params_rnn_b3.pkl.


In [40]:
from resources import rnn, rnn_training, bandits, rnn_utils

if train:
  if checkpoint:
    # load trained parameters
    state_dict = torch.load(params_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict['model'])
    optimizer_rnn.load_state_dict(state_dict['optimizer'])
    print('Loaded parameters.')
  
  start_time = time.time()
  
  #Fit the hybrid RNN
  print('Training the hybrid RNN...')
  model, optimizer_rnn, training_loss = rnn_training.fit_model(
      model=model,
      dataset=dataset_train,
      optimizer=optimizer_rnn,
      convergence_threshold=convergence_threshold,
      epochs=epochs,
      n_steps_per_call = n_steps_per_call,
      batch_size=batch_size,
      n_submodels=n_submodels,
      return_ensemble=ensemble,
      voting_type=voting_type,
      sampling_replacement=sampling_replacement,
  )
  

Training the hybrid RNN...
Epoch 1/8 --- Loss: 0.6646731; Time: 0.5s; Convergence value: 3.35e-01
Epoch 2/8 --- Loss: 0.6363984; Time: 0.4s; Convergence value: 1.75e-01
Epoch 3/8 --- Loss: 0.6075963; Time: 0.4s; Convergence value: 1.22e-01
Epoch 4/8 --- Loss: 0.5902762; Time: 0.4s; Convergence value: 9.11e-02
Epoch 5/8 --- Loss: 0.5812210; Time: 0.4s; Convergence value: 7.05e-02
Epoch 6/8 --- Loss: 0.5499825; Time: 0.5s; Convergence value: 6.11e-02
Epoch 7/8 --- Loss: 0.5548703; Time: 0.4s; Convergence value: 4.90e-02
Epoch 8/8 --- Loss: 0.5572730; Time: 0.7s; Convergence value: 1.57e-02
Maximum number of training epochs reached.
Model did not converge yet.


In [35]:
print(training_loss)

tensor(0.5315, grad_fn=<DivBackward0>)


In [None]:

  # validate model
  print('\nValidating the trained hybrid RNN on a test dataset...')
  with torch.no_grad():
    model, optimizer_rnn, val_loss, losses_list = rnn_training.fit_model(     #val_loss gets the final validation loss
        model=model,
        dataset=dataset_test,
        n_steps_per_call=1,
    )

  print(f'Training took {time.time() - start_time:.2f} seconds.')
  
  
  # save trained parameters  
  state_dict = {
    'model': model.state_dict() if isinstance(model, torch.nn.Module) else [model_i.state_dict() for model_i in model],
    'optimizer': optimizer_rnn.state_dict() if isinstance(optimizer_rnn, torch.optim.Adam) else [optim_i.state_dict() for optim_i in optimizer_rnn],
  }
  torch.save(state_dict, params_path)

else:
  # load trained parameters
  model.load_state_dict(torch.load(params_path)['model'])
  print(f'Loaded parameters from file {params_path}.')

# if hasattr(model, 'beta'):
#   print(f'beta: {model.beta}')

# Synthesize a dataset using the fitted network
environment = bandits.EnvironmentBanditsDrift(0.1)
model.set_device(torch.device('cpu'))
model.to(torch.device('cpu'))
rnn_agent = bandits.AgentNetwork(model, n_actions=2, habit=use_habit)
# dataset_rnn, experiment_list_rnn = bandits.create_dataset(rnn_agent, environment, 220, 10)



In [26]:

#print(float(v_loss))
print(rnn_training.fit_model.loss)            

AttributeError: 'function' object has no attribute 'loss'

In [None]:

# Analysis
session_id = 0

choices = experiment_list_test[session_id].choices
rewards = experiment_list_test[session_id].rewards

list_probs = []
list_qs = []

# get q-values from groundtruth
qs_test, probs_test = bandits.get_update_dynamics(experiment_list_test[session_id], agent)
list_probs.append(np.expand_dims(probs_test, 0))
list_qs.append(np.expand_dims(qs_test, 0))

# get q-values from trained rnn
qs_rnn, probs_rnn = bandits.get_update_dynamics(experiment_list_test[session_id], rnn_agent)
list_probs.append(np.expand_dims(probs_rnn, 0))
list_qs.append(np.expand_dims(qs_rnn, 0))

colors = ['tab:blue', 'tab:orange', 'tab:pink', 'tab:grey']

# concatenate all choice probs and q-values
probs = np.concatenate(list_probs, axis=0)
qs = np.concatenate(list_qs, axis=0)

# normalize q-values
# qs = (qs - np.min(qs, axis=1, keepdims=True)) / (np.max(qs, axis=1, keepdims=True) - np.min(qs, axis=1, keepdims=True))


### Model 2: 2 submodels

In [17]:
# train model
train = True
checkpoint = False
data = False

use_lstm = False

path_data = 'data/dataset_train.pkl'
params_path = 'params/params_lstm_b3.pkl'  # overwritten if data is False (adapted to the ground truth model)

# rnn parameters
hidden_size = 4
last_output = False
last_state = False
use_habit = False

# ensemble parameters
sampling_replacement = True
n_submodels = 2   
ensemble = False
voting_type = rnn.EnsembleRNN.MEAN  # necessary if ensemble==True, can be mean or median

# tracked variables in the RNN
x_train_list = ['xQf','xQr']
control_list = ['ca','ca[k-1]', 'cr']
if use_habit:
  x_train_list += ['xHf']
sindy_feature_list = x_train_list + control_list

# training parameters
epochs = 100
n_steps_per_call = 10  # None for full sequence
batch_size = None  # None for one batch per epoch
learning_rate = 1e-2
convergence_threshold = 1e-6

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if not data:
  # agent parameters
  agent_kw = 'basic'  #@param ['basic', 'quad_q'] 
  gen_alpha = .25 #@param
  gen_beta = 3 #@param
  forget_rate = 0. #@param
  perseverance_bias = 0. #@param
  # environment parameters
  non_binary_reward = False #@param
  n_actions = 2 #@param
  sigma = .1  #@param

  # dataset parameters
  n_trials_per_session = 200  #@param
  n_sessions = 64  #@param


  # setup
  environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_rewards=non_binary_reward)
  agent = bandits.AgentQ(gen_alpha, gen_beta, n_actions, forget_rate, perseverance_bias)  

  dataset_train, experiment_list_train = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=n_trials_per_session,
      n_sessions=n_sessions,
      device=device)

  dataset_test, experiment_list_test = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=200,#n_trials_per_session,
      n_sessions=1024,
      device=device)
  
  params_path = rnn_utils.parameter_file_naming(
      'params/params',
      use_lstm,
      last_output,
      last_state,
      use_habit,
      gen_beta,
      forget_rate,
      perseverance_bias,
      non_binary_reward,
      verbose=True,
  )
  
else:
  # load data
  with open(path_data, 'rb') as f:
      dataset_train = pickle.load(f)

# define model
if use_lstm:
  model = rnn.LSTM(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      device=device,
      ).to(device)
else:
  model = rnn.RLRNN(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      last_output=last_output,
      last_state=last_state,
      use_habit=use_habit,
      device=device,
      list_sindy_signals=sindy_feature_list,
      ).to(device)

optimizer_rnn = torch.optim.Adam(model.parameters(), lr=learning_rate)

if train:
  if checkpoint:
    # load trained parameters
    state_dict = torch.load(params_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict['model'])
    optimizer_rnn.load_state_dict(state_dict['optimizer'])
    print('Loaded parameters.')
  
  start_time = time.time()
  
  #Fit the hybrid RNN
  print('Training the hybrid RNN...')
  model, optimizer_rnn, _ = rnn_training.fit_model(
      model=model,
      dataset=dataset_train,
      optimizer=optimizer_rnn,
      convergence_threshold=convergence_threshold,
      epochs=epochs,
      n_steps_per_call = n_steps_per_call,
      batch_size=batch_size,
      n_submodels=n_submodels,
      return_ensemble=ensemble,
      voting_type=voting_type,
      sampling_replacement=sampling_replacement,
  )
  

  # validate model
  print('\nValidating the trained hybrid RNN on a test dataset...')
  with torch.no_grad():
    model, optimizer_rnn, val_loss = rnn_training.fit_model(     #val_loss gets the final validation loss
        model=model,
        dataset=dataset_test,
        n_steps_per_call=1,
    )

  print(f'Training took {time.time() - start_time:.2f} seconds.')
  
  
  # save trained parameters  
  state_dict = {
    'model': model.state_dict() if isinstance(model, torch.nn.Module) else [model_i.state_dict() for model_i in model],
    'optimizer': optimizer_rnn.state_dict() if isinstance(optimizer_rnn, torch.optim.Adam) else [optim_i.state_dict() for optim_i in optimizer_rnn],
  }
  torch.save(state_dict, params_path)

else:
  # load trained parameters
  model.load_state_dict(torch.load(params_path)['model'])
  print(f'Loaded parameters from file {params_path}.')

# if hasattr(model, 'beta'):
#   print(f'beta: {model.beta}')

# Synthesize a dataset using the fitted network
environment = bandits.EnvironmentBanditsDrift(0.1)
model.set_device(torch.device('cpu'))
model.to(torch.device('cpu'))
rnn_agent = bandits.AgentNetwork(model, n_actions=2, habit=use_habit)
# dataset_rnn, experiment_list_rnn = bandits.create_dataset(rnn_agent, environment, 220, 10)



Automatically generated name for model parameter file: params/params_rnn_b3.pkl.
Training the hybrid RNN...
Epoch 1/100 --- Loss: 0.6391087; Time: 1.0s; Convergence value: 3.61e-01
Epoch 2/100 --- Loss: 0.5916474; Time: 1.0s; Convergence value: 2.02e-01
Epoch 3/100 --- Loss: 0.5740830; Time: 1.2s; Convergence value: 1.39e-01
Epoch 4/100 --- Loss: 0.5851295; Time: 1.1s; Convergence value: 1.05e-01
Epoch 5/100 --- Loss: 0.5633962; Time: 1.2s; Convergence value: 8.75e-02
Epoch 6/100 --- Loss: 0.5581319; Time: 1.1s; Convergence value: 7.27e-02
Epoch 7/100 --- Loss: 0.5494964; Time: 1.2s; Convergence value: 6.26e-02
Epoch 8/100 --- Loss: 0.5561534; Time: 1.0s; Convergence value: 5.47e-02
Epoch 9/100 --- Loss: 0.5690601; Time: 1.1s; Convergence value: 4.93e-02
Epoch 10/100 --- Loss: 0.5492237; Time: 0.9s; Convergence value: 4.58e-02
Epoch 11/100 --- Loss: 0.5367042; Time: 0.9s; Convergence value: 4.21e-02
Epoch 12/100 --- Loss: 0.5391178; Time: 1.0s; Convergence value: 3.81e-02
Epoch 13/100 

In [18]:
# 3 submodels

# train model
train = True
checkpoint = False
data = False

use_lstm = False

path_data = 'data/dataset_train.pkl'
params_path = 'params/params_lstm_b3.pkl'  # overwritten if data is False (adapted to the ground truth model)

# rnn parameters
hidden_size = 4
last_output = False
last_state = False
use_habit = False

# ensemble parameters
sampling_replacement = True
n_submodels = 3   
ensemble = False
voting_type = rnn.EnsembleRNN.MEAN  # necessary if ensemble==True, can be mean or median

# tracked variables in the RNN
x_train_list = ['xQf','xQr']
control_list = ['ca','ca[k-1]', 'cr']
if use_habit:
  x_train_list += ['xHf']
sindy_feature_list = x_train_list + control_list

# training parameters
epochs = 100
n_steps_per_call = 10  # None for full sequence
batch_size = None  # None for one batch per epoch
learning_rate = 1e-2
convergence_threshold = 1e-6

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if not data:
  # agent parameters
  agent_kw = 'basic'  #@param ['basic', 'quad_q'] 
  gen_alpha = .25 #@param
  gen_beta = 3 #@param
  forget_rate = 0. #@param
  perseverance_bias = 0. #@param
  # environment parameters
  non_binary_reward = False #@param
  n_actions = 2 #@param
  sigma = .1  #@param

  # dataset parameters
  n_trials_per_session = 200  #@param
  n_sessions = 64  #@param


  # setup
  environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_rewards=non_binary_reward)
  agent = bandits.AgentQ(gen_alpha, gen_beta, n_actions, forget_rate, perseverance_bias)  

  dataset_train, experiment_list_train = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=n_trials_per_session,
      n_sessions=n_sessions,
      device=device)

  dataset_test, experiment_list_test = bandits.create_dataset(
      agent=agent,
      environment=environment,
      n_trials_per_session=200,#n_trials_per_session,
      n_sessions=1024,
      device=device)
  
  params_path = rnn_utils.parameter_file_naming(
      'params/params',
      use_lstm,
      last_output,
      last_state,
      use_habit,
      gen_beta,
      forget_rate,
      perseverance_bias,
      non_binary_reward,
      verbose=True,
  )
  
else:
  # load data
  with open(path_data, 'rb') as f:
      dataset_train = pickle.load(f)

# define model
if use_lstm:
  model = rnn.LSTM(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      device=device,
      ).to(device)
else:
  model = rnn.RLRNN(
      n_actions=n_actions, 
      hidden_size=hidden_size, 
      init_value=0.5,
      last_output=last_output,
      last_state=last_state,
      use_habit=use_habit,
      device=device,
      list_sindy_signals=sindy_feature_list,
      ).to(device)

optimizer_rnn = torch.optim.Adam(model.parameters(), lr=learning_rate)

if train:
  if checkpoint:
    # load trained parameters
    state_dict = torch.load(params_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict['model'])
    optimizer_rnn.load_state_dict(state_dict['optimizer'])
    print('Loaded parameters.')
  
  start_time = time.time()
  
  #Fit the hybrid RNN
  print('Training the hybrid RNN...')
  model, optimizer_rnn, _ = rnn_training.fit_model(
      model=model,
      dataset=dataset_train,
      optimizer=optimizer_rnn,
      convergence_threshold=convergence_threshold,
      epochs=epochs,
      n_steps_per_call = n_steps_per_call,
      batch_size=batch_size,
      n_submodels=n_submodels,
      return_ensemble=ensemble,
      voting_type=voting_type,
      sampling_replacement=sampling_replacement,
  )
  

  # validate model
  print('\nValidating the trained hybrid RNN on a test dataset...')
  with torch.no_grad():
    model, optimizer_rnn, val_loss, losses_list = rnn_training.fit_model(     #val_loss gets the final validation loss
        model=model,
        dataset=dataset_test,
        n_steps_per_call=1,
    )

  print(f'Training took {time.time() - start_time:.2f} seconds.')
  
  
  # save trained parameters  
  state_dict = {
    'model': model.state_dict() if isinstance(model, torch.nn.Module) else [model_i.state_dict() for model_i in model],
    'optimizer': optimizer_rnn.state_dict() if isinstance(optimizer_rnn, torch.optim.Adam) else [optim_i.state_dict() for optim_i in optimizer_rnn],
  }
  torch.save(state_dict, params_path)

else:
  # load trained parameters
  model.load_state_dict(torch.load(params_path)['model'])
  print(f'Loaded parameters from file {params_path}.')

# if hasattr(model, 'beta'):
#   print(f'beta: {model.beta}')

# Synthesize a dataset using the fitted network
environment = bandits.EnvironmentBanditsDrift(0.1)
model.set_device(torch.device('cpu'))
model.to(torch.device('cpu'))
rnn_agent = bandits.AgentNetwork(model, n_actions=2, habit=use_habit)
# dataset_rnn, experiment_list_rnn = bandits.create_dataset(rnn_agent, environment, 220, 10)



Automatically generated name for model parameter file: params/params_rnn_b3.pkl.
Training the hybrid RNN...
Epoch 1/100 --- Loss: 0.6557027; Time: 1.5s; Convergence value: 3.44e-01
Epoch 2/100 --- Loss: 0.6089835; Time: 1.7s; Convergence value: 1.93e-01
Epoch 3/100 --- Loss: 0.5718898; Time: 1.6s; Convergence value: 1.40e-01
Epoch 4/100 --- Loss: 0.5655484; Time: 1.7s; Convergence value: 1.05e-01
Epoch 5/100 --- Loss: 0.5846874; Time: 1.5s; Convergence value: 8.67e-02
Epoch 6/100 --- Loss: 0.5786796; Time: 1.5s; Convergence value: 7.21e-02
Epoch 7/100 --- Loss: 0.5775130; Time: 1.4s; Convergence value: 6.09e-02
Epoch 8/100 --- Loss: 0.5572833; Time: 1.1s; Convergence value: 5.52e-02
Epoch 9/100 --- Loss: 0.5614052; Time: 1.2s; Convergence value: 4.87e-02
Epoch 10/100 --- Loss: 0.5979993; Time: 1.2s; Convergence value: 4.71e-02
Epoch 11/100 --- Loss: 0.5673509; Time: 1.3s; Convergence value: 4.52e-02
Epoch 12/100 --- Loss: 0.5617303; Time: 1.3s; Convergence value: 4.12e-02
Epoch 13/100 

ValueError: not enough values to unpack (expected 4, got 3)