In [None]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Move the working directory two levels up (presumably to the repository root, so that the
# project imports and relative data paths used below resolve — confirm).
# NOTE(review): the path is relative, so re-running this cell moves up two more levels.
%cd ../..

In [None]:
import os
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax_tqdm import scan_tqdm

from algorithms.utils import restore_state

In [None]:
from ernestogym.envs.single_agent.env import MicroGridEnv
from ernestogym.envs.single_agent.env_trading_soc import MicroGridEnvSocAction

In [None]:
def my_env_creator(params, battery_type, env_type='normal'):
    """Build the environment selected by ``env_type``.

    Args:
        params: Parameters forwarded to the environment constructor.
        battery_type: Battery type forwarded to the environment constructor.
        env_type: ``'normal'`` builds a ``MicroGridEnv``; ``'soc_action'``
            builds a ``MicroGridEnvSocAction``.

    Returns:
        Tuple ``(env, env.params)``.

    Raises:
        ValueError: If ``env_type`` is neither ``'normal'`` nor ``'soc_action'``.
    """
    known_types = ('normal', 'soc_action')
    if env_type not in known_types:
        raise ValueError('Unknown env_type')

    env_cls = MicroGridEnv if env_type == known_types[0] else MicroGridEnvSocAction
    microgrid = env_cls(params, battery_type)
    return microgrid, microgrid.params

In [None]:
from ernestogym.envs.single_agent.utils import parameter_generator

# Testing

In [None]:
@partial(jax.jit, static_argnums=(0, 2, 3))
def test(env: MicroGridEnv, env_params, network, num_iter, rng):
    """Roll out the policy deterministically for ``num_iter`` environment steps.

    The environment is reset once and then stepped inside ``jax.lax.scan``; at
    every step the action is the mode of the distribution returned by
    ``network``.

    Args:
        env: Environment exposing ``reset(key, params)`` and
            ``step(key, state, action, params)``. Static under ``jax.jit``.
        env_params: Parameters forwarded to ``env.reset`` and ``env.step``.
        network: Callable mapping an observation to ``(pi, _)`` where ``pi``
            exposes ``mode()``. Static under ``jax.jit``.
        num_iter: Number of environment steps to simulate. Static under
            ``jax.jit``.
        rng: JAX PRNG key; split once for the reset and once per step.

    Returns:
        Tuple ``(infos, actions)``: the per-step ``info`` structures and the
        actions, stacked along the leading axis by ``jax.lax.scan``.
    """
    rng, reset_key = jax.random.split(rng)
    obsv, env_state = env.reset(reset_key, env_params)

    # `num_iter // 100` is 0 for rollouts shorter than 100 steps, and a zero
    # print rate is not a valid progress-bar rate; clamp it to at least 1.
    print_rate = max(1, num_iter // 100)

    @scan_tqdm(num_iter, print_rate=print_rate)
    def _env_step(runner_state, unused):
        obsv, env_state, rng = runner_state

        pi, _ = network(obsv)
        # Deterministic evaluation: take the distribution mode instead of sampling.
        action = pi.mode()

        rng, step_key = jax.random.split(rng)
        # Reward and done flag are returned by the env but not collected here.
        obsv, env_state, _reward, _done, info = env.step(step_key, env_state, action, env_params)

        return (obsv, env_state, rng), (info, action)

    initial_state = (obsv, env_state, rng)
    # The scanned index array is consumed by the scan_tqdm wrapper.
    _, (infos, actions) = jax.lax.scan(_env_step, initial_state, jnp.arange(num_iter))

    return infos, actions

In [None]:
# Battery pack and sub-model configuration files (relative paths).
pack_options = "ernestogym/ernesto/data/battery/pack.yaml"

# Electrical model.
# Alternative: "ernestogym/ernesto/data/battery/models/electrical/thevenin_fading_pack.yaml"
ecm = "ernestogym/ernesto/data/battery/models/electrical/thevenin_pack.yaml"

# Thermal and aging models.
r2c = "ernestogym/ernesto/data/battery/models/thermal/r2c_thermal_pack.yaml"
bolun = "ernestogym/ernesto/data/battery/models/aging/bolun_pack.yaml"

# World definition.
# Alternative: "ernestogym/envs/single_agent/world_fading.yaml"
world = "ernestogym/envs/single_agent/ijcnn_deg_test.yaml"

# Battery type later passed to the environment constructor.
# Other values previously used here: 'fading', 'degrading'.
battery_type = 'degrading_dropflow'

params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True,
)

# Show the configured 'step' entry as the cell output.
params['step']

In [None]:
# Directory of the trained agent to restore. The default is a machine-specific
# absolute path; set the TRAINED_AGENT_DIR environment variable to point to a
# different location without editing the notebook.
directory = os.environ.get(
    'TRAINED_AGENT_DIR',
    '/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER/trained_agents/20250225_103553_lr_0.0005_tot_timesteps_3504000_anneal_rl_True_normal',
)

network, config, params_training = restore_state(directory)
# The environment variant is read from the restored config.
env_type = config['ENV_TYPE']

In [None]:
# Build the environment variant named by the restored config.
env, env_params = my_env_creator(
    params,
    battery_type,
    env_type=env_type,
)
# Replace the parameters with the result of env.eval
# (presumably the evaluation-mode parameters — confirm against the env implementation).
env_params = env.eval(env_params)

In [None]:
# Rollout length: 8 * 8760 steps (8760 = 365 * 24; presumably eight years of
# hourly steps — confirm against params['step']).
num_iter = 8 * 8760
# Fixed PRNG seed (51) so the rollout is reproducible.
info, actions = test(
    env,
    env_params,
    network,
    num_iter,
    jax.random.PRNGKey(51),
)


In [None]:
# Show the shape of every leaf of the `info` pytree returned by the rollout.
jax.tree.map(
    lambda leaf: leaf.shape,
    info,
)

In [None]:
# Range and mean of the actions taken during the rollout: (max, min, mean).
jnp.max(actions), jnp.min(actions), jnp.mean(actions)

In [None]:
# Per-step `soc` entry of the rollout info.
fig, ax = plt.subplots()
ax.plot(info['soc'])
ax.set(title="soc", xlabel="step", ylabel="soc");

In [None]:
# Key of the reward dict inside `info` that the reward plots in this notebook read
# ('pure_reward' is the other reward key accessed in this notebook).
reward_type = 'weig_reward'

In [None]:
# Per-step `r_trad` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_trad'])
ax.set(title=f"{reward_type}: r_trad", xlabel="step", ylabel="r_trad");

In [None]:
# Per-step `r_clipping` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_clipping'])
ax.set(title=f"{reward_type}: r_clipping", xlabel="step", ylabel="r_clipping");

In [None]:
# Mean `r_clipping` for the two reward dicts, as a (pure_reward, weig_reward) tuple.
tuple(
    info[reward_key]['r_clipping'].mean()
    for reward_key in ('pure_reward', 'weig_reward')
)

In [None]:
# Per-step `r_deg` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_deg'])
ax.set(title=f"{reward_type}: r_deg", xlabel="step", ylabel="r_deg");

In [None]:
# Per-step `r_tot` entry of the rollout info.
fig, ax = plt.subplots()
ax.plot(info['r_tot'])
ax.set(title="r_tot", xlabel="step", ylabel="r_tot");

In [None]:
# Running sum over steps of the `r_trad` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_trad']))
ax.set(title=f"Cumulative {reward_type}: r_trad", xlabel="step", ylabel="cumulative r_trad");

In [None]:
# Running sum over steps of the `r_clipping` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_clipping']))
ax.set(title=f"Cumulative {reward_type}: r_clipping", xlabel="step", ylabel="cumulative r_clipping");

In [None]:
# Running sum over steps of the `r_deg` component of the selected reward dict.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_deg']))
ax.set(title=f"Cumulative {reward_type}: r_deg", xlabel="step", ylabel="cumulative r_deg");

In [None]:
# Running sum over steps of the `r_tot` entry of the rollout info.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info['r_tot']))
ax.set(title="Cumulative r_tot", xlabel="step", ylabel="cumulative r_tot");

In [None]:
# Display the `soc` series as the cell output (last-expression display instead of print).
info['soc']

In [None]:
# Running sum over steps of the `r_deg` component of the selected reward dict.
# NOTE(review): an identical cumulative r_deg plot already appears in an earlier cell;
# consider removing one of the two.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_deg']))
ax.set(title=f"Cumulative {reward_type}: r_deg", xlabel="step", ylabel="cumulative r_deg");

In [None]:
# Display the `soc` series as the cell output (last-expression display instead of print).
# NOTE(review): `info['soc']` is already shown by an earlier cell; consider removing one of the two.
info['soc']