# Import Libraries

In [1]:
#@title Import libraries
import sys
import os
import warnings
from typing import Callable, Tuple, Iterable, Union

import matplotlib.pyplot as plt
from sympy.parsing.sympy_parser import parse_expr
import numpy as np
import pandas as pd
import scipy.stats as st
import pickle

# deepmind related libraries
import haiku as hk
import jax
import jax.numpy as jnp
import optax

import pysindy as ps

warnings.filterwarnings("ignore")

# RL libraries
sys.path.append('resources')  # add source directoy to path
from resources import bandits, disrnn, hybrnn, hybrnn_forget, plotting, rat_data, rnn_utils

# Agent

In [2]:
#@title make update rule of Q-/SINDyNetwork-Agents adjustable and make values of RNN-Agent visible

class AgentQuadQ(bandits.AgentQ):
  
  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perseveration_bias: float=0.,
      ):
    super().__init__(alpha, beta, n_actions, forgetting_rate, perseveration_bias)
  
  def update(self,
            choice: int,
            reward: float):
    """Update the agent after one step of the task.

    Args:
      choice: The choice made by the agent. 0 or 1
      reward: The reward received by the agent. 0 or 1
    """
    
    # Decay q-values toward the initial value.
    self._q = (1-self._forgetting_rate) * self._q + self._forgetting_rate * self._q_init

    # Update chosen q for chosen action with observed reward.
    self._q[choice] = self._q[choice] - self._alpha * self._q[choice]**2 + self._alpha * reward


class AgentSindy(bandits.AgentQ):

  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perservation_bias: float=0.,):
    super().__init__(alpha, beta, n_actions, forgetting_rate, perservation_bias)

    self._update_rule = lambda q, choice, reward: (1 - self._alpha) * q[choice] + self._alpha * reward
    self._update_rule_formula = None

  def set_update_rule(self, update_rule: callable, update_rule_formula: str=None):
    self._update_rule=update_rule
    self._update_rule_formula=update_rule_formula

  @property
  def update_rule(self):
    if self._update_rule_formula is not None:
      return self._update_rule_formula
    else:
      return f'{self._update_rule}'

  def update(self, choice: int, reward: int):

    for c in range(self._n_actions):
      self._q[c] = self._update_rule(self._q[c], int(c==choice), reward)


class AgentNetwork_VisibleState(bandits.AgentNetwork):

  def __init__(self,
               make_network: Callable[[], hk.RNNCore],
               params: hk.Params,
               n_actions: int = 2,
               state_to_numpy: bool = False,
               habit=False):
    super().__init__(make_network=make_network, params=params, n_actions=n_actions, state_to_numpy=state_to_numpy)
    self.habit = habit

  @property
  def q(self):
    if self.habit:
      return self._state[2], self._state[3]
    else:
      return self._state[3].reshape(-1)

dict_agents = {
    'basic': lambda alpha, beta, n_actions, forgetting_rate, perseveration_bias: bandits.AgentQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias),
    'quad_q': lambda alpha, beta, n_actions, forgetting_rate, perseveration_bias: AgentQuadQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias)
}

# RNN


## dataset 

In [3]:
# create dataset


dataset_type = 'synt'  #@param ['synt', 'real']

#@markdown Set up parameters for synthetic data generation:
if dataset_type == 'synt':
    # agent parameters
    agent_kw = 'basic'  #@param ['basic', 'quad_q'] 
    gen_alpha = .25 #@param
    gen_beta = 5 #@param
    forgetting_rate = 0.1 #@param
    perseveration_bias = 0.  #@param
    # environment parameters
    non_binary_reward = False #@param
    n_actions = 2 #@param
    sigma = .1  #@param
    
    # experiement parameters
    n_trials_per_session = 200  #@param
    n_sessions = 220  #@param
    
    # setup
    environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=n_actions, non_binary_rewards=non_binary_reward)
    agent = dict_agents[agent_kw](gen_alpha, gen_beta, n_actions, forgetting_rate, perseveration_bias)  
  
    dataset_train, experiment_list_train = bandits.create_dataset(
        agent=agent,
        environment=environment,
        n_trials_per_session=n_trials_per_session,
        n_sessions=n_sessions)

    dataset_test, experiment_list_test = bandits.create_dataset(
        agent=agent,
        environment=environment,
        n_trials_per_session=n_trials_per_session,
        n_sessions=n_sessions)

