In [1]:
import wandb
import os
import numpy as np
import pandas as pd
from argparse import ArgumentParser
from pathlib import Path
import itertools

def _max_logged_score(run, key):
    """Return the largest non-NaN value a run logged under ``key``.

    Returns NaN when the run logged nothing for ``key`` (the history then has
    no such column) or when every logged value is NaN.
    """
    # NOTE(review): run.history() is documented as returning *sampled* rows by
    # default, so this is the max over the sampled rows - confirm that is enough.
    history = run.history(keys=[key])
    if key not in history:
        return float("nan")
    values = np.asarray(history[key], dtype=float)
    values = values[~np.isnan(values)]
    return values.max() if values.size else float("nan")


def fetch_data(groups, sweeps, project="pfrommerd/diffusion_policy_sweeps"):
    """Collect the best replica/deconv test score of every run in the given sweeps.

    Args:
        groups: Unused; kept so existing call sites keep working.
        sweeps: Iterable of sweep ids; each is looked up as ``{project}/{sweep_id}``.
        project: Path prefix under which the sweeps are looked up.

    Returns:
        A DataFrame with two rows per run (``type`` is "replica" or "deconv")
        and the columns ``num``, ``sigma``, ``type``, ``seed`` and ``max_score``.
        ``max_score`` is NaN when the run logged no usable value for that metric.
        The columns are present even when the sweeps contain no runs.

    Raises:
        KeyError: If a run's config lacks one of the expected entries.
    """
    # (row label, history key) for the two metrics summarised per run.
    score_keys = (
        ("replica", "test/replica_mean_score"),
        ("deconv", "test/deconv_mean_score"),
    )
    api = wandb.Api()
    sweep_runs = [api.sweep(f"{project}/{s}").runs for s in sweeps]
    run_datas = []
    for r in itertools.chain.from_iterable(sweep_runs):
        num_data = r.config['task']['dataset']['max_train_episodes']
        sigma = r.config['global_cond_noise']
        seed = r.config['training']['seed']
        print(f"Fetching {num_data}, {sigma} {r.id} {seed}")
        for kind, key in score_keys:
            run_datas.append({
                "num": num_data,
                "sigma": sigma,
                "type": kind,
                "seed": seed,
                "max_score": _max_logged_score(r, key),
            })
    # Explicit columns keep the schema stable when no runs were found.
    return pd.DataFrame(
        run_datas, columns=["num", "sigma", "type", "seed", "max_score"]
    )

In [2]:
# Fetch every run of sweep "54txftng" (no groups passed) and print the table.
data = fetch_data(groups=[], sweeps=["54txftng"])
print(data)

Fetching 40, 1 361895iw 44
Fetching 40, 1 9k0l0e4o 43
Fetching 40, 1 3f4xcj1r 42
Fetching 20, 1 md5x0099 44
Fetching 20, 1 wmp5ou67 43
Fetching 20, 1 scgbzxyc 42
Fetching 10, 1 0sgahlsv 44
Fetching 10, 1 ucteewkj 43
Fetching 10, 1 8gi9aosc 42
Fetching 40, 0.5 7lh8l3lr 44
Fetching 40, 0.5 a6qlibb0 43
Fetching 40, 0.5 oxa7ufnr 42
Fetching 20, 0.5 1x4i7gr8 44
Fetching 20, 0.5 saq5qr6d 43
Fetching 20, 0.5 pof8g7c4 42
Fetching 10, 0.5 fb69sal5 44
Fetching 10, 0.5 hcu6ez0q 43
Fetching 40, 0.1 5v7q5eu6 44
Fetching 40, 0.1 bph5kdmw 43
Fetching 40, 0.1 pg8px1yv 42
Fetching 20, 0.1 7sobeawt 44
Fetching 20, 0.1 c61rh13u 43
Fetching 20, 0.1 u12sjghd 42
Fetching 10, 0.1 t3766egy 44
Fetching 10, 0.1 shgzyeyg 43
Fetching 10, 0.1 qhzxxppf 42
Fetching 40, 0.05 pd2pgnzq 44
Fetching 40, 0.05 zh09osw2 43
Fetching 40, 0.05 84e7k14i 42
Fetching 20, 0.05 tgj07uqz 44
Fetching 20, 0.05 t3w4mg4o 43
Fetching 20, 0.05 6w5y2r6a 42
Fetching 10, 0.05 c12a29cu 44
Fetching 10, 0.05 e4rhj63l 43
Fetching 10, 0.05 b8jd9e

In [6]:
# Palette passed by scatter() to data_groups() as `items`: each row is paired
# with one level-1 group, and the entries of a row with that group's level-2
# sub-groups.
colors = [
    ['red', 'darkred'],
    ['green', 'darkgreen'],
    ['blue', 'darkblue'],
    ['orange', 'darkorange'],
]

In [7]:
import plotly.graph_objects as go

def select(data, group):
    """Return the rows of ``data`` whose columns equal the fields of ``group``.

    ``group`` is a namedtuple-like object (it must offer ``_asdict()``); its
    field names are used as column names and its values as the required
    values, all of which must match for a row to be kept.
    """
    # Split the {column: value} mapping into two parallel lists.
    columns, wanted = map(list, zip(*group._asdict().items()))
    row_mask = (data[columns] == wanted).all(1)
    return data[row_mask]

