In [1]:
# ToDo:
# print statements for the class require verbose=True in the class
# Sklearn estimator
# SINDY class is not quite right! 

# libraries

In [1]:
#@title Import libraries
import sys
import os
import warnings
from typing import Callable, Tuple, Iterable, Union

import matplotlib.pyplot as plt
from sympy.parsing.sympy_parser import parse_expr
import numpy as np
import pandas as pd
import scipy.stats as st
import pickle

# deepmind related libraries
import haiku as hk
import jax
import jax.numpy as jnp
import optax

import pysindy as ps

warnings.filterwarnings("ignore")

# RL libraries
sys.path.append('resources')  # add source directoy to path
from resources import bandits, disrnn, hybrnn, hybrnn_forget, plotting, rat_data, rnn_utils

# Agent

In [2]:
#@title make update rule of Q-/SINDyNetwork-Agents adjustable and make values of RNN-Agent visible

class AgentQuadQ(bandits.AgentQ):
  
  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perseveration_bias: float=0.,
      ):
    super().__init__(alpha, beta, n_actions, forgetting_rate, perseveration_bias)
  
  def update(self,
            choice: int,
            reward: float):
    """Update the agent after one step of the task.

    Args:
      choice: The choice made by the agent. 0 or 1
      reward: The reward received by the agent. 0 or 1
    """
    
    # Decay q-values toward the initial value.
    self._q = (1-self._forgetting_rate) * self._q + self._forgetting_rate * self._q_init

    # Update chosen q for chosen action with observed reward.
    self._q[choice] = self._q[choice] - self._alpha * self._q[choice]**2 + self._alpha * reward


class AgentSindy(bandits.AgentQ):

  def __init__(
      self,
      alpha: float=0.2,
      beta: float=3.,
      n_actions: int=2,
      forgetting_rate: float=0.,
      perservation_bias: float=0.,):
    super().__init__(alpha, beta, n_actions, forgetting_rate, perservation_bias)

    self._update_rule = lambda q, choice, reward: (1 - self._alpha) * q[choice] + self._alpha * reward
    self._update_rule_formula = None

  def set_update_rule(self, update_rule: callable, update_rule_formula: str=None):
    self._update_rule=update_rule
    self._update_rule_formula=update_rule_formula

  @property
  def update_rule(self):
    if self._update_rule_formula is not None:
      return self._update_rule_formula
    else:
      return f'{self._update_rule}'

  def update(self, choice: int, reward: int):

    for c in range(self._n_actions):
      self._q[c] = self._update_rule(self._q[c], int(c==choice), reward)


class AgentNetwork_VisibleState(bandits.AgentNetwork):

  def __init__(self,
               make_network: Callable[[], hk.RNNCore],
               params: hk.Params,
               n_actions: int = 2,
               state_to_numpy: bool = False,
               habit=False):
    super().__init__(make_network=make_network, params=params, n_actions=n_actions, state_to_numpy=state_to_numpy)
    self.habit = habit

  @property
  def q(self):
    if self.habit:
      return self._state[2], self._state[3]
    else:
      return self._state[3].reshape(-1)


In [3]:
dict_agents = {
    'basic': lambda alpha, beta, n_actions, forgetting_rate, perseveration_bias: bandits.AgentQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias),
    'quad_q': lambda alpha, beta, n_actions, forgetting_rate, perseveration_bias: AgentQuadQ(alpha, beta, n_actions, forgetting_rate, perseveration_bias)
}

# Dataset

