In [None]:
import pandas as pd
from utils import dir_path, extract_from_files

In [None]:
# Every report file under the output directory. Materialised into a list
# because it is scanned once per agent configuration below: if dir_path
# returns a one-shot iterator/generator, every scan after the first would
# otherwise see nothing. (`.parts` below assumes dir_path yields
# pathlib.Path objects — NOTE(review): confirm against utils.dir_path.)
paths = list(dir_path('../output/report_output/'))

# Agent-configuration directory names used to bucket the report files.
AGENT_CONFIGS = ('2_agents', '3_agents')

# config name -> report files whose path contains that directory component.
paths_dict = {
    config: [path for path in paths if config in path.parts]
    for config in AGENT_CONFIGS
}

In [None]:
def extract_exp_def_succeded(data):
    """Flatten one run into an ``{exp, defender_model, won}`` record.

    ``data['exp_path']`` is stringified and split on ``'_'``. The
    third-from-last token is the experiment name and the fourth-from-last
    is the defender model.

    ``won`` is ``True`` when ``data['result']`` is falsy (e.g. ``None`` or
    empty). Otherwise it is whatever ``data['result']['solved']`` holds.
    """
    tokens = str(data['exp_path']).split('_')
    outcome = data['result']
    won = True if not outcome else outcome['solved']
    return dict(exp=tokens[-3], defender_model=tokens[-4], won=won)

# Load every 3-agent run, then reduce each one to a single
# (exp, defender_model, won) row of the results table.
extracted_3_agents = list(map(extract_from_files, paths_dict['3_agents']))
results = pd.DataFrame(
    [extract_exp_def_succeded(run) for run in extracted_3_agents],
    columns=['exp', 'defender_model', 'won'],
)

In [None]:
import numpy as np
from collections import defaultdict

# One record per (exp, defender_model): the raw list of per-run outcomes,
# plus their mean and population (ddof=0, np.std default) standard deviation.
results_list = (
    results.groupby(['exp', 'defender_model'])['won']
    .apply(list)
    .reset_index(name='won')
    .to_dict('records')
)
results_list = [
    {**result, 'mean': np.mean(result['won']), 'std': np.std(result['won'])}
    for result in results_list
]

# Nested lookup: tree[exp][defender_model] -> record from results_list.
tree = defaultdict(dict)
for result in results_list:
    tree[result['exp']][result['defender_model']] = result

exps = list(tree.keys())
# Models are collected across *all* experiments rather than only exps[0].
# This way a model missing from the first experiment is not dropped, and an
# empty result set gives [] instead of IndexError. Sorting reproduces
# groupby's default ordering, so the list matches the previous behaviour
# whenever every experiment has every model.
models = sorted({result['defender_model'] for result in results_list})


def _stat(exp, model, key):
    """Return tree[exp][model][key], or NaN when that (exp, model) pair has no runs."""
    return tree[exp].get(model, {}).get(key, np.nan)


# exps x models matrices of per-cell mean success and std.
mean_data = [[_stat(exp, model, 'mean') for model in models] for exp in exps]
std_data = [[_stat(exp, model, 'std') for model in models] for exp in exps]
std_data

In [None]:
import plotly.graph_objects as go
from pathlib import Path

# Bar fill colours, assigned to defender models in order. They are cycled
# when there are more models than colours instead of raising IndexError.
colors = ["#BAABDA", "#D6E5FA", "#FFF9F9"]
# Error bars (±1 std from the aggregation cell) were disabled in the
# published figure; set to True to draw them.
SHOW_ERROR_BARS = False
FIGURE_PATH = Path("../output/imgs/3_agent_barplot.pdf")

# One bar trace per defender model, grouped by experiment along the x-axis.
# A missing (exp, model) combination becomes a gap (None) instead of a KeyError.
data = []
for model_idx, model in enumerate(models):
    y_data = [tree[exp].get(model, {}).get("mean") for exp in exps]
    e_data = [tree[exp].get(model, {}).get("std") for exp in exps]
    data.append(
        go.Bar(
            name=model,
            x=exps,
            y=y_data,
            error_y=dict(type="data", array=e_data) if SHOW_ERROR_BARS else None,
            marker_color=colors[model_idx % len(colors)],
        )
    )

fig = go.Figure(data=data)
fig.update_layout(
    barmode="group",
    plot_bgcolor="white",
    font=dict(size=32),
    # Horizontal legend placed just above the plot area, right-aligned.
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        font=dict(size=24),
    ),
    yaxis=dict(
        title="Success rate",
        showgrid=True,
        gridcolor="lightgray",
        gridwidth=0.5,
    ),
)

# write_image does not create missing directories, so ensure the target exists.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(FIGURE_PATH))
fig.show()