# libraries

In [32]:
#@title Import libraries
import sys
import os
import warnings
from typing import Callable, Tuple, Iterable, Union

import matplotlib.pyplot as plt
from sympy.parsing.sympy_parser import parse_expr
import numpy as np
import pandas as pd
import scipy.stats as st
import pickle

# deepmind related libraries
import haiku as hk
import jax
import jax.numpy as jnp
import optax

import pysindy as ps

warnings.filterwarnings("ignore")

# RL libraries
sys.path.append('resources')  # add source directory to path
from resources import bandits, disrnn, hybrnn, hybrnn_forget, plotting, rat_data, rnn_utils

# Agent

In [33]:
#@title make update rule of Q-/SINDyNetwork-Agents adjustable and make values of RNN-Agent visible

class AgentQuadQ(bandits.AgentQ):
  """Q-agent variant whose chosen-action update is quadratic in the value.

  Construction is delegated unchanged to `bandits.AgentQ`; only `update`
  differs from the base class.
  """

  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perseveration_bias: float=0.,
      ):
    # All arguments are passed positionally, in this order, to the base class.
    super().__init__(alpha, beta, n_actions, forgetting_rate, perseveration_bias)

  def update(self,
            choice: int,
            reward: float):
    """Advance the agent by one trial.

    Args:
      choice: Index of the action the agent took (0 or 1).
      reward: Reward obtained for that action (0 or 1).
    """
    decay = self._forgetting_rate

    # Every value first relaxes toward its initial value.
    self._q = (1-decay) * self._q + decay * self._q_init

    # Then only the chosen action is updated: q <- q - alpha*q^2 + alpha*reward.
    q_chosen = self._q[choice]
    self._q[choice] = q_chosen - self._alpha * q_chosen**2 + self._alpha * reward


class AgentSindy(bandits.AgentQ):
  """Q-agent whose per-action value update rule can be replaced at runtime.

  The rule is called once per action as `rule(q_value, chosen, reward)`, where
  `q_value` is that action's current value, `chosen` is 1 for the action that
  was taken and 0 otherwise, and `reward` is the observed reward. The return
  value becomes the action's new value.
  """

  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perservation_bias: float=0.,):
    # NOTE: the parameter name `perservation_bias` (sic) is kept as-is so that
    # existing keyword callers keep working.
    super().__init__(alpha, beta, n_actions, forgetting_rate, perservation_bias)

    # The default must follow the same (scalar value, 0/1 flag, reward) calling
    # convention that `update` uses; indexing the scalar value would fail.
    self._update_rule = self._default_update_rule
    self._update_rule_formula = None

  def _default_update_rule(self, q, choice, reward):
    """Delta-rule update for the chosen action; other actions are unchanged."""
    if choice:
      return (1 - self._alpha) * q + self._alpha * reward
    return q

  def set_update_rule(self, update_rule: Callable, update_rule_formula: str=None):
    """Install a new update rule.

    Args:
      update_rule: Callable invoked as `update_rule(q_value, chosen, reward)`
        for every action; returns that action's new value.
      update_rule_formula: Optional human-readable description of the rule,
        reported by the `update_rule` property.
    """
    self._update_rule=update_rule
    self._update_rule_formula=update_rule_formula

  @property
  def update_rule(self):
    """Textual description of the active rule.

    Returns the formula string given to `set_update_rule` if there is one,
    otherwise the string representation of the rule callable.
    """
    if self._update_rule_formula is not None:
      return self._update_rule_formula
    else:
      return f'{self._update_rule}'

  def update(self, choice: int, reward: int):
    """Apply the active update rule to every action's value.

    This method applies only the rule itself; it performs no separate
    forgetting step.

    Args:
      choice: Index of the action that was taken.
      reward: Reward observed for that action.
    """
    for c in range(self._n_actions):
      self._q[c] = self._update_rule(self._q[c], int(c==choice), reward)