In [4]:
class DatasetCreator:
    def __init__(self, dataset_type, agent_dict):
        self.dataset_type = dataset_type
        self.agent_dict = agent_dict

    def create_dataset(self):
        if self.dataset_type == 'synt':
            self.setup_synthetic_data()
            self.dataset_train, self.experiment_list_train = self.generate_data()
            self.dataset_test, self.experiment_list_test = self.generate_data()
        
        elif self.dataset_type == 'real':
            raise NotImplementedError('Real data setup not implemented yet.')
        
        else:
            raise NotImplementedError(f'dataset_type {self.dataset_type} not implemented. Please select from drop-down list.')

    def setup_synthetic_data(self):
        # Define agent parameters
        agent_kw = 'basic'  # ['basic', 'quad_q']
        gen_alpha = 0.25
        gen_beta = 5
        forgetting_rate = 0.1
        perseveration_bias = 0.0
        
        # Define environment parameters
        non_binary_reward = False
        self.n_actions = 2
        sigma = 0.1
        
        # Define experiment parameters
        self.n_trials_per_session = 200
        self.n_sessions = 220
        
        # Setup environment and agent
        self.environment = bandits.EnvironmentBanditsDrift(sigma=sigma, n_actions=self.n_actions, non_binary_rewards=non_binary_reward)
        self.agent = self.agent_dict[agent_kw](gen_alpha, gen_beta, self.n_actions, forgetting_rate, perseveration_bias)
    
    def setup_real_data(self):
        pass


    def generate_data(self):
        return bandits.create_dataset(
            agent=self.agent,
            environment=self.environment,
            n_trials_per_session=self.n_trials_per_session,
            n_sessions=self.n_sessions
        )


In [5]:
data = DatasetCreator(dataset_type='synt', agent_dict=dict_agents)
data.create_dataset()
n_actions = data.n_actions
agent = data.agent
environment = data.environment
n_trials_per_session = data.n_trials_per_session
n_sessions = data.n_sessions

dataset_train, experiment_list_train = data.dataset_train, data.experiment_list_train
dataset_test, experiment_list_test = data.dataset_test, data.experiment_list_test

# RNN

In [6]:
class HybridRNN:
    def __init__(self, use_hidden_state=False, use_previous_values=False, fit_forget=False, habit_weight=0.0, value_weight=1.0, n_actions=2, hidden_size=16):
        self.use_hidden_state = use_hidden_state
        self.use_previous_values = use_previous_values
        self.fit_forget = fit_forget
        self.habit_weight = float(habit_weight)
        self.value_weight = value_weight
        self.n_actions = n_actions
        self.hidden_size = hidden_size
        
        # Set up the RNN parameters
        self.rnn_rl_params = {
            's': self.use_hidden_state,
            'o': self.use_previous_values,
            'fit_forget': self.fit_forget,
            'forget': 0.,
            'w_h': self.habit_weight,
            'w_v': self.value_weight
        }
        self.network_params = {
            'n_actions': self.n_actions,
            'hidden_size': self.hidden_size
        }
        
    def make_hybrnn(self):
        return hybrnn_forget.BiRNN(rl_params=self.rnn_rl_params, network_params=self.network_params)
    
    def get_make_hybrnn(self):
        return self.make_hybrnn

In [7]:
habit_weight=0.0 # used in Sindy RNN
rnn = HybridRNN(habit_weight)
optimizer_rnn = optax.adam(learning_rate=1e-3)

In [8]:
class RNNTrainer:
    def __init__(self, params_path, train=True, load=False, loss_function='categorical'):
        self.params_path = params_path
        self.train = train
        self.load = load
        self.loss_function = loss_function
        self.optimizer = optax.adam(learning_rate=1e-3)
        self.rnn_params = None
        self.opt_state = None

    def load_parameters(self):
        try:
            with open(self.params_path, 'rb') as f:
                saved_params = pickle.load(f)
            self.rnn_params, self.opt_state = saved_params[0], saved_params[1]
            print('Loaded parameters.')
        except FileNotFoundError:
            print('No parameters found to load.')

    def save_parameters(self):
        with open(self.params_path, 'wb') as f:
            pickle.dump((self.rnn_params, self.opt_state), f)
        print('Parameters saved.')

    def train_model(self, dataset_train, n_steps_max=10000, convergence_thresh=1e-5):
        if self.train:
            if self.load:
                self.load_parameters()
            else:
                self.rnn_params, self.opt_state = None, None

            print('Training the hybrid RNN...')
            self.rnn_params, self.opt_state, _ = rnn_utils.fit_model(
                model_fun=rnn.make_hybrnn,
                dataset=dataset_train,
                optimizer=self.optimizer,
                optimizer_state=self.opt_state,
                model_params=self.rnn_params,
                loss_fun=self.loss_function,
                convergence_thresh=convergence_thresh,
                n_steps_max=n_steps_max
            )

            self.save_parameters()

    def execute(self, dataset_train):
        if self.train:
            self.train_model(dataset_train)
        else:
            self.load_parameters()
    
    def get_rnn_params(self):
        return self.rnn_params

