In [None]:
# Reload imported project modules automatically before each cell runs, so edits
# to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Move the working directory two levels up so the relative "ernestogym/..." paths
# used further down resolve (presumably the repository root -- TODO confirm).
# NOTE: the path is relative, so re-running this cell climbs two more levels;
# restart the kernel instead of re-executing it.
%cd ../..

In [None]:
import os

# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp

# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

jax.devices()

In [None]:
# Scratch check: np.ceil returns a NumPy float scalar, and .astype(int) turns it
# into a NumPy integer scalar (not a built-in Python int). Shows (type, value).
type(np.ceil(5./2).astype(int)), np.ceil(5./2).astype(int)

In [None]:
from ernestogym.envs.single_agent.utils import parameter_generator

In [None]:
# Configuration files, given relative to the current working directory
# (set by the `%cd ../..` in the first cell).
# Battery pack options.
pack_options = "ernestogym/ernesto/data/battery/pack.yaml"
# Electrical model (Thevenin pack); the commented line is the "fading" alternative.
ecm = "ernestogym/ernesto/data/battery/models/electrical/thevenin_pack.yaml"
# ecm = "ernestogym/ernesto/data/battery/models/electrical/thevenin_fading_pack.yaml"
# Thermal model (R2C pack).
r2c = "ernestogym/ernesto/data/battery/models/thermal/r2c_thermal_pack.yaml"
# Aging model: the "dropflow" variant is active, the plain Bolun file is the alternative.
# bolun = "ernestogym/ernesto/data/battery/models/aging/bolun_pack.yaml"
bolun = "ernestogym/ernesto/data/battery/models/aging/bolun_pack_dropflow.yaml"
# World/environment options; the commented line is the "fading" alternative.
world = "ernestogym/envs/single_agent/world_deg.yaml"
# world = "ernestogym/envs/single_agent/world_fading.yaml"

# Battery variant handed to the environment constructor further down.
# NOTE(review): presumably this has to match the electrical/aging/world files
# selected above (e.g. 'degrading_dropflow' with bolun_pack_dropflow.yaml) -- confirm.
# battery_type = 'fading'
# battery_type = 'degrading'
battery_type = 'degrading_dropflow'

# Build the parameter object consumed by the environment from the files above.
# The battery is driven by current (input_var='current') and reward
# normalization is requested.
params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True

)

# Display the generated parameters for inspection.
params

In [None]:
# Root PRNG key for the rollout; the fixed seed (30) keeps runs reproducible.
key = jax.random.key(30)

In [None]:
from ernestogym.envs.single_agent.env import MicroGridEnv as JaxEnv

In [None]:
def prep_for_training(params, battery_type, demand_profile):
    """Instantiate the micro-grid environment and expose what a rollout needs.

    Args:
        params: Parameter object forwarded to ``JaxEnv``.
        battery_type: Battery variant forwarded to ``JaxEnv``.
        demand_profile: Demand profile identifier forwarded to ``JaxEnv``.

    Returns:
        A tuple ``(env, initial_state, env_params)`` where the last two are the
        ``initial_state`` and ``params`` attributes of the freshly built env.
    """
    microgrid = JaxEnv(params, battery_type, demand_profile)
    return microgrid, microgrid.initial_state, microgrid.params

In [None]:
# Build the environment for the selected battery type with demand profile '64',
# keeping its initial state and environment parameters for the rollout below.
env_jax, initial_state, env_params = prep_for_training(params, battery_type=battery_type, demand_profile='64')

In [None]:
# Display the environment's initial state for inspection.
initial_state