class AgentNetwork_VisibleState(bandits.AgentNetwork):
  """Network agent that exposes parts of its recurrent state through `q`.

  Args:
    make_network: Factory returning the haiku RNN core.
    params: Parameters for that network.
    n_actions: Number of available actions.
    state_to_numpy: Forwarded unchanged to `bandits.AgentNetwork`.
    habit: Selects which state entries `q` returns (see `q`).
  """

  def __init__(self,
               make_network: Callable[[], hk.RNNCore],
               params: hk.Params,
               n_actions: int = 2,
               state_to_numpy: bool = False,
               habit=False):
    super().__init__(
        make_network=make_network,
        params=params,
        n_actions=n_actions,
        state_to_numpy=state_to_numpy,
    )
    self.habit = habit

  @property
  def q(self):
    """State entries exposed as the agent's values.

    Without `habit`, returns state entry 3 flattened to one dimension. With
    `habit`, returns the pair (state entry 2, state entry 3) unmodified.
    NOTE(review): presumably entry 2 holds habit values and entry 3 reward
    values - confirm against the network's state layout.
    """
    state = self._state
    if not self.habit:
      return state[3].reshape(-1)
    return state[2], state[3]


In [34]:
def _make_basic_agent(alpha, beta, n_actions, forgetting_rate, perseveration_bias):
    """Build the plain Q-agent from the bandits module."""
    return bandits.AgentQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias)


def _make_quad_q_agent(alpha, beta, n_actions, forgetting_rate, perseveration_bias):
    """Build the Q-agent with the quadratic update rule."""
    return AgentQuadQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias)


# Registry mapping an agent keyword to a factory; every factory takes
# (alpha, beta, n_actions, forgetting_rate, perseveration_bias).
dict_agents = {
    'basic': _make_basic_agent,
    'quad_q': _make_quad_q_agent,
}

# Dataset

In [67]:
class DatasetCreator:
    """Creates train and test datasets for the bandit experiments.

    Only synthetic data (`dataset_type='synt'`) is supported: an agent taken
    from `agent_dict` is run in a drifting bandit environment, and
    `bandits.create_dataset` is called twice (train, then test).

    Args:
        dataset_type: 'synt' for synthetic data; 'real' is recognised but not
            implemented.
        agent_dict: Mapping from agent keyword to a factory called as
            `factory(alpha, beta, n_actions, forgetting_rate,
            perseveration_bias)`.
    """

    def __init__(self, dataset_type, agent_dict):
        self.dataset_type = dataset_type
        self.agent_dict = agent_dict

    def create_dataset(self, **synthetic_kwargs):
        """Set up the data source and generate the train and test datasets.

        On success the results are stored in `dataset_train`,
        `experiment_list_train`, `dataset_test` and `experiment_list_test`.

        Args:
            **synthetic_kwargs: Optional overrides forwarded to
                `setup_synthetic_data`; with none given, the original
                defaults are used.

        Raises:
            NotImplementedError: If `dataset_type` is 'real' or unknown.
        """
        if self.dataset_type == 'synt':
            self.setup_synthetic_data(**synthetic_kwargs)
            self.dataset_train, self.experiment_list_train = self.generate_data()
            self.dataset_test, self.experiment_list_test = self.generate_data()

        elif self.dataset_type == 'real':
            raise NotImplementedError('Real data setup not implemented yet.')

        else:
            raise NotImplementedError(f'dataset_type {self.dataset_type} not implemented. Please select from drop-down list.')

    def setup_synthetic_data(
        self,
        *,
        agent_kw='basic',
        gen_alpha=0.25,
        gen_beta=5,
        forgetting_rate=0.1,
        perseveration_bias=0.0,
        non_binary_reward=False,
        n_actions=2,
        sigma=0.1,
        n_trials_per_session=200,
        n_sessions=220,
    ):
        """Create the generating agent and environment for synthetic data.

        All arguments are keyword-only and default to the values that were
        previously hard-coded, so calling without arguments behaves as before.

        Args:
            agent_kw: Key into `agent_dict` selecting the generating agent.
            gen_alpha: First positional argument passed to the agent factory.
            gen_beta: Second positional argument passed to the agent factory.
            forgetting_rate: Forgetting rate passed to the agent factory.
            perseveration_bias: Perseveration bias passed to the agent factory.
            non_binary_reward: Passed to the environment as
                `non_binary_rewards`.
            n_actions: Number of actions for both agent and environment.
            sigma: Passed to the environment as `sigma`.
            n_trials_per_session: Trials per generated session.
            n_sessions: Number of sessions per generated dataset.

        Raises:
            KeyError: If `agent_kw` is not a key of `agent_dict`.
        """
        # Resolve the factory first so a bad keyword fails before any state
        # is written or any environment is built.
        try:
            make_agent = self.agent_dict[agent_kw]
        except KeyError:
            raise KeyError(
                f'agent_kw {agent_kw!r} not found in agent_dict; '
                f'available: {list(self.agent_dict)}'
            ) from None

        # Experiment parameters
        self.n_actions = n_actions
        self.n_trials_per_session = n_trials_per_session
        self.n_sessions = n_sessions

        # Environment first, then agent (same construction order as before).
        self.environment = bandits.EnvironmentBanditsDrift(
            sigma=sigma,
            n_actions=self.n_actions,
            non_binary_rewards=non_binary_reward,
        )
        self.agent = make_agent(gen_alpha, gen_beta, self.n_actions, forgetting_rate, perseveration_bias)

    def setup_real_data(self):
        """Placeholder for real-data setup; currently does nothing."""
        pass

    def generate_data(self):
        """Generate one dataset with the configured agent and environment.

        Requires `setup_synthetic_data` to have been called first.

        Returns:
            The result of `bandits.create_dataset`, which `create_dataset`
            unpacks as a `(dataset, experiment_list)` pair.
        """
        return bandits.create_dataset(
            agent=self.agent,
            environment=self.environment,
            n_trials_per_session=self.n_trials_per_session,
            n_sessions=self.n_sessions
        )


