In [1]:
import os

import pandas as pd
import plotly.express as px
from analysis.analysis_utils import get_runs, fetch_projects

# Route DataFrame.plot() / Series.plot() through Plotly instead of the default backend.
pd.options.plotting.backend = "plotly"

# IPython autoreload in mode 2: modules are re-imported before each cell executes,
# so edits to local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Weights & Biases query settings and the local directory used for cached results
entity = "phys_inversion"
env_name = "direction_facing"
project_name = "chens_runs_" + env_name
# Restrict queries to runs whose state is "finished".
filters = dict(state={"$eq": "finished"})
output_dir = "/".join(["convergence", env_name])


In [3]:
# Raw metric name -> display name, one map per logging source.
# NOTE(review): neither rename map is referenced by the cells visible in this notebook.
pm_rename_metrics = {
    "info/episode_reward": "Episode Reward",
}
pulse_rename_metrics = {
    "mean_rewards": "Episode Reward",
}

# Canonical algorithm order; keys have the form "<algo_type>_prior_<prior>".
order = ['MaskedMimic_Inversion_prior_False',
         'MaskedMimic_Inversion_prior_True',
         'MaskedMimic_Prior_Only_prior_True',
         'MaskedMimic_Finetune_prior_False',
         'MaskedMimic_Finetune_prior_True',
         'PULSE_prior_False',
         'AMP_prior_False',
         'PureRL_prior_False',
         'PPO_prior_False', ]

# NOTE(review): 'steering' -> 'Direction' and 'direction_facing' -> 'Steering' look
# crossed; presumably this is the intended display naming — confirm before changing.
env_pretty_names = {
    'reach': 'Reach',
    'steering': 'Direction',
    'direction_facing': 'Steering',
    'long_jump': 'Long Jump',
    'strike': 'Strike'
}

# Algorithm key -> legend label.
algo_pretty_names = {
    'MaskedMimic_Inversion_prior_False': 'Task Tokens (ours)',
    'MaskedMimic_Inversion_prior_True': 'Task Tokens (ours) + J.C.',
    'MaskedMimic_Prior_Only_prior_True': 'MaskedMimic (J.C. only)',
    'MaskedMimic_Finetune_prior_False': 'MaskedMimic Fine-Tune',
    'MaskedMimic_Finetune_prior_True': 'MaskedMimic Fine-Tune + J.C.',
    'PULSE_prior_False': 'PULSE',
    'AMP_prior_False': 'AMP',
    'PureRL_prior_False': 'PureRL',
    'PPO_prior_False': 'PPO',
}

# Consistent color palette for algorithms
pretty_names_order = [algo_pretty_names[algo] for algo in order]
# Set2 alone has fewer colours than there are algorithms, so cycling through it gave
# the last algorithm the same colour as the first. Appending Set1 keeps the existing
# colours unchanged and gives every algorithm a distinct one; both palettes use
# 'rgb(r,g,b)' strings, which the plotting helper expects.
algo_palette = list(px.colors.qualitative.Set2) + list(px.colors.qualitative.Set1)
color_discrete_map = {
    algo_pretty_names[algo]: algo_palette[i % len(algo_palette)]  # modulo guards against future growth of `order`
    for i, algo in enumerate(order)
}


In [4]:
def try_fill_columns(df, verbose=False):
    """Broadcast values that are present only in the first row down their column.

    A column is filled when it holds exactly one non-null value and that value
    sits in the first row; every other column is left untouched. Filling is
    best-effort: a column whose value cannot be used as a fill value is skipped.

    Note: relies on ``tqdm`` being available in the notebook namespace at call time.

    Args:
        df: DataFrame to process. It is modified in place.
        verbose: When True, print the error for every column that could not be filled.

    Returns:
        The same ``df`` object, to allow use inside comprehensions.
    """
    for col in tqdm(df.columns, desc="Filling columns", position=0, leave=False):
        present = df[col].notna()
        # Exactly one value, and it is the first one. The explicit first-row check
        # avoids "filling" with NaN when the lone value sits elsewhere in the column.
        if present.sum() == 1 and present.iloc[0]:
            try:
                # Assign back instead of `df[col].fillna(..., inplace=True)`: the chained
                # in-place form acts on a temporary and does not update `df` under
                # pandas Copy-on-Write.
                df[col] = df[col].fillna(df[col].iloc[0])
            except Exception as e:
                # Deliberate best-effort: e.g. list-valued cells are rejected by fillna.
                if verbose:
                    print(f"Error filling column {col}: {e}")
    return df