In [9]:
params_path = 'params/params_rnn_forget_f01_b5.pkl'
rnn_train = RNNTrainer(params_path=params_path, train=True, load=False)
rnn_train.execute(dataset_train)
rnn_params = rnn_train.get_rnn_params()

Training the hybrid RNN...
Step 500 of 500; Loss: 0.5304508; Time: 11.1s)
Model not yet converged - Running more steps of gradient descent. Time elapsed = 2e-05s.
Step 500 of 500; Loss: 0.5302345; Time: 14.0s)
Model not yet converged (convergence_value = 0.0004077763) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5301611; Time: 10.5s)
Model not yet converged (convergence_value = 0.000138379) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5300407; Time: 10.5s)
Model not yet converged (convergence_value = 0.0002271034) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5297045; Time: 11.7s)
Model not yet converged (convergence_value = 0.0006343471) - Running more steps of gradient descent. Time elapsed = 3e-05s.
Step 500 of 500; Loss: 0.5247840; Time: 11.6s)
Model not yet converged (convergence_value = 0.009289108) - Running more steps of gradient descent. Time el

# Sindy

In [10]:
def make_sindy_data(
    dataset,
    agent: bandits.AgentQ,
    sessions=-1,
    get_choices=True,
    # keep_sessions=False,
    ):

  # Get training data for SINDy
  # put all relevant signals in x_train

  if not isinstance(sessions, Iterable) and sessions == -1:
    # use all sessions
    sessions = np.arange(len(dataset))
  else:
    # use only the specified sessions
    sessions = np.array(sessions)
    
  if get_choices:
    n_control = 2
  else:
    n_control = 1
  
  # if keep_sessions:
  #   # concatenate all sessions along the trial dimensinon -> shape: (n_trials, n_sessions, n_features)
  #   choices = np.expand_dims(np.stack([dataset[i].choices for i in sessions], axis=1), -1)
  #   rewards = np.expand_dims(np.stack([dataset[i].rewards for i in sessions], axis=1), -1)
  #   qs = np.stack([dataset[i].q for i in sessions], axis=1)
  # else:
  # concatenate all sessions along the trial dimensinon -> shape: (n_trials*n_sessions, n_features)
  # choices = np.expand_dims(np.concatenate([dataset[i].choices for i in sessions], axis=0), -1)
  # rewards = np.expand_dims(np.concatenate([dataset[i].rewards for i in sessions], axis=0), -1)
  # qs = np.concatenate([dataset[i].q for i in sessions], axis=0)
  
  choices = np.stack([dataset[i].choices for i in sessions], axis=0)
  rewards = np.stack([dataset[i].rewards for i in sessions], axis=0)
  qs = np.stack([dataset[i].q for i in sessions], axis=0)
  
  if not get_choices:
    raise NotImplementedError('Only get_choices=True is implemented right now.')
    n_sessions = qs.shape[0]
    n_trials = qs.shape[1]*qs.shape[2]
    qs_all = np.zeros((n_sessions, n_trials))
    r_all = np.zeros((n_sessions, n_trials))
    c_all = None
    # concatenate the data of all arms into one array for more training data
    index_end_last_arm = 0
    for index_arm in range(agent._n_actions):
      index = np.where(choices==index_arm)[0]
      r_all[index_end_last_arm:index_end_last_arm+len(index)] = rewards[index]
      qs_all[index_end_last_arm:index_end_last_arm+len(index)] = qs[index, index_arm].reshape(-1, 1)
      index_end_last_arm += len(index)
  else:
    choices_oh = np.zeros((len(sessions), choices.shape[1], agent._n_actions))
    for sess in sessions:
      # one-hot encode choices
      choices_oh[sess] = np.eye(agent._n_actions)[choices[sess]]
      # add choices as control parameter; no sorting required then
      # qs_all = np.concatenate([qs[sess, :, i] for i in range(agent._n_actions)], axis=1)
      # c_all = np.concatenate([choices[:, sess, i] for i in range(agent._n_actions)], axis=1)
      # r_all = np.concatenate([rewards for _ in range(agent._n_actions)], axis=1)
      # concatenate all qs values of one sessions along the trial dimension
      qs_all = np.concatenate([np.stack([np.expand_dims(qs_sess[:, i], axis=-1) for i in range(agent._n_actions)], axis=0) for qs_sess in qs], axis=0)
      c_all = np.concatenate([np.stack([c_sess[:, i] for i in range(agent._n_actions)], axis=0) for c_sess in choices_oh], axis=0)
      r_all = np.concatenate([np.stack([r_sess for _ in range(agent._n_actions)], axis=0) for r_sess in rewards], axis=0)
  
  # get observed dynamics
  x_train = qs_all
  feature_names = ['q']

  # get control
  control_names = []
  control = np.zeros((*x_train.shape[:-1], n_control))
  if get_choices:
    control[:, :, 0] = c_all
    control_names += ['c']
  control[:, :, n_control-1] = r_all
  control_names += ['r']
  
  feature_names += control_names
  
  print(f'Shape of Q-Values is: {x_train.shape}')
  print(f'Shape of control parameters is: {control.shape}')
  print(f'Feature names are: {feature_names}')
  
  # make x_train and control sequences instead of arrays
  x_train = [x_train_sess for x_train_sess in x_train]
  control = [control_sess for control_sess in control]
 
  return x_train, control, feature_names


