In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import pandas as pd
import glob

# Notebook-wide matplotlib defaults, applied once before any figure is drawn.
rc = {
    # SimHei is a CJK typeface, so Chinese text in titles/labels can render.
    "font.sans-serif": "SimHei",
    # Draw minus signs with the ASCII hyphen instead of U+2212.
    "axes.unicode_minus": False,
}
plt.rcParams.update(rc)

In [None]:
def label_from_path(trace_path):
    """Build a plot label from a trace file path.

    The label is ``"<grandparent dir>/<parent dir>"`` of the file, i.e. the two
    directory names directly above ``trace.json``.

    Both ``/`` and ``\\`` are treated as separators, because ``glob`` on Windows
    can return paths that mix the two.
    """
    parts = trace_path.replace("\\", "/").split("/")
    return parts[-3] + "/" + parts[-2]


# Glob patterns for the runs to compare. Forward slashes are used because they
# work on Windows as well as on POSIX systems; backslash-separated patterns
# match nothing outside Windows.
trace_patterns = [
    # "logs/spb/Flat32/Medium/RandomBuild/Qwen2.5-32B-Instruct/*/trace.json",
    "logs/spb/Flat64/MediumHard/RandomBuild/Qwen2.5-32B-Instruct-SFT/*/trace.json",
    "logs/spb/Flat96/MediumHard/RandomBuild/Qwen2.5-32B-Instruct-SFT/*/trace.json",
]

agent_files = []
for pattern in trace_patterns:
    agent_files += glob.glob(pattern)
agent_files.sort()
# agent_files = agent_files[:1] + agent_files[-5:]

if not agent_files:
    # Without this notice the notebook would go on to draw empty figures.
    print("No trace.json files matched; check the working directory and trace_patterns.")

# Parsed trace content and its display label, kept in matching order.
agent_data = []
labels = []
for agent_file in agent_files:
    with open(agent_file, "r", encoding="utf-8") as f:
        agent_data.append(json.load(f))
    labels.append(label_from_path(agent_file))

In [None]:
def _extract_series(json_data, key):
    """Return ``(times, values)`` arrays for ``key``, ordered by ascending time."""
    # A trace may be handed over as raw JSON text; decode it first.
    data = json.loads(json_data) if isinstance(json_data, str) else json_data
    # Entries are taken from a mapping's values; a plain list of entries is
    # accepted as well.
    entries = data.values() if isinstance(data, dict) else data

    times = []
    values = []
    for entry in entries:
        # Only records that carry both the timestamp and the metric are plotted.
        if isinstance(entry, dict) and "time_seconds" in entry and key in entry:
            times.append(entry["time_seconds"])
            values.append(entry[key])

    # Stable sort keeps the original order of entries sharing a timestamp.
    order = np.argsort(times, kind="stable")
    return np.array(times)[order], np.array(values)[order]


def plot_key_vs_time(json_list, key, labels=None):
    """
    Plot how the value stored under ``key`` changes over time, one line per dataset.

    Parameters:
    json_list -- list of JSON datasets (parsed objects or JSON strings); each
                 entry with both "time_seconds" and ``key`` becomes a point
    key -- name of the field to plot (e.g. "n_structures")
    labels -- optional list of legend labels (same length as json_list);
              defaults to "Data 1", "Data 2", ...

    Raises:
    ValueError -- if the number of labels differs from the number of datasets
    """
    # Validate before touching pyplot so a bad call does not leave an empty
    # figure behind.
    if labels is None:
        labels = [f"Data {i+1}" for i in range(len(json_list))]
    if len(labels) != len(json_list):
        raise ValueError("标签数量必须与JSON数据数量一致")

    # seaborn's theme overwrites "font.sans-serif"; pass the current value back
    # in so a font configured earlier in the notebook stays in effect.
    current_fonts = list(plt.rcParams["font.sans-serif"])
    sns.set_theme(style="whitegrid", rc={"font.sans-serif": current_fonts})
    fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

    # Every plotted timestamp, used to size the shared X axis.
    all_times = []
    for label, json_data in zip(labels, json_list):
        times, values = _extract_series(json_data, key)
        all_times.extend(times)
        ax.plot(times, values, marker="o", markersize=4, linewidth=2, label=label)

    ax.set_xlabel("Time (seconds)", fontsize=12)
    ax.set_ylabel(key, fontsize=12)
    ax.set_title(f"{key} over Time", fontsize=14)

    # X axis: start at 0, end at the next multiple of 60 seconds, one tick per
    # minute. At least one minute is shown so the axis range is never empty.
    max_time = max(all_times) if all_times else 100
    max_time = max(60, int(np.ceil(max_time / 60) * 60))
    ax.set_xlim(0, max_time)
    ax.set_xticks(np.arange(0, max_time + 1, 60))
    ax.tick_params(axis="both", labelsize=10)
    # Y axis: pin the lower bound to 0 and leave the upper bound automatic.
    ax.set_ylim(bottom=0)

    # A legend without any plotted line only triggers a matplotlib warning.
    if len(json_list) > 0:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()
    # Release the figure; this function is called repeatedly in a loop.
    plt.close(fig)

In [None]:
# Fields to chart, one figure each, in display order.
# Currently disabled: "supply_left", "n_unit_types", "n_structure_types".
metrics_to_plot = (
    "minerals",
    "vespene",
    "supply_army",
    "supply_workers",
    "n_structures",
    "n_enemy_units",
    "n_enemy_structures",
)

# Each figure overlays every loaded run, labelled by its source directory.
for key in metrics_to_plot:
    plot_key_vs_time(agent_data, key, labels=labels)