In [None]:
# Pick up edits to imported project modules without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Move the working directory two levels up, presumably so the relative
# `ernestogym/...` paths used below resolve (verify against the repo layout).
# NOTE(review): the path is relative, so re-running this cell without a kernel
# restart moves up two more levels each time.
%cd ../..

In [None]:
import os

# Optional override, left disabled. It sits before `import jax`, presumably
# because the platform has to be chosen before JAX initialises (TODO confirm).
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
# Optional NaN debugging switch, left disabled.
# jax.config.update("jax_debug_nans", True)
# Turn on JAX's 64-bit mode; run this before any array is created.
jax.config.update("jax_enable_x64", True)

# Last expression of the cell: displays the devices JAX can see.
jax.devices()

In [None]:
from ernestogym.envs.single_agent.utils import parameter_generator

In [None]:
# Battery variant handed to MicroGridEnv further down in the notebook.
# Other values previously tried here: 'fading', 'degrading'.
battery_type = 'degrading_dropflow'

# --- YAML configuration files (relative paths, resolved against the cwd) ---
# Battery pack options.
pack_options = "ernestogym/ernesto/data/battery/pack.yaml"
# Electrical model.
# Alternative: "ernestogym/ernesto/data/battery/models/electrical/thevenin_fading_pack.yaml"
ecm = "ernestogym/ernesto/data/battery/models/electrical/thevenin_pack.yaml"
# Thermal model.
r2c = "ernestogym/ernesto/data/battery/models/thermal/r2c_thermal_pack.yaml"
# Aging model.
# Alternative: "ernestogym/ernesto/data/battery/models/aging/bolun_pack.yaml"
bolun = "ernestogym/ernesto/data/battery/models/aging/bolun_pack_dropflow.yaml"
# World options.
# Alternative: "ernestogym/envs/single_agent/world_fading.yaml"
world = "ernestogym/envs/single_agent/world_deg.yaml"

# Assemble the parameter structure consumed by the environment constructor.
params = parameter_generator(
    input_var="current",
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True,
)

# Display the generated parameters as the cell output.
params

In [None]:
from ernestogym.envs.single_agent.env import MicroGridEnv

In [None]:
def prep_for_training(params, battery_type, demand_profile):
    """Instantiate the micro-grid environment and expose what a rollout needs.

    Args:
        params: Parameter structure forwarded unchanged to ``MicroGridEnv``.
        battery_type: Battery variant identifier forwarded to ``MicroGridEnv``.
        demand_profile: Demand profile identifier forwarded to ``MicroGridEnv``.

    Returns:
        A 3-tuple ``(env, initial_state, env_params)`` where the last two are
        the ``initial_state`` and ``params`` attributes of the new environment.
    """
    microgrid = MicroGridEnv(params, battery_type, demand_profile)
    return microgrid, microgrid.initial_state, microgrid.params

In [None]:
# Build the environment once and keep its initial state and parameters as
# notebook-level names; later cells pass them to `env.step` and `train`.
# '64' is the demand profile identifier handed to MicroGridEnv.
env, initial_state, env_params = prep_for_training(params, battery_type=battery_type, demand_profile='64')

In [None]:
# Root PRNG key built from a fixed seed, so the random draws below are repeatable.
key = jax.random.key(30)
# Working copy of the state; later cells rebind `state`, `initial_state` stays untouched.
state = initial_state
# Display the starting state as the cell output.
state

In [None]:
# Split off a dedicated key for this one-off step. The remaining `key` is later
# used as the root key of the rollout, and a JAX key must not be consumed twice.
key, step_key = jax.random.split(key)
# Advance the environment by one step with a fixed action of -20.
# NOTE(review): units and sign convention of the action are defined by the
# environment (params were generated with input_var='current') — confirm.
obs, state, reward, done, info = env.step(step_key, state, -20., env_params)

In [None]:
# Dump the post-step state and the step's info dict, separated by a banner.
# A single print with sep='\n' emits exactly what three consecutive prints would.
print(
    state,
    '\n\n\n############################\n\n\n',
    info,
    sep='\n',
)

