In [None]:
import matplotlib.font_manager
from IPython.core.display import HTML

def make_html(fontname):
    """Build an HTML paragraph that previews a font by rendering its own name.

    Args:
        fontname: Font family name to preview.

    Returns:
        str: A ``<p>`` element containing the name as plain text, followed by
        the same name inside a 24px ``<span>`` whose ``font-family`` is that
        font.
    """
    import html  # local import keeps this helper self-contained within its cell

    # Escape the name so characters such as & < > or quotes cannot break the markup.
    safe_name = html.escape(str(fontname), quote=True)
    # The <span> is closed explicitly so the generated markup is well-formed.
    return (
        "<p>{font}: <span style='font-family:{font}; font-size: 24px;'>{font}</span></p>"
    ).format(font=safe_name)

# One preview paragraph per distinct font family name registered with
# matplotlib's font manager, in sorted order, joined with newlines.
code = "\n".join(
    make_html(family)
    for family in sorted({entry.name for entry in matplotlib.font_manager.fontManager.ttflist})
)

# Last expression of the cell: show the previews in a two-column layout.
HTML("<div style='column-count: 2;'>{}</div>".format(code))

In [None]:
import shutil
import matplotlib

# Destructive: recursively deletes the directory reported by
# matplotlib.get_cachedir(), including everything stored inside it.
# NOTE(review): presumably run so matplotlib rebuilds its font cache and picks
# up newly installed fonts; a kernel restart is likely needed afterwards - confirm.
shutil.rmtree(matplotlib.get_cachedir())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# rc overrides shared by every plot; the same dict is also handed to
# sns.set_theme() in setup_plot().
rc = {
    "font.sans-serif": "Microsoft Sans Serif",
    "axes.unicode_minus": False,
}
plt.rcParams.update(rc)

# Default series colours: ten hex colours, repeated once so that up to twenty
# series can be coloured (series 11-20 reuse the colours of series 1-10).
color_list = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
] * 2


def validate_data(label_data, labels):
    """Check that ``label_data`` has the structure the plotting helpers expect.

    Args:
        label_data: Dict mapping a series name to a list/tuple of values.
        labels: Sized collection of x-axis labels; every series must contain
            exactly ``len(labels)`` values.

    Raises:
        TypeError: If ``label_data`` is not a dict, or a series is neither a
            list nor a tuple.
        ValueError: If a series' length differs from ``len(labels)``.
    """
    if not isinstance(label_data, dict):
        raise TypeError("label_data must be a dictionary.")

    for name, series in label_data.items():
        if isinstance(series, (list, tuple)):
            if len(series) == len(labels):
                continue
            raise ValueError(f"Length of values for {name} does not match labels length.")
        raise TypeError(f"Values for {name} must be a list or tuple.")


def prepare_data(labels, label_data):
    """Reshape per-series values into the long-format frame used for plotting.

    Args:
        labels: Iterable of x-axis labels, one per value in each series.
        label_data: Dict mapping a series name to its values (same length as
            ``labels``).

    Returns:
        pandas.DataFrame with columns ``category`` (series name), ``value``
        and ``x`` (the label belonging to that value).
    """
    df = pd.DataFrame(label_data).melt(var_name="category", value_name="value")
    # melt() stacks the series one after another, so the labels must repeat once
    # per series. Convert to a list first: `labels * n` only repeats for
    # list/tuple; a numpy array or pandas Series/Index multiplies elementwise.
    df["x"] = list(labels) * len(label_data)
    return df


def setup_plot(fig_size, dpi):
    """Open a new figure, apply the seaborn theme and return the figure's axes.

    Args:
        fig_size: Figure size forwarded to ``plt.figure(figsize=...)``.
        dpi: Figure resolution forwarded to ``plt.figure(dpi=...)``.

    Returns:
        The current axes of the newly created figure.
    """
    figure = plt.figure(figsize=fig_size, dpi=dpi)
    # The module-level rc overrides are passed again so the theme keeps them.
    theme_options = {"style": "whitegrid", "font_scale": 1.2, "rc": rc}
    sns.set_theme(**theme_options)
    return figure.gca()