def data_groups(data, l1_groups, l2_groups, items):
    """Split ``data`` into series and pair each series with an entry of ``items``.

    Args:
        data: DataFrame to split.
        l1_groups: Column names defining the level-1 groups. When empty, the
            whole frame is treated as a single level-1 group.
        l2_groups: Column names defining sub-groups inside each level-1 group;
            when empty, each level-1 group is one series.
        items: One entry per level-1 group. With ``l2_groups`` set, each entry
            must itself be iterable and provides one element per sub-group.

    Yields:
        ``(item, group, rows)`` where ``item`` is the paired entry (or element
        of an entry), ``group`` is a ``{column: value}`` dict identifying the
        series and ``rows`` is the matching slice of ``data``.

    Groups appear in order of first occurrence. If there are more groups than
    entries, the entries are reused from the start instead of the surplus
    groups being dropped. Empty ``items`` still yields nothing.
    """
    def _partition(frame, columns):
        # Distinct value combinations of `columns`, each with its matching rows.
        distinct = frame[columns].drop_duplicates()
        for key in distinct.itertuples(index=False):
            yield key._asdict(), select(frame, key)

    # itertuples() on a zero-column frame yields no rows, so the "no level-1
    # columns" case needs an explicit single group covering everything.
    level1 = _partition(data, l1_groups) if l1_groups else [({}, data)]
    for (group, sub_data), item in zip(level1, itertools.cycle(items)):
        if l2_groups:
            level2 = _partition(sub_data, l2_groups)
            for (sub_group, series_data), sub_item in zip(level2, itertools.cycle(item)):
                series_group = dict(group)
                series_group.update(sub_group)
                yield sub_item, series_group, series_data
        else:
            yield item, dict(group), sub_data

def scatter(data, x_label, y_label, l1_groups=None, l2_groups=None, is_dashed=lambda x: False):
    """Yield plotly traces summarising ``y_label`` against ``x_label`` per series.

    Series are produced by ``data_groups(data, l1_groups, l2_groups, colors)``.
    For each series two traces are yielded: a line through the median of
    ``y_label`` at every distinct ``x_label`` value, then a filled band
    between the minimum and maximum at those values.

    Args:
        data: DataFrame holding the ``x_label``, ``y_label`` and group columns.
        x_label: Column used for the x axis; rows are sorted by it.
        y_label: Column whose median/min/max are plotted.
        l1_groups: Level-1 grouping columns (``None`` means no columns).
        l2_groups: Level-2 grouping columns (``None`` means no columns).
        is_dashed: Called with the series' ``{column: value}`` dict; a truthy
            result draws the median line dashed.

    Yields:
        ``go.Scatter`` traces, two per series.
    """
    l1_groups = [] if l1_groups is None else l1_groups
    l2_groups = [] if l2_groups is None else l2_groups
    data = data.sort_values(x_label)
    for color, series_group, series_data in data_groups(
                data, l1_groups, l2_groups, colors
            ):
        # Without level-2 groups the paired item is a whole palette row rather
        # than one colour; plot with the row's first entry in that case.
        if not isinstance(color, str):
            color = color[0]
        xs = list(series_data[x_label].drop_duplicates())
        ys = []
        ys_upper = []
        ys_lower = []
        for v in xs:
            x_data = series_data[series_data[x_label] == v][y_label]
            ys.append(x_data.median())
            ys_lower.append(x_data.min())
            ys_upper.append(x_data.max())
        dashed = is_dashed(series_group)
        label = ','.join([f'{k}={v}' for (k,v) in series_group.items()])
        # Median line.
        yield go.Scatter(x=xs, y=ys, name=label,
                 line_color=color, opacity=0.8,
                 line=dict(dash="dash" if dashed else None))
        # Min/max band, drawn as one closed outline.
        yield go.Scatter(
            x=xs+xs[::-1], # x, then x reversed
            y=ys_upper+ys_lower[::-1], # upper, then lower reversed
            fill='toself',
            fillcolor=color,
            marker=dict(opacity=0),
            line=dict(color=color),
            hoverinfo="skip",
            showlegend=False,
            opacity=0.2
        )

# Score against training sigma: one palette row per `num`, one shade per
# `type`, with the "deconv" series dashed.
fig = go.Figure(data=[
    *scatter(
        data, 'sigma', 'max_score',
        l1_groups=['num'],
        l2_groups=['type'],
        is_dashed=lambda x: x['type'] == 'deconv',
    )
])
fig.update_layout(
    xaxis_title="Train Sigma",
    yaxis_title="Score",
)
fig

In [9]:
# Replica rows only, leaving out the sigma values 0.0001, 0.001 and 0.5.
reduced_data = data[
    (data['type'] == "replica")
    & ~data['sigma'].isin([0.0001, 0.001, 0.5])
]
# Score against dataset size: one palette row per sigma.
fig = go.Figure(data=[
    *scatter(
        reduced_data, 'num', 'max_score',
        l1_groups=['sigma'],
        l2_groups=['type'],
        is_dashed=lambda x: x['type'] != 'replica',
    )
])
fig.update_layout(
    xaxis_title="Num Datapoints",
    yaxis_title="Score",
)
fig