In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

In [None]:
import matplotlib.pyplot as plt
from functools import partial
from algorithms.tqdm_custom import scan_tqdm

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
import seaborn as sns

from plotly_resampler import register_plotly_resampler, FigureWidgetResampler
# register_plotly_resampler(mode="auto", default_n_shown_samples=4500)

import pandas as pd

from algorithms.utils import restore_state, restore_state_multi_agent_with_double_battery_net

import numpy as np

import jax
import jax.numpy as jnp

import flax.nnx as nnx
from flax.core.frozen_dict import freeze
# jax.config.update("jax_enable_x64", True)


In [None]:
from ernestogym.envs.multi_agent.env import RECEnv
# from ernestogym.envs.single_agent.env_trading_soc import MicroGridEnvSocAction

In [None]:
def my_env_creator(params, battery_type, env_type='normal'):
    """Build the REC environment used for evaluation.

    Args:
        params: Environment parameters, forwarded unchanged to ``RECEnv``.
        battery_type: Battery model identifier, forwarded unchanged to ``RECEnv``.
        env_type: Kind of environment to create; only ``'normal'`` is supported.

    Returns:
        A ``RECEnv`` instance.

    Raises:
        ValueError: If ``env_type`` is not a supported value.
    """
    if env_type == 'normal':
        return RECEnv(params, battery_type)

    # An unknown type used to fall through to ``return env`` with ``env`` never
    # assigned, which surfaced as an opaque UnboundLocalError.
    raise ValueError(f"Unsupported env_type: {env_type!r} (expected 'normal')")

In [None]:
def print_heatmap(X, Y, Z, num_bins=10, x_label='', y_label='', title='', n_decimals_axes=0, ax=None, print_axes=False, annot=True, file_path=None):
    """Draw a heatmap of the mean of ``Z`` over a 2-D binning of ``(X, Y)``.

    ``X`` and ``Y`` are binned into a ``num_bins`` x ``num_bins`` grid and each
    cell is coloured by the average of the ``Z`` values that fall into it.
    Cells without samples are left as NaN.

    Args:
        X: 1-D values binned along the horizontal axis.
        Y: 1-D values binned along the vertical axis.
        Z: 1-D values averaged inside each cell (used as histogram weights).
        num_bins: Number of bins along each axis.
        x_label: Label of the horizontal axis.
        y_label: Label of the vertical axis.
        title: Title of the plot.
        n_decimals_axes: Number of decimals used for the bin-centre tick labels.
        ax: Matplotlib axes to draw on; a new 8x6 figure is created when None.
        print_axes: If True, draw green lines at the bin boundary where zero
            would be inserted among the bin centres of each axis.
        annot: If True, annotate every cell with its sample count.
        file_path: If given, the figure that owns ``ax`` is saved to this path.
    """
    bins = (num_bins, num_bins)

    # histogram2d puts its first argument on the rows, hence Y before X.
    z_sum, y_edges, x_edges = np.histogram2d(Y, X, bins=bins, weights=Z)
    counts, _, _ = np.histogram2d(Y, X, bins=bins)

    # Per-cell mean; ``out`` pre-filled with NaN keeps empty cells undefined.
    z_grid = np.divide(z_sum, counts, out=np.full_like(z_sum, fill_value=np.nan), where=counts > 0)

    x_centers = (x_edges[1:]+x_edges[:-1])/2
    y_centers = (y_edges[1:]+y_edges[:-1])/2

    x_ticks = [f'{v:.{n_decimals_axes}f}' for v in x_centers]
    y_ticks = [f'{v:.{n_decimals_axes}f}' for v in y_centers]

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    heatmap_kwargs = dict(xticklabels=x_ticks, yticklabels=y_ticks, cmap='coolwarm', center=0, robust=True, ax=ax)
    if annot:
        heatmap_kwargs.update(annot=counts, annot_kws={"fontsize":8})
    sns.heatmap(z_grid, **heatmap_kwargs)

    # seaborn draws row 0 at the top; flip so that Y increases upwards.
    ax.invert_yaxis()

    if print_axes:
        # Heatmap cell k spans [k, k+1], so index k is the boundary before centre k.
        x_zero_index = np.searchsorted(x_centers, 0)
        y_zero_index = np.searchsorted(y_centers, 0)

        if 0 <= x_zero_index < len(x_edges):
            ax.axvline(x=x_zero_index, color='green', linewidth=2)
        if 0 <= y_zero_index < len(y_edges):
            ax.axhline(y=y_zero_index, color='green', linewidth=2)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    if file_path is not None:
        # Save the figure that owns ``ax``: pyplot's "current figure" can be a
        # different one when the caller supplies its own axes.
        ax.figure.savefig(file_path, bbox_inches='tight')

def plot_hist_soc(soc, actions):
    """Bar-plot the mean action observed in each of 16 equal-width SoC bins.

    Args:
        soc: 1-D state-of-charge samples.
        actions: 1-D actions, aligned element-wise with ``soc``.
    """
    num_bins = 16

    action_sums, bin_edges = np.histogram(soc, weights=actions, bins=num_bins)
    counts, _ = np.histogram(soc, bins=num_bins)

    # ``where`` without ``out`` leaves the skipped entries uninitialised, so
    # empty bins used to show arbitrary values. Pre-fill with zeros instead
    # (an empty bin is drawn as a zero-height bar, i.e. no bar).
    mean_actions = np.divide(action_sums, counts, out=np.zeros(action_sums.shape), where=counts > 0)

    bin_centers = (bin_edges[:-1] + bin_edges[1:])/2

    plt.bar(bin_centers, mean_actions, width=0.045)

    plt.xlabel('SoC')
    plt.ylabel('Mean action')

In [None]:
from ernestogym.envs.multi_agent.utils import get_world_data

In [None]:
# Number of scan iterations rolled out by `test`.
# NOTE(review): 8760 * 5 is presumably five years of hourly steps — confirm against the env step size.
num_iter = 8760 * 5

# Testing

In [None]:
# Registries filled by the experiment cells below, keyed by experiment name.
logs = {}  # experiment name -> info returned by `test`
experiments = []  # names of the experiments that have been loaded
agents = {}  # experiment name -> list of agent display names