def configure_axes(ax, title, x_label, y_label):
    """Set the title and both axis labels on ``ax``.

    Args:
        ax: Axes to decorate.
        title: Title text (bold, 14pt, padded 20 from the axes).
        x_label: x-axis label text (12pt, label pad 10).
        y_label: y-axis label text (12pt, label pad 10).
    """
    ax.set_title(title, pad=20, fontsize=14, weight="bold")
    # Both axis labels share the same styling.
    for apply_label, text in ((ax.set_xlabel, x_label), (ax.set_ylabel, y_label)):
        apply_label(text, labelpad=10, fontsize=12)


def configure_legend(
    ax, legend_outside, title="Categories", title_fontsize=11, fontsize=10, frameon=True, framealpha=0.9, edgecolor="0.6"
):
    """Create (or replace) the legend on ``ax`` and return it.

    Args:
        ax: Axes whose legend is built.
        legend_outside: When truthy, place the legend outside the plotting
            area, to the right of the axes; otherwise use matplotlib's default
            placement.
        title: Legend title text.
        title_fontsize: Font size of the legend title.
        fontsize: Font size of the legend entries.
        frameon: Whether to draw the legend frame.
        framealpha: Opacity of the legend frame background.
        edgecolor: Colour of the legend frame edge.

    Returns:
        The legend object returned by ``ax.legend``.
    """
    legend_kwargs = {
        "title": title,
        "title_fontsize": title_fontsize,
        "fontsize": fontsize,
        "frameon": frameon,
        "framealpha": framealpha,
        "edgecolor": edgecolor,
    }
    if legend_outside:
        # Pin the legend's upper-left corner just right of the axes. `loc` must
        # accompany the anchor: with the default "best" location the legend is
        # positioned relative to that single point and can end up over the plot.
        # No add_artist() call is needed; the axes already owns the legend
        # returned by ax.legend().
        legend_kwargs.update(loc="upper left", bbox_to_anchor=(1.05, 1))
    return ax.legend(**legend_kwargs)


def set_y_range(ax, y_range):
    """Fix the y-axis limits and place six evenly spaced ticks across them.

    Args:
        ax: Axes to adjust.
        y_range: Indexable pair; element 0 is the lower and element 1 the upper
            y limit.
    """
    ax.set_ylim(y_range)
    lower, upper = y_range[0], y_range[1]
    tick_positions = np.linspace(lower, upper, num=6)
    ax.yaxis.set_ticks(tick_positions)


def configure_grid(ax, grid_linestyle, grid_alpha):
    """Style the grid lines of the y axis on ``ax``.

    Args:
        ax: Axes to adjust.
        grid_linestyle: Line style for the grid lines.
        grid_alpha: Opacity of the grid lines.
    """
    line_style = {"linestyle": grid_linestyle, "alpha": grid_alpha}
    ax.grid(axis="y", **line_style)


def save_or_show_plot(output_path, dpi):
    """Finish the current figure: save it or show it, then close it.

    Args:
        output_path: Destination file; when falsy the figure is shown instead
            of saved.
        dpi: Resolution used when saving.
    """
    plt.tight_layout()
    if not output_path:
        plt.show()
    else:
        plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
    # Close in both cases so the figure is released once it has been emitted.
    plt.close()


def plot_bar(
    labels,
    label_data,
    title,
    x_label,
    y_label,
    output_path=None,
    y_range=(0, 1.0),
    legend_outside=False,
    palette=None,
    fig_size=(10, 6),
    dpi=300,
    grid_linestyle="--",
    grid_alpha=0.6,
):
    """Draw a grouped bar chart with one bar group per label and one bar per series.

    Args:
        labels: x-axis labels, one per value in every series.
        label_data: Dict mapping a series name to its list/tuple of values.
        title: Chart title.
        x_label: x-axis label.
        y_label: y-axis label.
        output_path: File to save to; when falsy the chart is shown instead.
        y_range: ``(lower, upper)`` y-axis limits.
        legend_outside: Forwarded to ``configure_legend``.
        palette: Colours for the series; defaults to the first
            ``len(label_data)`` entries of ``color_list``.
        fig_size: Figure size passed to ``setup_plot``.
        dpi: Resolution for the figure and for saving.
        grid_linestyle: Line style of the y grid.
        grid_alpha: Opacity of the y grid.

    Raises:
        TypeError, ValueError: Propagated from ``validate_data``.
    """
    validate_data(label_data, labels)
    long_df = prepare_data(labels, label_data)

    ax = setup_plot(fig_size, dpi)

    series_colors = color_list[: len(label_data)] if palette is None else palette

    ax = sns.barplot(
        data=long_df,
        x="x",
        y="value",
        hue="category",
        palette=series_colors,
        linewidth=0,
        ax=ax,
    )

    configure_axes(ax, title, x_label, y_label)

    # Annotate every bar with its value, formatted to two decimals.
    bar_label_style = {
        "fmt": "%.2f",
        "padding": 4,
        "color": "black",
        "fontsize": 9,
        "rotation": 0,
    }
    for bars in ax.containers:
        ax.bar_label(bars, **bar_label_style)

    configure_legend(ax, legend_outside)

    set_y_range(ax, y_range)
    configure_grid(ax, grid_linestyle, grid_alpha)

    save_or_show_plot(output_path, dpi)


