In [42]:
import os
from glob import glob
from typing import List

import pandas as pd
import wandb

# Render pandas .plot() calls with plotly instead of matplotlib.
pd.set_option("plotting.backend", "plotly")
# W&B API client used by the project/run queries below.
api = wandb.Api()

In [43]:
# Define WandB Parameters: which entity/projects to read and which runs to keep.
entity = "phys_inversion"
project_prefix = "FINALLY__"
filters = {"state": {"$eq": "finished"}}  # only runs whose state is "finished"

# Every project under `entity` whose name starts with `project_prefix`.
projects = [
    proj
    for proj in api.projects(entity=entity)
    if proj.name.startswith(project_prefix)
]

# W&B metric key -> short column name used throughout the rest of the notebook.
rename_metrics = {
    "env/Env reach_success": "reach_success",
    "env/Env reach_distance": "reach_distance",
    "core/Episode Reward": "episode_reward",
}

In [44]:
# Pipeline switches for the expensive / file-writing stages.
GENERATE_RAW_CSV = True      # re-download run histories from W&B into <project_prefix>/raw
GENERATE_GROUPED_CSV = True  # re-aggregate the raw CSVs into <project_prefix>/grouped
GENERATE_TABLES = True       # NOTE(review): not read by any cell visible in this notebook

In [45]:
# Fetch Data from WandB

def get_runs(entity, project, filters, keys=None, wandb_api=None):
    """Fetch the logged history of every run in a W&B project that matches ``filters``.

    Each run's history rows are annotated with the run's config, flattened with
    ``pd.json_normalize(..., sep='.')`` (nested keys become ``"a.b.c"`` columns),
    plus ``run_name`` and ``run_id`` columns.

    Parameters
    ----------
    entity, project : str
        W&B entity and project; queried as ``f"{entity}/{project}"``.
    filters : dict
        Filter forwarded to ``runs(...)``, e.g. ``{"state": {"$eq": "finished"}}``.
    keys : list of str, optional
        History keys to fetch; ``None`` lets ``run.history`` choose.
    wandb_api : optional
        Object exposing ``runs(path, filters=...)``; defaults to the module-level ``api``.

    Returns
    -------
    pandas.DataFrame
        All runs' rows concatenated. The per-run row index is kept as an ``index``
        column. An empty DataFrame is returned when no run matches.
    """
    client = api if wandb_api is None else wandb_api
    frames = []
    for run in client.runs(f"{entity}/{project}", filters=filters):
        history = run.history(keys=keys, pandas=True)

        # Fetch hyperparameters and broadcast the single flattened config row onto
        # EVERY history row. A bare axis=1 concat aligns on the index, so the config
        # would land only on row 0 and every later row would get NaN hyperparameters.
        config_row = pd.json_normalize(run.config, sep='.')
        if not history.empty:
            config_row = config_row.iloc[[0] * len(history)].set_axis(history.index, axis=0)
        history = pd.concat([history, config_row], axis=1)

        history["run_name"] = run.name
        history["run_id"] = run.id

        frames.append(history)
    return pd.concat(frames).reset_index() if frames else pd.DataFrame()


In [46]:

# ALL TOGETHER PER PROJECT
from tqdm import tqdm


def generate_raw_csvs():
    """Download finished-run histories for every selected project.

    One CSV per project is written to ``<project_prefix>/raw/<project name>.csv``,
    using ``get_runs`` with the notebook-level ``entity`` and ``filters``.
    """
    raw_dir = f"{project_prefix}/raw"
    os.makedirs(raw_dir, exist_ok=True)
    for proj in tqdm(projects):
        runs_df = get_runs(entity, proj.name, filters)
        runs_df.to_csv(f"{raw_dir}/{proj.name}.csv", index=False)

In [47]:
# Slow stage (one W&B history download per run); skip by setting GENERATE_RAW_CSV = False.
if GENERATE_RAW_CSV:
    generate_raw_csvs()

100%|██████████| 5/5 [04:56<00:00, 59.39s/it]


