In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

In [None]:
import jax
import jax.numpy as jnp
from flax import nnx
from functools import partial
from jax_tqdm import scan_tqdm

from algorithms.utils import save_state, restore_state

import numpy as np

In [None]:
from ernestogym.envs.single_agent.env_new_clip import MicroGridEnv as NewClipEnv
from ernestogym.envs.single_agent.env import MicroGridEnv
from ernestogym.envs.single_agent.env_trading_soc import MicroGridEnvSocAction

In [None]:
def my_env_creator(params, battery_type, env_type='normal'):
    """Construct a single-agent MicroGrid environment of the requested flavour.

    Args:
        params: environment parameters (as produced by ``parameter_generator``),
            forwarded unchanged to the environment constructor.
        battery_type: battery model identifier, forwarded to the constructor.
        env_type: ``'normal'`` -> MicroGridEnv, ``'soc_action'`` ->
            MicroGridEnvSocAction, ``'new_clip'`` -> NewClipEnv.

    Returns:
        Tuple ``(env, env.params)``.

    Raises:
        ValueError: if ``env_type`` names none of the supported environments.
    """
    # Ordered (name, lazy class lookup) pairs: only the selected class is ever resolved,
    # and names are compared with `==` in the same order as an if/elif chain would.
    registry = (
        ('normal', lambda: MicroGridEnv),
        ('soc_action', lambda: MicroGridEnvSocAction),
        ('new_clip', lambda: NewClipEnv),
    )
    for name, resolve_cls in registry:
        if env_type == name:
            built = resolve_cls()(params, battery_type)
            return built, built.params
    raise ValueError('Unknown env_type')

In [None]:
from ernestogym.envs.single_agent.utils import parameter_generator

In [None]:
# Training environment configuration.
# Alternatives tried in earlier experiments:
#   pack_options: pack.yaml, pack_init_full.yaml
#   ecm:          thevenin_fading_pack.yaml
#   world:        ijcnn_deg_train.yaml, ijcnn_deg_train_no_clip.yaml,
#                 ijcnn_deg_train_only_trad.yaml, world_fading.yaml,
#                 ijcnn_deg_train_new_gen_data.yaml
#   battery_type: 'fading', 'degrading'
_BATTERY_DATA = "ernestogym/ernesto/data/battery"
_BATTERY_MODELS = _BATTERY_DATA + "/models"

pack_options = _BATTERY_DATA + "/pack_init_full_cheap.yaml"
ecm = _BATTERY_MODELS + "/electrical/thevenin_pack.yaml"
r2c = _BATTERY_MODELS + "/thermal/r2c_thermal_pack.yaml"
bolun = _BATTERY_MODELS + "/aging/bolun_pack.yaml"
world = "ernestogym/envs/single_agent/ijcnn_deg_train_new_gen_data_new_clip.yaml"
battery_type = 'degrading_dropflow'
del _BATTERY_DATA, _BATTERY_MODELS

params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    # use_reward_normalization=True
)
params

In [None]:
# Held-out ("testing") world used for periodic validation during training.
# Pack and battery sub-models are the same files as the training configuration; only `world` differs.
# Alternatives tried in earlier experiments:
#   pack_options: pack.yaml, pack_init_full.yaml
#   ecm:          thevenin_fading_pack.yaml
#   world:        ijcnn_deg_test.yaml, ijcnn_deg_test_no_clip.yaml,
#                 ijcnn_deg_test_only_trad.yaml, world_fading.yaml,
#                 ijcnn_deg_test_new_gen_data.yaml
_BATTERY_DATA = "ernestogym/ernesto/data/battery"
_BATTERY_MODELS = _BATTERY_DATA + "/models"

pack_options = _BATTERY_DATA + "/pack_init_full_cheap.yaml"
ecm = _BATTERY_MODELS + "/electrical/thevenin_pack.yaml"
r2c = _BATTERY_MODELS + "/thermal/r2c_thermal_pack.yaml"
bolun = _BATTERY_MODELS + "/aging/bolun_pack.yaml"
world = "ernestogym/envs/single_agent/ijcnn_deg_test_new_gen_data_new_clip.yaml"
del _BATTERY_DATA, _BATTERY_MODELS

params_testing = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    # use_reward_normalization=True
)
params_testing['step']

In [None]:
from algorithms.ppo import make_train, train_wrapper  # , train_for, train_for_flax
from flax.core.frozen_dict import freeze, unfreeze

In [None]:
num_envs = 4

# Steps in one simulated year (presumably hourly data — confirm against params['step']).
# Shared by the training budget and the validation rollout so they cannot drift apart.
STEPS_PER_YEAR = 8760

# 200 simulated years per environment (previously tried: 876000 * num_envs * 3).
total_timesteps = STEPS_PER_YEAR * num_envs * 200
# NOTE(review): total_timesteps is not a multiple of NUM_STEPS * NUM_ENVS (8192 * 4); if
# make_train derives the number of updates by integer division, the final partial rollout
# is silently dropped.

env_type = 'new_clip'

config = {
    "LR": 1e-3,
    'LR_MIN': 1e-5,
    "NUM_ENVS": num_envs,
    "NUM_STEPS": 8192,  # previously 2048 / 10000
    "TOTAL_TIMESTEPS": total_timesteps,
    "UPDATE_EPOCHS": 10,
    "NUM_MINIBATCHES": 32,
    "NORMALIZE_ENV": False,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.98,
    "CLIP_EPS": 0.25,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "NET_ARCH": (64, 32, 16),  # previously (16, 16, 16)
    "ACTIVATION": 'tanh',
    "LOGISTIC_FUNCTION_TO_ACTOR": False,
    "ENV_NAME": None,
    # "ANNEAL_LR": False,
    'LR_SCHEDULE': 'constant',
    'DEBUG': False,
    "NORMALIZE_REWARD_FOR_GAE_AND_TARGETS": False,
    "NORMALIZE_TARGETS": False,
    "NORMALIZE_ADVANTAGES": True,
    "ENV_TYPE": env_type,
    'NETWORK': 'actor_critic',
    'USE_WEIGHT_DECAY': True
}

rng = jax.random.PRNGKey(42)  # NOTE: the training cell re-creates this key before use.

# Training environment; make_train returns the (possibly wrapped) env/params and the initial train state.
env, env_params = my_env_creator(params, battery_type, env_type=env_type)
env, env_params, train_state = make_train(config, env, env_params)

# Held-out environment used for periodic validation inside train_wrapper.
env_testing, env_testing_params = my_env_creator(params_testing, battery_type, env_type=env_type)

val_rng = jax.random.PRNGKey(51)
# 8 simulated years of validation rollout. Previously hard-coded as 8670*8, a digit
# transposition of 8760 that made validation 720 steps shorter than the 8760*8 test rollout.
val_num_iters = STEPS_PER_YEAR * 8

# Frozen after make_train has consumed the plain dict; FrozenDict is immutable and hashable
# (presumably required by train_wrapper / save_state — confirm).
config = freeze(config)

In [None]:
# Show the frozen training configuration.
print(config)

In [None]:
# Rebuild the (network, optimizer) pair from the graph definition and state held in the
# `train_state` returned by make_train, so the not-yet-trained network can be inspected.
network, optimizer = nnx.merge(train_state.graph_def, train_state.state)

In [None]:
# Display the network module (structure and parameters).
network

In [None]:
import time
import matplotlib.pyplot as plt

# Train with periodic validation on the held-out environment built in the config cell.
rng = jax.random.PRNGKey(42)
t0 = time.perf_counter()
out = train_wrapper(
    env, env_params, config, train_state, rng,
    validate=True, freq_val=10,
    val_env=env_testing, val_params=env_testing_params,
    val_rng=val_rng, val_num_iters=val_num_iters,
)
# JAX dispatches work asynchronously; wait for every array in `out` to be computed so the
# elapsed time measures the training itself rather than just its dispatch.
jax.block_until_ready(out)
elapsed = time.perf_counter() - t0

# The first element of the final runner state is the trained train_state (merged below).
train_state = out['runner_state'][0]

print(f"time: {elapsed:.2f} s")

fig, ax = plt.subplots()
# Episode returns averaged over the last axis, flattened into a single curve.
ax.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
ax.set(title="Training episode returns", xlabel="Update Step", ylabel="Return")
plt.show()

In [None]:
# Debug: report which device holds the training metrics.
out["metrics"]['r_tot'].device

In [None]:
# Debug: report which device holds the validation metrics.
out['val_info']['r_tot'].device

In [None]:
# Debug: copy the whole training-output pytree to the first CPU device and check
# where the validation metrics now live. `out2` is not used further below.
out2 = jax.device_put(
    out,
    device=jax.devices(backend='cpu')[0],
)
out2['val_info']['r_tot'].device

In [None]:
# Re-merge after training: `train_state` was replaced by the trained one in the training cell,
# so `network` now holds the trained policy used by save_state and the test rollout.
network, optimizer = nnx.merge(train_state.graph_def, train_state.state)
network

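## State saving

Persist the trained network together with the config, environment parameters and training/validation metrics via `save_state`.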

In [None]:
# Persist the trained network with its config, environment params and training/validation logs.
# NOTE(review): `params` and `pack_options` are reassigned by the "Testing" section below;
# re-running this cell after that section records the evaluation configuration instead.
save_state(
    network,
    config,
    params,
    train_info=out['metrics'],
    val_info=out.get('val_info'),
    env_type=env_type,
    additional_info=f"total_timesteps{total_timesteps}_init_full_{'full' in pack_options}",
)

In [None]:
# Summarise the training output as a pytree of leaf shapes rather than dumping every array
# (it includes the full runner state and all per-step metrics). Non-array leaves are shown as-is.
jax.tree.map(lambda leaf: getattr(leaf, 'shape', leaf), out)

In [None]:
# Weighted trading reward from training: mean over the last axis, flattened, plus running total.
r_trad = out['metrics']['weig_reward']['r_trad'].mean(axis=-1).ravel()
r_trad_cum = np.cumsum(r_trad)
r_trad.shape

In [None]:
# Range and mean of every action recorded during training.
acts = out['metrics']['action'].ravel()
(acts.min(), acts.max(), acts.mean())

In [None]:
# Running total of the weighted r_trad reward over training.
fig, ax = plt.subplots()
ax.plot(r_trad_cum)
ax.set(title="Cumulative r_trad (training, weig_reward)", xlabel="Step", ylabel="Cumulative r_trad")
plt.show()

In [None]:
# Running total of r_tot over training (mean over the last axis, flattened).
r_tot_cum = np.cumsum(out['metrics']['r_tot'].mean(-1).reshape(-1))
fig, ax = plt.subplots()
ax.plot(r_tot_cum)
ax.set(title="Cumulative r_tot (training)", xlabel="Step", ylabel="Cumulative r_tot")
plt.show()

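# Testing

Roll out the trained policy deterministically (mode of the policy distribution) on a separate evaluation configuration and inspect the SoC and reward traces.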

In [None]:
# Evaluation configuration.
# NOTE(review): this reassigns `params`, `pack_options`, `ecm`, `r2c`, `bolun`, `world` and
# `battery_type`, overwriting the training configuration that the "State saving" cell reads.
# It also differs from the training/validation setup: pack.yaml instead of
# pack_init_full_cheap.yaml, the `only_trad` test world (while env_type is still 'new_clip'),
# and reward normalisation enabled. Confirm this mismatch is intended.
# Alternatives tried: pack_init_full.yaml, thevenin_fading_pack.yaml, world_fading.yaml,
#                     battery_type 'fading' / 'degrading'.
_BATTERY_DATA = "ernestogym/ernesto/data/battery"
_BATTERY_MODELS = _BATTERY_DATA + "/models"

pack_options = _BATTERY_DATA + "/pack.yaml"
ecm = _BATTERY_MODELS + "/electrical/thevenin_pack.yaml"
r2c = _BATTERY_MODELS + "/thermal/r2c_thermal_pack.yaml"
bolun = _BATTERY_MODELS + "/aging/bolun_pack.yaml"
world = "ernestogym/envs/single_agent/ijcnn_deg_test_only_trad.yaml"
battery_type = 'degrading_dropflow'
del _BATTERY_DATA, _BATTERY_MODELS

params = parameter_generator(
    input_var='current',
    battery_options=pack_options,
    electrical_model=ecm,
    thermal_model=r2c,
    aging_model=bolun,
    world_options=world,
    use_reward_normalization=True,
)
params['step']

In [None]:
@partial(jax.jit, static_argnums=(0, 2, 3))
def test(env: MicroGridEnv, env_params, network, num_iter, rng):
    """Roll out ``network`` deterministically in ``env`` for ``num_iter`` steps.

    Args:
        env: environment exposing ``reset(rng, params)`` and
            ``step(rng, state, action, params)``; static under jit.
        env_params: environment parameters passed to ``reset`` / ``step``.
        network: policy module; ``network(obs)`` must return ``(pi, value)``
            where ``pi`` has a ``mode()`` method. Static under jit.
        num_iter: number of environment steps (static; fixes the scan length).
        rng: PRNG key used for the reset and for per-step keys.

    Returns:
        Tuple ``(info, actions)``: the per-step ``info`` pytrees returned by
        ``env.step`` and the actions taken, each stacked along a new leading
        axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)
    obsv, env_state = env.reset(_rng, env_params)

    # scan_tqdm requires a print rate >= 1; `num_iter // 100` alone is 0 for num_iter < 100.
    print_rate = max(1, num_iter // 100)

    @scan_tqdm(num_iter, print_rate=print_rate)
    def _env_step(runner_state, unused):
        # `unused` is the step index from jnp.arange(num_iter); scan_tqdm reads it for the progress bar.
        obsv, env_state, rng = runner_state

        pi, _ = network(obsv)
        # Deterministic evaluation: take the mode of the policy distribution instead of sampling.
        action = pi.mode()

        rng, _rng = jax.random.split(rng)
        # reward/done are not used here; episode boundaries are left to env.step
        # (presumably auto-resetting — confirm).
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        return (obsv, env_state, rng), (info, action)

    runner_state = (obsv, env_state, rng)
    _, (info, actions) = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return info, actions

In [None]:
# Build the evaluation environment from the evaluation `params`, then pass its params through
# env.eval (presumably switching to evaluation mode — confirm).
# NOTE(review): this replaces the training `env` / `env_params` created in the config cell.
env, env_params = my_env_creator(params, battery_type, env_type)
env_params = env.eval(env_params)

In [None]:
# Deterministic rollout of the trained policy for 8 x 8760 steps, seeded like the validation key.
num_iter = 8 * 8760
info, actions = test(env, env_params, network, num_iter, jax.random.PRNGKey(51))


In [None]:
# Shape of every leaf in the rollout's info pytree.
jax.tree.map(jnp.shape, info)

In [None]:
# Max, min and mean of the deterministic test actions.
(jnp.max(actions), jnp.min(actions), jnp.mean(actions))

In [None]:
# soc trace over the deterministic test rollout.
fig, ax = plt.subplots()
ax.plot(info['soc'])
ax.set(title="soc during test rollout", xlabel="Test step", ylabel="soc")
plt.show()

In [None]:
# Reward breakdown plotted below: 'weig_reward' (presumably the weighted terms) or 'pure_reward'.
reward_type = 'weig_reward'

In [None]:
# Per-step r_trad during the test rollout, for the selected reward breakdown.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_trad'])
ax.set(title=f"r_trad ({reward_type}, test rollout)", xlabel="Test step", ylabel="r_trad")
plt.show()

In [None]:
# Per-step r_clipping during the test rollout, for the selected reward breakdown.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_clipping'])
ax.set(title=f"r_clipping ({reward_type}, test rollout)", xlabel="Test step", ylabel="r_clipping")
plt.show()

In [None]:
# Mean r_clipping in the pure vs. weighted reward breakdowns.
tuple(info[kind]['r_clipping'].mean() for kind in ('pure_reward', 'weig_reward'))

In [None]:
# Per-step r_deg during the test rollout, for the selected reward breakdown.
fig, ax = plt.subplots()
ax.plot(info[reward_type]['r_deg'])
ax.set(title=f"r_deg ({reward_type}, test rollout)", xlabel="Test step", ylabel="r_deg")
plt.show()

In [None]:
# Per-step total reward r_tot during the test rollout.
fig, ax = plt.subplots()
ax.plot(info['r_tot'])
ax.set(title="r_tot (test rollout)", xlabel="Test step", ylabel="r_tot")
plt.show()

In [None]:
# Running total of r_trad over the test rollout.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_trad']))
ax.set(title=f"Cumulative r_trad ({reward_type}, test rollout)", xlabel="Test step", ylabel="Cumulative r_trad")
plt.show()

In [None]:
# Running total of r_clipping over the test rollout.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_clipping']))
ax.set(title=f"Cumulative r_clipping ({reward_type}, test rollout)", xlabel="Test step", ylabel="Cumulative r_clipping")
plt.show()

In [None]:
# Running total of r_deg over the test rollout.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_deg']))
ax.set(title=f"Cumulative r_deg ({reward_type}, test rollout)", xlabel="Test step", ylabel="Cumulative r_deg")
plt.show()

In [None]:
# Running total of r_tot over the test rollout.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info['r_tot']))
ax.set(title="Cumulative r_tot (test rollout)", xlabel="Test step", ylabel="Cumulative r_tot")
plt.show()

In [None]:
# Raw soc trace from the deterministic test rollout.
print(info['soc'])

In [None]:
# NOTE(review): duplicates the cumulative r_deg plot a few cells above — candidate for deletion.
fig, ax = plt.subplots()
ax.plot(np.cumsum(info[reward_type]['r_deg']))
ax.set(title=f"Cumulative r_deg ({reward_type}, test rollout)", xlabel="Test step", ylabel="Cumulative r_deg")
plt.show()

In [None]:
# NOTE(review): identical to an earlier `print(info['soc'])` cell — candidate for deletion.
print(info['soc'])