# Two-Phase Training Curve Plotter (Notebook)

This notebook is an `.ipynb` port of the `plot_curves.py` script and its batch runner.

- **Phase 1**: steps `1..110` (single curve)
- **Phase 2**: steps `111..220` (3 branches, each starting from the last Phase 1 point)
- **Outputs**: saved images under `plots/` and also displayed inline.


In [None]:

# ===== Imports =====
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Make plots show inside the notebook
%matplotlib inline


## Config (same as your script)

In [None]:

# ==================== Color configuration constants ====================
# Colors assigned per task name (low saturation, professional palette).
# Looked up by get_task_color(); unknown task names fall back to black there.
TASK_COLORS = {
    'math': (0.6, 0.2, 0.2),      # RGB: dark red (low saturation)
    'science': (0.3, 0.4, 0.8),   # RGB: dark blue (low saturation)
    'logic': (0.3, 0.6, 0.4),     # RGB: dark green (low saturation)
    'puzzle': (0.8, 0.5, 0.2),    # RGB: orange (low saturation)
}

# ==================== Style configuration constants ====================
LINE_WIDTH = 4.0
LEGEND_FONTSIZE = 24
AXIS_LABEL_FONTSIZE = 24
AXIS_TICK_FONTSIZE = 20
TITLE_FONTSIZE = 30
MARKER_SIZE = 6
# When False, plot_curves() ignores the title it is given.
SHOW_TITLE = False

# ==================== Configuration parameters ====================
# Values of the 'Step' column that are sampled from each CSV for plotting.
SELECTED_STEPS = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110]
# Phase 1 length; also the x-offset added to Phase 2 steps when plotting.
PHASE1_STEPS = 110
PHASE2_STEPS = 110
# Right-hand limit of the x-axis.
TOTAL_STEPS = PHASE1_STEPS + PHASE2_STEPS


## Helper functions

In [None]:

def identify_columns(df):
    '''
    Find the Phase 1 and Phase 2 curve columns of a results table.

    Only columns whose name contains 'mean@1' are considered; the '__MIN'
    and '__MAX' companion columns are skipped. A column is classified by
    the task string that follows 'Base-': one task (e.g. 'math') marks the
    Phase 1 curve, two hyphen-joined tasks (e.g. 'math-puzzle') mark a
    Phase 2 curve.

    Returns:
        phase1_col: str or None (the last single-task column seen)
        phase2_cols: list[str] (in column order)
    '''
    # Task string after 'Base-', e.g. 'math' or 'math-puzzle'.
    task_pattern = re.compile(r'Base-([a-z]+(?:-[a-z]+)?)')

    phase1_col = None
    phase2_cols = []

    for name in df.columns.tolist():
        if 'mean@1' not in name or '__MIN' in name or '__MAX' in name:
            continue

        found = task_pattern.search(name)
        if found is None:
            continue

        # The pattern admits at most one hyphen, so this is either 1 or 2.
        task_count = found.group(1).count('-') + 1
        if task_count == 1:
            phase1_col = name
        else:
            phase2_cols.append(name)

    return phase1_col, phase2_cols


def _value_at_step(df, col, step):
    '''
    Return the value of `col` in the first row whose 'Step' equals `step`.

    Yields NaN when the step or column is absent, when the cell is
    missing/empty, or when the cell cannot be converted to float.
    '''
    row = df[df['Step'] == step]
    if row.empty or col not in row.columns:
        return np.nan

    val = row[col].iloc[0]
    # pd.isna also covers None, so no separate None check is needed.
    if pd.isna(val) or val == '':
        return np.nan

    try:
        return float(val)
    except (ValueError, TypeError):
        return np.nan


