In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER


In [2]:
import os

# Optional switch (currently disabled) that sets the JAX_PLATFORM_NAME
# environment variable to "cpu" before jax is imported.
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
# Optional NaN debugging flag (currently disabled).
# jax.config.update("jax_debug_nans", True)
# Turn on 64-bit mode; the states printed later in this notebook are
# float64 / int64 arrays.
jax.config.update("jax_enable_x64", True)

# Display the devices JAX can see (the recorded output shows a single CUDA device).
jax.devices()

[CudaDevice(id=0)]

In [3]:
from ernestogym.envs_jax.single_agent.utils import parameter_generator

In [4]:
# Paths to the YAML configuration files handed to `parameter_generator`.
# They are relative, so the working directory must contain "ernestogym/".
pack_options = "ernestogym/ernesto_jax/data/battery/pack.yaml"
ecm = "ernestogym/ernesto_jax/data/battery/models/electrical/thevenin_pack.yaml"
# ecm = "ernestogym/ernesto_jax/data/battery/models/electrical/thevenin_fading_pack.yaml"
r2c = "ernestogym/ernesto_jax/data/battery/models/thermal/r2c_thermal_pack.yaml"
# bolun = "ernestogym/ernesto_jax/data/battery/models/aging/bolun_pack.yaml"
bolun = "ernestogym/ernesto_jax/data/battery/models/aging/bolun_pack_dropflow.yaml"
world = "ernestogym/envs_jax/single_agent/world_deg.yaml"
# world = "ernestogym/envs_jax/single_agent/world_fading.yaml"

# Battery variant passed to MicroGridEnv later on. The commented-out lines
# above/below are the alternative configurations.
# NOTE(review): `battery_type` and the selected ecm/bolun/world files are
# switched by hand and presumably have to stay consistent - confirm.
# battery_type = 'fading'
# battery_type = 'degrading'
battery_type = 'degrading_dropflow'

# Build the parameter dictionary for the environment from the files above.
params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True

)

# Show the generated parameters (a nested dict, see output).
params

