In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER


In [2]:
import os

# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp

jax.devices()

[CudaDevice(id=0)]

In [3]:
from ernestogym.envs_jax.single_agent.utils import parameter_generator

In [4]:
pack_options = "ernestogym/ernesto_jax/data/battery/pack.yaml"
# ecm = "ernestogym/ernesto_jax/data/battery/models/electrical/thevenin_pack.yaml"
ecm = "ernestogym/ernesto_jax/data/battery/models/electrical/thevenin_fading_pack.yaml"
r2c = "ernestogym/ernesto_jax/data/battery/models/thermal/r2c_thermal_pack.yaml"
bolun = "ernestogym/ernesto_jax/data/battery/models/aging/bolun_pack.yaml"
# bolun = "ernestogym/ernesto_jax/data/battery/models/aging/bolun_pack_dropflow.yaml"
# world = "ernestogym/envs_jax/single_agent/world_deg.yaml"
world = "ernestogym/envs_jax/single_agent/world_fading.yaml"

battery_type = 'fading'

params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True

)

params

{'battery': {'sign_convention': 'passive',
  'params': {'nominal_voltage': 350.4,
   'nominal_capacity': 60.0,
   'nominal_dod': 0.8,
   'nominal_lifetime': 3000,
   'v_max': 398.4,
   'v_min': 288.0,
   'temp_ambient': 298.15,
   'nominal_cost': 3000.0},
  'bounds': {'voltage': {'low': 288.0, 'high': 398.4},
   'current': {'low': -180.0, 'high': 60.0},
   'power': {'low': -71712.0, 'high': 23904.0},
   'temperature': {'low': 273.15, 'high': 323.15},
   'temp_ambient': {'low': 273.15, 'high': 313.15},
   'soc': {'low': 0.2, 'high': 0.8},
   'soh': {'low': 0.8, 'high': 1.0}},
  'init': {'voltage': 288.0,
   'current': 0.0,
   'power': 0.0,
   'temperature': 293.15,
   'temp_ambient': 293.15,
   'soc': 0.5,
   'soh': 1.0}},
 'input_var': 'current',
 'models_config': [{'type': 'electrical',
   'class_name': 'TheveninFadingModel',
   'use_fading': True,
   'alpha_fading': 0.00066667,
   'beta_fading': 1.66667e-07,
   'components': {'r0': {'selected_type': 'scalar',
     'scalar': 10.0,
   

In [5]:
from ernestogym.envs_jax.single_agent.env import MicroGridEnv

In [6]:
def prep_for_training(params, battery_type):
    env = MicroGridEnv(params, battery_type)

    initial_state = env.initial_state
    env_params = env.params

    return env, initial_state, env_params

In [7]:
env, initial_state, env_params = prep_for_training(params, battery_type=battery_type)

In [8]:
key = jax.random.key(30)
state = initial_state

In [10]:
obs, state, reward, done, info = env.step(key, state, 5., env_params)

In [11]:
def train(env:MicroGridEnv, initial_state, env_params, num_iter, init_key):

    def iter_body(i, val):
        state, key, log = val
        key, subkey1, subkey2 = jax.random.split(key, 3)
        a = jax.random.uniform(subkey1, minval=env_params.i_min_action, maxval=env_params.i_max_action)
        obs, state, reward, done, info = env.step(subkey2, state, a, env_params)

        def update_dict(d, r_trad, r_op, r_deg, r_clipping):
            d['r_trad'] = d['r_trad'].at[i].set(r_trad)
            d['r_op'] = d['r_op'].at[i].set(r_op)
            d['r_deg'] = d['r_deg'].at[i].set(r_deg)
            d['r_clipping'] = d['r_clipping'].at[i].set(r_clipping)

        log['soc'] = log['soc'].at[i].set(info['soc'])
        log['soh'] = log['soh'].at[i].set(info['soh'])

        update_dict(log['pure_reward'], **info['pure_reward'])
        update_dict(log['norm_reward'], **info['norm_reward'])
        update_dict(log['weig_reward'], **info['weig_reward'])
        log['r_tot'] = log['r_tot'].at[i].set(info['r_tot'])

        log['action'] = log['action'].at[i].set(a)

        return state, key, log

    def training_loop():

        log = {'soc': jnp.zeros(num_iter),
               'soh': jnp.zeros(num_iter),
               'pure_reward': {'r_trad': jnp.zeros(num_iter),
                               'r_op': jnp.zeros(num_iter),
                               'r_deg': jnp.zeros(num_iter),
                               'r_clipping': jnp.zeros(num_iter)},
               'norm_reward': {'r_trad': jnp.zeros(num_iter),
                               'r_op': jnp.zeros(num_iter),
                               'r_deg': jnp.zeros(num_iter),
                               'r_clipping': jnp.zeros(num_iter)},
               'weig_reward': {'r_trad': jnp.zeros(num_iter),
                               'r_op': jnp.zeros(num_iter),
                               'r_deg': jnp.zeros(num_iter),
                               'r_clipping': jnp.zeros(num_iter)},
               'r_tot': jnp.zeros(num_iter),
               'action': jnp.zeros(num_iter)}

        state = jax.lax.fori_loop(0,num_iter, iter_body, (initial_state, init_key, log))
        return state

    return jax.jit(training_loop)

In [12]:
num_iter = 10000

fun = train(env, initial_state, env_params, num_iter, key)

In [14]:
from time import time

t1 = time()

state, _, log = fun()

print(time() - t1)

0.5482473373413086


In [15]:
log

{'action': Array([-2031.5222 ,  1340.9451 ,   781.2202 , ...,   869.72314,
          301.07422,   560.17285], dtype=float32),
 'norm_reward': {'r_clipping': Array([ 45.67307 ,   0.      ,   0.      , ..., 231.57426 ,  50.292633,
          58.834675], dtype=float32),
  'r_deg': Array([-21.212696 , -10.547787 ,  -4.7662854, ...,  -0.       ,
          -0.       ,  -0.       ], dtype=float32),
  'r_op': Array([-15.253226, -45.256145, -23.087864, ...,  -0.      ,  -0.      ,
          -0.      ], dtype=float32),
  'r_trad': Array([ 1.06014595e+01,  6.77827393e+02, -1.11533386e+03, ...,
          1.22076720e-02,  1.18588815e-02,  1.56955793e-02], dtype=float32)},
 'pure_reward': {'r_clipping': Array([ 274038.4 ,       0.  ,       0.  , ..., 1389445.6 ,  301755.8 ,
          353008.06], dtype=float32),
  'r_deg': Array([-21.212696 , -10.547787 ,  -4.7662854, ...,  -0.       ,
          -0.       ,  -0.       ], dtype=float32),
  'r_op': Array([ -45759.68, -135768.44,  -69263.59, ...,      -0