In [11]:
# This is not needed anymore! Check SINDY class!!!

class SINDyTrainerGroundTruth:
    def __init__(self, library, dataset_type='synt', threshold=0.01, dt=1, ensemble=False, library_ensemble=False, get_choices=True):
        self.library = library
        self.dataset_type = dataset_type
        self.threshold = threshold
        self.dt = dt
        self.ensemble = ensemble
        self.library_ensemble = library_ensemble
        self.get_choices = get_choices

    def fit(self, experiment_list_train, agent, custom_lib_functions=None, custom_lib_names=None, poly_order=3):
        if self.library == 'custom_lib':
            library_datasindy = ps.CustomLibrary(
                library_functions=custom_lib_functions,
                function_names=custom_lib_names,
                include_bias=True
            )
        elif self.library == 'poly_lib':
            library_datasindy = ps.PolynomialLibrary(poly_order)
        else:
            raise ValueError("Unsupported library type")

        if self.dataset_type == 'synt':
            x_train, control, feature_names = make_sindy_data(experiment_list_train, agent, get_choices=self.get_choices)

            optimizer = ps.STLSQ(threshold=self.threshold, verbose=True, alpha=0.1)
            datasindy = ps.SINDy(
                optimizer=optimizer,
                feature_library=library_datasindy,
                discrete_time=True,
                feature_names=feature_names
            )
            datasindy.fit(x_train, t=self.dt, u=control, ensemble=self.ensemble, library_ensemble=self.library_ensemble, multiple_trajectories=True)
            datasindy.print()

            return datasindy
        else:
            raise ValueError("Unsupported dataset type")
        
    def update_rule(self, datasindy, get_choices):
        if not get_choices:
            return lambda q, choice, reward: datasindy.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
        else:
            return lambda q, choice, reward: datasindy.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]