{'battery': {'sign_convention': 'passive',
  'params': {'nominal_voltage': 350.4,
   'nominal_capacity': 60.0,
   'nominal_dod': 0.8,
   'nominal_lifetime': 3000,
   'v_max': 398.4,
   'v_min': 288.0,
   'temp_ambient': 298.15,
   'nominal_cost': 3000.0},
  'bounds': {'voltage': {'low': 288.0, 'high': 398.4},
   'current': {'low': -180.0, 'high': 60.0},
   'power': {'low': -71712.0, 'high': 23904.0},
   'temperature': {'low': 273.15, 'high': 323.15},
   'temp_ambient': {'low': 273.15, 'high': 313.15},
   'soc': {'low': 0.2, 'high': 0.8},
   'soh': {'low': 0.8, 'high': 1.0}},
  'init': {'voltage': 288.0,
   'current': 0.0,
   'power': 0.0,
   'temperature': 293.15,
   'temp_ambient': 293.15,
   'soc': 0.5,
   'soh': 1.0}},
 'input_var': 'current',
 'models_config': [{'type': 'electrical',
   'class_name': 'TheveninModel',
   'use_fading': False,
   'components': {'r0': {'selected_type': 'scalar',
     'scalar': 10.0,
     'lookup': {'table': 'r0_pack.csv',
      'inputs': [{'var': 'temp

In [5]:
from ernestogym.envs_jax.single_agent.env import MicroGridEnv

In [6]:
def prep_for_training(params, battery_type, demand_profile):
    """Create a MicroGridEnv and expose what a rollout needs from it.

    Args:
        params: Parameters forwarded unchanged to ``MicroGridEnv``.
        battery_type: Battery variant forwarded unchanged to ``MicroGridEnv``.
        demand_profile: Demand profile forwarded unchanged to ``MicroGridEnv``.

    Returns:
        A 3-tuple ``(env, env.initial_state, env.params)``.
    """
    microgrid = MicroGridEnv(params, battery_type, demand_profile)
    return microgrid, microgrid.initial_state, microgrid.params

In [7]:
# Instantiate the environment for the selected battery variant.
# NOTE(review): '64' is passed as the demand profile; presumably an identifier
# of a demand trace - confirm against MicroGridEnv.
env, initial_state, env_params = prep_for_training(params, battery_type=battery_type, demand_profile='64')

In [8]:
# Fixed PRNG seed so the random rollout below is repeatable.
key = jax.random.key(30)
# Start from the environment's initial state and display it.
state = initial_state
state

EnvState(time=Array(0, dtype=int64, weak_type=True), battery_state=BessBolunDropflowState(nominal_capacity=Array(60., dtype=float64, weak_type=True), nominal_cost=Array(3000., dtype=float64, weak_type=True), nominal_voltage=Array(350.4, dtype=float64, weak_type=True), nominal_dod=Array(0.8, dtype=float64, weak_type=True), nominal_lifetime=Array(3000, dtype=int64, weak_type=True), c_max=Array(60., dtype=float64, weak_type=True), temp_ambient=Array(293.15, dtype=float64, weak_type=True), v_max=Array(398.4, dtype=float64, weak_type=True), v_min=Array(288., dtype=float64, weak_type=True), elapsed_time=Array(0., dtype=float64, weak_type=True), electrical_state=ElectricalModelState(r0_nominal=Array(10., dtype=float64, weak_type=True), r0=Array(10., dtype=float64, weak_type=True), rc=RCState(resistance_nominal=Array(10., dtype=float64, weak_type=True), resistance=Array(10., dtype=float64, weak_type=True), capacity=Array(10., dtype=float64, weak_type=True), i_resistance=Array(0., dtype=float64

In [9]:
# Take one environment step by hand with a fixed action of -20.
# `params` was generated with input_var='current', so the action is presumably
# a current - units/sign to be confirmed against the environment.
obs, state, reward, done, info = env.step(key, state, -20., env_params)

In [10]:
# Dump the post-step state and the step's info dict, separated by a banner.
print(state, '\n\n\n############################\n\n\n', info, sep='\n')

EnvState(time=Array(60, dtype=int64, weak_type=True), battery_state=BessBolunDropflowState(nominal_capacity=Array(60., dtype=float64, weak_type=True), nominal_cost=Array(3000., dtype=float64, weak_type=True), nominal_voltage=Array(350.4, dtype=float64, weak_type=True), nominal_dod=Array(0.8, dtype=float64, weak_type=True), nominal_lifetime=Array(3000, dtype=int64, weak_type=True), c_max=Array(60., dtype=float64, weak_type=True), temp_ambient=Array(293.15, dtype=float64, weak_type=True), v_max=Array(398.4, dtype=float64, weak_type=True), v_min=Array(288., dtype=float64, weak_type=True), elapsed_time=Array(1., dtype=float64, weak_type=True), electrical_state=ElectricalModelState(r0_nominal=Array(10., dtype=float64, weak_type=True), r0=Array(10., dtype=float64, weak_type=True), rc=RCState(resistance_nominal=Array(10., dtype=float64, weak_type=True), resistance=Array(10., dtype=float64, weak_type=True), capacity=Array(10., dtype=float64, weak_type=True), i_resistance=Array(7.5, dtype=float

In [11]:
def train(env:MicroGridEnv, initial_state, env_params, num_iter, init_key):
    """Build a jitted random-action rollout of ``env`` that logs every step.

    Despite the name, nothing is learned here: each iteration samples an action
    uniformly between ``env_params.i_min_action`` and ``env_params.i_max_action``,
    calls ``env.step`` and stores selected entries of the returned ``info`` dict.
    The ``done`` flag returned by ``env.step`` is not consulted by this loop.

    Args:
        env: Environment exposing ``step(key, state, action, params)``.
        initial_state: Environment state the rollout starts from.
        env_params: Parameters passed to every ``env.step`` call; also provides
            the action sampling bounds.
        num_iter: Number of steps to run and length of every log array.
        init_key: PRNG key the per-step keys are split from.

    Returns:
        A jitted function taking no arguments that returns
        ``(final_state, final_key, log)``. ``log`` holds arrays of length
        ``num_iter`` under ``'soc'``, ``'soh'``, ``'r_tot'``, ``'action'`` and,
        under ``'pure_reward'``, ``'norm_reward'`` and ``'weig_reward'``, one
        array each for ``'r_trad'``, ``'r_op'``, ``'r_deg'``, ``'r_clipping'``.
    """
    reward_groups = ('pure_reward', 'norm_reward', 'weig_reward')
    reward_terms = ('r_trad', 'r_op', 'r_deg', 'r_clipping')

    def step_once(i, carry):
        env_state, rng, history = carry

        # One sub-key for sampling the action, one for the environment step.
        rng, action_rng, step_rng = jax.random.split(rng, 3)
        action = jax.random.uniform(action_rng, minval=env_params.i_min_action, maxval=env_params.i_max_action)
        _obs, env_state, _reward, _done, info = env.step(step_rng, env_state, action, env_params)

        def put(buffer, value):
            # Write this iteration's value into slot i of a log array.
            return buffer.at[i].set(value)

        def with_terms(bucket, r_trad, r_op, r_deg, r_clipping):
            # Keyword signature mirrors the reward-term keys expected in `info`.
            return {**bucket,
                    'r_trad': put(bucket['r_trad'], r_trad),
                    'r_op': put(bucket['r_op'], r_op),
                    'r_deg': put(bucket['r_deg'], r_deg),
                    'r_clipping': put(bucket['r_clipping'], r_clipping)}

        history = {
            **history,
            'soc': put(history['soc'], info['soc']),
            'soh': put(history['soh'], info['soh']),
            'pure_reward': with_terms(history['pure_reward'], **info['pure_reward']),
            'norm_reward': with_terms(history['norm_reward'], **info['norm_reward']),
            'weig_reward': with_terms(history['weig_reward'], **info['weig_reward']),
            'r_tot': put(history['r_tot'], info['r_tot']),
            'action': put(history['action'], action),
        }

        return env_state, rng, history

    def training_loop():

        def blank():
            return jnp.zeros(num_iter)

        # Pre-allocate every log array; the loop fills slot i at iteration i.
        history = {'soc': blank(), 'soh': blank()}
        for group in reward_groups:
            history[group] = {term: blank() for term in reward_terms}
        history['r_tot'] = blank()
        history['action'] = blank()

        # The loop carry is (environment state, PRNG key, log dict).
        return jax.lax.fori_loop(0, num_iter, step_once, (initial_state, init_key, history))

    return jax.jit(training_loop)

In [12]:
# Number of environment steps in the random rollout.
num_iter = 1000

# `train` only builds and jit-wraps the rollout; nothing is executed until
# `fun()` is called in the next cell.
fun = train(env, initial_state, env_params, num_iter, key)

In [13]:
from time import time

# Wall-clock timing of the first call to the jitted rollout, so the figure
# includes tracing and compilation as well as execution.
t1 = time()

# Keep the returned key under a real name: assigning to `_` would shadow
# IPython's "last output" variable for the rest of the session.
state, final_key, log = fun()

# JAX dispatches work asynchronously and returns before the device has
# finished. Wait for the results, otherwise the timer may stop too early.
jax.block_until_ready((state, log))

print(time() - t1)

1.6407296657562256


In [14]:
# Display the full rollout log (every array has num_iter entries).
# NOTE(review): the recorded 'action' values reach roughly +/-2100, while the
# battery current bounds shown in `params` are [-180, 60] - confirm that
# env_params.i_min_action / i_max_action are the intended sampling limits.
log

{'action': Array([-1698.54334675,  1932.12488843,  -862.07735192,  1821.11030368,
          749.36976305,   745.17617247, -1269.50135217,  1394.75427937,
         -697.40944059, -2147.99919599,   146.25048375,   -19.08365028,
          465.3238229 ,   953.10599758, -1020.8902266 ,  -115.32097698,
        -1672.04078533,   510.61082472, -1798.6202373 ,  1774.12825608,
         1269.95696008,  1237.69449196, -1334.40616383,  -661.18989989,
          -33.38860568,  -519.60608041, -1669.13660861,   273.0278752 ,
        -1508.84742942,  1650.37776817,  1554.35705522,  1129.15921945,
         1859.51403301,  -748.99626881,  -721.97344449,   504.34940957,
        -1571.91836292,   242.03401586,   762.35408705,  -312.54722286,
         -730.42236014,  -799.83356378, -1154.68146685, -1720.78976258,
         1778.41423762,  2081.72748391,  -957.92952336,   642.02804094,
        -1833.47065511,    28.21789051, -1652.30135502,  1919.30422015,
          729.94959151, -1928.38877841, -2124.61714547

In [15]:
# Print the state-of-health trace on one tab-separated line.
# Convert the device array to a Python list once, instead of iterating over
# the jax array and converting each element separately while printing.
# NOTE(review): the output recorded for this cell contains isolated 0.0 entries
# among the 1.0 values - check whether these mark resets/terminations or a
# logging problem in the environment.
soh_trace = log['soh'].tolist()
for e in soh_trace:
    print(e, end='\t')

1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	

In [16]:
# Display the environment state returned by the rollout.
# NOTE(review): the recorded output shows time=5460 after 1000 iterations,
# whereas a single step advanced time by 60 - presumably the environment
# resets or wraps its clock internally; confirm.
state

EnvState(time=Array(5460, dtype=int64, weak_type=True), battery_state=BessBolunDropflowState(nominal_capacity=Array(60., dtype=float64, weak_type=True), nominal_cost=Array(3000., dtype=float64, weak_type=True), nominal_voltage=Array(350.4, dtype=float64, weak_type=True), nominal_dod=Array(0.8, dtype=float64, weak_type=True), nominal_lifetime=Array(3000, dtype=int64, weak_type=True), c_max=Array(60., dtype=float64, weak_type=True), temp_ambient=Array(293.15, dtype=float64, weak_type=True), v_max=Array(398.4, dtype=float64, weak_type=True), v_min=Array(288., dtype=float64, weak_type=True), elapsed_time=Array(91., dtype=float64, weak_type=True), electrical_state=ElectricalModelState(r0_nominal=Array(10., dtype=float64, weak_type=True), r0=Array(10., dtype=float64, weak_type=True), rc=RCState(resistance_nominal=Array(10., dtype=float64, weak_type=True), resistance=Array(10., dtype=float64, weak_type=True), capacity=Array(10., dtype=float64, weak_type=True), i_resistance=Array(607.43970575,