# Metrics Plotter（実験結果の再プロット）

このノートブックは `experiments/outputs/<run>/metrics/*.csv` を読み込み、
評価曲線（episode reward vs steps）を再プロットします。

## 使い方（最短）
1. 上から順にセルを実行します
2. GUIでrunを選ぶか、手動で `OUTPUT_DIR` を指定します
3. `STEP_BIN` で点の間隔、`SMOOTH_WINDOW` で滑らかさを調整します

Tip: もっと細かいログが欲しい場合は、実験の `eval_interval` を小さくします。


In [1]:
from pathlib import Path
import csv
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

try:
    import ipywidgets as widgets
    HAS_WIDGETS = True
except Exception:
    widgets = None
    HAS_WIDGETS = False


In [2]:
# Optional explicit location of the outputs root. While this is None (or
# empty), detect_outputs_root() searches for 'experiments/outputs' itself.
OUTPUTS_ROOT = None  # e.g. 'experiments/outputs' or '/home/.../DRL-Pytorch/experiments/outputs'

# Line colour per run. pick_color() returns the colour of the first key that
# is a prefix of the run id, so iteration order of this dict matters.
COLOR_MAP = {
    'q_learning': '#d62728',
    'dqn': '#1f77b4',
    'per_dqn': '#2ca02c',
    'double_dqn': '#ff7f0e',
}

def find_repo_root(start=None):
    """Return the nearest directory that looks like a repository root.

    The search begins at ``start`` (the current working directory when
    ``start`` is falsy) and climbs through its parents. A directory counts
    as a root when it contains ``pyproject.toml`` or ``.git``.

    Args:
        start: Path-like starting point, or None for the working directory.

    Returns:
        The first matching directory as a resolved ``Path``, or None when no
        directory on the way up contains either marker.
    """
    origin = Path(start or Path.cwd()).resolve()
    markers = ('pyproject.toml', '.git')
    return next(
        (
            folder
            for folder in (origin, *origin.parents)
            if any((folder / marker).exists() for marker in markers)
        ),
        None,
    )

def detect_outputs_root():
    """Work out which directory holds the experiment outputs.

    Resolution order:
      1. ``OUTPUTS_ROOT`` when it is set (user-expanded and resolved).
      2. ``<repo root>/experiments/outputs`` when a repo root is found and
         that directory exists.
      3. The first existing ``experiments/outputs`` found while climbing
         from the current working directory.
      4. ``<cwd>/experiments/outputs`` as a last resort, even though it
         does not exist.

    Returns:
        A ``Path`` to the outputs root.
    """
    if OUTPUTS_ROOT:
        return Path(OUTPUTS_ROOT).expanduser().resolve()

    suffix = Path('experiments') / 'outputs'

    repo_root = find_repo_root()
    if repo_root is not None:
        in_repo = repo_root / suffix
        if in_repo.exists():
            return in_repo

    here = Path.cwd().resolve()
    for folder in (here, *here.parents):
        candidate = folder / suffix
        if candidate.exists():
            return candidate
    return here / suffix

def list_output_dirs(root=None):
    """List the run directories directly under the outputs root, newest first.

    Args:
        root: Path-like outputs root. When falsy, the root is obtained from
            ``detect_outputs_root()``.

    Returns:
        A list of ``Path`` objects for every sub-directory of ``root``,
        sorted by modification time in descending order. The list is empty
        when ``root`` is missing or is not a directory.
    """
    root = Path(root) if root else detect_outputs_root()
    # is_dir() rather than exists(): a root that points at a regular file
    # would otherwise make iterdir() raise NotADirectoryError.
    if not root.is_dir():
        return []
    subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
    subdirs.sort(key=lambda entry: entry.stat().st_mtime, reverse=True)
    return subdirs

def latest_output_dir(root=None):
    """Return the most recently modified run directory, if there is one.

    Args:
        root: Outputs root forwarded to ``list_output_dirs``.

    Returns:
        The first entry of ``list_output_dirs(root)``, or None when that
        list is empty.
    """
    candidates = list_output_dirs(root)
    if not candidates:
        return None
    return candidates[0]