In [None]:
def train(env:MicroGridEnv, initial_state, env_params, num_iter, init_key):
    """Build a jitted random-action rollout of ``env`` that logs every step.

    Despite the name, nothing is learned here: each step draws an action
    uniformly between ``env_params.i_min_action`` and
    ``env_params.i_max_action`` and feeds it to ``env.step``. The ``done``
    flag returned by the environment is ignored, so the environment is never
    reset during the rollout.

    Args:
        env: Environment exposing ``step(key, state, action, env_params)``.
        initial_state: State the rollout starts from.
        env_params: Environment parameters; also provide the action bounds.
        num_iter: Number of steps; also the length of every log array.
        init_key: PRNG key the per-step keys are split from.

    Returns:
        A jitted zero-argument function. All inputs are captured by closure,
        so each call replays the same rollout. It returns
        ``(final_state, final_key, log)`` where ``log`` maps
        ``'soc'``, ``'soh'``, ``'r_tot'`` and ``'action'`` to arrays of length
        ``num_iter`` and ``'pure_reward'``, ``'norm_reward'``,
        ``'weig_reward'`` to dicts of such arrays keyed by reward term.
    """
    reward_groups = ('pure_reward', 'norm_reward', 'weig_reward')
    reward_terms = ('r_trad', 'r_op', 'r_deg', 'r_clipping')

    def _empty_buffer():
        # One slot per rollout step.
        return jnp.zeros(num_iter)

    def _step(step, carry):
        env_state, rng, history = carry

        # One key to carry forward, one for the action draw, one for the env.
        rng, action_rng, env_rng = jax.random.split(rng, 3)
        action = jax.random.uniform(
            action_rng,
            minval=env_params.i_min_action,
            maxval=env_params.i_max_action,
        )
        # Only the new state and the info dict are used; obs, reward and done are dropped.
        _, env_state, _, _, step_info = env.step(env_rng, env_state, action, env_params)

        def _store_terms(buffers, r_trad, r_op, r_deg, r_clipping):
            # Keyword-only call site: the info dict must carry exactly these four terms.
            for term, value in zip(reward_terms, (r_trad, r_op, r_deg, r_clipping)):
                buffers[term] = buffers[term].at[step].set(value)

        for name in ('soc', 'soh'):
            history[name] = history[name].at[step].set(step_info[name])

        for group in reward_groups:
            _store_terms(history[group], **step_info[group])

        history['r_tot'] = history['r_tot'].at[step].set(step_info['r_tot'])
        history['action'] = history['action'].at[step].set(action)

        return env_state, rng, history

    def _rollout():
        # Pre-allocate every log array; key order mirrors the order they are filled in.
        history = {'soc': _empty_buffer(), 'soh': _empty_buffer()}
        history.update(
            {group: {term: _empty_buffer() for term in reward_terms}
             for group in reward_groups}
        )
        history['r_tot'] = _empty_buffer()
        history['action'] = _empty_buffer()

        # The loop carry is (state, key, log); its final value is the result.
        return jax.lax.fori_loop(0, num_iter, _step, (initial_state, init_key, history))

    return jax.jit(_rollout)

In [None]:
# Number of environment steps in the rollout (also the length of each log array).
num_iter = 1000

# Build the jitted rollout; nothing is executed until `fun()` is called.
fun = train(env, initial_state, env_params, num_iter, key)

In [None]:
# `time` is still imported so that any other cell relying on it keeps working;
# the measurement itself uses the monotonic high-resolution counter.
from time import perf_counter, time

# NOTE: when this is the first call of `fun`, the figure printed below includes
# JIT tracing/compilation; run the cell again to time execution alone.
t1 = perf_counter()

# JAX dispatches work asynchronously, so the call may return before the rollout
# has finished. Block on the whole result before reading the clock.
# The second element is the final PRNG key; it is given a real name because
# assigning to `_` would shadow IPython's last-output variable.
state, final_key, log = jax.block_until_ready(fun())

print(perf_counter() - t1)

In [None]:
# Display the full per-step log returned by the rollout.
log

In [None]:
# Print every logged `soh` value (presumably state of health — confirm) on one
# line, each followed by a tab.
for e in log['soh']:
    print(e, end='\t')

In [None]:
# Display the environment state left after the last rollout step.
state