In [12]:
trainer = SINDyTrainerGroundTruth(library='poly_lib')
rnnsindyagent = AgentSindy(alpha=0, beta=1, n_actions=2)
rnnsindyagent.set_update_rule(trainer.update_rule)
trainer.fit(experiment_list_train, agent=agent)


Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
 Iteration ... |y - Xw|^2 ...  a * |w|_2 ...      |w|_0 ... Total error: |y - Xw|^2 + a * |w|_2
         0 ... 7.1269e+00 ... 8.3683e-02 ...          9 ... 7.2106e+00
         1 ... 1.3575e+00 ... 8.7696e-02 ...          8 ... 1.4452e+00
         2 ... 3.6690e-01 ... 8.6887e-02 ...          8 ... 4.5379e-01
(q)[k+1] = 0.047 1 + 0.903 q[k] + -107786162.664 q[k] c[k] + 753016385.140 c[k] r[k] + 0.010 q[k]^3 + 107786162.416 q[k] c[k]^2 + -1535317856.596 c[k]^2 r[k] + 782301471.706 c[k] r[k]^2


In [13]:
# RNN agent for Sindy

hybrnn_agent = AgentNetwork_VisibleState(rnn.get_make_hybrnn(), rnn_params, habit=habit_weight==1, n_actions=n_actions)

dataset_hybrnn, experiment_list_hybrnn = bandits.create_dataset(hybrnn_agent, environment, n_trials_per_session, int(n_sessions*1e0))

In [19]:
# This is not needed anymore! Check SINDY class!!!
class SINDyRNN:
    def __init__(self, experiment_list, hybrnn_agent, get_choices, poly_order, dt, n_actions, custom_lib_functions=None, custom_lib_names=None, threshold=0.015):
        self.experiment_list = experiment_list
        self.hybrnn_agent = hybrnn_agent
        self.get_choices = get_choices
        self.poly_order = poly_order
        self.dt = dt
        self.n_actions = n_actions
        self.threshold = threshold
        self.custom_lib_functions = custom_lib_functions
        self.custom_lib_names = custom_lib_names

    def make_sindy_data(self):
        # Assuming make_sindy_data is a function defined elsewhere
        return make_sindy_data(self.experiment_list, self.hybrnn_agent, get_choices=self.get_choices)

    def fit(self):
        x_train, control, feature_names = self.make_sindy_data()
        self.feature_names = feature_names

        # Scale q-values between 0 and 1 for more realistic dynamics
        self.x_max = np.max(np.stack(x_train, axis=0))
        self.x_min = np.min(np.stack(x_train, axis=0))
        print(f'Dataset characteristics: max={self.x_max}, min={self.x_min}')
        x_train = [(x - self.x_min) / (self.x_max - self.x_min) for x in x_train]

        if self.custom_lib_functions and self.custom_lib_names:
            library_rnnsindy = ps.CustomLibrary(
                library_functions=self.custom_lib_functions,
                function_names=self.custom_lib_names,
                include_bias=True,
            )
        else:
            library_rnnsindy = ps.PolynomialLibrary(self.poly_order)

        self.rnnsindy = ps.SINDy(
            optimizer=ps.STLSQ(threshold=self.threshold, verbose=False, alpha=0.1),
            feature_library=library_rnnsindy,
            discrete_time=True,
            feature_names=self.feature_names,
        )

        self.rnnsindy.fit(x_train, t=self.dt, u=control, ensemble=True, library_ensemble=False, multiple_trajectories=True)
        self.rnnsindy.print()
        self.sparsity_index = np.sum(self.rnnsindy.coefficients() < self.threshold) / self.rnnsindy.coefficients().size
        

        if not self.get_choices:
            update_rule_rnnsindy = lambda q, choice, reward: self.rnnsindy.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
        else:
            update_rule_rnnsindy = lambda q, choice, reward: self.rnnsindy.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]

        self.rnnsindyagent = AgentSindy(alpha=0, beta=1, n_actions=self.n_actions)
        self.rnnsindyagent.set_update_rule(update_rule_rnnsindy)

    def get_rnnsindyagent(self):
        return self.rnnsindyagent

    def get_sparsity_index(self):
        return self.sparsity_index