In [48]:
# Only the per-project CSVs written by generate_raw_csvs: skip stray entries such as
# .DS_Store or .ipynb_checkpoints (read_csv would fail on them), and sort so the
# processing order is deterministic.
raw_csvs_paths = sorted(
    name for name in os.listdir(f"{project_prefix}/raw") if name.endswith(".csv")
)
grouped_dir = f"{project_prefix}/grouped"
os.makedirs(grouped_dir, exist_ok=True)


def generate_grouped_csvs(raw_csvs_paths=raw_csvs_paths):
    """Aggregate each raw per-project CSV into reach-success statistics per configuration.

    For every file name in ``raw_csvs_paths`` (read from ``<project_prefix>/raw``):

    - metric columns are renamed via ``rename_metrics``;
    - rows without a ``reach_success`` value are dropped;
    - ``reach_success`` is scaled to a percentage;
    - rows are grouped by (``algo_type``, ``prior``, ``use_perturbations``), NaN keys
      included.

    The per-group mean and std of ``reach_success`` are written to
    ``<grouped_dir>/<same file name>``.

    Note: the default ``raw_csvs_paths`` is the listing captured when this cell ran.
    """
    os.makedirs(grouped_dir, exist_ok=True)

    def _resolve_algo_type(row):
        # Runs that did not log algo_type: PPO if the discriminator was disabled, else AMP.
        # A missing (NaN) flag counts as "not disabled". NaN is truthy in Python, so a
        # plain `if flag` would silently classify such rows as PPO.
        if pd.notna(row['algo_type']):
            return row['algo_type']
        disabled = row['env.config.disable_discriminator']
        return 'PPO' if pd.notna(disabled) and bool(disabled) else 'AMP'

    for project_file in raw_csvs_paths:
        df = pd.read_csv(f"{project_prefix}/raw/{project_file}").rename(columns=rename_metrics)

        # Keep only rows that actually logged the success metric. read_csv parses missing
        # values as float NaN, so an equality test against the string "NaN" never matches.
        df = df[df['reach_success'].notna()].copy()

        # multiply by 100 to get percentage
        df['reach_success'] = df['reach_success'] * 100

        # DataFrame.apply(axis=1) on an empty frame returns a DataFrame, not a Series.
        if not df.empty:
            df['algo_type'] = df.apply(_resolve_algo_type, axis=1)
        df['prior'] = df['prior'].fillna(False)

        df_grouped = df.groupby(["algo_type", "prior", "use_perturbations"], dropna=False).agg(
            {"reach_success": ["mean", "std"]})
        # ("reach_success", "mean") -> "reach_success_mean", etc.
        df_grouped.columns = ["_".join(col).strip() for col in df_grouped.columns.to_flat_index()]
        df_grouped = df_grouped.reset_index()

        # index=False: the positional index carries no information (same as the raw CSVs).
        df_grouped.to_csv(f"{grouped_dir}/{project_file}", index=False, float_format="%.4f")


In [49]:
# Rebuild <project_prefix>/grouped/*.csv from the raw CSVs; skip by setting GENERATE_GROUPED_CSV = False.
if GENERATE_GROUPED_CSV:
    generate_grouped_csvs()