In [5]:
from tqdm.notebook import tqdm
import wandb


def get_runs(entity, project, filters=None, keys=None, samples=40e3):
    """Download the sampled history of every matching W&B run into one DataFrame.

    NOTE(review): this definition shadows the ``get_runs`` imported from
    ``analysis.analysis_utils`` in the first cell.

    Args:
        entity: W&B entity that owns the project.
        project: Project name; runs are queried from ``"{entity}/{project}"``.
        filters: Run filter passed to ``wandb.Api().runs``. Defaults to finished runs only.
        keys: History keys forwarded to ``run.history``.
        samples: Number of history samples requested per run. Floats such as ``40e3``
            are accepted and converted to ``int`` before the request is made.

    Returns:
        The per-run frames concatenated and passed through ``reset_index()``, so each
        run's original row position ends up in an ``index`` column. Each run's frame
        holds its history columns, the run config flattened with ``.`` as separator
        (joined column-wise, so it lines up with the history row labelled 0), plus
        ``run_name`` and ``run_id``. An empty DataFrame is returned when no run matches.
    """
    if filters is None:
        filters = {"state": {"$eq": "finished"}}
    # The default and the notebook's callers use float literals (40e3, 30e3); the history
    # API takes a sample *count*, so hand it a real integer.
    samples = int(samples)

    api = wandb.Api()
    runs = api.runs(f"{entity}/{project}", filters=filters)
    frames = []
    for run in tqdm(runs, desc=f"Fetching runs from {project}", position=1):
        history = run.history(keys=keys, pandas=True, samples=samples)

        # Attach the flattened hyperparameters as extra columns
        config_row = pd.json_normalize(run.config, sep='.')
        history = pd.concat([history, config_row], axis=1)

        history["run_name"] = run.name
        history["run_id"] = run.id

        frames.append(history)
    return pd.concat(frames).reset_index() if frames else pd.DataFrame()


In [6]:
import plotly.graph_objects as go
import plotly.express as px


def _with_alpha(color, alpha=0.2):
    """Return ``color`` as an ``rgba(...)`` string with opacity ``alpha`` when possible.

    Handles ``'rgb(r,g,b)'`` and ``'#rrggbb'`` inputs; any other format (named colours,
    strings that are already ``rgba(...)``) is returned unchanged.
    """
    if color.startswith('rgb(') and color.endswith(')'):
        return f"rgba({color[4:-1]}, {alpha})"
    if color.startswith('#') and len(color) == 7:
        try:
            r, g, b = (int(color[i:i + 2], 16) for i in (1, 3, 5))
        except ValueError:
            return color
        return f"rgba({r}, {g}, {b}, {alpha})"
    return color


