In [None]:
# Expose only CUDA device 0 to this kernel; set here, before `jax` is imported below.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import os
import pandas as pd
import seaborn as sns

from atlas.envs.xminigrid.level import XMinigridLevel, get_num_objects
from atlas.hrm.ops import get_num_rm_states, path_info
from atlas.hrm.types import HRM
from atlas.ued.buffer import BufferManager
from atlas.utils.checkpointing import load_runner_state
from atlas.utils.logging import download_checkpoints

# Helpers

In [None]:
# Settings handed to `BufferManager` when the problem weights are recomputed.
BUFFER_CAPACITY = 50_000
STALENESS_COEFF = 0.1

# Root folder for the downloaded checkpoints and exported CSV files,
# resolved against the directory the kernel was started from.
NOTEBOOK_PATH = os.path.join(
    os.getcwd(),
    "experiments",
    "evaluation",
    "curriculum",
)

# Weights & Biases location passed to `download_checkpoints`.
WANDB_ENTITY = None   # TODO: replace with your W&B entity
WANDB_PROJECT = None  # TODO: replace with your W&B project

def get_num_rooms_jax(level: XMinigridLevel) -> jax.Array:
    """Map a level's grid dimensions to its number of rooms.

    Implemented with `jnp.where` instead of Python `if` statements so the
    function can be wrapped in JAX transformations such as `jax.vmap`.

    Recognised (height, width) layouts and their room counts:
    (7, 7) -> 1, (7, 13) -> 2, (13, 13) -> 4, (13, 19) -> 6.

    Args:
        level: Level whose `height` and `width` attributes are inspected.

    Returns:
        A JAX array with the room count, or -1 for any other dimensions.
    """
    height, width = level.height, level.width

    # The four layouts are mutually exclusive, so at most one of the updates
    # below replaces the -1 "unknown layout" default.
    num_rooms = jnp.where((height == 7) & (width == 7), 1, -1)
    num_rooms = jnp.where((height == 7) & (width == 13), 2, num_rooms)
    num_rooms = jnp.where((height == 13) & (width == 13), 4, num_rooms)
    num_rooms = jnp.where((height == 13) & (width == 19), 6, num_rooms)
    return num_rooms

def download_all_checkpoints(runs, steps, folder):
    """Download the checkpoints of every run that has no local folder yet.

    Args:
        runs: Mapping from experiment name to an iterable of run ids.
        steps: Checkpoint steps forwarded to `download_checkpoints`.
        folder: Sub-folder of `<NOTEBOOK_PATH>/ckpts` the runs are stored in.
    """
    for experiment, run_ids in runs.items():
        for run_id in run_ids:
            target_dir = os.path.join(NOTEBOOK_PATH, "ckpts", folder, experiment, run_id)
            if Path(target_dir).exists():
                # The folder's existence is the only check made; such runs are skipped.
                continue
            download_checkpoints(WANDB_ENTITY, WANDB_PROJECT, run_id, steps, target_dir)

# Sequential Problem Sampler

## Download Checkpoints and Dump Data

In [None]:
# Replace with the run ids from the `01-core` experiments
# (experiment name -> list of run ids, five placeholders each).
RUNS = {
    "plr-i": [None] * 5,
    "plr-c": [None] * 5,
    "accel_full-i": [None] * 5,
    "accel_full-c": [None] * 5,
    "accel_scratch-i": [None] * 5,
    "accel_scratch-c": [None] * 5,
}

# Checkpoint steps to analyse: 10, 110, ..., 1910, plus the final step 2000.
STEPS = list(range(10, 2000, 100)) + [2000]

# Sub-folder of `<NOTEBOOK_PATH>/ckpts` for these runs, and the CSV the stats go to.
CHECKPOINT_FOLDER = "seq"
EXPORT_FILENAME = "seq_buffer_data.csv"

def get_stats_from_saved_checkpoint(run_path, step):
    """Load one checkpoint and compute per-slot statistics of its problem buffer.

    NOTE: a function with the same name is defined again in the
    "Random Walk Sampler" section; once that cell has run, this cell must be
    re-run to get this version back.

    Args:
        run_path: Folder holding the checkpoints of a single run.
        step: Checkpoint step to load.

    Returns:
        Tuple `(mask, probs, num_hrm_states, num_rooms, num_objects)`.
        `mask` is True for slot indices below the buffer's `size` and `probs`
        is the output of `BufferManager.problem_weights` for the buffer.
    """
    checkpoint = load_runner_state(path=Path(run_path).resolve(), target=None, step=step)
    buffer = checkpoint['buffer']

    # `problems[0]` holds the level fields and `problems[1]` the HRM fields.
    level = XMinigridLevel(**buffer['problems'][0])
    hrm = HRM(**buffer['problems'][1])

    count_rm_states = jax.vmap(get_num_rm_states, in_axes=(0, None))
    num_hrm_states = count_rm_states(hrm, 0)
    num_rooms = jax.vmap(get_num_rooms_jax)(level)
    num_objects = get_num_objects(level)

    # Only the first `size` of the BUFFER_CAPACITY slots are marked as in use.
    slot_ids = jnp.arange(BUFFER_CAPACITY)
    mask = slot_ids < buffer['size']

    manager = BufferManager(capacity=BUFFER_CAPACITY, staleness_coeff=STALENESS_COEFF)
    probs = manager.problem_weights(buffer)

    return mask, probs, num_hrm_states, num_rooms, num_objects