def plot_line(
    labels,
    label_data,
    title,
    x_label,
    y_label,
    output_path=None,
    y_range=(0, 1.0),
    legend_outside=False,
    palette=None,
    fig_size=(10, 6),
    dpi=300,
    grid_linestyle="--",
    grid_alpha=0.6,
    line_width=2.0,
    markers=True,
    show_values=True,
    value_format="%.2f",
    value_padding=4,
    value_fontsize=9,
    value_rotation=0,
    legend_title="Categories",
    legend_title_fontsize=11,
    legend_fontsize=10,
    legend_frameon=True,
    legend_framealpha=0.9,
    legend_edgecolor="0.6",
):
    """Draw a line chart with one line per series in ``label_data``.

    Args:
        labels: x-axis positions/labels, one per value in every series.
        label_data: Dict mapping a series name to its list/tuple of values.
        title: Chart title.
        x_label: x-axis label.
        y_label: y-axis label.
        output_path: File to save to; when falsy the chart is shown instead.
        y_range: ``(lower, upper)`` y-axis limits.
        legend_outside: Forwarded to ``configure_legend``.
        palette: Colours for the series; defaults to the first
            ``len(label_data)`` entries of ``color_list``.
        fig_size: Figure size passed to ``setup_plot``.
        dpi: Resolution for the figure and for saving.
        grid_linestyle: Line style of the y grid.
        grid_alpha: Opacity of the y grid.
        line_width: Width of the plotted lines.
        markers: Forwarded to ``sns.lineplot(markers=...)``; ``True`` draws a
            distinct default marker per series.
        show_values: Whether to print each point's value above it.
        value_format: ``%``-style format applied to each value label.
        value_padding: Vertical offset of the value label above its point, as
            a percentage of the ``y_range`` height.
        value_fontsize: Font size of the value labels.
        value_rotation: Rotation of the value labels.
        legend_title: Legend title text.
        legend_title_fontsize: Font size of the legend title.
        legend_fontsize: Font size of the legend entries.
        legend_frameon: Whether to draw the legend frame.
        legend_framealpha: Opacity of the legend frame background.
        legend_edgecolor: Colour of the legend frame edge.

    Raises:
        TypeError, ValueError: Propagated from ``validate_data``.
    """
    validate_data(label_data, labels)
    df = prepare_data(labels, label_data)

    ax = setup_plot(fig_size, dpi)

    if palette is None:
        palette = color_list[: len(label_data)]

    ax = sns.lineplot(
        x="x",
        y="value",
        hue="category",
        # seaborn applies `markers` only to levels of a `style` variable, so the
        # series column is mapped to style as well; without it `markers=True`
        # is silently ignored.
        style="category",
        # Style is used for markers only: keep every line solid.
        dashes=False,
        data=df,
        palette=palette,
        markers=markers,
        linewidth=line_width,
        ax=ax,
    )

    configure_axes(ax, title, x_label, y_label)

    ax.set_xticks(labels)
    ax.set_xticklabels(labels)

    if show_values:
        # The offset is the same for every point, so compute it once.
        label_offset = value_padding / 100 * (y_range[1] - y_range[0])
        for x_val, y_val in zip(df["x"], df["value"]):
            ax.text(
                x_val,
                y_val + label_offset,
                value_format % y_val,
                ha="center",
                va="bottom",
                fontsize=value_fontsize,
                rotation=value_rotation,
                color="black",
            )

    configure_legend(
        ax,
        legend_outside,
        legend_title,
        legend_title_fontsize,
        legend_fontsize,
        legend_frameon,
        legend_framealpha,
        legend_edgecolor,
    )

    set_y_range(ax, y_range)
    configure_grid(ax, grid_linestyle, grid_alpha)

    save_or_show_plot(output_path, dpi)