In [50]:
def generate_combined_df(grouped_dir, prefix=None):
    """Combine the grouped per-project CSVs into one wide, display-ready table.

    Parameters
    ----------
    grouped_dir : str
        Directory holding the ``*.csv`` files written by ``generate_grouped_csvs``.
    prefix : str, optional
        Prefix removed from each file stem to obtain the environment name.
        Defaults to the notebook-level ``project_prefix``.

    Returns
    -------
    pandas.DataFrame
        Columns ``algo_str``, ``algo_type`` and ``prior`` (``algo_str`` is
        ``"<algo_type>_prior_<prior>"``), plus one column per
        ``"<env>_perturb_<use_perturbations>"``. Cells are formatted as
        ``"mean ± std"``, or just ``"mean"`` when std is missing.
        An empty DataFrame is returned when there is nothing to combine.
    """
    if prefix is None:
        prefix = project_prefix
    keep_cols = ['algo_type', 'prior', 'use_perturbations', 'reach_success_mean', 'reach_success_std']
    renamed_keep_cols = ['algo_type', 'prior', 'use_perturbations', 'mean', 'std']

    data = {}  # "<env>_perturb_<flag>" -> Series of formatted values
    for file in sorted(glob(os.path.join(grouped_dir, "*.csv"))):  # sorted: deterministic column order
        env_name = os.path.splitext(os.path.basename(file))[0].replace(prefix, '')  # Remove project prefix
        # reindex instead of dropna(how='all') + selection: a kept column that is entirely
        # NaN (e.g. std when every group has a single run) must survive, not raise KeyError.
        df = pd.read_csv(file).reindex(columns=keep_cols)
        df.columns = renamed_keep_cols
        if df.empty:
            continue

        mean = pd.to_numeric(df['mean'], errors='coerce')  # non-numeric -> NaN
        std = pd.to_numeric(df['std'], errors='coerce')
        mean_str = mean.map('{:.2f}'.format)
        df['value'] = mean_str.where(std.isna(), mean_str + ' ± ' + std.map('{:.2f}'.format))
        df['algo_str'] = df['algo_type'] + '_prior_' + df['prior'].map(str)
        df['env_perturb'] = env_name + '_perturb_' + df['use_perturbations'].map(str)

        for env_perturb, group in df.groupby('env_perturb', sort=False):
            data[env_perturb] = group.set_index(['algo_str', 'algo_type', 'prior'])['value']

    if not data:
        return pd.DataFrame()
    # Combine data into a single DataFrame (outer-aligned on the method index)
    return pd.concat(data, axis=1).reset_index()


# One row per (algo_str, algo_type, prior), one column per "<env>_perturb_<flag>".
combined_df = generate_combined_df(grouped_dir=grouped_dir)

In [51]:
def reorder_rows(df, order=None):
    """Sort rows by a fixed presentation order of their ``algo_str`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an ``algo_str`` column.
    order : list of str, optional
        Desired ``algo_str`` order; defaults to the paper's method order.

    Returns
    -------
    pandas.DataFrame
        ``df`` sorted (original index labels kept). Rows whose ``algo_str`` is not
        in ``order`` are placed last, in their existing relative order, instead
        of raising ValueError.
    """
    if order is None:
        order = ['MaskedMimic_Inversion_prior_False',
                 'MaskedMimic_Inversion_prior_True',
                 'MaskedMimic_Prior_Only_prior_True',
                 'MaskedMimic_Finetune_prior_False',
                 'MaskedMimic_Finetune_prior_True',
                 'PULSE_prior_False',
                 'AMP_prior_False',
                 'PureRL_prior_False',
                 'PPO_prior_False', ]

    rank = {name: position for position, name in enumerate(order)}
    # Unknown methods get rank len(order); a stable sort keeps their relative order.
    return df.sort_values(
        'algo_str',
        key=lambda col: col.map(rank).fillna(len(order)),
        kind='stable',
    )


def merge_rows(combined_df):
    """Fold the ``PureRL_prior_False`` row into the ``PPO_prior_False`` row.

    Presumably both names denote the same pure-RL baseline logged under different
    ``algo_type`` values in different projects (TODO confirm). When both rows
    exist:

    - the PPO row's missing cells are filled from the PureRL row (``combine_first``);
    - both rows are removed;
    - the merged row, which keeps the PPO labels, is appended last with a fresh
      0..n-1 index.

    Other PPO/PureRL variants (e.g. ``prior=True``) are left untouched.
    The input frame is not modified.
    """
    df = combined_df.copy()
    is_pure_rl = df['algo_str'] == 'PureRL_prior_False'
    is_ppo = df['algo_str'] == 'PPO_prior_False'
    if is_pure_rl.any() and is_ppo.any():
        merged_row = df[is_ppo].iloc[0].combine_first(df[is_pure_rl].iloc[0])
        # Drop only the two rows being merged, not every PPO/PureRL row.
        df = df[~(is_pure_rl | is_ppo)]
        df = pd.concat([df, merged_row.to_frame().T], ignore_index=True)
    return df


def post_process_df(df):
    """Put rows into presentation order, then fold the PureRL row into the PPO row."""
    return merge_rows(reorder_rows(df))

In [52]:
final_df = combined_df.pipe(post_process_df)
final_df

