# In [None]:
import numpy as np
from typing import NamedTuple, List
from tabular.simulation_parameters import SimulationConfiguration, SimulationGeneralParameters, EnvType, SimulationParameters, EnvParameters, DoubleChainParameters, RiverSwimParameters, make_env, NArmsParameters, ForkedRiverSwimParameters
from tabular.agents.mr_nas import MRNaSParameters
from tabular.agents.random import RandomAgentParameters
from tabular.agents.ide3al import IDE3ALParameters
from tabular.agents.rf_ucrl import RFUCRLParameters
from tabular.agents.mr_psrl import MRPSRLparameters
from tabular.utils.period import ConstantPeriod
seed = 0


def _benchmark_agents(omega_period):
    """Build the list of agent parameter sets benchmarked on one environment.

    Every environment is run with the same four agents: IDE3AL, MR-PSRL,
    RF-UCRL and MR-NaS. Only MR-NaS differs between environments, through the
    constant period passed to ``ConstantPeriod`` for
    ``period_computation_omega``. A fresh list of fresh objects is returned on
    every call, so environments never share parameter instances.
    """
    mr_nas = MRNaSParameters(
        agent_parameters=None,
        enable_averaging=True,
        alpha=0.99,
        beta=0.01,
        period_computation_omega=ConstantPeriod(omega_period),
    )
    # Built in a statement-by-statement style; construction order matches the
    # literal list (IDE3AL, MR-PSRL, RF-UCRL, then MR-NaS) only for the first
    # three, so build MR-NaS last to keep the original order exactly.
    agents = [
        IDE3ALParameters(agent_parameters=None, xi=None),
        MRPSRLparameters(agent_parameters=None),
        RFUCRLParameters(agent_parameters=None),
    ] if False else None
    del agents
    first = IDE3ALParameters(agent_parameters=None, xi=None)
    second = MRPSRLparameters(agent_parameters=None)
    third = RFUCRLParameters(agent_parameters=None)
    return [first, second, third, mr_nas]


# Experiment grid: (environment parameters, agents to run on it) pairs. The
# third EnvParameters argument is the run length in steps (presumably the
# horizon — TODO confirm against EnvParameters).
CONFIG = SimulationConfiguration(
    sim_parameters=SimulationGeneralParameters(
        num_sims=100,
        num_rewards=30,
        freq_eval=500,
        discount_factor=0.9,
        delta=1e-2,
    ),
    envs=[
        (
            EnvParameters(EnvType.FORKED_RIVERSWIM, ForkedRiverSwimParameters(river_length=4), 50000),
            _benchmark_agents(250),
        ),
        (
            EnvParameters(EnvType.RIVERSWIM, RiverSwimParameters(num_states=10), 50000),
            _benchmark_agents(100),
        ),
        (
            EnvParameters(EnvType.N_ARMS, NArmsParameters(num_arms=4, p0=1), 50000),
            _benchmark_agents(10000),
        ),
        (
            EnvParameters(EnvType.DOUBLE_CHAIN, DoubleChainParameters(length=6), 50000),
            _benchmark_agents(100),
        ),
    ],
)

def run_simulation(p, agent_parameters, seed, verbose: bool = True) -> list:
    """Run one agent on one environment and collect periodic evaluation results.

    This code was previously pasted at notebook top level, where the trailing
    ``return`` was a SyntaxError and ``p`` / ``agent_parameters`` were unbound.
    It is now an explicit function.

    Args:
        p: Simulation parameters. Must expose ``p.env_parameters`` (passed to
            ``make_env`` and providing ``.horizon``) and ``p.sim_parameters``
            (providing ``discount_factor``, ``num_rewards`` and ``freq_eval``).
            Presumably a ``SimulationParameters`` instance; TODO confirm.
        agent_parameters: Parameters forwarded to ``make_agent``.
        seed: Seed for NumPy's global RNG, set before the environment is built.
        verbose: When True (default), print the same progress lines as before.

    Returns:
        A list of ``Results``, one appended every ``freq_eval`` steps.

    NOTE(review): ``make_agent``, ``Experience`` and ``Results`` are neither
    defined nor imported in this file. They must be in scope (presumably from
    the ``tabular`` package; verify) before this function is called.
    """
    import time
    from copy import deepcopy

    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
    np.random.seed(seed)
    env = make_env(env=p.env_parameters)
    # The original drew and printed one uniform sample right after seeding.
    # The draw is kept unconditionally so the RNG stream, and therefore the
    # results, do not depend on ``verbose``.
    rng_probe = np.random.uniform()
    if verbose:
        print(rng_probe)

    start_time = time.time()
    s = env.reset()
    discount_factor = p.sim_parameters.discount_factor
    freq_eval = p.sim_parameters.freq_eval
    agent = make_agent(agent_parameters)

    results = []

    # Evaluation reward set: boundary rewards stacked with num_rewards random ones.
    R_basis = env.generate_boundary_rewards()
    R_random = env.generate_random_rewards(N=p.sim_parameters.num_rewards)
    R = np.vstack([R_basis, R_random])

    for t in range(p.env_parameters.horizon):
        a = agent.forward(s, t)
        next_state, _ = env.step(a)
        exp = Experience(s, a, next_state)
        reset = agent.backward(exp, t)

        # The agent decides whether the episode restarts.
        s = env.reset() if reset else next_state

        # Evaluate the agent every freq_eval steps.
        if (t + 1) % freq_eval != 0:
            continue

        V_res, pi_res, Q_res = env.eval_transition(
            Phat=agent.empirical_transition(), R=R, discount_factor=discount_factor)

        if verbose:
            # U_t / Z_t / beta / state_action_visits look MR-NaS specific; other
            # agents in CONFIG may lack them, so read them defensively.
            u_t = getattr(agent, 'U_t', None)
            z_t = getattr(agent, 'Z_t', None)
            beta = getattr(agent, 'beta', None)
            visits = getattr(agent, 'state_action_visits', None)
            print(f'[{t}] {u_t} {z_t} - {beta} -  {V_res.mean()} - {pi_res.mean()} - {visits}')
            print('--------')

        # Snapshot agent statistics (deep copies, since the agent keeps mutating
        # them). NOTE(review): confirm every agent type exposes these attributes.
        results.append(
            Results(step=t, omega=deepcopy(agent.omega), total_state_visits=deepcopy(agent.total_state_visits),
                    last_visit=deepcopy(agent.last_visit), exp_visits=deepcopy(agent.exp_visits), V_res=V_res,
                    Q_res=Q_res, pi_res=pi_res, elapsed_time=time.time() - start_time))
    return results