In [1]:
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as mpatches
import wandb
from rliable import library as rly
from rliable import metrics
from rliable import plot_utils
import seaborn as sns
import numpy as np
import os
from collections import OrderedDict
import matplotlib.ticker as ticker

from utils import *

In [2]:
# Notebook-wide matplotlib font sizes: 25 pt for general text, 20 pt for the
# x/y tick labels (equivalent to matplotlib.rc("xtick"/"ytick", labelsize=20)).
plt.rcParams.update({
    'font.size': 25,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
})

In [3]:
# W&B public-API client; the cells below use it to query runs.
api = wandb.Api()

# W&B entity (user or team name) that owns the queried projects; the projects
# themselves are listed per config below. Replace with your own entity.
entity = "cl-probing"

# Output directories, created up front (no error if they already exist).
figures_path = "./figures"
data_path = "./data"
for _output_dir in (figures_path, data_path):
    os.makedirs(_output_dir, exist_ok=True)

# NOTE(review): none of the cells shown here read these three settings;
# presumably seed count / smoothing window / standard-error toggle for the
# plotting helpers — TODO confirm before removing.
take_x_seeds = 10
window_size = 10
use_se = True

# Run-query / plot configs, appended to by the config cells below.
all_configs = []

## Generalized Stitching Configs

In [4]:
def generalized_grouping_func(config):
    """Build the group label ``"<AGENT> <MC|TD> <SMALL|BIG>"`` for one run.

    Args:
        config: run config mapping; reads ``config['agent']['agent_name']``,
            ``config['agent']['is_td']`` and ``config['agent']['net_arch']``.

    Returns:
        A label such as ``"GCDQN TD SMALL"``. Agent names other than
        ``gcdqn`` / ``clearn_search`` / ``gciql_search`` are labelled ``CRL``;
        a truthy ``is_td`` gives ``TD`` (otherwise ``MC``); ``net_arch == 'mlp'``
        gives ``SMALL`` (anything else ``BIG``).
    """
    agent_cfg = config['agent']

    # Ordered (raw agent_name, display label) pairs; first match wins.
    known_agents = (
        ('gcdqn', 'GCDQN'),
        ('clearn_search', 'C-LEARN'),
        ('gciql_search', 'GCIQL'),
    )
    agent_name = agent_cfg['agent_name']
    agent_label = next(
        (label for raw, label in known_agents if agent_name == raw),
        'CRL',  # fallback for every other agent name
    )

    if agent_cfg['is_td']:
        estimator = 'TD'
    else:
        estimator = 'MC'

    if agent_cfg['net_arch'] == 'mlp':
        size = 'SMALL'
    else:
        size = 'BIG'

    return ' '.join([agent_label, estimator, size])

agents = ['GCDQN', 'C-LEARN', 'CRL', 'GCIQL']
grid_sizes = [4]

# Estimator / network-size suffixes as produced by `generalized_grouping_func`.
variant_suffixes = ['MC BIG', 'MC SMALL', 'TD BIG', 'TD SMALL']


def make_generalized_config(grid_size, split, metric_key, metric_label,
                            number_of_boxes_max=3, number_of_moving_boxes_max=2):
    """Build one run-query/plot config for the `paper_generalized` project.

    The train and test configs share everything (project, filters, grouping,
    candidate names) except the title's split name and the plotted metric, so
    they are generated here to keep the shared parts from drifting apart.

    Args:
        grid_size: value matched against ``config.env.grid_size``.
        split: name embedded in the title (e.g. ``"train"`` / ``"test"``).
        metric_key: W&B metric to plot (e.g. ``"eval/mean_success"``).
        metric_label: axis/subplot title for that metric.
        number_of_boxes_max: value matched against ``config.env.number_of_boxes_max``.
        number_of_moving_boxes_max: value matched against
            ``config.env.number_of_moving_boxes_max``.

    Returns:
        A config dict with keys ``directory``, ``title``, ``projects``,
        ``possible_names``, ``filters``, ``grouping_func`` and ``metrics``.
    """
    return {
        "directory": "generalized",
        "title": (f"scaling_generalized_{split}_{number_of_boxes_max}_boxes_"
                  f"{number_of_moving_boxes_max}_movable_{grid_size}_grid"),
        "projects": ["paper_generalized"],
        "possible_names": [' '.join([agent, s]) for s in variant_suffixes for agent in agents],
        "filters": {
            "config.env.grid_size": grid_size,
            "config.env.number_of_boxes_max": number_of_boxes_max,
            "config.env.number_of_moving_boxes_max": number_of_moving_boxes_max,
            # Exclude runs tagged with the old MC discount setting.
            "tags": {"$nin": ["mc_old_gamma"]},
        },
        "grouping_func": generalized_grouping_func,
        "metrics": OrderedDict([(metric_key, metric_label)]),
    }


for grid_size in grid_sizes:
    config_1 = make_generalized_config(
        grid_size, "train", "eval/mean_success",
        "Training success on 3 boxes \n(1 on target)",
    )
    config_2 = make_generalized_config(
        grid_size, "test", "eval_3/mean_success",
        "Test success on 3 boxes \n(0 on target)",
    )
    all_configs.extend([config_1, config_2])


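# Plotting

Average wall-clock runtime (in hours) of the W&B runs selected by each config, grouped by method label.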

In [9]:


# Average wall-clock runtime (hours) per method label over every W&B run
# matched by any config in `all_configs`.
#
# Iterates the configs directly: the previous version zipped with an `axs`
# variable that no cell in this notebook defines (NameError on a fresh kernel),
# and zip() would have silently dropped configs beyond len(axs).
runtimes = {}
seen_runs = set()    # (project, run id): a run matched by several configs counts once
skipped_runs = []    # runs whose summary has no `_wandb.runtime` entry

for config in all_configs:
    for project in config['projects']:
        project_runs = api.runs(path=f"{entity}/{project}", filters=config['filters'])
        for r in project_runs:
            run_key = (project, r.id)
            if run_key in seen_runs:
                continue
            seen_runs.add(run_key)

            try:
                runtime_s = r.summary["_wandb"]["runtime"]
            except (KeyError, TypeError):
                # e.g. crashed / never-synced runs; report rather than abort the cell
                skipped_runs.append(run_key)
                continue

            name = config['grouping_func'](r.config)
            runtimes.setdefault(name, []).append(runtime_s)

runtimes = {k: np.mean(v) / 3600 for k, v in runtimes.items()}  # seconds -> hours

if skipped_runs:
    print(f"Skipped {len(skipped_runs)} run(s) without a recorded runtime: {skipped_runs}")
for k in runtimes:
    print(f"{k}: {runtimes[k]:.2f} hours")

CRL MC BIG: 8.77 hours
CRL MC SMALL: 1.32 hours
C-LEARN TD BIG: 8.80 hours
C-LEARN TD SMALL: 1.27 hours
GCDQN TD BIG: 7.26 hours
GCDQN TD SMALL: 1.03 hours
GCIQL TD SMALL: 1.25 hours
GCIQL TD BIG: 8.41 hours
GCDQN MC SMALL: 1.07 hours
GCIQL MC SMALL: 1.19 hours
GCIQL MC BIG: 6.56 hours
GCDQN MC BIG: 4.77 hours
C-LEARN MC SMALL: 0.83 hours
C-LEARN MC BIG: 5.49 hours