In [None]:
class Metric:
    """Plain record holding the settings and measured results of one experiment run.

    Every constructor argument is stored unchanged as an attribute of the same
    name. The spelling ``diffculty`` is part of the keyword interface used by
    callers and is therefore kept as-is.
    """

    def __init__(
        self,
        diffculty,
        map_name,
        ai_build,
        model,
        enable_rag,
        enable_plan,
        enable_plan_verifier,
        enable_action_verifier,
        win_rate,
        win_avg_time,
        sbr,
        rur,
        avg_inference,
        avg_token,
        action_valid_rate,
    ):
        # Scenario the run was played in.
        self.diffculty, self.map_name = diffculty, map_name
        self.ai_build, self.model = ai_build, model

        # Feature switches of the evaluated configuration.
        self.enable_rag, self.enable_plan = enable_rag, enable_plan
        self.enable_plan_verifier = enable_plan_verifier
        self.enable_action_verifier = enable_action_verifier

        # Measured results. NOTE(review): `sbr` and `rur` are abbreviations whose
        # meaning is not defined in this notebook - confirm before documenting.
        self.win_rate, self.win_avg_time = win_rate, win_avg_time
        self.sbr, self.rur = sbr, rur
        self.avg_inference, self.avg_token = avg_inference, avg_token
        self.action_valid_rate = action_valid_rate


# Settings that are identical for every recorded run.
_COMMON_SETTINGS = {
    "diffculty": "Easy",
    "ai_build": "RandomBuild",
    "model": "Qwen2.5-32B-Instruct",
    "enable_rag": False,
    "enable_plan": False,
}

# Recorded runs; only the map, the verifier switches and the results vary.
metrics = [
    # Flat32 - plan verifier off, action verifier off
    Metric(
        **_COMMON_SETTINGS,
        map_name="Flat32",
        enable_plan_verifier=False,
        enable_action_verifier=False,
        win_rate=0.3,
        win_avg_time=897.33,
        sbr=0.1313,
        rur=8.7600,
        avg_inference=389.10,
        avg_token=261.02,
        action_valid_rate=0.3795,
    ),
    # Flat32 - plan verifier off, action verifier on
    Metric(
        **_COMMON_SETTINGS,
        map_name="Flat32",
        enable_plan_verifier=False,
        enable_action_verifier=True,
        win_rate=0.4,
        win_avg_time=536.00,
        sbr=0.2111,
        rur=8.3149,
        avg_inference=204.30,
        avg_token=643.81,
        action_valid_rate=0.6368,
    ),
    # Flat32 - plan verifier on, action verifier on
    Metric(
        **_COMMON_SETTINGS,
        map_name="Flat32",
        enable_plan_verifier=True,
        enable_action_verifier=True,
        win_rate=0.9,
        win_avg_time=465.44,
        sbr=0.0227,
        rur=10.4580,
        avg_inference=112.90,
        avg_token=838.73,
        action_valid_rate=0.9676,
    ),
    # Flat64 - plan verifier on, action verifier on
    Metric(
        **_COMMON_SETTINGS,
        map_name="Flat64",
        enable_plan_verifier=True,
        enable_action_verifier=True,
        win_rate=0.8,
        win_avg_time=463.88,
        sbr=0.0544,
        rur=13.3699,
        avg_inference=134.30,
        avg_token=1117.87,
        action_valid_rate=0.8931,
    ),
    # Flat64 - plan verifier off, action verifier off
    Metric(
        **_COMMON_SETTINGS,
        map_name="Flat64",
        enable_plan_verifier=False,
        enable_action_verifier=False,
        win_rate=0.7,
        win_avg_time=690.43,
        sbr=0.3339,
        rur=15.1130,
        avg_inference=782.90,
        avg_token=284.87,
        action_valid_rate=0.3083,
    ),
]