In [None]:
def train(env: "JaxEnv", env_params):
    """Build a jitted rollout that drives ``env`` with a uniform random policy.

    Despite the name, no learning happens here: the returned function resets
    the environment, applies random actions for a fixed number of steps and
    records per-step diagnostics taken from the ``info`` dict of each step.

    Args:
        env: Environment exposing ``reset(key, params)`` and
            ``step_env(key, state, action, params)``.
        env_params: Environment parameters, forwarded unchanged to ``reset``
            and ``step_env``.

    Returns:
        A ``jax.jit``-compiled function ``training_loop(num_iter, init_key)``
        (``num_iter`` is a static argument) returning ``(final_state, log)``,
        where ``log`` is a nested dict of arrays of length ``num_iter``.
    """

    reward_terms = ('r_trad', 'r_op', 'r_deg', 'r_clip')

    def policy(obs, key):
        """Sample one action uniformly from [-50, 50); ``obs`` is ignored."""
        # return jax.random.uniform(key, minval=env_params.i_min_action, maxval=env_params.i_max_action)
        return jax.random.uniform(key, minval=-50, maxval=50)

    def iter_body(i, val):
        """Run step ``i`` of the rollout and write its diagnostics into the log."""
        state, obs, key, log = val
        key, policy_key, step_key = jax.random.split(key, 3)

        # Use the dedicated subkey for sampling. The carried ``key`` must only
        # ever be split: consuming it here as well (as the code used to do)
        # reuses the same key for sampling and for the next iteration's split.
        a = policy(obs, policy_key)

        # FIXME: this should call `step`, not `step_env`.
        obs, state, _reward, _done, info = env.step_env(step_key, state, a, env_params)

        def update_dict(d, r_trad, r_op, r_deg, r_clipping):
            """Store the reward components of step ``i`` into the buffers of ``d``."""
            d['r_trad'] = d['r_trad'].at[i].set(r_trad)
            d['r_op'] = d['r_op'].at[i].set(r_op)
            d['r_deg'] = d['r_deg'].at[i].set(r_deg)
            # The env reports this term as 'r_clipping'; the log calls it 'r_clip'.
            d['r_clip'] = d['r_clip'].at[i].set(r_clipping)

        log['soc'] = log['soc'].at[i].set(info['soc'])
        log['soh'] = log['soh'].at[i].set(info['soh'])

        for reward_kind in ('pure_reward', 'norm_reward', 'weig_reward'):
            update_dict(log[reward_kind], **info[reward_kind])
        log['r_tot'] = log['r_tot'].at[i].set(info['r_tot'])

        log['action'] = log['action'].at[i].set(a)

        return state, obs, key, log

    def training_loop(num_iter, init_key):
        """Reset the env and roll it out for ``num_iter`` steps."""

        def reward_buffers():
            # One zero-initialised buffer per reward component.
            return {term: jnp.zeros(num_iter) for term in reward_terms}

        log = {'soc': jnp.zeros(num_iter),
               'soh': jnp.zeros(num_iter),
               'pure_reward': reward_buffers(),
               'norm_reward': reward_buffers(),
               'weig_reward': reward_buffers(),
               'r_tot': jnp.zeros(num_iter),
               'action': jnp.zeros(num_iter)}

        key, reset_key = jax.random.split(init_key)

        obs, initial_state = env.reset(reset_key, env_params)

        state, obs, key, log = jax.lax.fori_loop(0, num_iter, iter_body, (initial_state, obs, key, log))
        return state, log

    # num_iter fixes the buffer sizes and the trip count, so it must be static.
    return jax.jit(training_loop, static_argnums=(0,))

In [None]:
# Build the jitted rollout function bound to this environment and its parameters.
jitted_training_loop = train(env_jax, env_params)

In [None]:
from time import time

num_iter = 10001

t1 = time()

state, log = jitted_training_loop(num_iter, key)
# JAX dispatches work asynchronously: wait for the results before reading the
# clock, otherwise the timer may stop before the computation has finished.
# The first call also includes tracing/compilation time.
state, log = jax.block_until_ready((state, log))

print(time() - t1)

# Convert every logged JAX array to a NumPy array for plotting.
log = jax.tree.map(lambda x: np.array(x), log)
log

In [None]:
# Which reward variant to plot: 'pure_reward', 'norm_reward' or 'weig_reward'.
reward_type = 'weig_reward'

# (panel title, log key) pairs in subplot order (row-major on the 2x2 grid).
reward_panels = [('r trad', 'r_trad'),
                 ('r clip', 'r_clip'),
                 ('r op', 'r_op'),
                 ('r deg', 'r_deg')]

f, axes = plt.subplots(2, 2, figsize=(20, 20))

for ax, (title, term) in zip(axes.flat, reward_panels):
    ax.plot(log[reward_type][term])
    ax.set_title(title)
    ax.set_xlabel('step')
    ax.set_ylabel(reward_type)

# Show the figure explicitly so the cell does not end on a stray Text repr.
plt.show()

In [None]:
# Logged 'soc' and 'soh' values over the rollout, side by side.
f, axes = plt.subplots(1, 2, figsize=(20, 6))

for ax, quantity in zip(axes, ('soc', 'soh')):
    ax.plot(log[quantity])
    ax.set_title(quantity)
    ax.set_xlabel('step')
    ax.set_ylabel(quantity)

# Show the figure explicitly so the cell does not end on a stray Text repr.
plt.show()