#@markdown Set up parameters for loading rat data from Miller et al 2019.
elif dataset_type == 'real':
    # TODO: ys are not the rewards but the following choices!!!!
    raise NotImplementedError('This is not implemented yet.')

else:
  raise NotImplementedError(
      (f'dataset_type {dataset_type} not implemented. '
       'Please select from drop-down list.'))

## RNN initialization 

In [4]:
#@title Set up Hybrid RNN.

#@markdown Is the model recurrent (ie can it see the hidden state from the previous step)
use_hidden_state = False  #@param ['True', 'False']

#@markdown Is the model recurrent (ie can it see the hidden state from the previous step)
use_previous_values = False  #@param ['True', 'False']

#@markdown If True, learn a value for the forgetting term
fit_forget = False  #@param ['True', 'False']

#@markdown Learn a reward-independent term that depends on past choices.
habit_weight = "0"  #@param [0, 1]
habit_weight = float(habit_weight)

value_weight = 1.  # This is needed for it to be doing RL

rnn_rl_params = {
    's': use_hidden_state,
    'o': use_previous_values,
    'fit_forget': fit_forget,
    'forget': 0.,
    'w_h': habit_weight,
    'w_v': value_weight}
network_params = {'n_actions': n_actions, 'hidden_size': 16}

def make_hybrnn():
  # model = hybrnn.BiRNN(rl_params=rnn_rl_params, network_params=network_params)
  model = hybrnn_forget.BiRNN(rl_params=rnn_rl_params, network_params=network_params)
  return model

optimizer_rnn = optax.adam(learning_rate=1e-3)

## RNN Training

In [5]:
train = True
load = False  # only relevant if train is True --> Determines whether to load trained parameters and continue training or start new training

# params_path = 'params/params_rnn_forget_f01.pkl'
params_path = 'params/params_rnn_forget_f01_b5.pkl'

if train:
  if load:
    with open(params_path, 'rb') as f:
      rnn_params = pickle.load(f)
    opt_state = rnn_params[1]
    rnn_params = rnn_params[0]
    print('Loaded parameters.')
  else:
    opt_state = None
    rnn_params = None

  # with jax.disable_jit():
  #@title Fit the hybrid RNN
  print('Training the hybrid RNN...')
  rnn_params, opt_state, _ = rnn_utils.fit_model(
      model_fun=make_hybrnn,
      dataset=dataset_train,
      optimizer=optimizer_rnn,
      optimizer_state=opt_state,
      model_params=rnn_params,
      loss_fun='categorical',  # penalized_categorical, categorical
      convergence_thresh=1e-5,
      n_steps_max=10000,
  )

  # save trained parameters
  params = (rnn_params, opt_state)
  with open(params_path, 'wb') as f:
    pickle.dump(params, f)
    
else:
  # load trained parameters
  with open(params_path, 'rb') as f:
    rnn_params = pickle.load(f)[0]
  print('Loaded parameters.')

Training the hybrid RNN...
Step 500 of 500; Loss: 0.5324958; Time: 10.9s)
Model not yet converged - Running more steps of gradient descent. Time elapsed = 2e-05s.
Step 500 of 500; Loss: 0.5321959; Time: 10.5s)
Model not yet converged (convergence_value = 0.0005632543) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5321779; Time: 10.9s)
Model not yet converged (convergence_value = 3.382327e-05) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5321552; Time: 10.6s)
Model not yet converged (convergence_value = 4.267252e-05) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5321294; Time: 10.7s)
Model not yet converged (convergence_value = 4.838665e-05) - Running more steps of gradient descent. Time elapsed = 0.0002s.
Step 500 of 500; Loss: 0.5321042; Time: 10.6s)
Model not yet converged (convergence_value = 4.738089e-05) - Running more steps of gradient descent. Time

# SINDY

## dataset