def extract_data_points(df, phase1_col, phase2_cols, *,
                        selected_steps=None, phase1_steps=None, phase2_steps=None):
    '''
    Sample the Phase 1 and Phase 2 curves at the selected steps.

    Args:
        df: table with a 'Step' column (must be castable to int) and the
            curve columns.
        phase1_col: name of the Phase 1 column.
        phase2_cols: names of the Phase 2 columns.
        selected_steps: steps to sample; defaults to SELECTED_STEPS.
        phase1_steps: Phase 1 length, also used as the x-offset of the
            Phase 2 points; defaults to PHASE1_STEPS.
        phase2_steps: Phase 2 length; defaults to PHASE2_STEPS.

    Returns:
        phase1_data: dict with 'steps', 'values' and 'label'
        phase2_data: list[dict], one dict of the same shape per Phase 2
            column, with 'steps' shifted by `phase1_steps`

    Missing or non-numeric samples are stored as NaN so that 'steps' and
    'values' always have equal length.
    '''
    # Resolve defaults at call time so edits to the notebook constants are picked up.
    if selected_steps is None:
        selected_steps = SELECTED_STEPS
    if phase1_steps is None:
        phase1_steps = PHASE1_STEPS
    if phase2_steps is None:
        phase2_steps = PHASE2_STEPS

    df = df.copy()
    df['Step'] = df['Step'].astype(int)

    # Phase 1
    p1_steps = [s for s in selected_steps if s <= phase1_steps]
    phase1_data = {
        'steps': p1_steps,
        'values': [_value_at_step(df, phase1_col, s) for s in p1_steps],
        'label': phase1_col,
    }

    # Phase 2: CSV step s is plotted at x = phase1_steps + s.
    p2_steps = [s for s in selected_steps if s <= phase2_steps]
    phase2_data = [
        {
            'steps': [phase1_steps + s for s in p2_steps],
            'values': [_value_at_step(df, col, s) for s in p2_steps],
            'label': col,
        }
        for col in phase2_cols
    ]

    return phase1_data, phase2_data


def generate_label(col_name):
    '''
    Build a short legend label from a Phase 2 column name.

    Returns the second task name of the 'Base-<task1>-<task2>' part of
    `col_name`. Task names are matched as lowercase letters only, the same
    rule identify_columns() and get_task_color() use, so text following the
    task name (spaces, metric paths) never leaks into the label. If the
    pattern is not found, `col_name` is returned unchanged.
    '''
    match = re.search(r'Base-[a-z]+-([a-z]+)', col_name)
    if match:
        return match.group(1)
    return col_name


def get_phase1_task_name(col_name):
    '''
    Return the first task name that follows 'Base-' in `col_name`.

    The task name is the run of lowercase letters right after 'Base-';
    an empty string is returned when the pattern is absent.
    '''
    found = re.search(r'Base-([a-z]+)', col_name)
    return found.group(1) if found else ''


def get_task_color(col_name):
    '''
    Pick the curve color for a column.

    Phase 1 columns ('Base-<task>') use the color of their only task;
    Phase 2 columns ('Base-<task1>-<task2>') use the color of the second
    task. Black is returned when the column does not match the pattern or
    the task has no entry in TASK_COLORS.
    '''
    fallback = (0.0, 0.0, 0.0)

    found = re.search(r'Base-([a-z]+(?:-[a-z]+)?)', col_name)
    if found is None:
        return fallback

    # The pattern yields one or two tasks; in both cases the last one decides.
    task_name = found.group(1).split('-')[-1]
    return TASK_COLORS.get(task_name, fallback)


def _plot_valid_points(ax, xs, ys, color, label):
    '''
    Plot (xs, ys) on `ax`, skipping points whose y value is NaN.

    Nothing is drawn (and no legend entry is created) when every y is NaN.
    '''
    kept = [(x, y) for x, y in zip(xs, ys) if not np.isnan(y)]
    if not kept:
        return

    ax.plot([x for x, _ in kept], [y for _, y in kept],
            color=color,
            linewidth=LINE_WIDTH,
            marker='o',
            markersize=MARKER_SIZE,
            label=label)


