In [28]:
import wandb
import os
import numpy as np
import pandas as pd
from argparse import ArgumentParser
from pathlib import Path

def _max_score(run, key):
    """Return the largest non-NaN value logged under ``key`` in ``run``'s history.

    Returns NaN when the metric is absent from the returned history (e.g. it was
    never logged) or has no numeric values, instead of raising.
    """
    history = run.history(keys=[key])
    # The history frame may lack the column entirely when nothing was logged.
    if key not in history:
        return float("nan")
    return float(pd.to_numeric(history[key], errors="coerce").max())


def fetch_data(groups, sweeps, project="diffusion_policy_debug", entity="pfrommerd", api=None):
    """Fetch the best test scores of every W&B run in the given groups or sweeps.

    A run is selected if it belongs to ANY of ``groups`` OR ANY of ``sweeps``.
    For each run, one row is emitted per evaluation type ("replica", "deconv")
    holding the maximum of ``test/<type>_mean_score`` over the run's history.

    Args:
        groups: iterable of W&B run-group ids to include.
        sweeps: iterable of W&B sweep ids to include.
        project: W&B project queried when ``api`` is not supplied.
        entity: W&B entity queried when ``api`` is not supplied.
        api: optional pre-built ``wandb.Api``-like object exposing
            ``runs(filters=...)``; useful for reuse or testing.

    Returns:
        pandas.DataFrame with columns ``num`` (config
        ``task.dataset.max_train_episodes``), ``sigma`` (config
        ``global_cond_noise``), ``type`` and ``max_score``. ``max_score`` is NaN
        for a metric the run never logged. If no groups or sweeps are given, an
        empty frame with the same columns is returned without querying W&B.
    """
    columns = ["num", "sigma", "type", "max_score"]
    filters = [{"group": gid} for gid in groups]
    filters += [{"sweep": sid} for sid in sweeps]
    if not filters:
        # An empty "$or" is not a valid query, and there is nothing to select.
        return pd.DataFrame(columns=columns)

    if api is None:
        api = wandb.Api(overrides={"project": project, "entity": entity})
    runs = api.runs(filters={"$or": filters})

    run_datas = []
    print(f"Fetching {len(runs)} runs")
    for r in runs:
        num_data = r.config['task']['dataset']['max_train_episodes']
        sigma = r.config['global_cond_noise']
        print(f"Fetching {num_data}, {sigma} {r.id}")
        for score_type in ("replica", "deconv"):
            run_datas.append({
                "num": num_data,
                "sigma": sigma,
                "type": score_type,
                "max_score": _max_score(r, f"test/{score_type}_mean_score"),
            })
    return pd.DataFrame(run_datas, columns=columns)

In [88]:
# Per the fetch log: group b0sbgoqc holds the sigma=3 runs, 3bc2gckb the sigma in [0, 0.1] runs.
data = fetch_data(groups=["3bc2gckb", "b0sbgoqc"], sweeps=[])

Fetching 21 runs
Fetching 90, 3 b0sbgoqc_train_n90_s3.0_0
Fetching 40, 3 b0sbgoqc_train_n40_s3.0_0
Fetching 200, 3 b0sbgoqc_train_n200_s3.0_0
Fetching 200, 0.001 3bc2gckb_train_n200_s0.001_0
Fetching 200, 0.005 3bc2gckb_train_n200_s0.005_0
Fetching 40, 0.005 3bc2gckb_train_n40_s0.005_0
Fetching 90, 0.005 3bc2gckb_train_n90_s0.005_0
Fetching 90, 0.01 3bc2gckb_train_n90_s0.01_0
Fetching 200, 0.01 3bc2gckb_train_n200_s0.01_0
Fetching 40, 0.01 3bc2gckb_train_n40_s0.01_0
Fetching 200, 0.05 3bc2gckb_train_n200_s0.05_0
Fetching 40, 0.05 3bc2gckb_train_n40_s0.05_0
Fetching 90, 0.05 3bc2gckb_train_n90_s0.05_0
Fetching 40, 0.1 3bc2gckb_train_n40_s0.1_0
Fetching 90, 0.1 3bc2gckb_train_n90_s0.1_0
Fetching 200, 0 3bc2gckb_train_n200_s0.0_0
Fetching 90, 0 3bc2gckb_train_n90_s0.0_0
Fetching 90, 0.001 3bc2gckb_train_n90_s0.001_0
Fetching 40, 0 3bc2gckb_train_n40_s0.0_0
Fetching 40, 0.001 3bc2gckb_train_n40_s0.001_0
Fetching 200, 0.1 3bc2gckb_train_n200_s0.1_0


In [108]:
# One color per plotted group. The sigma plot has 6 groups, and plot_runs pairs
# groups with colors via zip, so fewer colors than groups silently drops groups.
colors = ['red', 'green', 'blue', 'orange', 'purple',
          'brown', 'magenta', 'teal', 'olive', 'black']

