# Training curves with confidence intervals

This notebook loads TensorBoard event files from `./runs`, groups runs by (env_id, algorithm) across seeds, and uses Plotly to plot the across-seed mean as a line with a shaded ±1 std band. Run directories are expected to be named `{env_id}__{algo}__{seed}__{timestamp}`, e.g. `PointMassDiscrete-v0__dqn__0__1760811783`.


In [1]:
# Imports and config
from pathlib import Path
import re
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objs as go
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Directory that holds one subdirectory per run; non-directory entries are ignored.
RUNS_DIR = Path('runs')
# Trailing moving-average window, counted in logged points (not env steps), applied to
# every run's series when it is loaded; a value <= 1 (or 0/None) disables smoothing.
# The "final values" used for the bell curves are taken from these smoothed series too.
SMOOTH_WINDOW = 200

# Scalar tags to read from the TensorBoard event files; adapt to your own logging.
# Only tags listed here are loaded, so every `key_preference` passed to the plotting
# helpers must also appear in this list. Tags missing from a run are skipped for that run.
SCALAR_KEYS = [
    'charts/episodic_return',
    'charts/episode_reward',
    'charts/ep_len',
    'eval/episodic_return',
    'train/episodic_return',
    'charts/actual_performance',
]

# Fallback pattern for names that do not use '__' separators, e.g. 'env_algo_3'.
# Components matched by this regex cannot themselves contain '_'.
RUN_NAME_RE = re.compile(r'^(?P<env>[^_]+)__?(?P<algo>[^_]+)__?(?P<seed>\d+).*$')


def parse_run_name(run_dir: Path) -> Tuple[str, str, int]:
    """Extract ``(env, algo, seed)`` from a run directory name.

    The ``'__'``-separated layout ``env__algo__seed[__anything]`` (e.g.
    ``PointMassDiscrete-v0__dqn__0__1760811783``) is tried first because it
    tolerates single underscores inside the env id and the algo name
    (``point_mass_2-v0__dqn_atari__3__17``). Names that do not fit it are
    matched against ``RUN_NAME_RE``, which also accepts single-underscore
    separators but no underscores inside components.

    Args:
        run_dir: Path of the run directory; only its final name is used.

    Returns:
        ``(env, algo, seed)`` with ``seed`` as an int.

    Raises:
        ValueError: if the name matches neither layout.
    """
    name = run_dir.name
    parts = name.split('__')
    if len(parts) >= 3:
        try:
            return parts[0], parts[1], int(parts[2])
        except ValueError:
            pass  # third field is not an integer seed; try the regex layout below
    m = RUN_NAME_RE.match(name)
    if m:
        return m.group('env'), m.group('algo'), int(m.group('seed'))
    raise ValueError(f"Unrecognized run dir name: {name}")


def load_scalars_from_event(run_dir: Path, keys: List[str]) -> Dict[str, pd.DataFrame]:
    """Read the requested scalar tags from a run's TensorBoard event files.

    Event files directly inside ``run_dir`` are used; only when there are none
    is the directory searched recursively. Files whose ``Reload()`` raises are
    skipped. Points from all files are pooled per tag, the first point seen for
    each step is kept, rows are sorted by step and, when ``SMOOTH_WINDOW > 1``,
    values are replaced by a trailing rolling mean (``min_periods=1``).

    Args:
        run_dir: Run directory to search for ``events.out.tfevents.*`` files.
        keys: Scalar tags to extract.

    Returns:
        Dict ``tag -> DataFrame[step, value]``. Tags never found are absent;
        an empty dict means no event files (or no matching tags) were found.
    """
    pattern = 'events.out.tfevents.*'
    event_paths = sorted(run_dir.glob(pattern)) or sorted(run_dir.rglob(pattern))
    if not event_paths:
        return {}

    # Pool (step, value) points per tag across every event file of the run.
    collected: Dict[str, List[Tuple[int, float]]] = {}
    for path in event_paths:
        acc = EventAccumulator(str(path), size_guidance={'scalars': 100000})
        try:
            acc.Reload()
        except Exception:
            # Unreadable event file: best-effort, skip it rather than abort the run.
            continue
        available = acc.Tags().get('scalars', [])
        for tag in keys:
            if tag not in available:
                continue
            collected.setdefault(tag, []).extend(
                (ev.step, ev.value) for ev in acc.Scalars(tag)
            )

    frames: Dict[str, pd.DataFrame] = {}
    for tag, points in collected.items():
        if not points:
            continue
        frame = (
            pd.DataFrame(points, columns=['step', 'value'])
            .drop_duplicates('step')
            .sort_values('step')
            .reset_index(drop=True)
        )
        if SMOOTH_WINDOW and SMOOTH_WINDOW > 1:
            frame['value'] = frame['value'].rolling(SMOOTH_WINDOW, min_periods=1).mean()
        frames[tag] = frame
    return frames