In [None]:
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5, 7))
def test(env:RECEnv, networks_batteries, networks_batteries_only_local, network_rec, num_iter, config, rng, rec_rule_based_policy):
    """Roll the restored policies out in ``env`` for ``num_iter`` scan iterations.

    Every scan iteration performs two ``env.step`` calls: the first carries the
    battery actions (with an all-zero REC action), the second carries the REC
    action (with zero battery actions). The two ``info`` trees are summed
    leaf-wise and extended with the actions and done flags.

    Every argument except ``rng`` is a static jit argument, so ``config`` must be
    hashable (the calling cell passes ``freeze(config)``).

    Args:
        env: Environment to evaluate in.
        networks_batteries: Policy network of the RL battery agents.
        networks_batteries_only_local: Battery policy that is fed only the
            observations listed in ``config['OBS_NET_ONLY_LOCAL']``.
        network_rec: REC policy network; used unless
            ``config['USE_REC_RULE_BASED_POLICY']`` is set.
        num_iter: Number of scan iterations.
        config: Experiment configuration.
        rng: JAX PRNG key.
        rec_rule_based_policy: Callable mapping the REC observation to REC
            actions; used when ``config['USE_REC_RULE_BASED_POLICY']`` is set.

    Returns:
        The per-iteration ``info`` trees stacked by ``jax.lax.scan``.
    """

    networks_batteries.eval()
    networks_batteries_only_local.eval()
    if not config['USE_REC_RULE_BASED_POLICY']:
        network_rec.eval()

    rng, _rng = jax.random.split(rng)

    obsv, env_state = env.reset(_rng, profile_index=0)

    # Recurrent networks need their initial LSTM states; others carry None.
    if config['NETWORK_TYPE_BATTERIES'] == 'recurrent_actor_critic' and config['NUM_RL_AGENTS'] > 0:
        act_state_batteries, cri_state_batteries = networks_batteries.get_initial_lstm_state()
    else:
        act_state_batteries, cri_state_batteries = None, None

    if not config['USE_REC_RULE_BASED_POLICY'] and config['NETWORK_TYPE_REC'] == 'recurrent_actor_critic':
        act_state_rec, cri_state_rec = network_rec.get_initial_lstm_state()
    else:
        act_state_rec, cri_state_rec = None, None

    @scan_tqdm(0, 0, 1, num_iter, print_rate=num_iter // 100)
    def _env_step(runner_state, unused):
        obsv_batteries, env_state, act_state_batteries, act_state_rec, rng, next_profile_index = runner_state

        # Per-group action arrays, concatenated in agent order further down:
        # RL agents, battery-first agents, only-market agents, random agents.
        actions_batteries = []
        actions_diff_only_local = []

        if config['NUM_RL_AGENTS'] > 0:

            # The critic states are closed over and never updated here: only
            # the actor output is used during evaluation.
            if config['NETWORK_TYPE_BATTERIES'] == 'recurrent_actor_critic':
                pi, _, act_state_batteries, _ = networks_batteries(obsv_batteries[:config['NUM_RL_AGENTS']], act_state_batteries, cri_state_batteries)
            else:
                pi, _ = networks_batteries(obsv_batteries[:config['NUM_RL_AGENTS']])

            # Deterministic action: mode of the policy distribution.
            actions_batteries_rl = pi.mode()

            actions_batteries_rl = actions_batteries_rl.squeeze(axis=-1)
            actions_batteries.append(actions_batteries_rl)

            # Difference between the action of the full network and the action
            # of the network restricted to the local observations.
            obsv_batteries_only_local = jnp.stack([obsv_batteries[:config['NUM_RL_AGENTS'], env._obs_battery_agents_idx[obs]] for obs in config['OBS_NET_ONLY_LOCAL']], axis=-1)
            pi_only_local, _ = networks_batteries_only_local(obsv_batteries_only_local)
            actions_diff_only_local_rl = actions_batteries_rl - pi_only_local.mode().squeeze(axis=-1)
            actions_diff_only_local.append(actions_diff_only_local_rl)


        if config['NUM_BATTERY_FIRST_AGENTS'] > 0:
            idx_start_bf = config['NUM_RL_AGENTS']
            idx_end_bf = config['NUM_RL_AGENTS'] + config['NUM_BATTERY_FIRST_AGENTS']

            demand = obsv_batteries[idx_start_bf:idx_end_bf, env._obs_battery_agents_idx['demand']]
            generation = obsv_batteries[idx_start_bf:idx_end_bf, env._obs_battery_agents_idx['generation']]

            # Rule-based action: net generation divided by the battery voltage.
            actions_batteries_battery_first = (generation - demand) / env_state.battery_states.electrical_state.v[idx_start_bf:idx_end_bf]

            actions_batteries.append(actions_batteries_battery_first)
            actions_diff_only_local.append(jnp.zeros((config['NUM_BATTERY_FIRST_AGENTS'],)))

        if config['NUM_ONLY_MARKET_AGENTS'] > 0:
            actions_batteries_only_market = jnp.zeros((config['NUM_ONLY_MARKET_AGENTS'],))
            actions_batteries.append(actions_batteries_only_market)
            actions_diff_only_local.append(jnp.zeros((config['NUM_ONLY_MARKET_AGENTS'],)))

        if config['NUM_RANDOM_AGENTS'] > 0:
            rng, _rng = jax.random.split(rng)

            actions_batteries_random = jax.random.uniform(_rng,
                                                          shape=(config['NUM_RANDOM_AGENTS'],),
                                                          minval=-1.,
                                                          maxval=1.)
            actions_batteries_random *= config['MAX_ACTION_RANDOM_AGENTS']

            rng, _rng = jax.random.split(rng)

            # A second, independent draw stands in for the "local-only" action.
            actions_batteries_random2 = jax.random.uniform(_rng,
                                                          shape=(config['NUM_RANDOM_AGENTS'],),
                                                          minval=-1.,
                                                          maxval=1.)

            actions_batteries_random2 *= config['MAX_ACTION_RANDOM_AGENTS']

            actions_diff_only_local_random = actions_batteries_random - actions_batteries_random2

            actions_batteries.append(actions_batteries_random)
            actions_diff_only_local.append(actions_diff_only_local_random)

        actions_batteries = jnp.concat(actions_batteries, axis=0)
        actions_diff_only_local = jnp.concat(actions_diff_only_local, axis=0)

        # First env step: battery actions, with an all-zero REC action.
        actions_first = {env.battery_agents[i]: actions_batteries[i] for i in range(env.num_battery_agents)}
        actions_first[env.rec_agent] = jnp.zeros(env.num_battery_agents)

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward_first, done_first, info_first = env.step(
            _rng, env_state, actions_first
        )

        rec_obsv = obsv[env.rec_agent]
        rec_obsv['action_diff_only_local'] = actions_diff_only_local

        if not config['USE_REC_RULE_BASED_POLICY']:
            if config['NETWORK_TYPE_REC'] == 'recurrent_actor_critic':
                pi, _, act_state_rec, _ = network_rec(rec_obsv, act_state_rec, cri_state_rec)
            else:
                pi, _ = network_rec(rec_obsv)
            actions_rec = pi.mean()
        else:
            actions_rec = rec_rule_based_policy(rec_obsv)

        # Second env step: REC action, with zero battery actions.
        actions_second = {agent: jnp.array(0.) for agent in env.battery_agents}
        actions_second[env.rec_agent] = actions_rec

        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward_second, done_second, info_second = env.step(
            _rng, env_state, actions_second
        )

        done = jnp.logical_or(done_first['__all__'], done_second['__all__'])

        # The two half-step infos are merged by leaf-wise addition.
        info = jax.tree.map(lambda  x, y: x + y, info_first, info_second)

        info['actions_batteries'] = actions_batteries
        info['actions_rec'] = actions_rec
        info['dones'] = jax.tree.map(lambda x, y : jnp.logical_or(x, y), done_first, done_second)
        info['actions_diff_only_local'] = actions_diff_only_local

        # On episode end, reset onto the next profile and advance the profile counter.
        rng, _rng = jax.random.split(rng)
        obsv, env_state,next_profile_index = jax.lax.cond(done,
                                                          lambda : env.reset(_rng, profile_index=next_profile_index) + (next_profile_index+1,),
                                                          lambda : (obsv, env_state, next_profile_index))

        obs_batteries = jnp.vstack([obsv[a] for a in env.battery_agents])

        runner_state = (obs_batteries, env_state, act_state_batteries, act_state_rec, rng, next_profile_index)
        return runner_state, info

    obsv_batteries = jnp.vstack([obsv[a] for a in env.battery_agents])

    # Profile 0 was used for the initial reset, so the next reset uses profile 1.
    runner_state = (obsv_batteries, env_state, act_state_batteries, act_state_rec, rng, 1)

    runner_state, info = jax.lax.scan(_env_step, runner_state, jnp.arange(num_iter))

    return info

In [None]:
def rec_scarce_resource_policy(rec_obs):
    """Rule-based REC policy that rewards the houses supplying the scarcer side.

    When ``network_REC_plus`` exceeds ``network_REC_minus`` the houses with a
    negative net exchange get shares proportional to their deficit; otherwise
    the houses with a positive net exchange get shares proportional to their
    surplus. If the resulting shares sum to exactly zero they are replaced by
    a uniform split.

    Args:
        rec_obs: Mapping with the keys ``generations_battery_houses``,
            ``demands_base_battery_houses``, ``demands_battery_battery_houses``,
            ``network_REC_plus`` and ``network_REC_minus``.

    Returns:
        Array of per-house shares along the last axis.
    """
    def _fractions(amounts):
        # The epsilon keeps the division finite when every amount is zero.
        return amounts / (amounts.sum(axis=-1, keepdims=True) + 1e-8)

    house_net = (rec_obs['generations_battery_houses']
                 - rec_obs['demands_base_battery_houses']
                 - rec_obs['demands_battery_battery_houses'])

    surplus_share = _fractions(jnp.maximum(house_net, 0))
    deficit_share = _fractions(-jnp.minimum(house_net, 0))

    rec_plus = rec_obs['network_REC_plus']
    rec_minus = rec_obs['network_REC_minus']
    plus_dominates = (rec_plus > rec_minus)[..., None]
    minus_dominates = (rec_plus <= rec_minus)[..., None]

    shares = plus_dominates * deficit_share + minus_dominates * surplus_share

    # Fall back to an even split when no house receives anything.
    nothing_assigned = shares.sum(axis=-1, keepdims=True) == 0.
    even_split = jnp.ones_like(shares)/shares.shape[-1]

    return jnp.where(nothing_assigned, even_split, shares)

def rec_both_resources_policy(rec_obs):
    """Rule-based REC policy that averages surplus and deficit shares.

    Each house gets the mean of its normalised share of the total positive
    net exchange and its normalised share of the total negative net exchange.
    If the resulting shares sum to exactly zero they are replaced by a
    uniform split.

    Args:
        rec_obs: Mapping with the keys ``generations_battery_houses``,
            ``demands_base_battery_houses`` and
            ``demands_battery_battery_houses``.

    Returns:
        Array of per-house shares along the last axis.
    """
    def _fractions(amounts):
        # The epsilon keeps the division finite when every amount is zero.
        return amounts / (amounts.sum(axis=-1, keepdims=True) + 1e-8)

    house_net = (rec_obs['generations_battery_houses']
                 - rec_obs['demands_base_battery_houses']
                 - rec_obs['demands_battery_battery_houses'])

    surplus_share = _fractions(jnp.maximum(house_net, 0))
    deficit_share = _fractions(-jnp.minimum(house_net, 0))

    shares = (surplus_share + deficit_share) / 2

    # Fall back to an even split when no house receives anything.
    nothing_assigned = shares.sum(axis=-1, keepdims=True) == 0.
    even_split = jnp.ones_like(shares) / shares.shape[-1]

    return jnp.where(nothing_assigned, even_split, shares)

In [None]:
# Checkpoint directory of the run under analysis; the figures produced below are saved here too.
# NOTE(review): absolute, machine-specific path — consider moving it to a configurable constant.
directory = '/media/samuele/Disco/PycharmProjectsUbuntu/MARL-CER/trained_agents/server_06-07/20250407_061258_bat_net_type_actor_critic_rec_net_type_actor_critic_concat_lr_bat_3e-05_lr_REC_cosine_tot_timesteps_8970240_lr_sched_cosine_multiagent/checkpoints/20250407_082412_22'

curr_exp = 'exp1'
# Guarded so that re-running the cell does not register the experiment twice.
if curr_exp not in experiments:
    experiments.append(curr_exp)

networks_batteries, networks_batteries_only_local, network_rec, config, world_metadata, train_info, val_info = restore_state_multi_agent_with_double_battery_net(directory)

print(world_metadata)

test_params = get_world_data(world_metadata, get_test=True)
battery_type = 'degrading_dropflow'

# Display names of the agents, in the order used by `test`:
# RL agents, battery-first agents, only-market agents, random agents.
# The prefix is chosen up front: written inline, the conditional expression
# bound looser than `+`, so recurrent agents lost their index.
rl_agent_prefix = 'recurrent_ppo_agent_' if config['NETWORK_TYPE_BATTERIES'] == 'recurrent_actor_critic' else 'ppo_agent_'

curr_agents = []
curr_agents += [rl_agent_prefix + str(i) for i in range(config['NUM_RL_AGENTS'])]
curr_agents += ['battery_first_agent_' + str(i) for i in range(config['NUM_BATTERY_FIRST_AGENTS'])]
curr_agents += ['only_market_agent_' + str(i) for i in range(config['NUM_ONLY_MARKET_AGENTS'])]
curr_agents += ['random_agent_' + str(i) for i in range(config['NUM_RANDOM_AGENTS'])]

agents[curr_exp] = curr_agents

In [None]:
# Inspect the restored REC network.
network_rec

In [None]:
# Inspect the `log_std` attribute of the restored battery network.
networks_batteries.log_std

In [None]:
# Show the configuration restored from the checkpoint directory.
config

In [None]:
# Show the world metadata restored from the checkpoint directory.
world_metadata

In [None]:
# Build the evaluation environment from the parameters returned by get_world_data(..., get_test=True).
env = my_env_creator(test_params, battery_type, config['ENV_TYPE'])

In [None]:
# Names of the battery observations flagged as local by the environment.
# Stored as a tuple so that the frozen config stays hashable for jit.
config['OBS_NET_ONLY_LOCAL'] = tuple(
    obs_name
    for position, obs_name in enumerate(env._obs_battery_agents_keys)
    if env.obs_is_local_battery[position]
)

In [None]:
from time import time
t0 = time()
info = test(env, networks_batteries, networks_batteries_only_local, network_rec, num_iter, freeze(config), jax.random.PRNGKey(51), rec_rule_based_policy=rec_scarce_resource_policy)
# JAX dispatches work asynchronously: without waiting for the result, the
# timer would stop before the rollout has actually finished.
info = jax.block_until_ready(info)
print(time() - t0)
logs[curr_exp] = info

In [None]:
# Mean absolute difference, over axis 0 of the stacked log, between the battery actions and their local-only counterparts.
jnp.abs(info['actions_diff_only_local']).mean(axis=0)

In [None]:
# Shape of the logged demands.
info['demands'].shape

In [None]:
# Key of the reward breakdown in the logged info that the analysis cells below read.
reward_type = 'pure_reward'

In [None]:
# Mean battery action of the RL agents as a function of demand and generation.
n_rl = config['NUM_RL_AGENTS']

print_heatmap(
    info['demands'][:, :n_rl].flatten(),
    info['generations'][:, :n_rl].flatten(),
    info['actions_batteries'][:, :n_rl].flatten(),
    x_label='Demand',
    y_label='Generation',
    title='Actions batteries',
    file_path=directory+'/actions_rl.png',
)

In [None]:
# Mean, maximum and minimum of the logged global reward.
info[reward_type]['r_glob'].mean(), info[reward_type]['r_glob'].max(), info[reward_type]['r_glob'].min()

In [None]:
# Per-agent net exchange, REC balance, and global reward centred on its per-step mean across agents.
glob_reward = info[reward_type]['r_glob']

net_exchange = info['generations'] - info['demands'] - info['energy_to_batteries']
balance = info['balance_plus'] - info['balance_minus']
centered_r_glob = glob_reward - glob_reward.mean(axis=1, keepdims=True)

n_rl = config['NUM_RL_AGENTS']

# Columns are laid end to end (agent-major), so the balance is tiled once per agent.
print_heatmap(
    net_exchange[:, :n_rl].T.flatten(),
    np.tile(balance, reps=n_rl),
    centered_r_glob[:, :n_rl].T.flatten(),
    x_label='Network exchange',
    y_label='Balance',
    title='Centered global reward',
    print_axes=True,
    file_path=directory+'/centered_global_reward_rl.png',
)

In [None]:
# One centred-global-reward heatmap per RL agent, side by side.
num_rl_agents = config['NUM_RL_AGENTS']

# plt.subplots rejects zero columns, so skip the figure when there are no RL agents.
if num_rl_agents > 0:
    # squeeze=False always returns a 2-D array of Axes, so a single agent is
    # plotted like several (it used to be left as an empty figure).
    _, axs = plt.subplots(1, num_rl_agents, figsize=(15, 5), squeeze=False)
    axs = axs.flatten()
    for i in range(num_rl_agents):
        print_heatmap(net_exchange[:, i], balance, centered_r_glob[:, i], ax=axs[i], x_label='Network exchange', y_label='Balance', title=f'Centered global reward agent {i}', print_axes=True)

In [None]:
# Number of RL agents and of battery agents in the restored configuration.
config['NUM_RL_AGENTS'], config['NUM_BATTERY_AGENTS']

In [None]:
# Same heatmap for the agents that come after the RL agents.
n_rl = config['NUM_RL_AGENTS']
n_other = config['NUM_BATTERY_AGENTS']-n_rl

print_heatmap(
    net_exchange[:, n_rl:].T.flatten(),
    np.tile(balance, reps=n_other),
    centered_r_glob[:, n_rl:].T.flatten(),
    x_label='Network exchange',
    y_label='Balance',
    title='Centered global reward',
    print_axes=True,
    file_path=directory+'/centered_global_reward_rule_based.png',
)

In [None]:
# Shape of the per-agent net exchange computed above.
net_exchange.shape

In [None]:
# Net exchange centred on its per-step mean across agents.
centered_net_exchange = net_exchange - net_exchange.mean(axis=1, keepdims=True)

n_rl = config['NUM_RL_AGENTS']

print_heatmap(
    centered_net_exchange[:, :n_rl].T.flatten(),
    np.tile(balance, reps=n_rl),
    centered_r_glob[:, :n_rl].T.flatten(),
    x_label='Centered network exchange',
    y_label='Balance',
    title='Centered global reward',
    print_axes=True,
    file_path=directory+'/centered_global_reward_with_centered_network_rl.png',
)

In [None]:
# One heatmap per RL agent, using the centred net exchange on the x-axis.
num_rl_agents = config['NUM_RL_AGENTS']

# plt.subplots rejects zero columns, so skip the figure when there are no RL agents.
if num_rl_agents > 0:
    # squeeze=False always returns a 2-D array of Axes, so a single agent is
    # plotted like several (it used to be left as an empty figure).
    _, axs = plt.subplots(1, num_rl_agents, figsize=(15, 5), squeeze=False)
    axs = axs.flatten()
    for i in range(num_rl_agents):
        # The x-axis shows the centred quantity, so it is labelled accordingly.
        print_heatmap(centered_net_exchange[:, i], balance, centered_r_glob[:, i], ax=axs[i], x_label='Centered network exchange', y_label='Balance', title=f'Centered global reward agent {i}', print_axes=True)

In [None]:
# Centred-net-exchange heatmap for the agents that come after the RL agents.
# The x-axis shows the centred quantity, so it is labelled accordingly.
print_heatmap(
    centered_net_exchange[:, config['NUM_RL_AGENTS']:].T.flatten(),
    np.tile(balance, reps=config['NUM_BATTERY_AGENTS']-config['NUM_RL_AGENTS']),
    centered_r_glob[:, config['NUM_RL_AGENTS']:].T.flatten(),
    x_label='Centered network exchange',
    y_label='Balance',
    title='Centered global reward',
    print_axes=True,
    file_path=directory+'/centered_global_reward_with_centered_network_rule_based.png',
)

In [None]:
# print_heatmap(net_exchange[:, :config['NUM_RL_AGENTS']].T.flatten(), np.tile(info['balance_plus'] - info['balance_minus'], reps=config['NUM_RL_AGENTS']), info[reward_type]['r_glob'][:, :config['NUM_RL_AGENTS']].T.flatten(), x_label='Network exchange', y_label='Balance', title='Actions REC')

In [None]:
# print_heatmap(net_exchange[:, config['NUM_RL_AGENTS']:].T.flatten(), np.tile(info['balance_plus'] - info['balance_minus'], reps=config['NUM_BATTERY_AGENTS']-config['NUM_RL_AGENTS']), info['actions_rec'][:, config['NUM_RL_AGENTS']:].T.flatten(), x_label='Network exchange', y_label='Balance', title='Actions REC')

In [None]:
# Mean battery action per SoC bin, pooled over every entry of the logged arrays.
plot_hist_soc(info['soc'].flatten(), info['actions_batteries'].flatten())

# Plotting

In [None]:
# Convert every leaf of the logged trees to a NumPy array for the plotting and statistics cells.
logs = jax.tree.map(np.asarray, logs)


In [None]:
exp = ['exp1']


def _global_reward_shares(log):
    """Return each column's share of the total global reward."""
    glob_rew = log[reward_type]['r_glob'].sum(axis=0)
    return glob_rew/glob_rew.sum()


def _rec_action_range(log):
    """Return the mean, over axis 0, of the max-min spread of the REC actions along axis 1."""
    rec_actions = log['actions_rec']
    return np.mean(np.max(rec_actions, axis=1) - np.min(rec_actions, axis=1))


def _print_summaries(summaries):
    """Print each ``(title, statistic)`` pair, with one line per experiment in ``exp``."""
    for title, statistic in summaries:
        print(title)
        for e in exp:
            # The value is computed outside the f-string: reusing the same
            # quote character inside the braces only parses on Python >= 3.12.
            value = statistic(logs[e])
            print(f'\t{e}: {value}')


_print_summaries([
    ('Tot self consumption', lambda log: np.sum(log['self_consumption'])),
    ('Tot reward', lambda log: np.sum(log['r_tot'])),
    ('Tot glob reward', lambda log: np.sum(log[reward_type]['r_glob'])),
    ('Tot trad reward', lambda log: np.sum(log[reward_type]['r_trad'])),
    ('Tot deg reward', lambda log: np.sum(log[reward_type]['r_deg'])),
    ('Tot clipping reward', lambda log: np.sum(log[reward_type]['r_clipping'])),
    ('Tot rec reward', lambda log: np.sum(log['rec_reward'])),
    ('Ratios global reward', _global_reward_shares),
])

print('\n\n')

_print_summaries([
    ('Mean rec actions', lambda log: np.mean(log['actions_rec'], axis=0)),
    ('Mean of variance rec actions', lambda log: np.mean(np.var(log['actions_rec'], axis=1))),
    ('Mean of difference max-min rec actions', _rec_action_range),
])

# Covariance and Pearson correlation, per column, between the REC actions and
# the product (net exchange * REC balance).
print('Pearson correlation coeff')
for e in exp:
    log_e = logs[e]
    v1 = log_e['actions_rec']
    v2 = (log_e['generations']-log_e['demands']-log_e['energy_to_batteries']) * (log_e['balance_plus'] - log_e['balance_minus'])[:, None]
    coeff = np.mean((v1 - v1.mean(axis=0, keepdims=True)) * (v2 - v2.mean(axis=0, keepdims=True)), axis=0)
    norm_coeff = coeff/np.sqrt(np.var(v1, axis=0)*np.var(v2, axis=0))
    print(f'\t{e}: {coeff}\t\t{norm_coeff}')


In [None]:
import statsmodels.api as sm

e = 'exp1'
log_e = logs[e]
n_rl = config['NUM_RL_AGENTS']

# Response: REC actions of the RL agents, columns laid end to end (agent-major).
a = log_e['actions_rec'][:, :n_rl].T.flatten()

# Regressors: per-agent net exchange (b) and the REC balance tiled once per agent (c).
b = (
    log_e['generations'][:, :n_rl]
    - log_e['demands'][:, :n_rl]
    - log_e['energy_to_batteries'][:, :n_rl]
).T.flatten()
c = np.tile(log_e['balance_plus'] - log_e['balance_minus'], reps=n_rl)

print(np.corrcoef(b, c)[0, 1])

# OLS of a on b, c, their interaction b*c, and an intercept.
X = sm.add_constant(np.column_stack((b, c, b * c)))
model = sm.OLS(a, X).fit()

print(model.summary())

In [None]:
from statsmodels.stats.outliers_influence import variance_inflation_factor

# Variance inflation factor of each regressor (b, c and b*c, without the intercept column).
X = np.column_stack((b, c, b * c))
vif = [variance_inflation_factor(X, column) for column in range(X.shape[1])]
print(vif)

In [None]:
# Cross-check of the Pearson coefficient for column 0 of 'exp1': NumPy's corrcoef versus the explicit formula.
exp1_log = logs['exp1']

v1 = exp1_log['actions_rec'][:, 0]

first_net_exchange = exp1_log['generations'][:, 0]-exp1_log['demands'][:, 0]-exp1_log['energy_to_batteries'][:, 0]
v2 = first_net_exchange * (exp1_log['balance_plus'] - exp1_log['balance_minus'])

print(np.corrcoef(v1, v2))

print(v1.shape)
print(v2.shape)


np.mean((v1-v1.mean()) * (v2-v2.mean())) / np.sqrt(np.var(v1)*np.var(v2))

In [None]:
# exp = ['random', 'only_market', 'battery_first', 'ppo', 'recurrent_ppo']        #, 'recurrent_ppo']
# colors = {e: col for e, col in zip(es, plotly.colors.sample_colorscale('rainbow', len(es)))}
# One colour per battery agent, sampled from the 'rainbow' colourscale; read as a global by plot_data_plotly.
colors = plotly.colors.sample_colorscale('rainbow', env.num_battery_agents)
# Plotly dash style of each experiment's traces; read as a global by plot_data_plotly.
line_styles = {'exp1': 'solid'}

def plot_data_plotly(demands, generations, sell_prices, buy_prices, log, exp, full_exp, agents, time_step, reward_type='weig_reward', cumulative=True):
    """Plot the logged experiments as an 11-row Plotly figure and save it as HTML.

    Relies on the notebook globals ``colors`` (one colour per agent),
    ``line_styles`` (dash style per experiment) and ``directory`` (where
    ``plots.html`` is written).

    Args:
        demands: 2-D array of demands, one column per agent.
        generations: 2-D array of generations, one column per agent.
        sell_prices: Selling prices; only the length of axis 0 is used here.
        buy_prices: Buying prices; only the length of axis 0 is used here.
        log: Mapping from experiment name to its logged info.
        exp: Names of the experiments to plot.
        full_exp: Experiments plotted per agent; the remaining experiments of
            ``exp`` are plotted as a mean over agents.
        agents: Mapping from experiment name to the list of agent names.
        time_step: Length of one step in seconds, used to build the time axis.
        reward_type: Key of the reward breakdown read from each log.
        cumulative: If True, rewards are plotted as cumulative sums.
    """
    n_points = max(demands.shape[0], generations.shape[0], sell_prices.shape[0], buy_prices.shape[0])
    n_agents = demands.shape[1]
    n_rows = 11

    time = pd.date_range('2015-01-01', periods=n_points, freq=str(int(time_step))+'s')

    fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing = 0.02, subplot_titles=['Power demand and generation', 'SoC', 'Trad plus deg plus glob reward', 'Trading reward', 'Degradation reward', 'Global reward', 'Actions', 'REC actions', 'REC reward','REC balances', 'Network exchange'])

    def add_line(row, y, name, color, dash):
        # Every subplot row uses its own legend, named after the row number.
        fig.add_trace(go.Scatter(x=time, y=y, mode='lines', legend=f'legend{row}', name=name, line=dict(color=color, dash=dash)), row=row, col=1)

    # Row 1: demand (solid) and generation (dotted) of every agent.
    for i in range(n_agents):
        add_line(1, demands[:, i], f'demand {agents[exp[0]][i]}', colors[i], 'solid')
        add_line(1, generations[:, i], f'generation  {agents[exp[0]][i]}', colors[i], 'dot')

    # Reward series of every experiment, optionally accumulated over axis 0.
    rewards = {}
    for e in exp:
        series = {key: log[e][reward_type][key] for key in ('r_trad', 'r_deg', 'r_glob')}
        rec_reward = log[e]['rec_reward']
        if cumulative:
            series = {key: np.cumsum(values, axis=0) for key, values in series.items()}
            rec_reward = np.cumsum(rec_reward)
        series['r_rec'] = rec_reward
        rewards[e] = series

    # Rows 2-7 and 11: one trace per agent for the fully plotted experiments.
    for e in full_exp:
        dash = line_styles[e]
        r = rewards[e]
        for i in range(n_agents):
            name = f'{e} {agents[e][i]}'
            per_agent_rows = (
                (2, log[e]['soc'][:, i]),
                (3, r['r_trad'][:, i]+r['r_deg'][:, i]+r['r_glob'][:, i]),
                (4, r['r_trad'][:, i]),
                (5, r['r_deg'][:, i]),
                (6, r['r_glob'][:, i]),
                (7, log[e]['actions_batteries'][:, i]),
                (11, log[e]['generations'][:, i]-log[e]['demands'][:, i]-log[e]['energy_to_batteries'][:, i]),
            )
            for row, y in per_agent_rows:
                add_line(row, y, name, colors[i], dash)

    # Rows 2-7: a single grey mean-over-agents trace for the other experiments.
    for e in exp:
        if e in full_exp:
            continue
        dash = line_styles[e]
        r = rewards[e]
        mean_rows = (
            (2, log[e]['soc'].mean(axis=1)),
            (3, r['r_trad'].mean(axis=1) + r['r_deg'].mean(axis=1) + r['r_glob'].mean(axis=1)),
            (4, r['r_trad'].mean(axis=1)),
            (5, r['r_deg'].mean(axis=1)),
            (6, r['r_glob'].mean(axis=1)),
            (7, log[e]['actions_batteries'].mean(axis=1)),
        )
        for row, y in mean_rows:
            add_line(row, y, e, 'grey', dash)

    # Rows 8-10: REC actions, REC reward and REC balances of every experiment.
    for e in exp:
        dash = line_styles[e]
        for i in range(n_agents):
            add_line(8, log[e]['actions_rec'][:, i], e, colors[i], dash)
        add_line(9, rewards[e]['r_rec'], e, 'grey', dash)
        add_line(10, log[e]['balance_plus'], e, 'green', dash)
        add_line(10, log[e]['balance_minus'], e, 'red', dash)

    # Identical tick and grid settings for every row; Plotly names the first
    # pair of axes 'xaxis'/'yaxis' and the following ones 'xaxis2', 'yaxis2', ...
    axis_layout = {}
    for row in range(1, n_rows + 1):
        suffix = '' if row == 1 else str(row)
        axis_layout[f'xaxis{suffix}'] = dict(tickformat='%b %d %H:00', showticklabels=True)
        axis_layout[f'yaxis{suffix}'] = dict(fixedrange=True, minor=dict(ticklen=6, tickcolor="black", showgrid=True))
    axis_layout['yaxis']['title'] = 'Wh'

    fig.update_layout(
        title='',
        height=3000,
        width=1200,
        legend_tracegroupgap=20,
        **axis_layout
    )

    fig.update_layout(
        legend1=dict(
            xref="container",
            yref="container",
            y=0.6),
        legend2=dict(
            xref="container",
            yref="container",
            y=0.1)
    )

    # The static HTML is written before the figure is wrapped in the resampler widget.
    fig.write_html(directory + '/plots.html')
    fig = FigureWidgetResampler(fig, default_n_shown_samples=5000)
    print(directory + '/plots.html')

    display(fig)

