In [1]:
%%time
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2

import pandas as pd
import wandb
import pandas as pd
from collections import defaultdict

# Table processing
def process_line(means, highlight, highlight_index, highlight_max, ignore_std):
    if highlight:
        if highlight_max:
            tops = set(means.groupby(highlight_index).idxmax())
        else:
            tops = set(means.groupby(highlight_index).idxmin())
    else:
        tops = set()

    def process_line(x):
        if ignore_std:
            if x.name in tops:
                return rf"\textbf{{{x['mean']:0.3f}}}"
            return rf"{x['mean']:0.3f}"
        if x.name in tops:
            return rf"\textbf{{{x['mean']:0.3f} $\pm$ {x['std']:0.3f}}}"
        return rf"{x['mean']:0.3f} $\pm$ {x['std']:0.3f}"

    return process_line


def mean_pm_std(
    data,
    index,
    columns,
    value,
    highlight=True,
    highlight_cols=True,
    highlight_max=True,
    ignore_std=False,
):
    assert len(data) > 0
    groupby = data.groupby([*index, *columns])
    means = groupby.mean()[value].rename("mean")
    stds = groupby.std()[value].rename("std")
    ddf = pd.concat([means, stds], axis=1).T
    highlight_index = columns if highlight_cols else index
    ddf = ddf.apply(
        process_line(means, highlight, highlight_index, highlight_max, ignore_std)
    )
    ddf = ddf.reset_index().pivot(index=index, columns=columns)
    ddf.columns = ddf.columns.droplevel(level=0)
    return ddf

    
def flatten_dict(d, parent_key="", sep="/"):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def prepare_data(data):
    flattened_data = [flatten_dict(item) for item in data]
    return pd.DataFrame(flattened_data)



CPU times: user 1.08 s, sys: 504 ms, total: 1.59 s
Wall time: 3.09 s


In [83]:
api = wandb.Api(timeout=30)

# Project is specified by <entity/project-name>
runs = api.runs(
    "openproblems-comp/DEM-2",
    filters={
        "$and": [
            {
                "tags": {
                    "$in": [
                        "fixed_resampling_var_steps",
                        "start_resampling_gaussian",
                        "fixed_resampling_50steps",
                        "fixed_resampling_v2",
                        "start_resampling_gaussian_v2",
                    ]
                },
                #'group': {'$in': ['5_vars']},
                # "config.data.n_particles": {"$eq": 22},
                #'config.model': {'$eq': model},
                #'config.lr': {'$lt': 1.01 * lr, '$gt': 0.99 * lr},
            }
        ]
    },
)

summary_list, config_list, name_list, tag_list = [], [], [], []
for run in runs:
    tag_list.append(run.tags)
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary_list.append(run.summary._json_dict)
    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})
    # .name is the human-readable name of the run.
    name_list.append(run.name)
df_summary = prepare_data(summary_list)
df_config = prepare_data(config_list)
tag_list = [str(t) for t in tag_list]
df = pd.concat(
    [
        pd.DataFrame(name_list, columns=["name"]),
        pd.DataFrame(tag_list, columns=["Tags"]),
        df_summary,
        df_config,
    ],
    axis=1,
)

In [92]:
import math


def filterer(x):
    if isinstance(x, float) and not math.isfinite(x):
        return False
    return "table" in list(x)


filtered_df = df[~df["model/annealed_energy/temperature"].isin([0.5, 0.6])][
    [
        # "tags",
        "test/temp_annealed/energy_w2",
        "test/temp_annealed/energy_w1",
        "test/temp_annealed/dist_w2",
        "model/annealed_energy/temperature",
        "model/resampling_interval",
        "model/scale_diffusion",
        "model/start_resampling_step",
        "model/annealed_clipper/max_score_norm",
        "model/num_eval_samples",
        # "test/cropped_energy_w1",
        # "test/resampled/cropped_energy_w1",
        # # "val/effective_sample_size",
        # "test/effective_sample_size",
        # "data/n_particles",
        # # "val/rama/torus_wasserstein",
        # "test/rama/torus_wasserstein",
        # "test/resampled/rama/torus_wasserstein",
        # "model/sampling_config/num_proposal_samples",
        # "model/sampling_config/num_test_proposal_samples",
    ]
]

# filtered_df.sort_values("data/n_particles")

In [105]:
renamed_df = filtered_df
# renamed_df = filtered_df.replace(
#     {
#         "src.models.components.tbg.egnn_dynamics_ad2_cat.EGNN_dynamics_AD2_cat": "EQ-CFM",
#         "src.models.components.dit.DIT3D": "DiT-CFM",
#     }
# ).rename(columns={"model/net/_target_": "Model", "data/n_particles": "n_particles"})