def export_checkpoint_data():
    """Write the buffer-weighted statistics of all `RUNS` to a CSV file.

    One row is produced per (experiment, run id, step), holding the sums of
    RM states, rooms and objects weighted by `mask * probs`. The result is
    saved as `EXPORT_FILENAME` inside `NOTEBOOK_PATH`.

    NOTE: a function with the same name is defined again in the
    "Random Walk Sampler" section; once that cell has run, this cell must be
    re-run to get this version back.
    """
    records = []
    for experiment, run_ids in RUNS.items():
        for run_id in run_ids:
            run_path = os.path.join(NOTEBOOK_PATH, "ckpts", CHECKPOINT_FOLDER, experiment, run_id)
            for step in STEPS:
                mask, probs, num_hrm_states, num_rooms, num_objects = get_stats_from_saved_checkpoint(run_path, step)
                # Slots outside the mask contribute nothing to the sums.
                weights = mask * probs
                records.append({
                    "experiment": experiment,
                    "run_id": run_id,
                    "step": step,
                    "states": jnp.sum(weights * num_hrm_states),
                    "rooms": jnp.sum(weights * num_rooms),
                    "objects": jnp.sum(weights * num_objects),
                })

    csv_path = os.path.join(NOTEBOOK_PATH, EXPORT_FILENAME)
    pd.DataFrame(records).to_csv(csv_path, index=False)


In [None]:
# Fetch the sequential-sampler checkpoints (runs with an existing folder are skipped).
download_all_checkpoints(runs=RUNS, steps=STEPS, folder=CHECKPOINT_FOLDER)

In [None]:
# Dump the sequential-sampler statistics to `EXPORT_FILENAME`, which the plotting cell below reads.
export_checkpoint_data()

## Plot Data

