In [None]:
import json
import re

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import wandb

In [None]:
# Log in to Weights & Biases so the API calls in the following cells can run.
wandb.login()

In [None]:
# Download the logged history of every run in the project and keep only the
# step column plus the reward column(s), stored per run in `tagged_dfs`.
ENTITY = "p-j-c-michielsen-eindhoven-university-of-technology"
PROJECT = "thesis_experiments"
# Any history column whose name ends with this suffix is treated as a reward column.
REWARD_SUFFIX = "reward/reward"

# Setup
api = wandb.Api()
runs = api.runs(f"{ENTITY}/{PROJECT}")

print(f"Found {len(runs)} runs.")
tagged_dfs = {}

for run in runs:
    print(f"🔍 Run: {run.name}")

    # A run is keyed by its tags joined with "_"; untagged runs fall back to the run name.
    key = "_".join(run.tags) if run.tags else run.name

    try:
        df = run.history(samples=10000, keys=None)
    except Exception as e:
        # Best effort: one unreadable run must not abort the whole collection.
        print(f"⚠️ Failed to fetch history for {run.name}: {e}")
        continue

    # Use the configured suffix rather than repeating the literal.
    reward_cols = [col for col in df.columns if col.endswith(REWARD_SUFFIX)]
    if not reward_cols:
        print(f"⚠️ No {REWARD_SUFFIX} columns in {run.name}")
        continue

    cols_to_keep = ["global_step"] if "global_step" in df.columns else []
    cols_to_keep += reward_cols
    # Drop rows in which none of the reward columns carries a value.
    df_filtered = df[cols_to_keep].dropna(subset=reward_cols, how="all").copy()

    # Two runs with the same tag set map to the same key; make the overwrite visible.
    if key in tagged_dfs:
        print(f"⚠️ Key {key} already stored; overwriting with run {run.name}")

    tagged_dfs[key] = df_filtered
    print(f"✅ Stored: {key} with {len(df_filtered)} rows")

print(f"\n✅ Collected {len(tagged_dfs)} tagged DataFrames.")

# Second name for the unscaled frames; `tagged_dfs` is rebound to scaled data later on.
og_dfs = tagged_dfs

In [None]:
# Quick look at the unscaled reward curve of the chainsaw baseline run.
# NOTE: relies on the "reward/reward" column, so this must run before the
# scaling cell, which renames that column to "reward" and rebinds `tagged_dfs`.
df_chain = tagged_dfs["CHAINSAW_BASELINE"]
fig = px.line(df_chain, x="global_step", y="reward/reward")
fig.show()

In [None]:
# Display the unscaled reward series of the chainsaw baseline (column name as logged).
tagged_dfs["CHAINSAW_BASELINE"]["reward/reward"]

In [None]:
# List the keys (joined tags or run names) under which run histories were stored.
tagged_dfs.keys()

In [None]:
# Rescale the reward of every "*_BASELINE" run so that its maximum becomes 100,
# then trim those baselines to the first BASELINE_MAX_STEP steps. All other runs
# keep their reward unchanged (scale factor 1.0).
#
# The loop reads from `og_dfs` (the unscaled frames) instead of `tagged_dfs`,
# because `tagged_dfs` is rebound to the scaled result at the end of this cell.
# Reading the rebound name would rescale already-scaled data when the cell is
# executed a second time.
BASELINE_MAX_STEP = 2_000_000

scaled_tagged = {}
for run_name, df in og_dfs.items():
    df = df.rename(columns={'reward/reward': 'reward'}).copy()

    if run_name.endswith('_BASELINE'):
        print(f"{run_name} (scale based on full data)")
        # The factor is computed before trimming, i.e. from the full history.
        max_r = df['reward'].max()
        # Non-positive (or NaN) maxima cannot be normalised; leave them unscaled.
        scale = 100.0 / max_r if max_r > 0 else 1.0
        print(f"Scale factor: {scale}")

        df = df[df['global_step'] <= BASELINE_MAX_STEP].copy()
    else:
        scale = 1.0

    df['reward'] = df['reward'] * scale
    # Keep the applied factor alongside the data for later reference.
    df['scale_factor'] = scale
    scaled_tagged[run_name] = df

tagged_dfs = scaled_tagged

In [None]:
# Display the health-gathering baseline reward after scaling; the "reward"
# column only exists once the scaling cell has renamed it.
tagged_dfs["HEALTH_GATHERING_BASELINE"]["reward"]

In [None]:


# One figure per run. The reward curve is cut into consecutive blocks of
# 2,000,000 global steps and each block is drawn as a separate line trace.
run_layout = dict(
    xaxis_title="global_step",
    yaxis_title="reward/reward",
    showlegend=False,
    plot_bgcolor='rgba(0,0,0,0)',
    paper_bgcolor='rgba(0,0,0,0)',
)