In [6]:
def make_sindy_data(
    dataset,
    agent: bandits.AgentQ,
    sessions=-1,
    get_choices=True,
    # keep_sessions=False,
    ):

  # Get training data for SINDy
  # put all relevant signals in x_train

  if not isinstance(sessions, Iterable) and sessions == -1:
    # use all sessions
    sessions = np.arange(len(dataset))
  else:
    # use only the specified sessions
    sessions = np.array(sessions)
    
  if get_choices:
    n_control = 2
  else:
    n_control = 1
  
  # if keep_sessions:
  #   # concatenate all sessions along the trial dimensinon -> shape: (n_trials, n_sessions, n_features)
  #   choices = np.expand_dims(np.stack([dataset[i].choices for i in sessions], axis=1), -1)
  #   rewards = np.expand_dims(np.stack([dataset[i].rewards for i in sessions], axis=1), -1)
  #   qs = np.stack([dataset[i].q for i in sessions], axis=1)
  # else:
  # concatenate all sessions along the trial dimensinon -> shape: (n_trials*n_sessions, n_features)
  # choices = np.expand_dims(np.concatenate([dataset[i].choices for i in sessions], axis=0), -1)
  # rewards = np.expand_dims(np.concatenate([dataset[i].rewards for i in sessions], axis=0), -1)
  # qs = np.concatenate([dataset[i].q for i in sessions], axis=0)
  
  choices = np.stack([dataset[i].choices for i in sessions], axis=0)
  rewards = np.stack([dataset[i].rewards for i in sessions], axis=0)
  qs = np.stack([dataset[i].q for i in sessions], axis=0)
  
  if not get_choices:
    raise NotImplementedError('Only get_choices=True is implemented right now.')
    n_sessions = qs.shape[0]
    n_trials = qs.shape[1]*qs.shape[2]
    qs_all = np.zeros((n_sessions, n_trials))
    r_all = np.zeros((n_sessions, n_trials))
    c_all = None
    # concatenate the data of all arms into one array for more training data
    index_end_last_arm = 0
    for index_arm in range(agent._n_actions):
      index = np.where(choices==index_arm)[0]
      r_all[index_end_last_arm:index_end_last_arm+len(index)] = rewards[index]
      qs_all[index_end_last_arm:index_end_last_arm+len(index)] = qs[index, index_arm].reshape(-1, 1)
      index_end_last_arm += len(index)
  else:
    choices_oh = np.zeros((len(sessions), choices.shape[1], agent._n_actions))
    for sess in sessions:
      # one-hot encode choices
      choices_oh[sess] = np.eye(agent._n_actions)[choices[sess]]
      # add choices as control parameter; no sorting required then
      # qs_all = np.concatenate([qs[sess, :, i] for i in range(agent._n_actions)], axis=1)
      # c_all = np.concatenate([choices[:, sess, i] for i in range(agent._n_actions)], axis=1)
      # r_all = np.concatenate([rewards for _ in range(agent._n_actions)], axis=1)
      # concatenate all qs values of one sessions along the trial dimension
      qs_all = np.concatenate([np.stack([np.expand_dims(qs_sess[:, i], axis=-1) for i in range(agent._n_actions)], axis=0) for qs_sess in qs], axis=0)
      c_all = np.concatenate([np.stack([c_sess[:, i] for i in range(agent._n_actions)], axis=0) for c_sess in choices_oh], axis=0)
      r_all = np.concatenate([np.stack([r_sess for _ in range(agent._n_actions)], axis=0) for r_sess in rewards], axis=0)
  
  # get observed dynamics
  x_train = qs_all
  feature_names = ['q']

  # get control
  control_names = []
  control = np.zeros((*x_train.shape[:-1], n_control))
  if get_choices:
    control[:, :, 0] = c_all
    control_names += ['c']
  control[:, :, n_control-1] = r_all
  control_names += ['r']
  
  feature_names += control_names
  
  print(f'Shape of Q-Values is: {x_train.shape}')
  print(f'Shape of control parameters is: {control.shape}')
  print(f'Feature names are: {feature_names}')
  
  # make x_train and control sequences instead of arrays
  x_train = [x_train_sess for x_train_sess in x_train]
  control = [control_sess for control_sess in control]
 
  return x_train, control, feature_names


## SiNDY Ground truth dataset fit


In [7]:
#@title Fit SINDy to actual dataset
# library = custom_lib  # custom_lib, poly_lib, solution_lib
ensemble = False
library_ensemble = False

get_choices = True
poly_order = 3
threshold = 0.01
dt = 1

# library_datasindy = ps.CustomLibrary(
#     library_functions=custom_lib_functions,
#     function_names=custom_lib_names,
#     include_bias=True,
# )

library_datasindy = ps.PolynomialLibrary(poly_order)

experiment_list_datasindy = None