In [None]:
def make_plot(df, experiments, colors, names, filename):
    """Plot the buffer-weighted statistics over training as three panels.

    Draws RM states, rooms and objects side by side with one line per
    experiment, saves the figure to `filename` and shows it.

    NOTE: a function with the same name is defined again in the
    "Random Walk Sampler" section; once that cell has run, this cell must be
    re-run to get this version back.

    Args:
        df: DataFrame with the columns `experiment`, `step`, `states`,
            `rooms` and `objects`.
        experiments: Values of the `experiment` column to plot.
        colors: One line colour per entry of `experiments`.
        names: One legend label per entry of `experiments`.
        filename: Path the figure is saved to.
    """
    common_args = dict(
        x="step",
        marker='o',
        markeredgecolor='auto',
        hue="name",
        linewidth=2,
        err_kws={"alpha": 0.2}
    )

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(27, 5))
    panels = (("states", ax1), ("rooms", ax2), ("objects", ax3))
    for experiment, color, name in zip(experiments, colors, names):
        # `assign` builds a new frame carrying the legend label, so no column
        # is ever set on a filtered slice of the caller's DataFrame.
        df_algorithm = df[df["experiment"] == experiment].assign(name=name)
        for column, ax in panels:
            sns.lineplot(df_algorithm, y=column, palette=[color], ax=ax, **common_args)

    # States
    ax1.set_ylabel("Average #RM States", fontsize="xx-large")
    ax1.legend().set_visible(False)
    ax1.set_ylim(1.6, 6.4)

    # Rooms (the only panel that keeps a visible legend)
    ax2.set_ylabel("Average #Rooms", fontsize="xx-large")
    ax2.legend(title=None, fontsize="x-large", loc='lower right')
    ax2.set_ylim(0.6, 6.4)

    # Objects
    ax3.set_ylabel("Average #Objects", fontsize="xx-large")
    ax3.legend().set_visible(False)
    ax3.set_ylim(0.6, 20.4)

    # Common
    for ax in [ax1, ax2, ax3]:
        ax.set_xlabel('Number of Environment Steps (in millions)', fontsize="xx-large")
        # Tick labels are twice the checkpoint step they are placed at.
        ax.set_xticks([0, 500, 1000, 1500, 2000])
        ax.set_xticklabels([0, 1000, 2000, 3000, 4000])
        ax.grid(True, alpha=0.2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.tick_params(length=0.1, width=0.1, labelsize="xx-large")
        ax.spines['left'].set_position(('outward', 10))
        ax.spines['bottom'].set_position(('outward', 10))

    fig.savefig(filename, bbox_inches="tight", pad_inches=0.0)
    plt.show()


In [None]:
# Load the exported sequential-sampler statistics and draw one figure per
# sampling variant ("i" -> seq-indep.pdf, "c" -> seq-cond.pdf).
df = pd.read_csv(os.path.join(NOTEBOOK_PATH, EXPORT_FILENAME))
set2 = sns.color_palette("Set2", 8)
for sampling, sampling_full in {"i": "indep", "c": "cond"}.items():
    make_plot(
        df,
        experiments=[
            f"plr-{sampling}",
            f"accel_full-{sampling}",
            f"accel_scratch-{sampling}",
        ],
        colors=set2[:3],  # first three "Set2" colours, one per algorithm
        names=["PLR$^\\bot$", "ACCEL", "ACCEL-0"],
        filename=f"seq-{sampling_full}.pdf",
    )

# Random Walk Sampler

## Download Checkpoints and Dump Data

In [None]:
# Replace with the run ids from the `07-dag-sampling` experiments
# (experiment name -> list of run ids, five placeholders each).
RUNS_NONSEQ = {
    "plr-i": [None] * 5,
    "plr-c": [None] * 5,
}

# Checkpoint steps to analyse: 10, 110, ..., 2910, plus the final step 3000.
STEPS_NON_SEQ = list(range(10, 3000, 100)) + [3000]

# Sub-folder of `<NOTEBOOK_PATH>/ckpts` for these runs, and the CSV the stats go to.
CHECKPOINT_FOLDER_NON_SEQ = "non_seq"
EXPORT_FILENAME_NON_SEQ = "rw_buffer_data.csv"

def get_stats_from_saved_checkpoint(run_path, step):
    """Load one checkpoint and compute per-slot statistics, including path info.

    NOTE: this replaces the function of the same name defined in the
    "Sequential Problem Sampler" section; unlike that version it also returns
    the two outputs of `path_info`.

    Args:
        run_path: Folder holding the checkpoints of a single run.
        step: Checkpoint step to load.

    Returns:
        Tuple `(mask, probs, num_hrm_states, all_num_paths,
        all_avg_path_lengths, num_rooms, num_objects)`. `mask` is True for
        slot indices below the buffer's `size` and `probs` is the output of
        `BufferManager.problem_weights` for the buffer.
    """
    checkpoint = load_runner_state(path=Path(run_path).resolve(), target=None, step=step)
    buffer = checkpoint['buffer']

    # `problems[0]` holds the level fields and `problems[1]` the HRM fields.
    level = XMinigridLevel(**buffer['problems'][0])
    hrm = HRM(**buffer['problems'][1])

    count_rm_states = jax.vmap(get_num_rm_states, in_axes=(0, None))
    num_hrm_states = count_rm_states(hrm, 0)
    all_num_paths, all_avg_path_lengths = jax.vmap(path_info, in_axes=0)(hrm)

    num_rooms = jax.vmap(get_num_rooms_jax)(level)
    num_objects = get_num_objects(level)

    # Only the first `size` of the BUFFER_CAPACITY slots are marked as in use.
    slot_ids = jnp.arange(BUFFER_CAPACITY)
    mask = slot_ids < buffer['size']

    manager = BufferManager(capacity=BUFFER_CAPACITY, staleness_coeff=STALENESS_COEFF)
    probs = manager.problem_weights(buffer)

    return mask, probs, num_hrm_states, all_num_paths, all_avg_path_lengths, num_rooms, num_objects

def export_checkpoint_data():
    """Write the buffer-weighted statistics of all `RUNS_NONSEQ` to a CSV file.

    One row is produced per (experiment, run id, step), holding the sums of
    RM states, paths, path lengths, rooms and objects weighted by
    `mask * probs`. The result is saved as `EXPORT_FILENAME_NON_SEQ` inside
    `NOTEBOOK_PATH`.

    NOTE: this replaces the function of the same name defined in the
    "Sequential Problem Sampler" section.
    """
    records = []
    for experiment, run_ids in RUNS_NONSEQ.items():
        for run_id in run_ids:
            run_path = os.path.join(NOTEBOOK_PATH, "ckpts", CHECKPOINT_FOLDER_NON_SEQ, experiment, run_id)
            for step in STEPS_NON_SEQ:
                mask, probs, num_hrm_states, num_paths, path_lengths, num_rooms, num_objects = get_stats_from_saved_checkpoint(run_path, step)
                # Slots outside the mask contribute nothing to the sums.
                weights = mask * probs
                records.append({
                    "experiment": experiment,
                    "run_id": run_id,
                    "step": step,
                    "states": jnp.sum(weights * num_hrm_states),
                    "paths": jnp.sum(weights * num_paths),
                    "path_lengths": jnp.sum(weights * path_lengths),
                    "rooms": jnp.sum(weights * num_rooms),
                    "objects": jnp.sum(weights * num_objects),
                })

    csv_path = os.path.join(NOTEBOOK_PATH, EXPORT_FILENAME_NON_SEQ)
    pd.DataFrame(records).to_csv(csv_path, index=False)


In [None]:
# Store the random-walk runs under CHECKPOINT_FOLDER_NON_SEQ, the folder that
# `export_checkpoint_data` reads them back from.
download_all_checkpoints(RUNS_NONSEQ, STEPS_NON_SEQ, CHECKPOINT_FOLDER_NON_SEQ)

In [None]:
# Dump the random-walk statistics to `EXPORT_FILENAME_NON_SEQ` using the export function redefined above.
export_checkpoint_data()

## Plot Data

In [None]:
def make_plot(df, experiments, colors, names, filename):
    """Plot the buffer-weighted statistics over training as five panels.

    Draws RM states, paths, path lengths, rooms and objects side by side with
    one line per experiment, saves the figure to `filename` and shows it.

    NOTE: this replaces the three-panel function of the same name defined in
    the "Sequential Problem Sampler" section.

    Args:
        df: DataFrame with the columns `experiment`, `step`, `states`,
            `paths`, `path_lengths`, `rooms` and `objects`.
        experiments: Values of the `experiment` column to plot.
        colors: One line colour per entry of `experiments`.
        names: One legend label per entry of `experiments`.
        filename: Path the figure is saved to.
    """
    common_args = dict(
        x="step",
        marker='o',
        markeredgecolor='auto',
        hue="name",
        linewidth=2,
        err_kws={"alpha": 0.2}
    )

    fig, axes = plt.subplots(1, 5, figsize=(35, 5))
    columns = ("states", "paths", "path_lengths", "rooms", "objects")
    for experiment, color, name in zip(experiments, colors, names):
        # `assign` builds a new frame carrying the legend label, so no column
        # is ever set on a filtered slice of the caller's DataFrame.
        df_algorithm = df[df["experiment"] == experiment].assign(name=name)
        for column, ax in zip(columns, axes):
            sns.lineplot(df_algorithm, y=column, palette=[color], ax=ax, **common_args)

    # States
    axes[0].set_ylabel("Average #RM States", fontsize="xx-large")
    axes[0].legend().set_visible(False)
    axes[0].set_ylim(1.6, 6.4)

    # Number of paths (y-limits are left to autoscale)
    axes[1].set_ylabel("Average #Paths", fontsize="xx-large")
    axes[1].legend().set_visible(False)

    # Average path length (y-limits are left to autoscale); the only panel
    # that keeps a visible legend
    axes[2].set_ylabel("Average Path Length", fontsize="xx-large")
    axes[2].legend(title=None, fontsize="x-large", loc='lower right')

    # Rooms
    axes[3].set_ylabel("Average #Rooms", fontsize="xx-large")
    axes[3].legend().set_visible(False)
    axes[3].set_ylim(0.6, 6.4)

    # Objects
    axes[4].set_ylabel("Average #Objects", fontsize="xx-large")
    axes[4].legend().set_visible(False)
    axes[4].set_ylim(0.6, 20.4)

    # Common
    for ax in axes:
        ax.set_xlabel('Number of Environment Steps (in millions)', fontsize="xx-large")
        # Tick labels are twice the checkpoint step they are placed at.
        ax.set_xticks([0, 1000, 2000, 3000])
        ax.set_xticklabels([0, 2000, 4000, 6000])
        ax.grid(True, alpha=0.2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.tick_params(length=0.1, width=0.1, labelsize="xx-large")
        ax.spines['left'].set_position(('outward', 10))
        ax.spines['bottom'].set_position(('outward', 10))

    fig.savefig(filename, bbox_inches="tight", pad_inches=0.0)
    plt.show()

In [None]:
# Load the statistics written by the random-walk export; the file name comes
# from EXPORT_FILENAME_NON_SEQ so reader and writer cannot drift apart.
df = pd.read_csv(os.path.join(NOTEBOOK_PATH, EXPORT_FILENAME_NON_SEQ))
set2 = sns.color_palette("Set2", 8)
base_color = set2[0]
# The second variant uses the same hue with every channel scaled by 0.7.
darker_color = tuple(0.7 * c for c in base_color)
make_plot(
    df,
    experiments=["plr-i", "plr-c"],
    colors=[base_color, darker_color],
    names=["PLR$^\\bot_\\text{indep}$", "PLR$^\\bot_\\text{cond}$"],
    filename="rw.pdf"
)