In [None]:
# plot_external_data_matplotlib(logs['ppo']['demand'], logs['ppo']['generation'], logs['ppo']['sell_price'], logs['ppo']['buy_price'], start=1000, length_max=200)
# algs_to_plot = ['only_market', 'battery_first', 'ppo', 'recurrent_ppo']
# Experiments to draw, and the subset drawn per agent rather than as a mean over agents.
exp_to_plot = ['exp1']
full_exp = ['exp1']

# Demand, generation and price series are taken from the first plotted experiment.
dem_agent = exp_to_plot[0]
reference_log = logs[dem_agent]

plot_data_plotly(
    reference_log['demands'],
    reference_log['generations'],
    reference_log['sell_prices'],
    reference_log['buy_prices'],
    logs,
    exp_to_plot,
    full_exp,
    agents,
    env.env_step,
    reward_type='weig_reward',
    cumulative=True,
)

In [None]:
# exp = ['random', 'only_market', 'battery_first', 'ppo', 'recurrent_ppo']        #, 'recurrent_ppo']
# colors = {e: col for e, col in zip(es, plotly.colors.sample_colorscale('rainbow', len(es)))}
# One colour per battery agent, sampled from the 'rainbow' colourscale.
colors = plotly.colors.sample_colorscale('rainbow', env.num_battery_agents)
# Plotly dash style per experiment name; replaces the mapping defined in the earlier plotting cell.
line_styles = {'exp1': 'dot',
               'only_market': 'dash',
               'battery_first': 'dashdot',
               'ppo': 'solid',
               'recurrent_ppo': 'longdash'}

