In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

In [None]:
import matplotlib.pyplot as plt
from functools import partial
from jax_tqdm import scan_tqdm

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly

from plotly_resampler import register_plotly_resampler, FigureWidgetResampler

import pandas as pd

from algorithms.utils import restore_state

import numpy as np

import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)

In [None]:
from ernestogym.envs.single_agent.env_new_clip import MicroGridEnv as NewClipEnv
from ernestogym.envs.single_agent.env import MicroGridEnv
from ernestogym.envs.single_agent.env_trading_soc import MicroGridEnvSocAction

In [None]:
def my_env_creator(params, battery_type, env_type='normal'):
    """Instantiate one of the micro-grid environment variants.

    Args:
        params: parameter dictionary forwarded to the environment constructor
            (e.g. the output of ``parameter_generator``).
        battery_type: battery model identifier forwarded to the constructor.
        env_type: ``'normal'`` -> ``MicroGridEnv``, ``'soc_action'`` ->
            ``MicroGridEnvSocAction``, ``'new_clip'`` -> ``NewClipEnv``.

    Returns:
        Tuple ``(env, env.params)``.

    Raises:
        ValueError: if ``env_type`` is not one of the names above.
    """
    if env_type not in ('normal', 'soc_action', 'new_clip'):
        raise ValueError('Unknown env_type')

    constructors = {
        'normal': MicroGridEnv,
        'soc_action': MicroGridEnvSocAction,
        'new_clip': NewClipEnv,
    }
    environment = constructors[env_type](params, battery_type)
    return environment, environment.params

In [None]:
from ernestogym.envs.single_agent.utils import parameter_generator

In [None]:
# --- Environment configuration ------------------------------------------------------
# Battery pack, electrical / thermal / aging model files and the world definition.
# Alternatives that were tried before are listed next to the active choice.
pack_options = "ernestogym/ernesto/data/battery/pack_init_full_cheap.yaml"  # alt: pack.yaml, pack_init_full.yaml
ecm = "ernestogym/ernesto/data/battery/models/electrical/thevenin_pack.yaml"  # alt: thevenin_fading_pack.yaml
r2c = "ernestogym/ernesto/data/battery/models/thermal/r2c_thermal_pack.yaml"
bolun = "ernestogym/ernesto/data/battery/models/aging/bolun_pack.yaml"
# alt worlds (same folder): ijcnn_deg_test.yaml, ijcnn_deg_test_only_trad.yaml,
# ijcnn_deg_test_new_gen_data.yaml, world_fading.yaml
world = "ernestogym/envs/single_agent/ijcnn_deg_test_new_gen_data_new_clip.yaml"

battery_type = 'degrading_dropflow'  # alt: 'fading', 'degrading'

params = parameter_generator(input_var='current',
                             battery_options=pack_options,
                             electrical_model=ecm,
                             thermal_model=r2c,
                             aging_model=bolun,
                             world_options=world)  # optional: use_reward_normalization=True

env_type = 'new_clip'  # alt: 'normal'
# 8760 = hours in a non-leap year; presumably 8 yearly profiles at an hourly step — TODO confirm.
num_iter = 8760 * 8

params['step']

# Testing

In [None]:
# Evaluation rollouts keyed by policy name (e.g. 'ppo', 'only_market', 'battery_first').
logs = dict()

## PPO