def create_beam_plot(df, x_col, y_col, std_col, title, x_label, y_label, order, algo_pretty_names, color_discrete_map,
                     fig=None):
    """Draw mean curves with a shaded +/- std band, one per algorithm.

    For every entry of ``order`` that has rows in ``df`` (matched on the ``algo_str``
    column), three traces are added in this sequence: the invisible upper bound
    (``y + std``), the invisible lower bound (``y - std``) filled up to the upper
    bound, and the visible mean line that carries the legend entry.

    Args:
        df: Frame with an ``algo_str`` column plus the ``x_col``, ``y_col`` and ``std_col`` columns.
        x_col: Column used for the x axis.
        y_col: Column holding the mean values.
        std_col: Column holding the band half-width.
        title: Figure title.
        x_label: X axis title.
        y_label: Y axis title.
        order: Algorithm keys in the order their traces should be added.
        algo_pretty_names: Algorithm key -> legend label; missing keys fall back to the key itself.
        color_discrete_map: Legend label -> colour; missing labels fall back to black.
        fig: Existing figure to draw into. When omitted a new figure is created and
            shown before returning; a supplied figure is never shown here.

    Returns:
        None. The figure is modified in place.
    """
    show = fig is None
    if show:
        fig = go.Figure()

    for algo in order:
        algo_df = df[df['algo_str'] == algo]
        if algo_df.empty:
            continue
        pretty_name = algo_pretty_names.get(algo, algo)
        # The fallback is spelled as an rgb() string so it can be made translucent too;
        # a bare 'black' previously produced a fully opaque band.
        color = color_discrete_map.get(pretty_name, 'rgb(0, 0, 0)')

        # Add the upper bound of the std
        fig.add_trace(go.Scatter(
            x=algo_df[x_col],
            y=algo_df[y_col] + algo_df[std_col],
            fill=None,
            mode='lines',
            line=dict(width=0),
            showlegend=False,
        ))

        # Add the lower bound of the std; 'tonexty' fills up to the trace added just
        # before it, so the upper bound must stay immediately above.
        fig.add_trace(go.Scatter(
            x=algo_df[x_col],
            y=algo_df[y_col] - algo_df[std_col],
            fillcolor=_with_alpha(color, 0.2),
            fill='tonexty',
            mode='lines',
            line=dict(width=0),
            showlegend=False
        ))

        # Add the mean line
        fig.add_trace(go.Scatter(
            x=algo_df[x_col],
            y=algo_df[y_col],
            mode='lines',
            name=pretty_name,
            line=dict(width=2, color=color)
        ))

    # Titles, fixed canvas size, white background and a horizontal legend below the plot.
    # Legend order simply follows the order in which traces were added above.
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        autosize=False,
        width=1200,
        height=800,
        plot_bgcolor='white',
        paper_bgcolor='white',
        font=dict(color='black'),
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=-0.2,
            xanchor='center',
            x=0.5,
            itemsizing='constant',
        )
    )
    if show:
        fig.show()


In [7]:
# Set to True to re-download the runs from W&B and refresh the per-run CSV cache;
# leave False to load the cached CSVs from disk.
GENERATE_RAW_DATA = False
raw_dir = f"{output_dir}/raw"
df_dict = {}  # run name -> DataFrame of that run
if GENERATE_RAW_DATA:
    project = fetch_projects(entity, project_prefix=project_name)[0]
    os.makedirs(raw_dir, exist_ok=True)
    df = get_runs(entity, project.name, filters, keys=['info/episode_reward'], samples=4000)
    # split into runs and save
    for run_name, run_df in df.groupby('run_name'):
        run_df.to_csv(f"{raw_dir}/{run_name}.csv", index=False)
        df_dict[run_name] = run_df
else:
    # load dfs into dictionary; sorted so the load order is deterministic
    for filename in sorted(os.listdir(raw_dir)):
        run_name, ext = os.path.splitext(filename)
        # Skip anything that is not a per-run CSV. pulse.csv is written into this same
        # directory by the PULSE cell and is loaded separately, so it must not be mixed
        # into the per-run frames here.
        if ext != ".csv" or filename == "pulse.csv":
            continue
        # Key by the bare run name so both branches produce the same keys.
        df_dict[run_name] = pd.read_csv(f"{raw_dir}/{filename}", low_memory=False)

In [8]:
# Broadcast first-row-only values down each run's frame, then stack all runs.
df_dict = {run_name: try_fill_columns(run_df) for run_name, run_df in df_dict.items()}
pm_df = pd.concat(df_dict).reset_index()