Unnamed: 0,algo_str,algo_type,prior,reach_perturb_False,reach_perturb_True,long_jump_perturb_False,long_jump_perturb_True,strike_perturb_False,strike_perturb_True,direction_facing_perturb_False,steering_perturb_False,BUG_steering_perturb_False,BUG_steering_perturb_True,BUG_direction_facing_perturb_False,BUG_direction_facing_perturb_True
0,MaskedMimic_Inversion_prior_False,MaskedMimic_Inversion,False,95.37 ± 1.80,61.67 ± 38.73,99.81 ± 0.42,45.02 ± 48.36,70.27 ± 4.56,45.19 ± 11.00,83.66 ± 5.66,96.89 ± 4.33,97.05 ± 2.98,59.82 ± 43.96,79.17 ± 4.11,43.34 ± 34.04
1,MaskedMimic_Inversion_prior_True,MaskedMimic_Inversion,True,94.88 ± 1.99,65.06 ± 37.77,,,,,88.69 ± 4.04,99.26 ± 0.79,99.47 ± 0.46,52.02 ± 37.87,85.96 ± 3.67,55.98 ± 35.67
2,MaskedMimic_Prior_Only_prior_True,MaskedMimic_Prior_Only,True,24.77,23.99 ± 2.85,,,,,3.83,2.19,0.48,0.22 ± 0.37,3.77,3.53 ± 1.55
3,MaskedMimic_Finetune_prior_False,MaskedMimic_Finetune,False,93.70 ± 4.59,52.40 ± 33.50,47.38 ± 54.74,22.39 ± 34.85,79.61 ± 7.01,28.35 ± 23.19,87.44 ± 6.79,99.10 ± 1.29,99.23 ± 0.89,50.12 ± 43.89,90.67 ± 7.19,30.96 ± 32.58
4,MaskedMimic_Finetune_prior_True,MaskedMimic_Finetune,True,92.88 ± 3.42,66.31 ± 18.70,,,,,96.41 ± 4.94,98.86 ± 0.32,98.59 ± 1.66,45.91 ± 37.75,94.01 ± 7.44,30.14 ± 35.07
5,AMP_prior_False,AMP,False,57.14 ± 4.80,38.23 ± 16.67,76.26 ± 43.36,7.39 ± 21.94,30.92 ± 38.38,26.03 ± 31.65,4.28 ± 1.42,5.14 ± 0.68,5.83 ± 2.17,3.27 ± 2.34,4.30 ± 1.24,2.81 ± 2.36
6,PPO_prior_False,PPO,False,89.90 ± 3.25,66.83 ± 19.07,60.59 ± 53.98,11.81 ± 30.56,42.36 ± 35.51,41.11 ± 23.68,32.64 ± 40.21,97.74 ± 1.40,94.69 ± 10.67,12.91 ± 22.32,30.77 ± 38.23,6.61 ± 19.08


# Generate Tables

In [53]:
# All generated .tex tables go under <project_prefix>/latex_tables.
latex_output_dir = f"{project_prefix}/latex_tables"
os.makedirs(name=latex_output_dir, exist_ok=True)

In [54]:
# Environment key (as it appears in "<env>_perturb_<flag>" column names) -> display name.
# NOTE(review): 'steering' is shown as "Direction" and 'direction_facing' as "Steering";
# this looks deliberately swapped to match paper terminology — confirm.
env_pretty_names = dict(
    reach='Reach',
    steering='Direction',
    direction_facing='Steering',
    long_jump='Long Jump',
    strike='Strike',
)

# algo_str -> method name shown in the LaTeX tables ("J.C." entries are the prior=True variants).
algo_pretty_names = dict(
    MaskedMimic_Inversion_prior_False='Task Tokens (ours)',
    MaskedMimic_Inversion_prior_True='Task Tokens (ours) + J.C.',
    MaskedMimic_Prior_Only_prior_True='MaskedMimic (J.C. only)',
    MaskedMimic_Finetune_prior_False='MaskedMimic Fine-Tune',
    MaskedMimic_Finetune_prior_True='MaskedMimic Fine-Tune + J.C.',
    PULSE_prior_False='PULSE',
    AMP_prior_False='AMP',
    PureRL_prior_False='PureRL',
    PPO_prior_False='PPO',
)