def aggregate_by_group(runs_dir: Path, keys: List[str], interpolate: bool = True):
    """Aggregate scalar curves across seeds for every (env, algo) group.

    Each subdirectory of ``runs_dir`` is parsed with ``parse_run_name`` and
    loaded with ``load_scalars_from_event``. Directories whose name cannot be
    parsed, or that yield no matching scalars, are skipped.

    Seeds rarely log at identical steps (e.g. a return is written whenever an
    episode ends), so a plain outer join leaves most rows with a single
    non-NaN value: ``count`` is 1 and the std band collapses to 0. With
    ``interpolate=True`` (default) each seed's curve is linearly interpolated
    by step onto the union of all seeds' steps, but only *inside* that seed's
    own step range (no extrapolation). Pass ``interpolate=False`` to get the
    raw outer-join behaviour.

    Args:
        runs_dir: Directory containing one subdirectory per run.
        keys: Scalar tags to load and aggregate.
        interpolate: Whether to interpolate each seed onto the common steps.

    Returns:
        Dict ``(env, algo, key) -> DataFrame`` with columns ``step``, ``mean``,
        ``std`` (population std, ddof=0) and ``count`` (number of seeds with a
        value at that step), sorted by step.
    """
    # (env, algo) -> one {key: DataFrame[step, value]} dict per seed.
    groups: Dict[Tuple[str, str], List[Dict[str, pd.DataFrame]]] = {}
    for run in sorted(runs_dir.iterdir()):
        if not run.is_dir():
            continue
        try:
            env, algo, _seed = parse_run_name(run)
        except ValueError:
            continue  # not a run directory we recognise
        dfs = load_scalars_from_event(run, keys)
        if dfs:
            groups.setdefault((env, algo), []).append(dfs)

    agg: Dict[Tuple[str, str, str], pd.DataFrame] = {}
    for (env, algo), seed_runs in groups.items():
        for key in keys:
            series = [dfs[key].set_index('step')['value'] for dfs in seed_runs if key in dfs]
            if not series:
                continue
            # Outer-join on step; steps are unique per seed (deduplicated at load time).
            aligned = pd.concat(series, axis=1).sort_index()
            aligned.columns = [f'seed_{i}' for i in range(len(series))]
            if interpolate:
                # Fill a seed's gaps between its own logged steps; never extend past its ends.
                aligned = aligned.interpolate(method='index', limit_area='inside')
            agg[(env, algo, key)] = pd.DataFrame({
                'step': aligned.index.values,
                'mean': aligned.mean(axis=1, skipna=True).values,
                'std': aligned.std(axis=1, ddof=0, skipna=True).values,
                'count': aligned.count(axis=1).values,
            }).reset_index(drop=True)
    return agg


# Plotly's default qualitative colour sequence (plotly.express.colors.qualitative.Plotly).
_PLOT_COLORS = ('#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52')