# NOTE(review): AMP rewards are doubled, presumably to bring them onto the same scale
# as the other algorithms — confirm the factor against the reward definitions.
is_amp = pm_df['algo_type'] == 'AMP'
pm_df.loc[is_amp, 'info/episode_reward'] *= 2

# Algorithm key "<algo_type>_prior_<prior>"; a missing prior flag counts as False.
pm_df['algo_str'] = pm_df['algo_type'] + '_prior_' + pm_df['prior'].fillna(False).map(str)


Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

Filling columns:   0%|          | 0/1275 [00:00<?, ?it/s]

In [9]:
# Set to True to download the PULSE runs from W&B and refresh the cached CSV;
# set to False to load the cached CSV from disk.
GENERATE_RAW_DATA_PULSE = True
pulse_csv = f"{output_dir}/raw/pulse.csv"
if GENERATE_RAW_DATA_PULSE:
    pulse_projects = [p for p in fetch_projects(entity, project_prefix="pulse") if "debug" not in p.name]
    if not pulse_projects:
        raise IndexError(f"No non-debug project with prefix 'pulse' found for entity {entity!r}")
    project = pulse_projects[0]
    # Kept under its own name so the notebook-wide `filters` is not overwritten.
    pulse_filters = {"state": {"$eq": "finished"}, "display_name": {"$regex": f".*{env_name}.*"}}
    # samples is a count, so pass an int rather than the float literal 30e3
    p_df = get_runs(entity=entity, project=project.name, filters=pulse_filters, keys=['mean_rewards'],
                    samples=30_000)
    os.makedirs(os.path.dirname(pulse_csv), exist_ok=True)
    p_df.to_csv(pulse_csv, index=False)
else:
    p_df = pd.read_csv(pulse_csv, low_memory=False)

Fetching runs from pulse:   0%|          | 0/5 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error (HTTPError), entering retry loop.


KeyboardInterrupt: 

In [None]:
# Bring the PULSE runs into the same long format as pm_df: one row per (run, index)
# with the reward stored under 'info/episode_reward'. The plotting cell aggregates that
# column per algorithm; pre-aggregating here into 'mean'/'std' columns left it empty
# for PULSE, so its curve came out as all-NaN.
# NOTE(review): PULSE is fetched with a different `samples` count than the other runs,
# so its 'index' axis may not be directly comparable — confirm.
df_pulse_plot = (
    p_df[['index', 'mean_rewards']]
    .rename(columns={'mean_rewards': 'info/episode_reward'})
    .assign(algo_type='PULSE', prior=False)
)
df_pulse_plot['algo_str'] = df_pulse_plot['algo_type'] + '_prior_' + df_pulse_plot['prior'].map(str)

In [None]:
# Concatenate the PM and PULSE frames; ignore_index gives the result a fresh
# 0..n-1 index instead of repeating the row labels of both inputs.
pm_df_plot = pd.concat([pm_df, df_pulse_plot], ignore_index=True)

In [None]:
fig = go.Figure()

# Traces are added in reverse order and the legend is reversed again below, so the
# legend reads in `order` while the first algorithms are drawn last.
for algo_str in reversed(order):
    algo_rows = pm_df_plot.loc[pm_df_plot['algo_str'] == algo_str]
    # Mean and std of the reward across runs at each index
    algo_stats = (
        algo_rows.groupby('index')['info/episode_reward']
        .agg(['mean', 'std'])
        .reset_index()
        .assign(algo_str=algo_str)
    )
    # The frame holds a single algorithm, so pass only that key instead of making the
    # helper scan the full `order` list on every call.
    create_beam_plot(algo_stats, x_col='index', y_col='mean', std_col='std', title='', x_label='Step',
                     y_label='Mean Reward', fig=fig, order=[algo_str], algo_pretty_names=algo_pretty_names,
                     color_discrete_map=color_discrete_map)

fig.update_layout(legend_traceorder='reversed')
fig.show()