In [106]:
# renamed_df = renamed_df.groupby(
#     [
#         "model/annealed_energy/temperature",
#         "model/resampling_interval",
#         "model/scale_diffusion",
#         "model/start_resampling_step",
#     ]
# ).mean()

In [107]:
mylist = [
    "model/start_resampling_step",
    "model/annealed_energy/temperature",
    "model/resampling_interval",
    "model/scale_diffusion",
    "model/num_eval_samples",
]

In [108]:
metrics = [
    "test/temp_annealed/energy_w2",
    "test/temp_annealed/dist_w2",
    "test/temp_annealed/energy_w1",
]
df_melt = renamed_df.melt(  # .dropna()
    value_vars=metrics,
    id_vars=mylist,
    var_name="Metric",
)


df_melt.loc[
    df_melt["model/resampling_interval"] == -1, "model/start_resampling_step"
] = 0

In [109]:
pd.set_option("display.max_rows", 500)

In [110]:
results = mean_pm_std(
    df_melt, index=mylist, columns=["Metric"], value="value", highlight=False
)
results

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Metric,test/temp_annealed/dist_w2,test/temp_annealed/energy_w1,test/temp_annealed/energy_w2
model/start_resampling_step,model/annealed_energy/temperature,model/resampling_interval,model/scale_diffusion,model/num_eval_samples,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0.0,0.7,-1.0,False,22000.0,1.042 $\pm$ nan,nan $\pm$ nan,292.037 $\pm$ nan
0.0,0.7,-1.0,True,22000.0,1.048 $\pm$ 0.012,nan $\pm$ nan,84.635 $\pm$ 0.677
0.0,0.8,-1.0,False,22000.0,0.937 $\pm$ nan,nan $\pm$ nan,204.991 $\pm$ nan
0.0,0.8,-1.0,True,22000.0,0.928 $\pm$ 0.033,nan $\pm$ nan,38.751 $\pm$ 0.788
0.0,0.9,-1.0,False,22000.0,0.848 $\pm$ nan,nan $\pm$ nan,175.877 $\pm$ nan
0.0,0.9,-1.0,True,22000.0,0.859 $\pm$ 0.012,nan $\pm$ nan,27.607 $\pm$ 0.290
0.0,1.0,-1.0,False,22000.0,0.792 $\pm$ nan,nan $\pm$ nan,170.923 $\pm$ nan
0.0,1.0,-1.0,True,22000.0,0.779 $\pm$ 0.009,nan $\pm$ nan,27.665 $\pm$ 0.854
10.0,0.7,1.0,True,10000.0,1.044 $\pm$ nan,nan $\pm$ nan,170.760 $\pm$ nan
10.0,0.7,1.0,True,22000.0,1.046 $\pm$ 0.026,nan $\pm$ nan,131.713 $\pm$ 87.975


In [21]:
print(
    results.to_latex(
        float_format="{:.3f}".format,
    )
)

\begin{tabular}{llllllllll}
\toprule
n_particles & \multicolumn{3}{r}{22.000000} & \multicolumn{3}{r}{33.000000} & \multicolumn{3}{r}{42.000000} \\
Metric & test/effective_sample_size & test/resampled/cropped_energy_w1 & test/resampled/rama/torus_wasserstein & test/effective_sample_size & test/resampled/cropped_energy_w1 & test/resampled/rama/torus_wasserstein & test/effective_sample_size & test/resampled/cropped_energy_w1 & test/resampled/rama/torus_wasserstein \\
Model &  &  &  &  &  &  &  &  &  \\
\midrule
DiT-CFM & 0.025 $\pm$ nan & 0.473 $\pm$ nan & 0.618 $\pm$ nan & 0.027 $\pm$ 0.007 & 3.042 $\pm$ 0.594 & 1.755 $\pm$ 0.102 & 0.023 $\pm$ 0.001 & 5.970 $\pm$ 0.782 & 2.915 $\pm$ 0.058 \\
EQ-CFM & 0.233 $\pm$ 0.042 & 0.825 $\pm$ 0.038 & 0.349 $\pm$ 0.050 & 0.036 $\pm$ 0.027 & 1.759 $\pm$ 0.788 & 1.967 $\pm$ 0.062 & 0.059 $\pm$ 0.010 & 1.694 $\pm$ 0.058 & 2.737 $\pm$ 0.035 \\
\bottomrule
\end{tabular}