def plot_data_plotly(demands, generations, sell_prices, buy_prices, log, exp, full_exp, agents, time_step, reward_type='weig_reward', cumulative=True):
    """Build, save and display a 13-row Plotly dashboard of experiment logs.

    Rows, top to bottom: power demand/generation, market prices, SoC,
    trading+degradation+global reward, trading reward, degradation reward,
    clipping reward, global reward, battery actions, REC actions, REC reward,
    REC balances, network exchange. Every row ``r`` uses its own legend
    (``legend<r>``).

    Besides its arguments, the function reads the notebook globals ``colors``
    (one colour per agent column), ``line_styles`` (dash pattern per
    experiment name) and ``directory`` (output folder), and calls IPython's
    ``display``.

    Args:
        demands: 2-D array indexed ``[:, agent]``; its second dimension sets
            the number of per-agent traces drawn everywhere.
        generations: 2-D array indexed ``[:, agent]``.
        sell_prices: 2-D array; only column 0 is plotted, multiplied by 1e6
            on an axis titled '€/MWh'.
        buy_prices: 2-D array; only column 0 is plotted, multiplied by 1e6.
        log: Mapping ``experiment name -> log dict``. Keys read per experiment:
            ``reward_type`` (with 'r_trad', 'r_deg', 'r_clipping'),
            'weig_reward' (with 'r_glob'), 'rec_reward', 'soc',
            'actions_batteries', 'actions_rec', 'balance_plus',
            'balance_minus', and, for ``full_exp`` only, 'generations',
            'demands' and 'energy_to_batteries'.
        exp: Sequence of experiment names to plot. ``exp[0]`` supplies the
            agent labels for the demand/generation traces.
        full_exp: Experiment names plotted with one trace per agent. Names in
            ``exp`` that are not in ``full_exp`` are plotted as the mean over
            axis 1, in grey.
        agents: Mapping ``experiment name -> sequence of agent labels``.
        time_step: Sampling period; ``int(time_step)`` seconds is used as the
            frequency of the synthetic time axis.
        reward_type: Key of the reward sub-dict in each experiment log.
        cumulative: If True, reward series are cumulatively summed before
            plotting; otherwise they are plotted as logged.

    Returns:
        None. The figure is written to ``directory + '/plots.html'``, that
        path is printed, and the figure widget is displayed.
    """
    n_rows = 13
    n_agents = demands.shape[1]
    n_points = max(demands.shape[0], generations.shape[0], sell_prices.shape[0], buy_prices.shape[0])

    # Synthetic time axis: the logs carry no timestamps, so samples are laid
    # out from a fixed start date at `time_step`-second intervals.
    time = pd.date_range('2015-01-01', periods=n_points, freq=f'{int(time_step)}s')

    subplot_titles = [
        'Power demand and generation', 'Market prices', 'SoC',
        'Trad plus deg plus glob reward', 'Trading reward', 'Degradation reward',
        'Clipping reward', 'Global reward', 'Actions', 'REC actions',
        'REC reward', 'REC balances', 'Network exchange',
    ]
    fig = FigureWidgetResampler(
        make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=0.02, subplot_titles=subplot_titles),
        default_n_shown_samples=5000,
    )

    def add_line(y, row, name, **line_style):
        # Every trace is a line on subplot `row`, attached to that row's own legend.
        fig.add_trace(
            go.Scatter(x=time, y=y, mode='lines', legend=f'legend{row}', name=name, line=line_style or None),
            row=row, col=1,
        )

    # Row 1: per-agent demand (solid) and generation (dotted), same colour per agent.
    for i in range(n_agents):
        add_line(demands[:, i], 1, f'demand {agents[exp[0]][i]}', color=colors[i], dash='solid')
        add_line(generations[:, i], 1, f'generation  {agents[exp[0]][i]}', color=colors[i], dash='dot')

    # Row 2: prices, rescaled by 1e6 to match the '€/MWh' axis title.
    add_line(sell_prices[:, 0]*1000000, 2, 'Selling prices')
    add_line(buy_prices[:, 0]*1000000, 2, 'Buying prices')

    # Reward series, optionally accumulated over time. Built for every
    # experiment that is plotted in any form, so a name that appears only in
    # `full_exp` does not raise KeyError below.
    if cumulative:
        def over_time(a):
            return np.cumsum(a, axis=0)
        over_time_flat = np.cumsum  # REC reward is accumulated without an axis argument
    else:
        def over_time(a):
            return a
        over_time_flat = over_time

    rewards = {}
    for e in dict.fromkeys([*exp, *full_exp]):
        rewards[e] = {
            'r_trad': over_time(log[e][reward_type]['r_trad']),
            'r_deg': over_time(log[e][reward_type]['r_deg']),
            'r_clipping': over_time(log[e][reward_type]['r_clipping']),
            # NOTE(review): r_glob is always read from 'weig_reward', independent
            # of `reward_type` — confirm this is intentional.
            'r_glob': over_time(log[e]['weig_reward']['r_glob']),
            'r_rec': over_time_flat(log[e]['rec_reward']),
        }

    # Rows 3-9 and 13: one trace per agent for the fully-plotted experiments.
    for e in full_exp:
        r = rewards[e]
        for i in range(n_agents):
            label = f'{e} {agents[e][i]}'
            style = dict(color=colors[i], dash=line_styles[e])
            add_line(log[e]['soc'][:, i], 3, label, **style)
            add_line(r['r_trad'][:, i]+r['r_deg'][:, i]+r['r_glob'][:, i], 4, label, **style)
            add_line(r['r_trad'][:, i], 5, label, **style)
            add_line(r['r_deg'][:, i], 6, label, **style)
            add_line(r['r_clipping'][:, i], 7, label, **style)
            add_line(r['r_glob'][:, i], 8, label, **style)
            add_line(log[e]['actions_batteries'][:, i], 9, label, **style)
            # Generation minus demand minus energy sent to the battery.
            add_line(log[e]['generations'][:, i]-log[e]['demands'][:, i]-log[e]['energy_to_batteries'][:, i], 13, label, **style)

    # Rows 3-9: remaining experiments are summarised as the mean over axis 1, in grey.
    for e in (name for name in exp if name not in full_exp):
        r = rewards[e]
        style = dict(color='grey', dash=line_styles[e])
        add_line(log[e]['soc'].mean(axis=1), 3, e, **style)
        add_line(r['r_trad'].mean(axis=1) + r['r_deg'].mean(axis=1) + r['r_glob'].mean(axis=1), 4, e, **style)
        add_line(r['r_trad'].mean(axis=1), 5, e, **style)
        add_line(r['r_deg'].mean(axis=1), 6, e, **style)
        add_line(r['r_clipping'].mean(axis=1), 7, e, **style)
        add_line(r['r_glob'].mean(axis=1), 8, e, **style)
        add_line(log[e]['actions_batteries'].mean(axis=1), 9, e, **style)

    # Rows 10-12: REC-level series for every experiment.
    for e in exp:
        for i in range(n_agents):
            add_line(log[e]['actions_rec'][:, i], 10, e, color=colors[i], dash=line_styles[e])
        add_line(rewards[e]['r_rec'], 11, e, color='grey', dash=line_styles[e])
        add_line(log[e]['balance_plus'], 12, e, color='green', dash=line_styles[e])
        add_line(log[e]['balance_minus'], 12, e, color='red', dash=line_styles[e])

    # Identical axis formatting on every row; plotly names the first axes
    # 'xaxis'/'yaxis' and the following ones 'xaxis2', 'yaxis2', ...
    axis_layout = {}
    for row in range(1, n_rows + 1):
        suffix = '' if row == 1 else str(row)
        axis_layout[f'xaxis{suffix}'] = dict(tickformat='%b %d %H:00', showticklabels=True)
        axis_layout[f'yaxis{suffix}'] = dict(fixedrange=True, minor=dict(ticklen=6, tickcolor="black", showgrid=True))
    axis_layout['yaxis']['title'] = 'Wh'
    axis_layout['yaxis2']['title'] = '€/MWh'

    fig.update_layout(
        title='',
        height=3000,
        width=1200,
        legend_tracegroupgap=20,
        **axis_layout,
    )

    # NOTE(review): only the first two legends are positioned explicitly;
    # legends 3-13 keep plotly's default placement — confirm that is intended.
    fig.update_layout(
        legend1=dict(
            xref="container",
            yref="container",
            y=0.6),
        legend2=dict(
            xref="container",
            yref="container",
            y=0.1)
    )

    fig.write_html(directory + '/plots.html')
    print(directory + '/plots.html')

    display(fig)