In [None]:
@partial(jax.jit, static_argnums=(0, 2, 3))
def test_ppo(env: MicroGridEnv, env_params, network, num_iter, rng):
    """Evaluate a feed-forward PPO policy deterministically for ``num_iter`` steps.

    Args:
        env: environment instance (static argument of ``jax.jit``).
        env_params: environment parameters; ``env.eval`` is applied to them first.
        network: policy network, called as ``pi, _ = network(obs)`` (static argument).
        num_iter: number of environment steps to simulate (static argument).
        rng: JAX PRNG key.

    Returns:
        ``(info, actions)``: the per-step ``info`` returned by ``env.step`` and the chosen
        actions, each stacked by ``lax.scan`` along a leading axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)

    env_params = env.eval(env_params)
    obsv, env_state = env.reset(_rng, env_params)

    # Advance the test profile after the reset, and again after every finished episode
    # (see `_env_step`), so each episode presumably uses the next profile — TODO confirm
    # how `env.reset` / `env.step` consume `test_profile`.
    env_params = env_params.replace(test_profile=env_params.test_profile + 1)

    # `num_iter // 100` is 0 for runs shorter than 100 steps, which jax_tqdm does not
    # accept as a print rate; clamp it to at least 1.
    @scan_tqdm(num_iter, print_rate=max(1, num_iter // 100))
    def _env_step(runner_state, unused):
        obsv, env_state, env_params, rng = runner_state

        pi, _ = network(obsv)
        action = pi.mode()  # deterministic action: mode of the policy distribution

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        env_params = jax.lax.cond(done,
                                  lambda: env_params.replace(test_profile=env_params.test_profile + 1),
                                  lambda: env_params)

        runner_state = (obsv, env_state, env_params, rng)
        return runner_state, (info, action)

    runner_state = (obsv, env_state, env_params, rng)
    _, outputs = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return outputs

In [None]:
# Checkpoint directory of the trained agent to evaluate (also used later as the output
# folder for 'plots.html'). Override it with the TRAINED_AGENT_DIR environment variable
# so the notebook is not tied to one machine's absolute path.
import os

directory = os.environ.get(
    'TRAINED_AGENT_DIR',
    '/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER/trained_agents/20250320_101434_lr_0.001_tot_timesteps_7008000_rl_sched_cosine_new_clip_actor_critic_total_timesteps7008000_init_full_True',
)

network, config, params_training, train_info, val_info = restore_state(directory)

# The checkpoint must have been trained on the environment variant evaluated here.
assert config['ENV_TYPE'] == env_type, (
    f"checkpoint was trained with ENV_TYPE={config['ENV_TYPE']!r}, "
    f"but this notebook uses env_type={env_type!r}"
)

In [None]:
# Parameters returned by `restore_state` for the training run (presumably the training hyper-parameters).
params_training

In [None]:
# Validation information returned by `restore_state` for this checkpoint.
val_info

In [None]:
# Sum of all 'r_tot' entries recorded in the validation info.
val_info['r_tot'].sum()

In [None]:
# Run configuration returned by `restore_state` (includes the 'ENV_TYPE' key).
config

In [None]:
# Build the evaluation environment for the configured `env_type`.
env, env_params = my_env_creator(params, battery_type, env_type=env_type)
# Presumably switches the parameters to evaluation mode; the test_* rollouts also apply env.eval internally.
env_params = env.eval(env_params)

# Print the shape of the environment's demand matrix (`dem_matrix`).
print(env.dem_matrix.shape)

In [None]:
# Deterministic PPO rollout (seed 51); keep the flattened action sequence next to the env info.
info, actions = test_ppo(env, env_params, network, num_iter, jax.random.PRNGKey(51))
info['actions'] = actions.flatten()
logs['ppo'] = info

## Recurrent PPO

In [None]:
# @partial(jax.jit, static_argnums=(0, 2, 3))
def test_recurrent_ppo(env: MicroGridEnv, env_params, network, num_iter, rng):
    """Evaluate a recurrent (LSTM) PPO policy deterministically for ``num_iter`` steps.

    Args:
        env: environment instance.
        env_params: environment parameters; ``env.eval`` is applied to them first.
        network: recurrent policy exposing ``get_initial_lstm_state()`` and called as
            ``pi, _, act_state, _ = network(obs, act_state, cri_state)``.
        num_iter: number of environment steps to simulate.
        rng: JAX PRNG key.

    Returns:
        ``(info, actions)``: the per-step ``info`` returned by ``env.step`` and the chosen
        actions, each stacked by ``lax.scan`` along a leading axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)

    env_params = env.eval(env_params)
    obsv, env_state = env.reset(_rng, env_params)

    # Advance the test profile after the reset and after every finished episode, as in
    # `test_ppo` — TODO confirm how the env consumes `test_profile`.
    env_params = env_params.replace(test_profile=env_params.test_profile + 1)

    act_state, cri_state = network.get_initial_lstm_state()
    # Cast to the default float dtype (float64 if 'jax_enable_x64' is enabled).
    act_state, cri_state = jax.tree.map(lambda x: jnp.astype(x, float), (act_state, cri_state))

    # `num_iter // 100` is 0 for runs shorter than 100 steps, which jax_tqdm does not
    # accept as a print rate; clamp it to at least 1.
    @scan_tqdm(num_iter, print_rate=max(1, num_iter // 100))
    def _env_step(runner_state, unused):
        obsv, env_state, env_params, act_state, rng = runner_state

        # Only the actor output is used: the critic state stays at its initial value.
        # NOTE(review): the actor LSTM state is not reset when an episode ends — confirm
        # this matches how the policy was trained.
        pi, _, act_state, _ = network(obsv, act_state, cri_state)
        action = pi.mode()  # deterministic action: mode of the policy distribution

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        env_params = jax.lax.cond(done,
                                  lambda: env_params.replace(test_profile=env_params.test_profile + 1),
                                  lambda: env_params)

        runner_state = (obsv, env_state, env_params, act_state, rng)
        return runner_state, (info, action)

    runner_state = (obsv, env_state, env_params, act_state, rng)
    _, outputs = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return outputs

In [None]:
# directory = '/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER/trained_agents/20250227_144145_lr_0.001_tot_timesteps_3504000_anneal_rl_True_normal_recurrent_actor_critic'
#
# network, config, params_training, train_info, val_info = restore_state(directory)
#
# assert config['ENV_TYPE'] == env_type

In [None]:
# Re-display the run configuration (relevant when the commented-out recurrent-PPO restore is enabled).
config

In [None]:
# info, actions = test_recurrent_ppo(env, env_params, network, num_iter, jax.random.PRNGKey(51))
# info['actions'] = actions.flatten()
# logs['recurrent_ppo'] = info

## Random

In [None]:
@partial(jax.jit, static_argnums=(0, 2))
def test_random(env: MicroGridEnv, env_params, num_iter, rng):
    """Baseline: act uniformly at random within the environment's action bounds.

    Each action is drawn from ``jax.random.uniform`` on
    ``[env_params.i_min_action, env_params.i_max_action)``.

    Args:
        env: environment instance (static argument of ``jax.jit``).
        env_params: environment parameters; ``env.eval`` is applied to them first.
        num_iter: number of environment steps to simulate (static argument).
        rng: JAX PRNG key.

    Returns:
        ``(info, actions)`` stacked by ``lax.scan`` along a leading axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)

    env_params = env.eval(env_params)
    obsv, env_state = env.reset(_rng, env_params)

    # Advance the test profile after the reset and after every finished episode
    # — TODO confirm how the env consumes `test_profile`.
    env_params = env_params.replace(test_profile=env_params.test_profile + 1)

    # `num_iter // 100` is 0 for runs shorter than 100 steps, which jax_tqdm does not
    # accept as a print rate; clamp it to at least 1.
    @scan_tqdm(num_iter, print_rate=max(1, num_iter // 100))
    def _env_step(runner_state, unused):
        obsv, env_state, env_params, rng = runner_state

        rng, _rng = jax.random.split(rng)
        action = jax.random.uniform(_rng, minval=env_params.i_min_action, maxval=env_params.i_max_action)

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        env_params = jax.lax.cond(done,
                                  lambda: env_params.replace(test_profile=env_params.test_profile + 1),
                                  lambda: env_params)

        runner_state = (obsv, env_state, env_params, rng)
        return runner_state, (info, action)

    runner_state = (obsv, env_state, env_params, rng)
    _, outputs = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return outputs

In [None]:
# info, actions = test_random(env, env_params, num_iter, jax.random.PRNGKey(51))
# info['actions'] = actions.flatten()
# logs['random'] = info

## Only market

In [None]:
@partial(jax.jit, static_argnums=(0, 2))
def test_only_market(env: MicroGridEnv, env_params, num_iter, rng):
    """Baseline: always take action ``0.`` (presumably no battery current, so all energy
    is exchanged with the market — TODO confirm against the env's action semantics).

    Args:
        env: environment instance (static argument of ``jax.jit``).
        env_params: environment parameters; ``env.eval`` is applied to them first.
        num_iter: number of environment steps to simulate (static argument).
        rng: JAX PRNG key.

    Returns:
        ``(info, actions)`` stacked by ``lax.scan`` along a leading axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)

    env_params = env.eval(env_params)
    obsv, env_state = env.reset(_rng, env_params)

    # Advance the test profile after the reset and after every finished episode
    # — TODO confirm how the env consumes `test_profile`.
    env_params = env_params.replace(test_profile=env_params.test_profile + 1)

    # `num_iter // 100` is 0 for runs shorter than 100 steps, which jax_tqdm does not
    # accept as a print rate; clamp it to at least 1.
    @scan_tqdm(num_iter, print_rate=max(1, num_iter // 100))
    def _env_step(runner_state, unused):
        obsv, env_state, env_params, rng = runner_state

        action = 0.

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        env_params = jax.lax.cond(done,
                                  lambda: env_params.replace(test_profile=env_params.test_profile + 1),
                                  lambda: env_params)

        runner_state = (obsv, env_state, env_params, rng)
        return runner_state, (info, action)

    runner_state = (obsv, env_state, env_params, rng)
    _, outputs = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return outputs

In [None]:
# Zero-action baseline, same seed (51) as the PPO run; store flattened actions with the env info.
info, actions = test_only_market(env, env_params, num_iter, jax.random.PRNGKey(51))
info['actions'] = actions.flatten()
logs['only_market'] = info

## Battery first

In [None]:
@partial(jax.jit, static_argnums=(0, 2))
def test_battery_first(env: MicroGridEnv, env_params, num_iter, rng):
    """Heuristic baseline: route the generation/demand imbalance through the battery.

    The action is ``(generation - demand) / v``, where generation and demand are read
    from the current observation and ``v`` is the battery's electrical-state voltage.
    This presumably converts a net power surplus into a battery current (positive =
    charge?) — TODO confirm the sign convention and that the observation is not normalised.

    Args:
        env: environment instance (static argument of ``jax.jit``).
        env_params: environment parameters; ``env.eval`` is applied to them first.
        num_iter: number of environment steps to simulate (static argument).
        rng: JAX PRNG key.

    Returns:
        ``(info, actions)`` stacked by ``lax.scan`` along a leading axis of length ``num_iter``.
    """
    rng, _rng = jax.random.split(rng)

    env_params = env.eval(env_params)
    obsv, env_state = env.reset(_rng, env_params)

    # Advance the test profile after the reset and after every finished episode
    # — TODO confirm how the env consumes `test_profile`.
    env_params = env_params.replace(test_profile=env_params.test_profile + 1)

    # `num_iter // 100` is 0 for runs shorter than 100 steps, which jax_tqdm does not
    # accept as a print rate; clamp it to at least 1.
    @scan_tqdm(num_iter, print_rate=max(1, num_iter // 100))
    def _env_step(runner_state, unused):
        obsv, env_state, env_params, rng = runner_state

        demand = obsv[env._obs_idx['demand']]
        generation = obsv[env._obs_idx['generation']]

        action = (generation - demand) / env_state.battery_state.electrical_state.v

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

        env_params = jax.lax.cond(done,
                                  lambda: env_params.replace(test_profile=env_params.test_profile + 1),
                                  lambda: env_params)

        runner_state = (obsv, env_state, env_params, rng)
        return runner_state, (info, action)

    runner_state = (obsv, env_state, env_params, rng)
    _, outputs = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))
    return outputs

In [None]:
# Battery-first heuristic, same seed (51) as the PPO run; store flattened actions with the env info.
info, actions = test_battery_first(env, env_params, num_iter, jax.random.PRNGKey(51))
info['actions'] = actions.flatten()
logs['battery_first'] = info

# Plotting

In [None]:
# Convert every logged JAX array to NumPy for plotting and aggregation.
logs = jax.tree.map(np.asarray, logs)

# One fixed colour per policy (evenly sampled from the 'rainbow' scale) so each policy
# keeps the same colour in every figure, whether or not it was evaluated in this run.
algs = ['random', 'only_market', 'battery_first', 'ppo', 'recurrent_ppo']
colors = dict(zip(algs, plotly.colors.sample_colorscale('rainbow', len(algs))))

In [None]:
def plot_external_data_matplotlib(demand, generation, sell_prices, buy_prices, start=0, length_max=None):
    """Static matplotlib view of the external data.

    Top panel: demand and generation (y label 'W'). Bottom panel: selling and buying
    prices multiplied by 1e6 (y label '€/MWh'). The input arrays are not modified.

    Args:
        demand, generation: power series for the top panel.
        sell_prices, buy_prices: price series for the bottom panel (array-like).
        start: index of the first sample to plot.
        length_max: number of samples to plot; defaults to the length of the longest series.
    """
    if length_max is None:
        length_max = max(len(demand), len(generation), len(sell_prices), len(buy_prices))
    window = slice(start, start + length_max)

    # Scale into new arrays rather than in place, so the caller's data (e.g. the arrays
    # stored in `logs`) is left untouched no matter how often the function is called.
    sell_prices = np.asarray(sell_prices) * 1000000
    buy_prices = np.asarray(buy_prices) * 1000000

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 20))

    ax1.plot(demand[window], label='demand')
    ax1.plot(generation[window], label='generation')
    ax1.legend()
    ax1.set_ylabel('W')

    ax2.plot(sell_prices[window], label='sell_price')
    ax2.plot(buy_prices[window], label='buy_price')
    ax2.legend()
    ax2.set_ylabel('€/MWh')

    plt.show()

# def plot_external_data_plotly(demand, generation, sell_prices, buy_prices, time_step):
#
#     n_points = max(len(demand), len(generation), len(sell_prices), len(buy_prices))
#
#     time = pd.date_range('2015-01-01', periods=n_points, freq=str(int(time_step))+'s')
#     fig = make_subplots(rows=2, cols=1, shared_xaxes=True, subplot_titles=['Power demand and generation', 'Market prices'])
#
#     # Add first time series
#     fig.add_trace(go.Scatter(x=time, y=demand, mode='lines', legendgroup='group1', name='demand'), row=1, col=1)
#     fig.add_trace(go.Scatter(x=time, y=generation, mode='lines', legendgroup='group1', name='generation'), row=1, col=1)
#
#     # Add second time series
#     fig.add_trace(go.Scatter(x=time, y=sell_prices*1000000, mode='lines', legendgroup='group2', name='Selling prices'), row=2, col=1)
#     fig.add_trace(go.Scatter(x=time, y=buy_prices*1000000, mode='lines', legendgroup='group2', name='Buying prices'), row=2, col=1)
#
#     # Format x-axis to show only month & day
#     fig.update_layout(
#         title='Synchronized Zoom with Month/Day Formatting',
#         xaxis=dict(tickformat='%b %d', showticklabels=True),   # Apply to main x-axis
#         xaxis2=dict(tickformat='%b %d', showticklabels=True),  # Apply to second subplot
#         yaxis=dict(fixedrange=True, title='Wh', minor=dict(ticklen=6, tickcolor="black", showgrid=True)),
#         yaxis2=dict(fixedrange=True, title='€/MWh', minor=dict(ticklen=6, tickcolor="black", showgrid=True)),
#         height=800,
#         width=1000,
#         legend_tracegroupgap=20
#     )
#
#     # fig.update_layout(
#     # legends=[
#     #     dict(x=1.05, y=0.95, tracegroup="group1"),  # Legend for first subplot
#     #     dict(x=1.05, y=0.60, tracegroup="group2"),  # Legend for second subplot
#     # ]
#     # )
#
#     fig.update_layout(
#         legend_tracegroupgap=400
#     )
#
#
#     fig.show()


def plot_external_data_plotly(demand, generation, sell_prices, buy_prices, time_step):
    """Interactive (resampled) plotly view of the external data.

    Row 1 shows demand and generation ('legend1'); row 2 shows selling and buying prices
    multiplied by 1e6 ('legend2', axis titled '€/MWh'). The x axis is a synthetic time
    index starting on 2015-01-01 with one sample every ``time_step`` seconds.
    """
    longest = max(map(len, (demand, generation, sell_prices, buy_prices)))
    timestamps = pd.date_range('2015-01-01', periods=longest, freq=f'{int(time_step)}s')

    fig = FigureWidgetResampler(make_subplots(
        rows=2, cols=1, shared_xaxes=True,
        subplot_titles=['Power demand and generation', 'Market prices'],
    ))

    # (row, legend group, trace name, values) in drawing order.
    series = [
        (1, 'legend1', 'demand', demand),
        (1, 'legend1', 'generation', generation),
        (2, 'legend2', 'Selling prices', sell_prices * 1000000),
        (2, 'legend2', 'Buying prices', buy_prices * 1000000),
    ]
    for row, legend_name, trace_name, values in series:
        fig.add_trace(
            go.Scatter(x=timestamps, y=values, mode='lines', legend=legend_name, name=trace_name),
            row=row, col=1,
        )

    date_ticks = dict(tickformat='%b %d %H:00', showticklabels=True)
    minor_grid = dict(ticklen=6, tickcolor="black", showgrid=True)
    fig.update_layout(
        title='Synchronized Zoom with Month/Day Formatting',
        xaxis=dict(date_ticks),
        xaxis2=dict(date_ticks),
        yaxis=dict(fixedrange=True, title='Wh', minor=dict(minor_grid)),
        yaxis2=dict(fixedrange=True, title='€/MWh', minor=dict(minor_grid)),
        height=800,
        width=1000,
        legend_tracegroupgap=20,
    )

    # Position the two legends relative to the whole figure container.
    container_ref = dict(xref="container", yref="container")
    fig.update_layout(legend1=dict(container_ref, y=0.6), legend2=dict(container_ref, y=0.1))

    display(fig)


def plot_data_plotly(demand, generation, sell_prices, buy_prices, log, algs, time_step,
                     reward_type='weig_reward', cumulative=True, output_dir=None):
    """Plot external data, SoC, rewards and actions of several policies in one figure.

    Rows (shared x axis): 1 demand/generation, 2 market prices (scaled by 1e6, '€/MWh'),
    3 SoC, 4 total reward, 5 trading reward, 6 degradation reward, 7 clipping reward,
    8 trading + degradation reward, 9 actions. Every row has its own legend
    ('legend<row>') so policies can be toggled per panel. The figure is also written to
    ``<output_dir>/plots.html`` and that path is printed.

    Args:
        demand, generation, sell_prices, buy_prices: external time series.
        log: mapping alg -> rollout log with keys 'soc', 'r_tot', 'actions' and
            ``log[alg][reward_type]`` containing 'r_trad', 'r_deg', 'r_clipping'.
        algs: policies to draw; colours come from the global ``colors`` mapping.
        time_step: sampling period in seconds, used to build the time axis.
        reward_type: reward breakdown to plot (e.g. 'weig_reward', 'pure_reward').
        cumulative: plot cumulative sums of the rewards instead of per-step values.
        output_dir: folder for 'plots.html'; defaults to the global ``directory``.
    """
    if output_dir is None:
        output_dir = directory

    n_points = max(len(demand), len(generation), len(sell_prices), len(buy_prices))
    time = pd.date_range('2015-01-01', periods=n_points, freq=str(int(time_step)) + 's')
    fig = FigureWidgetResampler(
        make_subplots(rows=9, cols=1, shared_xaxes=True, vertical_spacing=0.05,
                      subplot_titles=['Power demand and generation', 'Market prices', 'SoC',
                                      'Total reward', 'Trading reward', 'Degradation reward',
                                      'Clipping reward', 'Trad + deg rewards', 'Actions']),
        default_n_shown_samples=5000)

    # External data (rows 1-2).
    fig.add_trace(go.Scatter(x=time, y=demand, mode='lines', legend='legend1', name='demand'), row=1, col=1)
    fig.add_trace(go.Scatter(x=time, y=generation, mode='lines', legend='legend1', name='generation'), row=1, col=1)
    fig.add_trace(go.Scatter(x=time, y=sell_prices * 1000000, mode='lines', legend='legend2', name='Selling prices'), row=2, col=1)
    fig.add_trace(go.Scatter(x=time, y=buy_prices * 1000000, mode='lines', legend='legend2', name='Buying prices'), row=2, col=1)

    accumulate = np.cumsum if cumulative else (lambda values: values)

    # Per-policy panels (rows 3-9); each row's traces go to that row's own legend.
    for alg in algs:
        alg_log = log[alg]
        breakdown = alg_log[reward_type]
        r_trad = accumulate(breakdown['r_trad'])
        r_deg = accumulate(breakdown['r_deg'])
        panels = [
            (3, alg_log['soc']),
            (4, accumulate(alg_log['r_tot'])),
            (5, r_trad),
            (6, r_deg),
            (7, accumulate(breakdown['r_clipping'])),
            (8, r_trad + r_deg),
            (9, alg_log['actions']),
        ]
        for row, values in panels:
            fig.add_trace(go.Scatter(x=time, y=values, line=dict(color=colors[alg]), mode='lines',
                                     legend=f'legend{row}', name=alg), row=row, col=1)

    # Date ticks on all nine x axes (including the bottom one, 'xaxis9') and fixed-range
    # y axes with a minor grid; plotly names the axes 'xaxis', 'xaxis2', ..., 'xaxis9'.
    axis_layout = {}
    for suffix in [''] + [str(i) for i in range(2, 10)]:
        axis_layout['xaxis' + suffix] = dict(tickformat='%b %d %H:00', showticklabels=True)
        axis_layout['yaxis' + suffix] = dict(fixedrange=True, minor=dict(ticklen=6, tickcolor="black", showgrid=True))
    axis_layout['yaxis']['title'] = 'Wh'
    axis_layout['yaxis2']['title'] = '€/MWh'

    fig.update_layout(title='', height=3000, width=1000, legend_tracegroupgap=20, **axis_layout)
    fig.update_layout(
        legend1=dict(xref="container", yref="container", y=0.6),
        legend2=dict(xref="container", yref="container", y=0.1),
    )

    html_path = f'{output_dir}/plots.html'
    fig.write_html(html_path)
    print(html_path)

    display(fig)

In [None]:
def plot_ext_data_and_reward_plotly(demand, generation, sell_prices, buy_prices, log, algs, time_step, cumulative=True):
    """Compact interactive view: external data, total reward and actions of several policies.

    Rows (shared x axis): 1 demand/generation, 2 market prices (scaled by 1e6, '€/MWh'),
    3 total reward, 4 actions. Every row has its own legend ('legend<row>'); colours come
    from the global ``colors`` mapping.

    Args:
        demand, generation, sell_prices, buy_prices: external time series.
        log: mapping alg -> rollout log with at least 'r_tot' and 'actions'.
        algs: policies to draw.
        time_step: sampling period in seconds, used to build the time axis.
        cumulative: plot the cumulative sum of 'r_tot' instead of per-step values.
    """
    n_points = max(len(demand), len(generation), len(sell_prices), len(buy_prices))
    time = pd.date_range('2015-01-01', periods=n_points, freq=str(int(time_step)) + 's')
    # One title per row.
    fig = FigureWidgetResampler(make_subplots(
        rows=4, cols=1, shared_xaxes=True, vertical_spacing=0.05,
        subplot_titles=['Power demand and generation', 'Market prices', 'Total reward', 'Actions']))

    # External data (rows 1-2).
    fig.add_trace(go.Scattergl(x=time, y=demand, mode='lines', legend='legend1', name='demand'), row=1, col=1)
    fig.add_trace(go.Scattergl(x=time, y=generation, mode='lines', legend='legend1', name='generation'), row=1, col=1)
    fig.add_trace(go.Scattergl(x=time, y=sell_prices * 1000000, mode='lines', legend='legend2', name='Selling prices'), row=2, col=1)
    fig.add_trace(go.Scattergl(x=time, y=buy_prices * 1000000, mode='lines', legend='legend2', name='Buying prices'), row=2, col=1)

    accumulate = np.cumsum if cumulative else (lambda values: values)
    for alg in algs:
        fig.add_trace(go.Scattergl(x=time, y=accumulate(log[alg]['r_tot']), line=dict(color=colors[alg]),
                                   mode='lines', legend='legend3', name=alg), row=3, col=1)
        fig.add_trace(go.Scattergl(x=time, y=log[alg]['actions'], line=dict(color=colors[alg]),
                                   mode='lines', legend='legend4', name=alg), row=4, col=1)

    # Date ticks on the four x axes ('xaxis' .. 'xaxis4') and fixed-range y axes with a minor grid.
    axis_layout = {}
    for suffix in ('', '2', '3', '4'):
        axis_layout['xaxis' + suffix] = dict(tickformat='%b %d %H:00', showticklabels=True)
        axis_layout['yaxis' + suffix] = dict(fixedrange=True, minor=dict(ticklen=6, tickcolor="black", showgrid=True))
    axis_layout['yaxis']['title'] = 'Wh'
    axis_layout['yaxis2']['title'] = '€/MWh'

    fig.update_layout(title='', height=2000, width=1000, legend_tracegroupgap=20, **axis_layout)
    fig.update_layout(
        legend1=dict(xref="container", yref="container", y=0.6),
        legend2=dict(xref="container", yref="container", y=0.1),
    )

    display(fig)

def plot_details_reward_plotly(log, algs, time_step, reward_type='weig_reward', cumulative=True):
    """Interactive breakdown of the reward components and actions of several policies.

    Rows (shared x axis, one legend per row 'legend<row>'): 1 total reward, 2 trading
    reward, 3 degradation reward, 4 clipping reward, 5 actions. The time axis has as many
    samples as ``log[algs[0]]['r_tot']`` and starts on 2015-01-01 with a period of
    ``time_step`` seconds. Colours come from the global ``colors`` mapping.

    Args:
        log: mapping alg -> rollout log with 'r_tot', 'actions' and ``log[alg][reward_type]``
            containing 'r_trad', 'r_deg', 'r_clipping'.
        algs: policies to draw (must be non-empty).
        time_step: sampling period in seconds.
        reward_type: reward breakdown to plot (e.g. 'weig_reward', 'pure_reward').
        cumulative: plot cumulative sums of the rewards instead of per-step values.
    """
    n_points = len(log[algs[0]]['r_tot'])
    timeline = pd.date_range('2015-01-01', periods=n_points, freq=f'{int(time_step)}s')

    fig = FigureWidgetResampler(make_subplots(
        rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.05,
        subplot_titles=['Total reward', 'Trading reward', 'Degradation reward', 'Clipping reward', 'Actions'],
    ))

    def _maybe_cumulative(values):
        return np.cumsum(values) if cumulative else values

    for alg in algs:
        breakdown = log[alg][reward_type]
        row_values = (
            _maybe_cumulative(log[alg]['r_tot']),
            _maybe_cumulative(breakdown['r_trad']),
            _maybe_cumulative(breakdown['r_deg']),
            _maybe_cumulative(breakdown['r_clipping']),
            log[alg]['actions'],
        )
        for row, values in enumerate(row_values, start=1):
            fig.add_trace(
                go.Scattergl(x=timeline, y=values, line=dict(color=colors[alg]), mode='lines',
                             legend=f'legend{row}', name=alg),
                row=row, col=1,
            )

    # Date ticks and fixed-range y axes (with minor grid) on the five rows.
    axes_layout = {}
    for suffix in ('', '2', '3', '4', '5'):
        axes_layout['xaxis' + suffix] = dict(tickformat='%b %d %H:00', showticklabels=True)
        axes_layout['yaxis' + suffix] = dict(fixedrange=True, minor=dict(ticklen=6, tickcolor="black", showgrid=True))
    fig.update_layout(title='', height=2000, width=1000, legend_tracegroupgap=20, **axes_layout)

    container_ref = dict(xref="container", yref="container")
    fig.update_layout(legend1=dict(container_ref, y=0.6), legend2=dict(container_ref, y=0.1))

    display(fig)

In [None]:
# Shapes of two PPO log series; both are stacked over the num_iter rollout steps.
logs['ppo']['soc'].shape, logs['ppo']['r_tot'].shape

In [None]:
# Total reward over the whole evaluation: PPO vs. the battery-first heuristic.
logs['ppo']['r_tot'].sum(), logs['battery_first']['r_tot'].sum()

In [None]:
# plot_external_data_matplotlib(logs['ppo']['demand'], logs['ppo']['generation'], logs['ppo']['sell_price'], logs['ppo']['buy_price'], start=1000, length_max=200)
# Policies shown in the comparison figure.
algs_to_plot = ['battery_first', 'only_market', 'ppo']

# External series (demand, generation, prices) are taken from the PPO log; presumably they are
# identical for every policy since all runs share env and seed — TODO confirm.
# Rewards are plotted using the 'pure_reward' breakdown.
plot_data_plotly(logs['ppo']['demand'], logs['ppo']['generation'], logs['ppo']['sell_price'], logs['ppo']['buy_price'], logs, algs_to_plot, env.params.env_step, reward_type='pure_reward')

In [None]:
# plot_ext_data_and_reward_plotly(logs['ppo']['demand'], logs['ppo']['generation'], logs['ppo']['sell_price'], logs['ppo']['buy_price'], logs, algs_to_plot, env.params.env_step)

In [None]:
# plot_details_reward_plotly(logs, algs_to_plot, env.params.env_step)

In [None]:
def _summarize_rollout(alg_log):
    """Aggregate one rollout log into scalar statistics (reward breakdowns as nested dicts)."""
    # 'norm_reward' and the 'r_op' component are currently not summarised.
    def reward_totals(reward_log):
        return {key: np.sum(reward_log[key]) for key in ('r_trad', 'r_deg', 'r_clipping')}

    alg_actions = alg_log['actions']
    return {'mean_soc': np.mean(alg_log['soc']),
            'r_tot': np.sum(alg_log['r_tot']),
            'weig_reward': reward_totals(alg_log['weig_reward']),
            'pure_reward': reward_totals(alg_log['pure_reward']),
            'mean_action': np.mean(alg_actions),
            'max_action': np.max(alg_actions),
            'min_action': np.min(alg_actions),
            'variance_action': np.var(alg_actions),
            'final_soh': alg_log['soh'][-1]}


def _flatten_stats(stats):
    """Flatten nested dicts into dotted keys, e.g. {'pure_reward': {'r_deg': x}} -> {'pure_reward.r_deg': x}."""
    flat = {}
    for key, value in stats.items():
        if isinstance(value, dict):
            flat.update({f'{key}.{subkey}': subvalue for subkey, subvalue in value.items()})
        else:
            flat[key] = value
    return flat


# Only summarise policies that were actually evaluated: `algs` also lists policies whose
# evaluation cells are commented out ('random', 'recurrent_ppo'), which are absent from `logs`.
summary = {alg: _summarize_rollout(logs[alg]) for alg in algs if alg in logs}
flatten_summary = {alg: _flatten_stats(stats) for alg, stats in summary.items()}


In [None]:
# Summary table: one row per evaluated policy, one column per flattened statistic.
df = pd.DataFrame.from_dict(flatten_summary, orient='index')
# Show all columns when rendering (note: this option is set globally for the session).
pd.set_option('display.max_columns', None)
df

In [None]:
# State of health ('soh' log) of the zero-action baseline over the first 1/14 of its rollout.
# NOTE(review): the 1/14 fraction is not explained in the notebook — TODO document why 14.
soh_only_market = logs['only_market']['soh']
fig, ax = plt.subplots()
ax.plot(soh_only_market[:len(soh_only_market) // 14])
ax.set_xlabel('step')
ax.set_ylabel('SoH')
ax.set_title("'only_market' state of health (first 1/14 of the rollout)");