In [None]:
import matplotlib.pyplot as plt

# Setup and helper code
import bauwerk
import bauwerk.eval
import bauwerk.benchmarks
import gym
import numpy as np

# Length of each evaluation task in environment steps: 24 steps/day x 30 days.
# NOTE(review): this reads as ~1 month only if one env step is one hour — confirm
# against the SolarBatteryHouse config.
TASK_LEN = 24 * 30

# Create SolarBatteryHouse environment
build_dist_b = bauwerk.benchmarks.BuildDistB(seed=0, episode_len=TASK_LEN)
test_env = build_dist_b.make_env()

battery_sizes = [1, 5, 15, 25]

# PLACEHOLDER scores for the meta-RL methods (NOT real results): each method is
# placed this fraction of the way from the "no charge" baseline (0) to the
# optimal performance (1).
PLACEHOLDER_FRACTIONS = {
    "PEARL": 0.9,
    "RL$^2$": 0.5,
    "MAML-TRPO": 0.2,  # was mislabelled "MAML-TPRO"
}


def interpolate_perf(perc, optimal, no_charge):
    """Linearly interpolate between ``no_charge`` (perc=0) and ``optimal`` (perc=1)."""
    return optimal * perc + no_charge * (1 - perc)


# env_data[battery_size] -> {"optimal", "no charge", "random", <methods>...}
# Key order matters: create_bar_chart draws bars in this insertion order.
env_data = {}
for size in battery_sizes:
    task = bauwerk.benchmarks.Task(
        cfg=bauwerk.envs.solar_battery_house.EnvConfig(
            battery_size=size,
            episode_len=TASK_LEN,
        )
    )
    test_env.set_task(task)

    size_perf = {}
    size_perf["optimal"] = bauwerk.eval.get_optimal_perf(test_env, eval_len=TASK_LEN)
    # Baseline that never charges/discharges the battery (all-zero actions).
    size_perf["no charge"] = bauwerk.eval.evaluate_actions(np.zeros((TASK_LEN, 1)), test_env)
    # NOTE(review): confirm get_avg_rndm_perf is seeded, otherwise this baseline
    # (and every chart normalised by it) changes on each Restart & Run All.
    size_perf["random"], _ = bauwerk.eval.get_avg_rndm_perf(
        test_env,
        eval_len=TASK_LEN,
        num_samples=10,
    )
    for method, perc in PLACEHOLDER_FRACTIONS.items():
        size_perf[method] = interpolate_perf(perc, size_perf["optimal"], size_perf["no charge"])

    env_data[size] = size_perf
    





In [None]:
# relevant tutorial https://www.geeksforgeeks.org/bar-plot-in-matplotlib/

import copy
import seaborn as sns
sns.set_theme(context="paper", font="serif", style="white")
palette = sns.color_palette("deep")

# Bar layout shared by get_loc and create_bar_chart.
num_values_per_house = 4    # bar slots reserved for each battery-size group
space_between_houses = 2.5  # gap between groups, measured in bar heights
# Each group sits on an integer y-position, so one bar is 1/(slots + gap) tall.
height = 1 / (space_between_houses + num_values_per_house)


def get_rel_perf(maximum, minimum, perf):
    """Rescale ``perf`` so that ``minimum`` maps to 0 and ``maximum`` maps to 1.

    Values outside [minimum, maximum] are extrapolated linearly (e.g. a result
    worse than ``minimum`` becomes negative).
    """
    offset = perf - minimum
    span = maximum - minimum
    return offset / span

def get_loc(house, idx):
    """Return the y-position of bar ``idx`` within the group for ``house``.

    The ``num_values_per_house`` bar slots of a group are ``height`` apart and
    centred on the integer position ``house``.
    """
    first_slot = house - height * (num_values_per_house / 2 - 0.5)
    return first_slot + height * idx



