In [1]:
import numpy as np
import time
import cvxpy as cp
from copy import deepcopy
from typing import NamedTuple, List
from simulation_parameters import SimulationConfiguration, SimulationGeneralParameters, \
    EnvType, SimulationParameters, EnvParameters
from multireward_ope.tabular.envs.env import make_env, EnvParameters, EnvType, RiverSwimParameters
from multireward_ope.tabular.agents.make_agent import make_agent, \
    NoisyPolicyParameters, MRNaSPEParameters, AgentParameters, PolicyNoiseType
from multireward_ope.tabular.agents.base_agent import Experience
from multireward_ope.tabular.reward_set import RewardSet, RewardSetCircle
from multireward_ope.tabular.policy import Policy, PolicyFactory
from multireward_ope.tabular.utils import policy_evaluation
# Experiment configuration. These module-level names are read by later cells,
# so they keep their original spelling.
seed = 0
discount_factor = 0.9
HORIZON = 20000   # number of environment interaction steps in the training loop
NUM_REWARDS = 30  # number of reward vectors sampled for evaluation

# Seed numpy's global RNG. NOTE(review): reproducibility of reward sampling and
# agent noise presumes the project code draws from this global RNG — confirm.
np.random.seed(seed)

# Display floats with three decimals.
np.set_printoptions(formatter=dict(float=lambda x: "{0:0.3f}".format(x)))

# RiverSwim environment with 5 states. EnvParameters/EnvType resolve to the
# multireward_ope.tabular.envs.env versions, whose import shadows the
# same-named ones from simulation_parameters.
env_params = EnvParameters(
    env_type=EnvType.RIVERSWIM,
    parameters=RiverSwimParameters(num_states=5),
)
env = make_env(env=env_params)

# Reward set configured with centre = zero vector of length dim_state,
# radius=1, p=2 (presumably an L_p ball — semantics per RewardSetCircleConfig).
reward_set = RewardSetCircle(
    env.dim_state,
    env.dim_action,
    RewardSetCircle.RewardSetCircleConfig(np.zeros(env.dim_state), radius=1, p=2),
)


# Target policy whose value we want to estimate.
policy_to_eval = PolicyFactory.random(env.dim_state, env.dim_action)

# Sample NUM_REWARDS reward vectors and, for each one, build a
# (dim_state, dim_action) reward table that is filled only at the entries
# indexed by (state, policy_to_eval). NOTE(review): this indexing assumes
# policy_to_eval behaves like an integer array holding one action per state.
eval_rewards = reward_set.sample(NUM_REWARDS)
rewards = np.zeros((NUM_REWARDS, env.dim_state, env.dim_action))
values = np.zeros((NUM_REWARDS, env.dim_state))

for i in range(NUM_REWARDS):
    # rewards[i] is a view, so writing into reward_table fills rewards in place.
    reward_table = rewards[i]
    reward_table[np.arange(env.dim_state), policy_to_eval] = eval_rewards[i]
    # Ground-truth value vector of policy_to_eval under this reward table.
    values[i] = env.policy_evaluation(reward_table, discount_factor, policy_to_eval)


In [2]:

# Hyper-parameters shared by the agent; frequency_evaluation is also read by
# the training loop to decide when to measure the estimation error.
agent_params = AgentParameters(
    dim_state_space=env.dim_state,
    dim_action_space=env.dim_action,
    discount_factor=discount_factor,
    horizon=HORIZON,
    frequency_evaluation=50,
    delta=1e-2,
    epsilon=0.1,
    # CVXPY's Gurobi backend: needs gurobipy installed and licensed.
    solver_type=cp.GUROBI,
)

# Agent configured with NoisyPolicyParameters (visitation-type noise, level 0.3)
# around policy_to_eval — presumably a noisy behaviour policy; confirm in make_agent.
agent = make_agent(
    NoisyPolicyParameters(
        agent_parameters=agent_params,
        noise_type=PolicyNoiseType.VISITATION,
        noise_parameter=0.3,
    ),
    policy=policy_to_eval,
    reward_set=reward_set,
)


# Run the agent for HORIZON steps. Every frequency_evaluation steps, compare the
# plug-in value estimates (computed from the agent's empirical transition model)
# against the ground-truth values for every sampled reward.
start_time = time.time()
s = env.reset()
results = []  # placeholder kept for later cells; not populated here

for t in range(HORIZON):
    # One interaction step: act, observe the next state, update the agent.
    a = agent.forward(s, t)
    next_state, _ = env.step(a)
    exp = Experience(s, a, next_state)
    reset = agent.backward(exp, t)

    # agent.backward's return value decides whether the environment restarts.
    s = env.reset() if reset else next_state

    # Periodic evaluation of the estimation error.
    if (t + 1) % agent_params.frequency_evaluation == 0:
        # The transition estimate does not depend on the reward, so build it
        # once per evaluation rather than once per reward.
        # NOTE(review): assumes empirical_transition() has no side effects.
        P_hat = agent.empirical_transition()
        hat_values = np.array([
            policy_evaluation(discount_factor, P_hat, R=R, policy=policy_to_eval)
            for R in rewards])

        # Per-reward sup-norm error between true and estimated value vectors.
        err = np.linalg.norm(values - hat_values, ord=np.inf, axis=-1)
        print(f'[{t}]  {err.mean()} - {err.std()}')
        print('--------')

[49]  1.0691895974493597 - 0.5074924132994711
--------
[99]  0.889301195277774 - 0.49389717009321277
--------
[149]  0.9419105249729718 - 0.5511047306980468
--------
[199]  1.0285127584211948 - 0.6244874009453112
--------
[249]  1.0299006439917853 - 0.6207596828363933
--------
[299]  0.9407988097891159 - 0.5730406837001083
--------
[349]  0.9291963235985862 - 0.5753053266805005
--------
[399]  0.9496504086665999 - 0.5876367861196212
--------
[449]  0.9530486671242229 - 0.5876694458469474
--------
[499]  0.8311098823903535 - 0.5270097046271968
--------
[549]  0.8044788800163617 - 0.5117035504956894
--------
[599]  0.6921111031847181 - 0.4400952595938651
--------
[649]  0.7109285596106883 - 0.45273670047970693
--------
[699]  0.7166238816437714 - 0.4564464782694844
--------
[749]  0.7161580732575245 - 0.45609526057761063
--------
[799]  0.6236674034817862 - 0.3966193858042196
--------
[849]  0.5816747002178141 - 0.3701907843958448
--------
[899]  0.5859750229990185 - 0.3744927706632066
-

KeyboardInterrupt: 