In [94]:
def _leading_number(series):
    """Numeric view of a column: numeric dtypes as-is, otherwise the leading number of each cell (NaN if none)."""
    if pd.api.types.is_numeric_dtype(series):
        return series
    extracted = series.astype(str).str.extract(r'^\s*([-+]?\d+(?:\.\d+)?)', expand=False)
    return pd.to_numeric(extracted, errors='coerce')


def max_values_mask(df, ignore_cols=None):
    """Return a boolean frame shaped like ``df`` marking each column's maximum cell(s).

    Formatted cells such as ``"95.37 ± 1.80"`` are ranked by their leading number,
    so ``"10.00 ± …"`` beats ``"9.50 ± …"`` (a string max would not). Cells without
    a number (e.g. the ``"-"`` placeholder) never win, and a column with no numbers
    gets no True cells. Ties mark every tied cell. Columns in ``ignore_cols`` are
    all False.
    """
    if ignore_cols is None:
        ignore_cols = []
    # create false mask like df
    mask = pd.DataFrame(False, index=df.index, columns=df.columns)
    for col in df.columns:
        if col in ignore_cols:
            continue
        values = _leading_number(df[col])
        mask[col] = values == values.max()  # max() skips NaN; NaN == NaN is False
    return mask

In [61]:
def generate_latex(df, output_file_path, caption=None, label=None, bold_mask=None, **kwargs):
    """
    Generate a LaTeX table from a DataFrame with optional bold values.

    Parameters:
    - df: DataFrame
    - output_file_path: str, path to save the LaTeX file
    - caption: str, optional caption for the table
    - label: str, optional label for the table
    - bold_mask: DataFrame-like, optional, same shape as df with True for values to be bold
    - **kwargs: additional arguments passed to DataFrame.to_latex; they override the
      defaults used here (index=False, float_format="%.2f", an auto 'l' + 'c'...
      column_format, escape=False) instead of clashing with them.
    """
    if bold_mask is not None:
        df = df.mask(bold_mask, df.applymap(lambda x: f"\\textbf{{{x}}}"))

    # Defaults: 'l' for the first column and 'c' for the rest. escape=False keeps the
    # \textbf markup (and any other raw LaTeX in cells) intact. Merging kwargs into this
    # dict avoids "got multiple values for keyword argument" when a caller overrides one.
    latex_options = {
        "index": False,
        "float_format": "%.2f",
        "column_format": 'l' + 'c' * (len(df.columns) - 1),
        "caption": caption,
        "label": label,
        "escape": False,
    }
    latex_options.update(kwargs)
    df.to_latex(output_file_path, **latex_options)


## Table 1 - Main Results

In [56]:
def generate_table(final_df: pd.DataFrame, algos: List[str], envs: List[str],
                   perturbed: bool = False, algo_names=None, env_names=None) -> pd.DataFrame:
    """Build a display table with one row per selected method and one column per environment.

    Parameters
    ----------
    final_df : DataFrame
        Output of ``post_process_df``: an ``algo_str`` column plus
        ``"<env>_perturb_<True|False>"`` columns.
    algos : list of str
        ``algo_str`` values to keep; rows keep their order in ``final_df``.
    envs : list of str
        Environment keys, in the desired column order.
    perturbed : bool
        Use the ``_perturb_True`` columns instead of the ``_perturb_False`` ones
        (default).
    algo_names, env_names : dict, optional
        Display-name maps. Default to ``algo_pretty_names`` / ``env_pretty_names``.

    Returns
    -------
    DataFrame
        A ``Method`` column followed by one column per env. Missing cells are
        ``'-'``.

    Raises
    ------
    KeyError
        If an env in ``envs`` has no column for the requested perturbation setting.
    """
    if algo_names is None:
        algo_names = algo_pretty_names
    if env_names is None:
        env_names = env_pretty_names
    suffix = f"_perturb_{bool(perturbed)}"

    selected = final_df[final_df['algo_str'].isin(algos)]
    env_cols = [col for col in selected.columns if col.endswith(suffix)]
    # Explicit copy so the column assignments below act on an independent frame,
    # not a slice of final_df (avoids SettingWithCopyWarning).
    table_df = selected[['algo_str'] + env_cols].copy()
    table_df.columns = ['Method'] + [col[:-len(suffix)] for col in env_cols]
    table_df['Method'] = table_df['Method'].replace(algo_names)
    table_df = table_df[['Method'] + list(envs)]
    return table_df.rename(columns=env_names).fillna('-')