In [20]:
get_choices = True
poly_order = 3
threshold = 0.01
dt = 1


sindy_rnn = SINDyRNN(
    experiment_list=experiment_list_hybrnn,
    hybrnn_agent=hybrnn_agent,
    get_choices=get_choices,
    poly_order=poly_order,
    dt=dt,
    n_actions=n_actions
)

sindy_rnn.fit()
rnnsindyagent = sindy_rnn.get_rnnsindyagent()
sparsity_index = sindy_rnn.get_sparsity_index()

print(f'Sparsity index: {sparsity_index}')

Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
Dataset characteristics: max=2.6803841590881348, min=-1.2137874364852905
(q)[k+1] = 0.043 1 + 0.922 q[k] + -1795847285.974 c[k] + -0.022 q[k]^2 + -201448.414 q[k] c[k] + -6583929.778 q[k] r[k] + -1797956774.496 c[k]^2 + -2641056717.121 c[k] r[k] + 0.013 q[k]^3 + 0.004 q[k]^2 c[k] + -0.001 q[k]^2 r[k] + 201448.171 q[k] c[k]^2 + 6583929.779 q[k] r[k]^2 + 3593804060.423 c[k]^3 + 5205082427.721 c[k]^2 r[k] + -2564025710.275 c[k] r[k]^2
Sparsity index: 0.7


# Sindy Mixture

In [49]:
class SINDY:
    def __init__(self, experiment_list, agent, library_config=None, scaling=True,
                 get_choices=True, dt=1, threshold=0.01, n_actions=None,
                 ensemble=False, library_ensemble=False):
        self.experiment_list = experiment_list
        self.agent = agent
        
        # Ensure library_config is a dictionary
        if library_config is None:
            self.library_config = {'type': 'poly_lib', 'poly_order': 3}
        elif isinstance(library_config, dict):
            self.library_config = library_config
        else:
            raise ValueError("library_config must be a dictionary")
        
        self.scaling = scaling
        self.get_choices = get_choices
        self.poly_order = poly_order
        self.dt = dt
        self.threshold = threshold
        self.n_actions = n_actions
        self.ensemble = ensemble
        self.library_ensemble = library_ensemble

    def make_sindy_data(self):
        # Assuming make_sindy_data is a function defined elsewhere
        return make_sindy_data(self.experiment_list, self.agent, get_choices=self.get_choices)

    def fit(self):
        x_train, control, feature_names = self.make_sindy_data()
        self.feature_names = feature_names

        # Library selection
        library_type = self.library_config.get('type')
        if library_type == 'custom_lib':
            custom_lib_functions = self.library_config.get('custom_lib_functions')
            custom_lib_names = self.library_config.get('custom_lib_names')
            library_sindy = ps.CustomLibrary(
                library_functions=custom_lib_functions,
                function_names=custom_lib_names,
                include_bias=True
            )
        
        elif library_type == 'poly_lib':
            poly_order = self.library_config.get('poly_order')
            library_sindy = ps.PolynomialLibrary(poly_order)
        
        else:
            raise ValueError("Unsupported library type")

        # Data scaling for RNN SINDy
        if self.scaling == True:
            self.x_max = np.max(np.stack(x_train, axis=0))
            self.x_min = np.min(np.stack(x_train, axis=0))
            x_train = [(x - self.x_min) / (self.x_max - self.x_min) for x in x_train]

        optimizer = ps.STLSQ(threshold=self.threshold, verbose=True, alpha=0.1)
        self.model = ps.SINDy(
            optimizer=optimizer,
            feature_library=library_sindy,
            discrete_time=True,
            feature_names=feature_names
        )
        self.model.fit(x_train, t=self.dt, u=control, ensemble=self.ensemble, library_ensemble=self.library_ensemble, multiple_trajectories=True)
        self.model.print()

        if not self.get_choices:
            update_rule_rnnsindy = lambda q, choice, reward: self.rnnsindy.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
        else:
            update_rule_rnnsindy = lambda q, choice, reward: self.rnnsindy.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]
            
        self.rnnsindyagent = AgentSindy(alpha=0, beta=1, n_actions=self.n_actions)
        self.rnnsindyagent.set_update_rule(update_rule_rnnsindy) 

    def update_rule(self, q, choice, reward):
        if not self.get_choices:
            return self.model.simulate(q[choice], t=2, u=np.array(reward).reshape(1, 1))[-1]
        else:
            return self.model.simulate(q, t=2, u=np.array([choice, reward]).reshape(1, 2))[-1]

    def get_rnnsindyagent(self):
        return self.rnnsindyagent

    def get_sparsity_index(self):
        if self.model:
            return np.sum(self.model.coefficients() < self.threshold) / self.model.coefficients().size
        return None