for df_name, frame in tagged_dfs.items():
    # assign() returns a new frame, so the stored DataFrame is left untouched.
    df = frame.assign(color_group=frame["global_step"] // 2_000_000)

    fig = go.Figure()
    for group, group_df in df.groupby("color_group"):
        fig.add_scatter(
            x=group_df["global_step"],
            y=group_df["reward"],
            mode='lines',
            name=f'Group {group}',
        )

    fig.update_layout(title=df_name, **run_layout)
    fig.show()

In [None]:
# Split every known run into one segment per scenario and collect, for each
# scenario, the segments of all runs that contain it (-> `scenario_dfs`).

# Length of the common step axis every segment is normalised to.
TOTAL_STEPS = 2_000_000

# Scenarios of each run, in the order their phases are laid out along the
# run's step axis (the phase index below follows this list order).
RUN_TASKS = {
    'ARMS_DEALER_BASELINE': ['arms_dealer'],
    'CHAINSAW_BASELINE':   ['chainsaw'],
    'HIDE_AND_SEEK_BASELINE': ['hide_and_seek'],
    'RAISE_THE_ROOF_BASELINE': ['raise_the_roof'],
    'RUN_AND_GUN_BASELINE': ['run_and_gun'],
    'PITFALL_BASELINE':    ['pitfall'],
    'FLOOR_IS_LAVA_BASELINE': ['floor_is_lava'],
    'HEALTH_GATHERING_BASELINE': ['health_gathering'],
    'ALL_SCENARIOS_BASELINE_SCALED': [
        "arms_dealer","chainsaw","floor_is_lava",
        "health_gathering","hide_and_seek","pitfall",
        "raise_the_roof","run_and_gun"
    ],
    'TARGETING':        ['health_gathering','arms_dealer'],
    'TARGETING_THEN_KILL':   ['health_gathering','arms_dealer','chainsaw','run_and_gun'],
    'TARGETING_THEN_SURVIVE': ['health_gathering','arms_dealer','floor_is_lava',
                               'hide_and_seek','raise_the_roof'],
    'KILL_THEN_PITFALL': ['chainsaw','run_and_gun','pitfall'],
    'TARGETING_THEN_PITFALL': ['health_gathering','arms_dealer','pitfall'],
    'TARGETING_FLIPPED': ['arms_dealer','health_gathering'],
    'KILL_FLIPPED':      ['run_and_gun','chainsaw'],
    'SURVIVE_FRH': ['floor_is_lava','raise_the_roof','hide_and_seek'],
    'SURVIVE_FHR': ['floor_is_lava','hide_and_seek','raise_the_roof'],
    'SURVIVE_HRF': ['hide_and_seek','raise_the_roof','floor_is_lava'],
    'SURVIVE_HFR': ['hide_and_seek','floor_is_lava','raise_the_roof'],
    'SURVIVE_RFH': ['raise_the_roof','floor_is_lava','hide_and_seek'],
    'SURVIVE_RHF': ['raise_the_roof','hide_and_seek','floor_is_lava'],
}

# Every scenario mentioned in RUN_TASKS, de-duplicated in first-seen order.
# Defined here so the cell does not depend on a name from outside the notebook.
all_tasks = list(dict.fromkeys(
    scenario for scenarios in RUN_TASKS.values() for scenario in scenarios
))

scenario_segments = {t: [] for t in all_tasks}

for run_name, df in tagged_dfs.items():
    tasks = RUN_TASKS.get(run_name, [])
    if not tasks:
        # Runs without a task list cannot be segmented.
        continue

    run_max = df['global_step'].max()

    # Assumption: every phase of a run spans an equal share of [0, run_max].
    n_phases = len(tasks)
    edges = np.linspace(0, run_max, n_phases + 1)

    for phase_idx, task in enumerate(tasks):
        start, end = edges[phase_idx], edges[phase_idx + 1]

        # Phases are half-open [start, end) so a row never lands in two phases;
        # the last phase is closed so the final row (global_step == run_max)
        # is not dropped.
        in_phase = df.global_step >= start
        if phase_idx == n_phases - 1:
            in_phase &= df.global_step <= end
        else:
            in_phase &= df.global_step < end

        seg = df[in_phase].copy()
        # Skip empty phases and zero-width (or NaN) phases, which would
        # otherwise divide by zero in the normalisation below.
        if seg.empty or not end > start:
            continue

        # Map the phase onto the common axis [0, TOTAL_STEPS].
        seg['step_norm'] = (
            (seg['global_step'] - start)
            / (end - start)
        ) * TOTAL_STEPS

        # tag it
        seg['run']      = run_name
        seg['scenario'] = task

        scenario_segments[task].append(seg)

# One long-format frame per scenario; scenarios without any segment get an empty frame.
scenario_dfs = {
    task: pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    for task, frames in scenario_segments.items()
}


In [None]:
# Reward curves of every run that contains the health_gathering scenario,
# one line per run, on the phase-normalised step axis.
df = scenario_dfs['health_gathering']
fig = px.line(
    df,
    x='step_norm',
    y='reward',
    color='run',
    labels={
        'step_norm': 'Steps',
        # The key must match the plotted column name for the label to apply.
        'reward': 'Reward',
    },
)
# The title names the scenario that is actually plotted.
fig.update_layout(title='Health Gathering: Baseline vs. Composite Runs')
fig.show()

In [None]:
# One compact figure per scenario comparing all runs that contain it.
# The settings shared by every figure are built once, outside the loop.
axis_labels = {
    'step_norm': 'Steps',
    'reward':    'Reward',
    'run':       'Run',
}
legend_below_plot = dict(
    title_text="Run",
    orientation="h",
    yanchor="top",
    y=-0.2,
    xanchor="center",
    x=0.5,
    font=dict(size=10),
)
figure_margin = dict(l=20, r=20, t=60, b=80)

for task, df in scenario_dfs.items():
    # Scenarios for which no segment was collected have nothing to draw.
    if df.empty:
        continue

    pretty = task.replace('_', ' ').title()
    fig = px.line(df, x='step_norm', y='reward', color='run', labels=axis_labels)
    fig.update_layout(
        title_text=f"{pretty}: Baseline vs. Composites",
        title_x=0.5,
        width=500,
        height=400,
        legend=legend_below_plot,
        margin=figure_margin,
    )
    fig.show()