def plot_curves(phase1_data, phase2_data, output_path, title=None, show_inline=True):
    '''
    Draw the two-phase curves, save the figure, and optionally show it inline.

    Args:
        phase1_data: dict with 'steps', 'values' and 'label' (column name;
            may be None or missing) for the Phase 1 curve.
        phase2_data: list of dicts of the same shape, one per Phase 2 curve.
        output_path: image path; its parent directory is created if needed.
        title: figure title, only drawn when SHOW_TITLE is True.
        show_inline: call plt.show() when True, otherwise close the figure.

    Each Phase 2 curve is prefixed with the point (PHASE1_STEPS, last
    Phase 1 value) when that value is not NaN, so the branches visually
    continue from the end of Phase 1.
    '''
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Liberation Sans']
    plt.rcParams['axes.unicode_minus'] = False

    fig, ax = plt.subplots(figsize=(10, 6))

    # Phase 1
    phase1_steps = phase1_data['steps']
    phase1_values = phase1_data['values']
    # The label is None when no Phase 1 column was identified; the regex
    # helpers below require a string.
    phase1_col_name = phase1_data.get('label') or ''

    phase1_task_name = get_phase1_task_name(phase1_col_name)
    phase1_label = f'Phase 1: {phase1_task_name}' if phase1_task_name else 'Phase 1'
    _plot_valid_points(ax, phase1_steps, phase1_values,
                       get_task_color(phase1_col_name), phase1_label)

    phase1_last_value = None
    if phase1_values and not np.isnan(phase1_values[-1]):
        phase1_last_value = phase1_values[-1]

    # Phase 2
    for curve_data in phase2_data:
        steps = curve_data['steps']
        values = curve_data['values']
        col_name = curve_data['label']

        if phase1_last_value is not None and len(steps) > 0:
            # Start the branch at the end of Phase 1.
            plot_steps = [PHASE1_STEPS] + steps
            plot_values = [phase1_last_value] + values
        else:
            plot_steps = steps
            plot_values = values

        _plot_valid_points(ax, plot_steps, plot_values,
                           get_task_color(col_name),
                           f'Phase 2: {generate_label(col_name)}')

    # Split line between the two phases
    ax.axvline(x=PHASE1_STEPS, color='gray', linestyle='--', linewidth=4, alpha=0.5)

    # Axes
    ax.set_xlabel('Step', fontsize=AXIS_LABEL_FONTSIZE)
    ax.set_ylabel('Accuracy', fontsize=AXIS_LABEL_FONTSIZE)
    ax.set_xlim(0, TOTAL_STEPS)

    # Ticks at the start, middle and end of each phase, derived from the
    # configured lengths (0, 55, 110, 165, 220 for the default 110 + 110).
    ticks = sorted({
        0,
        PHASE1_STEPS // 2,
        PHASE1_STEPS,
        PHASE1_STEPS + PHASE2_STEPS // 2,
        TOTAL_STEPS,
    })
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(t) for t in ticks])
    ax.tick_params(axis='both', which='major', labelsize=AXIS_TICK_FONTSIZE)

    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)

    # Only build a legend when at least one curve was actually drawn.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        legend = ax.legend(loc='best', fontsize=LEGEND_FONTSIZE)
        for text in legend.get_texts():
            if text.get_text().startswith('Phase 1'):
                text.set_fontweight('bold')

    if SHOW_TITLE and title:
        ax.set_title(title, fontsize=TITLE_FONTSIZE, fontweight='bold')

    fig.tight_layout()

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved: {output_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)


## Single file run (preview inline)

In [None]:

# ===== Set these paths =====
INPUT_CSV = "data/math.csv"
OUTPUT_IMG = "plots/math_curve.png"
TITLE = "Math Training Curve"

df = pd.read_csv(INPUT_CSV)
phase1_col, phase2_cols = identify_columns(df)
print("Phase1:", phase1_col)
print("Phase2:", len(phase2_cols), "curves")

# Without a Phase 1 column the plotting helpers would receive None as the
# column name and fail inside re.search, so report the problem instead
# (the batch cell applies a similar check).
if phase1_col is None:
    print(f"Warning: columns not found in {INPUT_CSV}")
else:
    phase1_data, phase2_data = extract_data_points(df, phase1_col, phase2_cols)
    plot_curves(phase1_data, phase2_data, OUTPUT_IMG, TITLE, show_inline=True)


## Batch run (generate all 4 figures)

In [None]:

OUTPUT_DIR = "plots"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One (input CSV, output image, figure title) entry per task.
jobs = [
    ("data/math.csv",    f"{OUTPUT_DIR}/math_curve.png",    "Math Training Curve"),
    ("data/logic.csv",   f"{OUTPUT_DIR}/logic_curve.png",   "Logic Training Curve"),
    ("data/puzzle.csv",  f"{OUTPUT_DIR}/puzzle_curve.png",  "Puzzle Training Curve"),
    ("data/science.csv", f"{OUTPUT_DIR}/science_curve.png", "Science Training Curve"),
]

for inp, outp, title in jobs:
    # A missing or unusable input is reported and skipped so the remaining
    # jobs still run.
    if os.path.exists(inp):
        df = pd.read_csv(inp)
        phase1_col, phase2_cols = identify_columns(df)

        if phase1_col is not None and len(phase2_cols) != 0:
            phase1_data, phase2_data = extract_data_points(df, phase1_col, phase2_cols)
            print(f"\n==> {inp}")
            plot_curves(phase1_data, phase2_data, outp, title, show_inline=True)
        else:
            print(f"Warning: columns not found in {inp}")
            print("  phase1_col:", phase1_col)
            print("  phase2_cols:", len(phase2_cols))
    else:
        print(f"Skip (missing): {inp}")

print("\nAll done.")