def _annotate_bars(ax):
    """Label every bar patch on ``ax`` with its width rounded to 3 decimals.

    Labels sit just past the end of each bar: to the right of positive bars and
    to the left of negative (or zero-width) bars.
    """
    x_min, x_max = ax.get_xbound()
    # small gap between bar end and label, proportional to the current x-range
    gap = (x_max - x_min) * 0.007
    for bar in ax.patches:
        width = bar.get_width()
        ax.text(
            width + gap * np.sign(width),
            bar.get_y() + bar.get_height() * 0.55,
            str(round(width, 3)),
            fontsize=8,
            color="black",
            horizontalalignment="left" if width > 0 else "right",
            verticalalignment="center",
        )


def create_bar_chart(max_key="optimal", min_key="random", remove_keys=None, save_path="test.png"):
    """Plot normalised performance per battery size as a grouped horizontal bar chart.

    Every entry of ``env_data[size]`` other than ``max_key``, ``min_key`` and
    ``remove_keys`` becomes one bar. It is rescaled with ``get_rel_perf`` so that
    ``min_key`` sits at 0 and ``max_key`` sits at 1.

    Args:
        max_key: key in ``env_data[size]`` used as the upper reference (mapped to 1).
        min_key: key in ``env_data[size]`` used as the lower reference (mapped to 0).
        remove_keys: optional iterable of further keys to leave out. Keys that are
            absent, or already excluded as ``max_key``/``min_key``, are ignored.
        save_path: file the figure is written to (dpi=300); ``None`` skips saving.
    """
    excluded = {max_key, min_key, *(remove_keys or ())}

    fig, ax = plt.subplots(figsize=(6.5, 4.5))

    ys = []
    y_labels = []
    for i, size in enumerate(env_data):
        size_perf = env_data[size]
        best = size_perf[max_key]
        worst = size_perf[min_key]
        plotted = [(key, value) for key, value in size_perf.items() if key not in excluded]
        # get_loc lays out num_values_per_house slots per group. Shift the slot
        # index so that groups with a different number of bars stay centred on
        # their tick (e.g. 3 bars when "random" is removed).
        slot_offset = (num_values_per_house - len(plotted)) / 2
        for j, (key, value) in enumerate(plotted):
            ax.barh(
                get_loc(i, j + slot_offset),
                width=get_rel_perf(maximum=best, minimum=worst, perf=value),
                height=height,
                color=palette[j % len(palette)],
                label=key,
            )
        ys.append(i)
        y_labels.append(f"{size}kWh")

    _annotate_bars(ax)

    # add the battery size labels
    ax.set_yticks(ys, y_labels)
    ax.set_xticks([0, 1], [f"0\n ({min_key})", f"1\n ({max_key})"])
    ax.tick_params(axis="x", pad=-10)  # reduce the padding of tick labels

    # add axis labels
    ax.set_xlabel(f"Performance relative to {min_key} and {max_key} ")
    ax.set_ylabel("Battery size")

    # one legend entry per method (each battery size re-adds the same labels)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys())

    # list battery sizes from smallest (top) to largest (bottom)
    ax.invert_yaxis()
    ax.set_title("Building distribution B performance (PLACEHOLDER, not real data)")

    # reference lines at 0 (min_key, solid) and 1 (max_key, dotted)
    ax.vlines([0, 1], *ax.get_ylim(), colors=["grey", "grey"], linestyles=["solid", "dotted"])

    # remove default frame around figure
    for side in ("top", "bottom", "right", "left"):
        ax.spines[side].set_visible(False)

    # extend figure slightly to left to show full vline at 0
    ax.set_xlim(left=ax.get_xlim()[0] - 0.005)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=300)
    plt.show()

In [None]:
# Normalise every method between the random policy (0) and the optimum (1).
create_bar_chart(max_key="optimal", min_key="random")

In [None]:
# Normalise between the no-charging baseline (0) and the optimum (1); random is hidden.
create_bar_chart(max_key="optimal", min_key="no charge", remove_keys=["random"])