def _hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Convert a '#RRGGBB' colour into a Plotly 'rgba(r,g,b,alpha)' string."""
    h = hex_color.lstrip('#')
    r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
    return f'rgba({r},{g},{b},{alpha})'


def plot_groups(agg: Dict[Tuple[str, str, str], pd.DataFrame], key_preference: List[str]):
    """Build one "mean ± 1 std across seeds" figure per environment.

    For every (env, algo) pair, the key from ``key_preference`` whose frame has
    the highest median ``count`` (seeds per step) is plotted. Frames without a
    ``count`` column fall back to their row count, and ties go to the key listed
    first. Each algorithm gets one colour, which its mean line and its shaded
    band share; the band is in the line's legend group, so toggling the legend
    entry hides both. An algo keeps the same colour in every figure.

    Args:
        agg: Mapping ``(env, algo, key) -> DataFrame`` with ``step``, ``mean``,
            ``std`` and optionally ``count`` columns.
        key_preference: Keys eligible for plotting, most preferred first.

    Returns:
        Dict ``env -> plotly Figure``; empty when no key in ``agg`` is eligible.
    """
    grouped: Dict[Tuple[str, str], List[Tuple[str, pd.DataFrame]]] = {}
    for (env, algo, key), df in agg.items():
        if key in key_preference:
            grouped.setdefault((env, algo), []).append((key, df))

    def coverage_score(item):
        key, df = item
        cov = df['count'].median() if 'count' in df.columns else len(df)
        return (cov, -key_preference.index(key))

    by_env: Dict[str, List[Tuple[str, pd.DataFrame]]] = {}
    for (env, algo), items in grouped.items():
        _best_key, best_df = max(items, key=coverage_score)
        by_env.setdefault(env, []).append((algo, best_df))

    # Assign colours by sorted algo name so they are stable across figures.
    algo_color = {
        algo: _PLOT_COLORS[i % len(_PLOT_COLORS)]
        for i, algo in enumerate(sorted({algo for _env, algo in grouped}))
    }

    figs = {}
    for env, items in by_env.items():
        fig = go.Figure()
        for algo, df in sorted(items, key=lambda x: x[0]):
            if df.empty:
                continue
            color = algo_color[algo]
            # Mean line
            fig.add_trace(go.Scatter(
                x=df['step'], y=df['mean'], name=f"{algo}", mode='lines',
                line=dict(color=color), legendgroup=algo,
            ))
            # ±1 std band: closed polygon (upper edge forward, lower edge reversed).
            fig.add_trace(go.Scatter(
                x=pd.concat([df['step'], df['step'][::-1]]),
                y=pd.concat([df['mean'] + df['std'], (df['mean'] - df['std'])[::-1]]),
                fill='toself', fillcolor=_hex_to_rgba(color, 0.15), line=dict(width=0),
                name=f"{algo} ±1σ (≈68%)", showlegend=False, legendgroup=algo,
                hoverinfo='skip',
            ))
        fig.update_layout(
            title=f"{env}: mean ± std across seeds",
            xaxis_title='step', yaxis_title='scalar value', template='plotly_white'
        )
        figs[env] = fig
    return figs


# Aggregate every (env, algo, key) curve once, then derive both figure sets from it.
agg = aggregate_by_group(RUNS_DIR, SCALAR_KEYS)

# Return-like metrics: per algo, the best-covered key is drawn as mean ± std (~68%).
figs = plot_groups(
    agg,
    key_preference=[
        'charts/episodic_return',
        'charts/episode_reward',
        'eval/episodic_return',
        'train/episodic_return',
    ],
)
# Separate figures for the 'actual performance' time series.
figs_actual = plot_groups(agg, key_preference=['charts/actual_performance'])

# Show the return figures first, then the actual-performance ones.
for fig in [*figs.values(), *figs_actual.values()]:
    fig.show()


In [12]:
# Final-value confidence plots (bell curves per algo)

from collections import defaultdict

# Number of most recent points of a run's series (already smoothed by SMOOTH_WINDOW at
# load time) used to estimate mean/std for a group that has only one run. It is not
# used for groups with >= 2 runs, whose std is computed across the runs' final values.
RUN_SERIES_WINDOW = 200

def collect_final_values(runs_dir: Path, keys: List[str], key_preference: List[str]):
    """Summarise each (env, algo) group's final value as a mean and a std.

    Runs under ``runs_dir`` are parsed with ``parse_run_name`` and loaded with
    ``load_scalars_from_event``, so values are already smoothed when
    ``SMOOTH_WINDOW > 1``. For each group, the key from ``key_preference``
    that the most runs logged (non-empty) is used; ties go to the key listed
    first.

    * Two or more runs: mean and population std (ddof=0) of the runs' last values.
    * One run: mean and std of that run's last ``RUN_SERIES_WINDOW`` points. This
      measures within-run variability, not across-seed spread.

    In both cases a std of exactly 0 (e.g. identical finals across seeds) is
    replaced by 1e-9, so a Gaussian built from it is well defined.

    Args:
        runs_dir: Directory containing one subdirectory per run.
        keys: Scalar tags to load from each run.
        key_preference: Candidate keys for the final value, most preferred first.

    Returns:
        Dict ``env -> [{'algo', 'mean', 'std', 'n', 'key'}, ...]``, where ``n``
        is the number of runs behind the estimate.
    """
    min_std = 1e-9  # floor so downstream Gaussian densities never divide by zero

    grouped: Dict[Tuple[str, str], List[Dict[str, pd.DataFrame]]] = {}
    for run in sorted(runs_dir.iterdir()):
        if not run.is_dir():
            continue
        try:
            env, algo, _seed = parse_run_name(run)
        except ValueError:
            continue  # not a run directory we recognise
        dfs = load_scalars_from_event(run, keys)
        if dfs:
            grouped.setdefault((env, algo), []).append(dfs)

    results: Dict[str, List[Dict[str, object]]] = {}
    for (env, algo), runs in grouped.items():
        # Number of runs that logged each preferred key.
        coverage = defaultdict(int)
        for pref in key_preference:
            for d in runs:
                if pref in d and not d[pref].empty:
                    coverage[pref] += 1
        if not coverage:
            continue
        best_key = max(coverage.items(), key=lambda kv: (kv[1], -key_preference.index(kv[0])))[0]

        series_frames = [d[best_key] for d in runs if best_key in d and not d[best_key].empty]
        if not series_frames:
            continue

        if len(series_frames) >= 2:
            # Across-seed distribution of final values.
            finals = np.array([float(df['value'].iloc[-1]) for df in series_frames], dtype=float)
            mean = float(np.mean(finals))
            std = float(np.std(finals, ddof=0))
            n = int(finals.size)
        else:
            # Single run: fall back to the spread of its own most recent points.
            tail = series_frames[0]['value'].tail(RUN_SERIES_WINDOW).astype(float).values
            mean = float(np.mean(tail))
            std = float(np.std(tail, ddof=0))
            n = 1

        if std == 0.0:
            std = min_std

        results.setdefault(env, []).append({
            'algo': algo,
            'mean': mean,
            'std': std,
            'n': n,
            'key': best_key,
        })
    return results


def _bell_x_range(items: List[Dict[str, object]]) -> Tuple[float, float]:
    """Return an x-window covering mean ± 4·std of every item (std 0 covers just the mean)."""
    lows, highs = [], []
    for it in items:
        mu, sigma = float(it['mean']), float(it['std'])
        half_width = 4.0 * sigma if sigma > 0 else 0.0
        lows.append(mu - half_width)
        highs.append(mu + half_width)
    x_min, x_max = min(lows), max(highs)
    if np.isclose(x_min, x_max):
        # Degenerate window (all stds ~0 and equal means): pad around the centre.
        center = 0.5 * (x_min + x_max)
        spread = max(1.0, 0.05 * abs(center))
        x_min, x_max = center - spread, center + spread
    return x_min, x_max


def plot_final_bell_curves(finals_by_env: Dict[str, List[Dict[str, object]]], num_points: int = 400):
    """Plot, per environment, one Gaussian N(mean, std) density per algorithm.

    The x-range spans mean ± 4·std of every algorithm and always includes every
    mean. When all stds are ~0 and the means coincide, the range is padded to
    ±max(1, 5% of |centre|). A std narrower than the sample spacing would fall
    between grid points and render as a flat line, so the drawn width is
    clamped to that spacing. Hover text still reports the true std. A dashed
    vertical line marks each mean.

    Args:
        finals_by_env: ``env -> [{'algo', 'mean', 'std', 'n', 'key'}, ...]``.
        num_points: Number of x samples per curve.

    Returns:
        Dict ``env -> plotly Figure``; environments with no items are skipped.
    """
    figs = {}
    for env, items in finals_by_env.items():
        if not items:
            continue
        x_min, x_max = _bell_x_range(items)
        xs = np.linspace(x_min, x_max, num_points)
        grid_step = (x_max - x_min) / max(num_points - 1, 1)

        fig = go.Figure()
        for it in sorted(items, key=lambda z: z['algo']):
            algo = it['algo']
            mu = float(it['mean'])
            sigma = float(it['std'])
            sigma_draw = max(sigma, grid_step)  # keep (near-)zero-std curves visible and finite
            pdf = (1.0/(sigma_draw*np.sqrt(2*np.pi))) * np.exp(-0.5*((xs - mu)/sigma_draw)**2)
            fig.add_trace(go.Scatter(x=xs, y=pdf, mode='lines', name=f"{algo} (n={it['n']})", hovertemplate=f"{algo} μ={mu:.3f}, σ={sigma:.3f}, n={it['n']}<extra></extra>"))
            fig.add_vline(x=mu, line_width=1, line_dash='dash', line_color='gray')
        # Algos in one env may have picked different keys; list all of them.
        used_key = ', '.join(sorted({str(it['key']) for it in items}))
        fig.update_layout(
            title=f"{env}: final distribution across seeds (key: {used_key})",
            xaxis_title='final value', yaxis_title='density', template='plotly_white'
        )
        figs[env] = fig
    return figs


def _show_all(figures):
    """Display every figure of an ``env -> Figure`` mapping, in insertion order."""
    for figure in figures.values():
        figure.show()


# Bell curves of final values for return-like metrics (best-covered key per algo).
finals = collect_final_values(
    RUNS_DIR,
    SCALAR_KEYS,
    key_preference=[
        'charts/episodic_return',
        'charts/episode_reward',
        'eval/episodic_return',
        'train/episodic_return',
    ],
)
final_figs = plot_final_bell_curves(finals)
_show_all(final_figs)

# Bell curves of final values for the actual-performance metric.
finals_actual = collect_final_values(RUNS_DIR, SCALAR_KEYS, key_preference=['charts/actual_performance'])
final_figs_actual = plot_final_bell_curves(finals_actual)
_show_all(final_figs_actual)