In [112]:
import plotly.graph_objects as go

def plot_run(info, scores, color):
    """Yield one Scatter trace of ``max_score`` against train ``sigma``.

    Args:
        info: dict with ``num`` and ``type`` keys; forms the legend label
            ``"<num>_<type>"`` and makes ``deconv`` traces dashed.
        scores: rows of one (num, type) group with ``sigma``/``max_score`` columns.
        color: line color for the trace.
    """
    ordered = scores.sort_values('sigma')
    dash_style = "dash" if info['type'] == 'deconv' else None
    trace_name = f"{info['num']}_{info['type']}"
    yield go.Scatter(
        x=ordered['sigma'],
        y=ordered['max_score'],
        name=trace_name,
        line_color=color,
        opacity=0.8,
        line=dict(dash=dash_style),
    )

def plot_runs(data, palette=None, max_sigma=1):
    """Yield Scatter traces of max score vs. train sigma, one per (num, type) pair.

    Rows with ``sigma >= max_sigma`` are excluded. Each ``num`` value gets one
    color shared by its ``replica`` and ``deconv`` traces (plot_run dashes the
    latter). Groups are visited in ascending ``num`` so the legend is ordered.

    Args:
        data: frame with ``num``, ``sigma``, ``type`` and ``max_score`` columns.
        palette: colors for successive ``num`` groups; defaults to the notebook
            ``colors`` list. It is cycled, so no group is dropped when there are
            more groups than colors.
        max_sigma: exclusive upper bound on the sigma values plotted.
    """
    import itertools

    data = data[data['sigma'] < max_sigma]
    nums = sorted(data['num'].unique().tolist())
    palette = colors if palette is None else palette
    for num, color in zip(nums, itertools.cycle(palette)):
        for t in ('replica', 'deconv'):
            info = {'num': num, 'type': t}
            scores = data[(data['num'] == num) & (data['type'] == t)]
            yield from plot_run(info, scores, color)
fig = go.Figure(list(plot_runs(data)))
# Legend entries read "<max_train_episodes>_<score type>"; deconv traces are dashed.
fig.update_layout(
    title="Best test score vs. training sigma",
    xaxis_title="Train Sigma",
    yaxis_title="Score",
    legend_title_text="episodes_type",
)
fig

In [114]:
import plotly.graph_objects as go

def plot_run(info, scores,color):
    """Yield one Scatter trace of ``max_score`` against dataset size ``num``.

    Args:
        info: dict with ``sigma`` and ``type`` keys; forms the legend label
            ``"<sigma>_<type>"`` and makes ``deconv`` traces dashed.
        scores: rows of one (sigma, type) group with ``num``/``max_score`` columns.
        color: line color for the trace.
    """
    by_size = scores.sort_values('num')
    if info['type'] == 'deconv':
        dash_style = "dash"
    else:
        dash_style = None
    yield go.Scatter(
        x=by_size['num'],
        y=by_size['max_score'],
        name=f"{info['sigma']}_{info['type']}",
        opacity=0.8,
        line_color=color,
        line=dict(dash=dash_style),
    )

def plot_runs(data, palette=None, max_sigma=1):
    """Yield Scatter traces of max score vs. dataset size, one per (sigma, type) pair.

    Rows with ``sigma >= max_sigma`` are excluded. Each sigma gets one color
    shared by its ``replica`` and ``deconv`` traces (plot_run dashes the latter).
    Sigmas are visited in ascending order so the legend is ordered.

    Args:
        data: frame with ``num``, ``sigma``, ``type`` and ``max_score`` columns.
        palette: colors for successive sigma groups; defaults to the notebook
            ``colors`` list. It is cycled: zipping against a shorter palette used
            to silently drop sigma groups (with 4 colors and 6 sigmas, the
            0.1 and 0 groups never appeared).
        max_sigma: exclusive upper bound on the sigma values plotted.
    """
    import itertools

    data = data[data['sigma'] < max_sigma]
    sigmas = sorted(data['sigma'].unique().tolist())
    palette = colors if palette is None else palette
    for sigma, color in zip(sigmas, itertools.cycle(palette)):
        for t in ('replica', 'deconv'):
            info = {'sigma': sigma, 'type': t}
            scores = data[(data['sigma'] == sigma) & (data['type'] == t)]
            yield from plot_run(info, scores, color)
fig = go.Figure(list(plot_runs(data)))
# Legend entries read "<sigma>_<score type>"; deconv traces are dashed.
fig.update_layout(
    title="Best test score vs. number of training episodes",
    xaxis_title="Datapoints",
    yaxis_title="Score",
    legend_title_text="sigma_type",
)
fig