In [57]:
 algos_1 = ['MaskedMimic_Inversion_prior_False',
            # 'MaskedMimic_Inversion_prior_True',
            'MaskedMimic_Prior_Only_prior_True',
            # 'MaskedMimic_Finetune_prior_False',
            # 'MaskedMimic_Finetune_prior_True',
            'PULSE_prior_False',
            'AMP_prior_False',
            'PureRL_prior_False',
            'PPO_prior_False', ]

In [63]:
table_1 = generate_table(final_df, algos_1, envs=['reach', 'steering', 'direction_facing', 'long_jump', 'strike'])
# Bold our method's name. The Method column already holds display names, so compare
# against the pretty name; comparing with the raw algo_str would never match any cell.
bold_mask = table_1 == algo_pretty_names['MaskedMimic_Inversion_prior_False']
generate_latex(table_1, f"{latex_output_dir}/table_1.tex", caption="Main Results", label="tab:table_1", bold_mask=bold_mask)
table_1

Unnamed: 0,Method,Reach,Direction,Steering,Long Jump,Strike
0,Task Tokens (ours),95.37 ± 1.80,96.89 ± 4.33,83.66 ± 5.66,99.81 ± 0.42,70.27 ± 4.56
2,MaskedMimic (J.C. only),24.77,2.19,3.83,-,-
5,AMP,57.14 ± 4.80,5.14 ± 0.68,4.28 ± 1.42,76.26 ± 43.36,30.92 ± 38.38
6,PPO,89.90 ± 3.25,97.74 ± 1.40,32.64 ± 40.21,60.59 ± 53.98,42.36 ± 35.51



## Table 2 - Ablations

In [75]:
env_name = 'direction_facing'

# Ablation rows: the MaskedMimic-based variants only (PULSE/AMP/PureRL/PPO appear in Table 1).
algos_2 = ['MaskedMimic_Inversion_prior_False',
           'MaskedMimic_Inversion_prior_True',
           'MaskedMimic_Prior_Only_prior_True',
           'MaskedMimic_Finetune_prior_False',
           'MaskedMimic_Finetune_prior_True',
           ]
table_2 = generate_table(final_df, algos_2, envs=[env_name])
# Bold the best value per column, consistent with the other ablation table.
generate_latex(table_2, f"{latex_output_dir}/{env_name}_table_2.tex", caption="Ablation Study",
               label=f"tab:{env_name}_table_2",
               bold_mask=max_values_mask(table_2, ignore_cols=['Method']))
table_2

Unnamed: 0,Method,Steering
0,Task Tokens (ours),83.66 ± 5.66
1,Task Tokens (ours) + J.C.,88.69 ± 4.04
2,MaskedMimic (J.C. only),3.83
3,MaskedMimic Fine-Tune,87.44 ± 6.79
4,MaskedMimic Fine-Tune + J.C.,96.41 ± 4.94


In [95]:
env_name = 'reach'

# Same ablation rows as above, for the reach task; best value per column in bold.
table_2_1 = generate_table(final_df, algos_2, envs=[env_name])
generate_latex(
    table_2_1,
    f"{latex_output_dir}/{env_name}_table_2.tex",
    caption="Ablation Study",
    label=f"tab:{env_name}_table_2",
    bold_mask=max_values_mask(table_2_1, ignore_cols=['Method']),
)
table_2_1

Unnamed: 0,Method,Reach
0,Task Tokens (ours),95.37 ± 1.80
1,Task Tokens (ours) + J.C.,94.88 ± 1.99
2,MaskedMimic (J.C. only),24.77
3,MaskedMimic Fine-Tune,93.70 ± 4.59
4,MaskedMimic Fine-Tune + J.C.,92.88 ± 3.42