def pick_output_dir(output_dir, select_index):
    """Choose the run directory to plot.

    Precedence: an explicit ``output_dir`` wins, then ``select_index`` into
    the newest-first listing, then the newest run directory.

    Args:
        output_dir: Explicit run directory (path-like), or a falsy value to
            choose from the listing instead.
        select_index: Index into ``list_output_dirs()``, or None to take the
            newest directory.

    Returns:
        The chosen directory as a ``Path``, or None when no explicit
        directory was given and no run directories exist.

    Raises:
        ValueError: If ``select_index`` is outside the listing.
    """
    # Check the explicit choice first so the outputs root is not scanned
    # when its result would be thrown away.
    if output_dir:
        return Path(output_dir)

    # One scan serves both the indexed and the "latest" path.
    dirs = list_output_dirs()
    if not dirs:
        return None
    if select_index is None:
        return dirs[0]
    if select_index < 0 or select_index >= len(dirs):
        raise ValueError('SELECT_INDEX out of range')
    return dirs[select_index]

def parse_score(row):
    """Read the score of one metrics row as a float.

    The columns ``episode_reward`` and ``score`` are tried in that order;
    the first one present in ``row`` is converted. When neither is present
    the ``reward`` column is used, defaulting to 0.0 if it is missing too.

    Args:
        row: Mapping from column name to cell value.

    Returns:
        The score as a float.
    """
    for column in ('episode_reward', 'score'):
        if column in row:
            return float(row[column])
    return float(row.get('reward', 0.0))