In [68]:
# Build the synthetic train/test datasets with the registered agents.
data = DatasetCreator(dataset_type='synt', agent_dict=dict_agents)
data.create_dataset()

# Expose the generating agent and task size as notebook-level names.
n_actions = data.n_actions
agent = data.agent

# Training split
dataset_train = data.dataset_train
experiment_list_train = data.experiment_list_train

# Test split
dataset_test = data.dataset_test
experiment_list_test = data.experiment_list_test

# RNN

In [37]:
class HybridRNN:
    """Holds the configuration for a `hybrnn_forget.BiRNN` and builds it on demand.

    The constructor only stores settings; no network is created until
    `make_hybrnn` is called.

    Args:
        use_hidden_state: Stored under key 's' of `rnn_rl_params`.
        use_previous_values: Stored under key 'o' of `rnn_rl_params`.
        fit_forget: Stored under key 'fit_forget' of `rnn_rl_params`.
        habit_weight: Converted to float; stored under key 'w_h'.
        value_weight: Stored as given under key 'w_v'.
        n_actions: Stored under 'n_actions' of `network_params`.
        hidden_size: Stored under 'hidden_size' of `network_params`.
    """

    def __init__(self, use_hidden_state=False, use_previous_values=False, fit_forget=False, habit_weight=0.0, value_weight=1.0, n_actions=2, hidden_size=16):
        # Keep the raw settings available as attributes.
        self.use_hidden_state = use_hidden_state
        self.use_previous_values = use_previous_values
        self.fit_forget = fit_forget
        self.habit_weight = float(habit_weight)
        self.value_weight = value_weight
        self.n_actions = n_actions
        self.hidden_size = hidden_size

        # Keyword bundles handed to the network constructor; 'forget' is
        # always initialised to 0.
        self.rnn_rl_params = dict(
            s=self.use_hidden_state,
            o=self.use_previous_values,
            fit_forget=self.fit_forget,
            forget=0.,
            w_h=self.habit_weight,
            w_v=self.value_weight,
        )
        self.network_params = dict(
            n_actions=self.n_actions,
            hidden_size=self.hidden_size,
        )

    def make_hybrnn(self):
        """Return a new `hybrnn_forget.BiRNN` built from the stored parameters."""
        rl_params = self.rnn_rl_params
        network_params = self.network_params
        return hybrnn_forget.BiRNN(rl_params=rl_params, network_params=network_params)

In [40]:
# Habit weight of the hybrid RNN (the original note says it is also used in the Sindy RNN).
habit_weight = 0.0
rnn = HybridRNN(habit_weight=habit_weight)

# Adam optimizer kept as a notebook-level name.
optimizer_rnn = optax.adam(learning_rate=1e-3)