def filter_metrics(
    diffculty=None,
    map_name=None,
    ai_build=None,
    model=None,
    enable_rag=None,
    enable_plan=None,
    enable_plan_verifier=None,
    enable_action_verifier=None,
    sort_with=None,
    sort_values=None,
):
    """Select entries of the module-level ``metrics`` list and optionally sort them.

    Args:
        diffculty, map_name, ai_build, model: Each may be a single value or a
            list/tuple of accepted values. A falsy value (``None``, ``""``,
            empty list) disables that filter.
        enable_rag, enable_plan, enable_plan_verifier, enable_action_verifier:
            Required value of the corresponding attribute; ``None`` disables
            the filter (``False`` is a real filter value).
        sort_with: Name of the attribute to sort by; falsy means no sorting.
        sort_values: Optional list/tuple giving the desired order of the
            ``sort_with`` values. When omitted, the natural ordering of the
            attribute values is used.

    Returns:
        list: A new list with the matching entries. The global ``metrics``
        list itself is never returned, so callers may mutate the result.

    Raises:
        AssertionError: If ``sort_values`` is truthy but not a list or tuple.
        ValueError: If an entry's ``sort_with`` value is missing from
            ``sort_values``.
        AttributeError: If an entry has no attribute named ``sort_with``.
    """

    def is_contain(src, values):
        # Match a single wanted value, or membership in a list/tuple of them.
        return src == values or (isinstance(values, (list, tuple)) and src in values)

    # Copy up front: without any filter or sort the original code handed back
    # the global list object itself.
    filtered_metrics = list(metrics)

    # Filters accepting one value or several; skipped when the argument is falsy.
    membership_filters = (
        ("diffculty", diffculty),
        ("map_name", map_name),
        ("ai_build", ai_build),
        ("model", model),
    )
    for attr, wanted in membership_filters:
        if wanted:
            filtered_metrics = [m for m in filtered_metrics if is_contain(getattr(m, attr), wanted)]

    # Exact-match switches; only None means "do not filter".
    flag_filters = (
        ("enable_rag", enable_rag),
        ("enable_plan", enable_plan),
        ("enable_plan_verifier", enable_plan_verifier),
        ("enable_action_verifier", enable_action_verifier),
    )
    for attr, wanted in flag_filters:
        if wanted is not None:
            filtered_metrics = [m for m in filtered_metrics if getattr(m, attr) == wanted]

    if sort_with:
        if sort_values:
            assert isinstance(sort_values, (list, tuple)), "sort_values must be a list or tuple."
            # Rank each entry by the position of its value in the explicit order.
            filtered_metrics = sorted(filtered_metrics, key=lambda x: sort_values.index(getattr(x, sort_with)))
        else:
            filtered_metrics = sorted(filtered_metrics, key=lambda x: getattr(x, sort_with))
    return filtered_metrics


# Compare runs with the plan verifier and action verifier both disabled against
# runs with both enabled, on Easy difficulty with the RandomBuild strategy and
# the Qwen2.5-32B-Instruct model.
# The disabled series is labelled "LLM"; the enabled series is labelled "Agent".
maps = ["Flat32", "Flat64"]
base_setting = {
    "diffculty": "Easy",
    "ai_build": "RandomBuild",
    "model": "Qwen2.5-32B-Instruct",
    "map_name": maps,
    "sort_with": "map_name",
    # Order the results exactly like `maps` (the x-axis labels). A plain
    # alphabetical sort only lines up with the labels by coincidence.
    "sort_values": maps,
}
metrics_1 = filter_metrics(
    enable_rag=False, enable_plan=False, enable_plan_verifier=False, enable_action_verifier=False, **base_setting
)
metrics_2 = filter_metrics(
    enable_rag=False, enable_plan=False, enable_plan_verifier=True, enable_action_verifier=True, **base_setting
)
# Each series must hold exactly one run per map, in label order; otherwise the
# bars would be attributed to the wrong map.
assert [m.map_name for m in metrics_1] == maps, "LLM runs do not line up with maps"
assert [m.map_name for m in metrics_2] == maps, "Agent runs do not line up with maps"
data_win_rate = {
    "LLM": [m.win_rate for m in metrics_1],
    "Agent": [m.win_rate for m in metrics_2],
}
print(data_win_rate)
plot_bar(
    maps,
    data_win_rate,
    title="不同地图下开启/关闭计划验证器和行动验证器的胜率对比",
    x_label="地图",
    y_label="胜率",
    # output_path="win_rate.png",
    y_range=(0, 1),
    legend_outside=False,
)