if dataset_type == 'synt':
    x_train, control, feature_names = make_sindy_data(experiment_list_train, agent, get_choices=get_choices)

    datasindy = ps.SINDy(
        optimizer=ps.STLSQ(threshold=threshold, verbose=True, alpha=0.1),
        feature_library=library_datasindy,
        discrete_time=True,
        feature_names=feature_names,
    )
    datasindy.fit(x_train, t=dt, u=control, ensemble=ensemble, library_ensemble=library_ensemble, multiple_trajectories=True)
    datasindy.print()

    # set new sindy update rule and synthesize new dataset
    if not get_choices:
        update_rule_datasindy = lambda q, choice, reward: datasindy.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
    else:
        update_rule_datasindy = lambda q, choice, reward: datasindy.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]
    
    datasindyagent = AgentSindy(alpha=0, beta=gen_beta, n_actions=n_actions)
    datasindyagent.set_update_rule(update_rule_datasindy)

    # _, experiment_list_datasindy = bandits.create_dataset(datasindyagent, environment, n_trials_per_session, n_sessions)

Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
 Iteration ... |y - Xw|^2 ...  a * |w|_2 ...      |w|_0 ... Total error: |y - Xw|^2 + a * |w|_2
         0 ... 7.0994e+00 ... 8.3619e-02 ...          9 ... 7.1830e+00
         1 ... 1.0142e+00 ... 8.7461e-02 ...          8 ... 1.1016e+00
         2 ... 3.6680e-01 ... 8.6815e-02 ...          8 ... 4.5361e-01
(q)[k+1] = 0.047 1 + 0.903 q[k] + 152632824.330 q[k] c[k] + -2321015003.621 c[k] r[k] + 0.010 q[k]^3 + -152632824.577 q[k] c[k]^2 + 4492037219.974 c[k]^2 r[k] + -2171022216.103 c[k] r[k]^2


## SiNDY fitt to RNN 

In [8]:
# This is RNN to SiNDY
#@title Synthesize a dataset using the fitted network
hybrnn_agent = AgentNetwork_VisibleState(make_hybrnn, rnn_params, habit=habit_weight==1, n_actions=n_actions)
dataset_hybrnn, experiment_list_hybrnn = bandits.create_dataset(hybrnn_agent, environment, n_trials_per_session, int(n_sessions*1e0))

In [9]:
#@title Fit SINDy to RNN data and synthesize new dataset

threshold = 0.015

x_train, control, feature_names = make_sindy_data(experiment_list_hybrnn, hybrnn_agent, get_choices=get_choices)
# x_train, control, feature_names = make_sindy_data(experiment_list_train, agent, get_choices=get_choices)
# scale q-values between 0 and 1 for more realistic dynamics
x_max = np.max(np.stack(x_train, axis=0))
x_min = np.min(np.stack(x_train, axis=0))
print(f'Dataset characteristics: max={x_max}, min={x_min}')
x_train = [(x - x_min) / (x_max - x_min) for x in x_train]

# library_rnnsindy = ps.CustomLibrary(
#     library_functions=custom_lib_functions,
#     function_names=custom_lib_names,
#     include_bias=True,
# )

library_rnnsindy = ps.PolynomialLibrary(poly_order)

rnnsindy = ps.SINDy(
    optimizer=ps.STLSQ(threshold=threshold, verbose=False, alpha=0.1),
    feature_library=library_rnnsindy,
    discrete_time=True,
    feature_names=feature_names,
)

rnnsindy.fit(x_train, t=dt, u=control, ensemble=True, library_ensemble=False, multiple_trajectories=True)
rnnsindy.print()
sparsity_index = np.sum(rnnsindy.coefficients() < threshold) / rnnsindy.coefficients().size
print(f'Sparsity index: {sparsity_index}')

if not get_choices:
    update_rule_rnnsindy = lambda q, choice, reward: rnnsindy.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
else:
    update_rule_rnnsindy = lambda q, choice, reward: rnnsindy.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]

rnnsindyagent = AgentSindy(alpha=0, beta=1, n_actions=n_actions)
rnnsindyagent.set_update_rule(update_rule_rnnsindy)

Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
Dataset characteristics: max=3.0769786834716797, min=-2.972050905227661
(q)[k+1] = 0.362 1 + 0.214 q[k] + -19048347826.358 c[k] + 0.060 q[k]^2 + -187415.909 q[k] c[k] + 38055833457.122 c[k]^2 + 25102425487.198 c[k] r[k] + -0.186 q[k]^2 c[k] + 187416.532 q[k] c[k]^2 + -19007485631.102 c[k]^3 + -12625863212.696 c[k]^2 r[k] + -12476562274.304 c[k] r[k]^2
Sparsity index: 0.7


# Plotting?