def load_metrics(output_dir):
    """Load every metrics CSV of a run directory into a flat record list.

    Reads ``<output_dir>/metrics/*.csv`` in sorted filename order. Each CSV
    row becomes one record; the file stem is used as the run id.

    Rows whose step or score cannot be parsed are skipped rather than
    aborting the whole load. That covers non-numeric text, cells missing
    because the row is shorter than the header, and non-finite steps.

    Args:
        output_dir: Path-like run directory containing a ``metrics`` folder.

    Returns:
        A list of dicts with the keys ``run_id``, ``step`` (int), ``score``
        (float), ``env`` and ``seed`` (strings, ``''`` when absent). The
        list is empty when there are no CSV files.
    """
    metrics_dir = Path(output_dir) / 'metrics'
    records = []
    for csv_path in sorted(metrics_dir.glob('*.csv')):
        run_id = csv_path.stem
        with open(csv_path, newline='', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                try:
                    step = int(float(row.get('step', 0)))
                    score = parse_score(row)
                except (TypeError, ValueError, OverflowError):
                    # TypeError: cell is None because the row is short.
                    # ValueError: non-numeric text or NaN step.
                    # OverflowError: infinite step.
                    continue
                records.append({
                    'run_id': run_id,
                    'step': step,
                    'score': score,
                    # DictReader fills missing cells with None; normalise to
                    # '' so env names stay sortable strings.
                    'env': row.get('env') or '',
                    'seed': row.get('seed') or '',
                })
    return records

def pick_color(run_id):
    """Look up the plot colour for a run.

    Args:
        run_id: Run identifier to match against the ``COLOR_MAP`` keys.

    Returns:
        The colour of the first ``COLOR_MAP`` key that ``run_id`` starts
        with, or None when no key matches.
    """
    return next(
        (colour for prefix, colour in COLOR_MAP.items() if run_id.startswith(prefix)),
        None,
    )

def bin_records(records, step_bin):
    """Average scores inside fixed-width step bins.

    Each record is assigned to the bin whose left edge is its step rounded
    down to a multiple of ``step_bin``.

    Args:
        records: Iterable of dicts with ``step`` and ``score`` entries.
        step_bin: Width of a bin in steps.

    Returns:
        A pair ``(xs, ys)`` of NumPy arrays: the sorted left bin edges and
        the mean score of each bin.
    """
    grouped = {}
    for rec in records:
        left_edge = (rec['step'] // step_bin) * step_bin
        grouped.setdefault(left_edge, []).append(rec['score'])
    edges = sorted(grouped)
    means = [float(np.mean(grouped[edge])) for edge in edges]
    return np.array(edges), np.array(means)

def smooth_series(x, y, window):
    """Apply a moving average of ``window`` points to ``y``.

    Args:
        x: Array of x positions aligned with ``y``.
        y: Array of values to smooth.
        window: Number of points averaged per output value.

    Returns:
        ``(x, y)`` unchanged when ``window`` is at most 1 or ``y`` has fewer
        than ``window`` points. Otherwise a pair ``(x_s, y_s)`` in which
        each smoothed value is paired with the x of the last point in its
        window, so the first ``window - 1`` x positions are dropped.
    """
    if not (window > 1 and len(y) >= window):
        return x, y
    weights = np.ones(window) / window
    smoothed = np.convolve(y, weights, mode='valid')
    return x[window - 1:], smoothed

def _group_by_env_and_run(records):
    """Index records as ``{env: {run_id: [record, ...]}}`` in one pass."""
    grouped = {}
    for rec in records:
        grouped.setdefault(rec['env'], {}).setdefault(rec['run_id'], []).append(rec)
    return grouped


def _draw_runs(runs, step_bin, smooth_window, label_prefix=''):
    """Draw one binned, smoothed line per run onto the current figure."""
    for run_id in sorted(runs):
        x, y = bin_records(runs[run_id], step_bin)
        x, y = smooth_series(x, y, smooth_window)
        plt.plot(x, y, label=f'{label_prefix}{run_id}', color=pick_color(run_id))


def _finish_figure(title):
    """Apply the shared title, axis labels, grid and legend, then show."""
    plt.title(title)
    plt.xlabel('steps')
    plt.ylabel('episode reward')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.show()


def plot_from_records(records, step_bin, smooth_window, only_envs=None, merge_envs=False):
    """Plot episode reward against steps for every run in ``records``.

    Records are grouped by environment and run id. Each run is binned with
    ``bin_records`` and smoothed with ``smooth_series`` before being drawn.

    Args:
        records: Record dicts with ``env``, ``run_id``, ``step`` and
            ``score`` entries.
        step_bin: Bin width in steps, passed to ``bin_records``.
        smooth_window: Moving-average window, passed to ``smooth_series``.
        only_envs: Optional collection of environment names to keep. A
            falsy value keeps every environment.
        merge_envs: When True, draw all environments on one figure with
            ``env:run_id`` labels. Otherwise draw one figure per
            environment.

    Returns:
        None. Nothing is drawn when no environment is left to plot.
    """
    # Group once up front instead of re-filtering the record list for
    # every (env, run) pair.
    grouped = _group_by_env_and_run(records)
    envs = sorted(grouped)
    if only_envs:
        envs = [env for env in envs if env in only_envs]
    if not envs:
        # Avoid an empty merged figure whose legend() call has no artists.
        return

    if merge_envs:
        plt.figure(figsize=(12, 5))
        for env in envs:
            _draw_runs(grouped[env], step_bin, smooth_window, label_prefix=f'{env}:')
        _finish_figure('ep_r comparison - all envs')
        return

    for env in envs:
        plt.figure(figsize=(10, 4))
        _draw_runs(grouped[env], step_bin, smooth_window)
        _finish_figure(f'ep_r comparison - {env}')


## 1) GUIモード（おすすめ）
ドロップダウンから run を選んで、ボタンで描画できます。

※ `ipywidgets` が無い場合は手動モードを使ってください。


In [None]:
# GUI mode: build the interactive run picker only when ipywidgets could be
# imported; otherwise print a hint and do nothing else.
if not HAS_WIDGETS:
    print('ipywidgets が見つかりません。GUIモードは使えません。')
    print('uv add ipywidgets で追加してから再起動してください。')
else:
    def build_gui():
        """Build and display the widget UI for choosing a run and plotting it.

        The dropdown lists the directories returned by ``list_output_dirs``.
        The two sliders, the checkbox and the ENVS selector supply the
        ``step_bin``, ``smooth_window``, ``merge_envs`` and ``only_envs``
        arguments of ``plot_from_records``.
        """
        root = detect_outputs_root()
        dirs = list_output_dirs(root)
        # Dropdown entries are (label, value) pairs: directory name shown,
        # full path string used as the value.
        options = [(str(d.name), str(d)) for d in dirs]

        output_dd = widgets.Dropdown(
            options=options,
            description='Output',
            layout=widgets.Layout(width='90%')
        )
        step_slider = widgets.IntSlider(value=1000, min=100, max=10000, step=100, description='STEP_BIN')
        smooth_slider = widgets.IntSlider(value=5, min=1, max=50, step=1, description='SMOOTH')
        merge_checkbox = widgets.Checkbox(value=False, description='Merge ENVS')
        env_select = widgets.SelectMultiple(options=[], description='ENVS')
        refresh_btn = widgets.Button(description='Refresh List')
        plot_btn = widgets.Button(description='Plot', button_style='success')
        status = widgets.HTML('')
        out = widgets.Output()

        def refresh_envs():
            """Refill the ENVS selector from the selected run's metrics."""
            if not output_dd.options:
                env_select.options = []
                return
            output_dir = output_dd.value
            try:
                recs = load_metrics(output_dir)
            except Exception as e:
                # Report the failure in the status line and leave the
                # current ENVS options as they are.
                status.value = f'<b>load error:</b> {e}'
                return
            envs = sorted(set(r['env'] for r in recs))
            env_select.options = envs

        def on_refresh(_):
            """Re-scan the outputs root and select the first listed run."""
            root = detect_outputs_root()
            dirs = list_output_dirs(root)
            output_dd.options = [(str(d.name), str(d)) for d in dirs]
            if output_dd.options:
                output_dd.value = output_dd.options[0][1]
            refresh_envs()

        def on_plot(_):
            """Load the selected run and plot it inside the ``out`` widget."""
            with out:
                # Replace whatever was rendered by the previous click.
                clear_output(wait=True)
                if not output_dd.value:
                    print('No output selected.')
                    return
                records = load_metrics(output_dd.value)
                # An empty selection becomes [], which plot_from_records
                # treats as "no filter" (all environments).
                only_envs = list(env_select.value)
                plot_from_records(records, step_slider.value, smooth_slider.value, only_envs, merge_checkbox.value)

        refresh_btn.on_click(on_refresh)
        plot_btn.on_click(on_plot)
        # Keep the ENVS selector in sync whenever the dropdown value changes.
        output_dd.observe(lambda _: refresh_envs(), names='value')

        # Initial state: select the first run (if any) and load its envs.
        if output_dd.options:
            output_dd.value = output_dd.options[0][1]
            refresh_envs()

        status.value = f'<b>Outputs root:</b> {root}'
        ui = widgets.VBox([
            status,
            output_dd,
            widgets.HBox([step_slider, smooth_slider, merge_checkbox]),
            env_select,
            widgets.HBox([refresh_btn, plot_btn]),
            out,
        ])
        display(ui)

    build_gui()


VBox(children=(HTML(value='<b>Outputs root:</b> /home/taiyo-sato/Desktop/B3kadai_2026/DRL-Pytorch/experiments/…

## 2) 手動モード（GUIが使えない場合）
直接フォルダを指定して再プロットします。


In [None]:
# Manual-mode settings, read by the final cell of this notebook.
OUTPUT_DIR = None  # e.g. 'experiments/outputs/cartpole_q_dqn_per_compare_20260204_025537'
SELECT_INDEX = None  # e.g. 0, 1, 2 ... (choose after viewing the listing)

STEP_BIN = 1000     # spacing between plotted points (e.g. 200, 500, 1000, 5000)
SMOOTH_WINDOW = 5   # moving-average window (1 disables smoothing)
ONLY_ENVS = []      # e.g. ['CartPole-v1'] (empty means all environments)
MERGE_ENVS = False  # True draws every environment on a single figure


In [None]:
def show_output_list():
    """Print the detected outputs root and its run directories.

    Directories are printed newest first, each prefixed with the index
    that can be used as ``SELECT_INDEX``. A notice is printed instead when
    there are none.
    """
    outputs_root = detect_outputs_root()
    run_dirs = list_output_dirs(outputs_root)
    print('Outputs root:', outputs_root)
    if run_dirs:
        print('Outputs (newest first):')
        for index, run_dir in enumerate(run_dirs):
            print(f'[{index}] {run_dir}')
    else:
        print('No output directories found.')


show_output_list()


In [None]:
# Resolve which run to plot from the manual-mode settings; stop early when
# nothing can be found.
output_dir = pick_output_dir(OUTPUT_DIR, SELECT_INDEX)
if output_dir is None:
    raise RuntimeError('No output directory found.')
print('Using output:', output_dir)

# Load every metrics CSV of the run and report how much data was read.
records = load_metrics(output_dir)
print('Records:', len(records), 'Runs:', len({r['run_id'] for r in records}))

# Draw the curves with the configured binning, smoothing and env filter.
plot_from_records(
    records,
    STEP_BIN,
    SMOOTH_WINDOW,
    ONLY_ENVS,
    MERGE_ENVS,
)