In [42]:
class RNNTrainer:
    """Trains the hybrid RNN and persists its parameters to a pickle file.

    Args:
        params_path: File the parameters are loaded from and saved to.
        train: If True, `execute` trains (and saves); otherwise it only loads.
        load: If True, training starts from the parameters stored at
            `params_path`; otherwise it starts from scratch.
        loss_function: Passed to `rnn_utils.fit_model` as `loss_fun`.
        optimizer: Optional optax optimizer; defaults to Adam with a learning
            rate of 1e-3.
        model_fun: Optional factory returning the network. If omitted, the
            module-level `rnn.make_hybrnn` is used when training starts.
    """

    def __init__(self, params_path, train=True, load=False, loss_function='categorical', optimizer=None, model_fun=None):
        self.params_path = params_path
        self.train = train
        self.load = load
        self.loss_function = loss_function
        self.optimizer = optax.adam(learning_rate=1e-3) if optimizer is None else optimizer
        self.model_fun = model_fun
        self.rnn_params = None
        self.opt_state = None

    def load_parameters(self):
        """Load `(rnn_params, opt_state)` from `params_path`.

        If the file does not exist, a message is printed and the current
        parameters are left unchanged.
        """
        # NOTE: pickle can execute arbitrary code while loading; only point
        # params_path at files this project wrote itself.
        try:
            with open(self.params_path, 'rb') as f:
                saved_params = pickle.load(f)
        except FileNotFoundError:
            print('No parameters found to load.')
            return

        self.rnn_params, self.opt_state = saved_params[0], saved_params[1]
        print('Loaded parameters.')

    def save_parameters(self):
        """Pickle `(rnn_params, opt_state)` to `params_path`.

        The parent directory is created if needed, so a missing output folder
        cannot discard the result of a finished training run.
        """
        directory = os.path.dirname(self.params_path)
        if directory:  # empty when params_path is a bare file name
            os.makedirs(directory, exist_ok=True)

        with open(self.params_path, 'wb') as f:
            pickle.dump((self.rnn_params, self.opt_state), f)
        print('Parameters saved.')

    def train_model(self, dataset_train, n_steps_max=10000, convergence_thresh=1e-5):
        """Fit the network on `dataset_train` and save the result.

        Does nothing when the trainer was created with `train=False`.

        Args:
            dataset_train: Dataset passed to `rnn_utils.fit_model`.
            n_steps_max: Passed to `rnn_utils.fit_model` as `n_steps_max`.
            convergence_thresh: Passed to `rnn_utils.fit_model` as
                `convergence_thresh`.
        """
        if not self.train:
            return

        if self.load:
            self.load_parameters()
        else:
            self.rnn_params, self.opt_state = None, None

        # Fall back to the notebook-level `rnn` when no factory was supplied.
        model_fun = self.model_fun if self.model_fun is not None else rnn.make_hybrnn

        print('Training the hybrid RNN...')
        # The third value returned by fit_model is not used here.
        self.rnn_params, self.opt_state, _ = rnn_utils.fit_model(
            model_fun=model_fun,
            dataset=dataset_train,
            optimizer=self.optimizer,
            optimizer_state=self.opt_state,
            model_params=self.rnn_params,
            loss_fun=self.loss_function,
            convergence_thresh=convergence_thresh,
            n_steps_max=n_steps_max
        )

        self.save_parameters()

    def execute(self, dataset_train, n_steps_max=10000, convergence_thresh=1e-5):
        """Train and save when `train` is set, otherwise only load parameters.

        Args:
            dataset_train: Training dataset (unused when only loading).
            n_steps_max: Forwarded to `train_model`.
            convergence_thresh: Forwarded to `train_model`.
        """
        if self.train:
            self.train_model(
                dataset_train,
                n_steps_max=n_steps_max,
                convergence_thresh=convergence_thresh,
            )
        else:
            self.load_parameters()

In [43]:
# Train the hybrid RNN from scratch (no warm start) and store the fitted
# parameters at params_path.
params_path = 'params/params_rnn_forget_f01_b5.pkl'
rnn_train = RNNTrainer(
    params_path=params_path,
    train=True,
    load=False,
)
rnn_train.execute(dataset_train)

Training the hybrid RNN...
Step 500 of 500; Loss: 0.5348482; Time: 10.6s)
Model not yet converged - Running more steps of gradient descent. Time elapsed = 2e-05s.
Step 500 of 500; Loss: 0.5346749; Time: 11.1s)
Model not yet converged (convergence_value = 0.000323851) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5346544; Time: 11.4s)
Model not yet converged (convergence_value = 3.846001e-05) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5346281; Time: 12.8s)
Model not yet converged (convergence_value = 4.916381e-05) - Running more steps of gradient descent. Time elapsed = 4e-05s.
Step 500 of 500; Loss: 0.5345983; Time: 13.7s)
Model not yet converged (convergence_value = 5.574403e-05) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5345682; Time: 11.5s)
Model not yet converged (convergence_value = 5.630461e-05) - Running more steps of gradient descent. Time e