## GT

In [50]:
# library_config = {
#     'type': 'custom_lib',
#     'custom_lib_functions':
#     'custom_lib_names': 
# }

lib_config = {
    'type': 'poly_lib',
    'poly_order': 3
}

trainer = SINDY(get_choices=True, dt=1, library_config=lib_config, n_actions=2, experiment_list=experiment_list_train, agent=agent, scaling=False)
rnnsindyagent = AgentSindy(alpha=0, beta=1, n_actions=2)
rnnsindyagent.set_update_rule(trainer.update_rule)
trainer.fit()

Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
 Iteration ... |y - Xw|^2 ...  a * |w|_2 ...      |w|_0 ... Total error: |y - Xw|^2 + a * |w|_2
         0 ... 7.1269e+00 ... 8.3683e-02 ...          9 ... 7.2106e+00
         1 ... 1.3575e+00 ... 8.7696e-02 ...          8 ... 1.4452e+00
         2 ... 3.6690e-01 ... 8.6887e-02 ...          8 ... 4.5379e-01
(q)[k+1] = 0.047 1 + 0.903 q[k] + -107786162.664 q[k] c[k] + 753016385.140 c[k] r[k] + 0.010 q[k]^3 + 107786162.416 q[k] c[k]^2 + -1535317856.596 c[k]^2 r[k] + 782301471.706 c[k] r[k]^2


## SINDYRNN

In [51]:
get_choices = True
poly_order = 3
threshold = 0.01
dt = 1
n_actions = 2

sindy_rnn = SINDY(
    experiment_list=experiment_list_hybrnn,
    agent=hybrnn_agent,
    get_choices=get_choices,
    dt=dt,
    n_actions=n_actions,
    scaling=True

)

sindy_rnn.fit()
rnnsindyagent = sindy_rnn.get_rnnsindyagent()
sparsity_index = sindy_rnn.get_sparsity_index()

print(f'Sparsity index: {sparsity_index}')

Shape of Q-Values is: (440, 200, 1)
Shape of control parameters is: (440, 200, 2)
Feature names are: ['q', 'c', 'r']
 Iteration ... |y - Xw|^2 ...  a * |w|_2 ...      |w|_0 ... Total error: |y - Xw|^2 + a * |w|_2
         0 ... 2.8895e+00 ... 9.1328e-02 ...         11 ... 2.9809e+00
         1 ... 7.6341e-01 ... 8.9726e-02 ...         10 ... 8.5314e-01
         2 ... 7.6109e-01 ... 8.9789e-02 ...         10 ... 8.5088e-01
(q)[k+1] = 0.044 1 + 0.912 q[k] + -0.016 c[k] + -0.119 q[k] c[k] + -0.016 c[k]^2 + 0.108 c[k] r[k] + -0.119 q[k] c[k]^2 + -0.016 c[k]^3 + 0.108 c[k]^2 r[k] + 0.108 c[k] r[k]^2
Sparsity index: 0.75
