In [1]:
import ast
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [2]:
# Load the per-run search statistics written by the experiments. Later cells
# expect the columns `id`, `best_trajectory`, `nodes`, `rs_nodes` and `r_nodes`.
summary_csv = "output/summary.csv"
df = pd.read_csv(summary_csv)

In [3]:
# `best_trajectory` is stored as a stringified Python list; its length is used
# as the depth of the best path. Missing entries (NaN after read_csv) would make
# ast.literal_eval raise, so they get a NaN length instead.
df["best_length"] = df["best_trajectory"].apply(
    lambda raw: len(ast.literal_eval(raw)) if isinstance(raw, str) else float("nan")
)

# Normalize node counts: subtract 1 for the root node, then divide by the best
# trajectory length. A zero-length (or missing) trajectory has no depth to
# normalize by -- dividing by 0 would give inf and poison the per-method mean --
# so such runs become NaN, which pandas' mean() skips.
root_node_size = 1
best_depth = df["best_length"].where(df["best_length"] > 0)

n_excluded = int(best_depth.isna().sum())
if n_excluded:
    print(f"Excluding {n_excluded} run(s) with an empty or missing best trajectory")

# Same normalization for each method's node-count column:
# nodes -> nodes_norm, rs_nodes -> rs_nodes_norm, r_nodes -> r_nodes_norm
for count_col in ("nodes", "rs_nodes", "r_nodes"):
    df[f"{count_col}_norm"] = (df[count_col] - root_node_size) / best_depth

# Display label for each normalized column; its key order is also the melt order
method_map = {
    "nodes_norm": "LLM agent",
    "rs_nodes_norm": "Random sampling",
    "r_nodes_norm": "Random possible action",
}

# Long format: one row per (id, method) pair holding that method's normalized count
df_melted = df.melt(
    id_vars="id",
    value_vars=list(method_map),
    var_name="Method",
    value_name="Normalized Nodes",
)
df_melted["Method"] = df_melted["Method"].map(method_map)

# Mean normalized node count per method, lowest (best) first, re-indexed 0..n-1
agg_df = (
    df_melted
    .groupby("Method", as_index=False)["Normalized Nodes"]
    .mean()
    .sort_values("Normalized Nodes")
    .reset_index(drop=True)
)

# How far each method's mean sits above the lowest mean (0 for the best method)
min_value = agg_df["Normalized Nodes"].min()
agg_df["Difference"] = agg_df["Normalized Nodes"] - min_value

# Bar colours: the best method (row 0 after sorting) in green, all others grey
best_method = agg_df.loc[0, "Method"]
colors = {
    method: ('#82B366' if method == best_method else '#666666')
    for method in agg_df["Method"]
}

# Plot one bar per method (row order of agg_df), highlighting the best one, and
# label every other bar with how much higher its mean is than the best mean.
fig, ax = plt.subplots()
sns.barplot(
    data=agg_df,
    x="Method",
    y="Normalized Nodes",
    hue="Method",
    order=agg_df["Method"].tolist(),  # pin bar positions to agg_df's row order
    palette=colors,
    legend=False,
    ax=ax,
)

# Annotate bars. Position k in agg_df is bar k on the x-axis (order= above), so
# use the enumeration position rather than relying on the index labels.
for x_pos, (_, row) in enumerate(agg_df.iterrows()):
    if row["Difference"] > 0:
        ax.annotate(
            f"+{row['Difference']:.2f}",
            xy=(x_pos, row["Normalized Nodes"]),
            xytext=(0, 3),              # 3 pt above the bar top, independent of data scale
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=10,
            color="black",
        )

# Make space for the annotations: extend the top by a fraction of the visible
# range (additive, so it also works when the lower limit is not 0).
y_bottom, y_top = ax.get_ylim()
ax.set_ylim(y_bottom, y_top + 0.08 * (y_top - y_bottom))

ax.set_title("Average branching factor of the search tree")
ax.set_ylabel("Branching factor")
ax.set_xlabel("")
fig.tight_layout()

# Save vector and raster versions; create the output folder if it is missing.
images_dir = Path("../images")
images_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(images_dir / "expansion_efficiency.svg", transparent=True)
fig.savefig(images_dir / "expansion_efficiency.png", transparent=True)
plt.close(fig)