In [None]:
# %env CUDA_VISIBLE_DEVICES=0

In [None]:
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import wandb

from atlas.envs.xminigrid.mutators.common import Mutations
from atlas.ued.buffer import BufferManager
from atlas.utils.checkpointing import load_runner_state
from atlas.utils.logging import download_checkpoints


# Global plotting defaults applied to every figure drawn in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})

# Configuration

In [None]:
# Passed to BufferManager when recomputing replay weights from a checkpoint.
BUFFER_CAPACITY = 50_000
STALENESS_COEFF = 0.1

# Checkpoints and caches are stored relative to the parent of the working directory.
NOTEBOOK_PATH = Path.cwd().parent

# Sentinel marking an empty slot in a buffer row's mutation ids.
PAD_VAL = -1
# Id used to represent "no mutation"; chosen above every real Mutations value.
MUTATION_PADDING_ID = max(Mutations) + 10

WANDB_ENTITY = None  # TODO: replace with your W&B entity and projects
WANDB_PROJECT = None

# Download Checkpoints and Dump Data

## Download

In [None]:
# W&B run ids per experiment (five runs each).
RUNS = {  # TODO: replace with W&B identifiers
    experiment_name: [None] * 5
    for experiment_name in (
        "accel_full-i",
        "accel_full-c",
        "accel_scratch-i",
        "accel_scratch-c",
    )
}

# Checkpoint steps to analyse.
# NOTE WE SKIP THE FIRST 100 STEPS BECAUSE THEY ARE NOT RELEVANT
STEPS = [*range(110, 2000, 100), 2000]
SEQUENTIAL = True

def download_all_checkpoints(runs, steps, sequential=True):
    """
    Download the checkpoints at `steps` for every run in `runs` via `download_checkpoints`.

    Each run is stored under
    <NOTEBOOK_PATH>/ckpts/{sequential|non_sequential}/<experiment>/<run_id>.
    A run whose folder already exists is skipped entirely (presumably including a
    partially downloaded one — delete the folder to force a re-download).

    Args:
        runs: mapping experiment name -> list of W&B run ids.
        steps: checkpoint steps to download for each run.
        sequential: selects the "sequential" or "non_sequential" sub-folder.

    Raises:
        ValueError: if any run id is missing (None or empty), e.g. an unfilled
            placeholder in RUNS. Checked before anything is downloaded.
    """
    # Fail fast with a readable message instead of a TypeError from os.path.join(None).
    missing = [experiment for experiment, run_ids in runs.items() if not all(run_ids)]
    if missing:
        raise ValueError(
            f"Missing W&B run ids for experiments {missing}; fill in the placeholders in RUNS."
        )

    subdir = 'sequential' if sequential else 'non_sequential'
    for experiment, run_ids in runs.items():
        for run_id in run_ids:
            run_path = os.path.join(NOTEBOOK_PATH, "ckpts", subdir, experiment, run_id)
            if not Path(run_path).exists():  # only download if the folder doesn't exist
                download_checkpoints(WANDB_ENTITY, WANDB_PROJECT, run_id, steps, run_path)


download_all_checkpoints(RUNS, STEPS, SEQUENTIAL)

## Dump Data

### Functions

In [None]:
def get_row_probabilities(row, row_probability, unique_mutations, pad_value=PAD_VAL):
    """
    Spread one buffer row's probability mass over the mutations recorded in that row.

    Returns a vector of length len(unique_mutations) whose i-th entry is
        (# occurrences of unique_mutations[i] in `row`) * row_probability / n_real
    where n_real is the number of non-padding entries of `row` that appear in
    `unique_mutations` (duplicates counted). A mutation occurring several times in
    a row therefore receives proportionally more of that row's mass. A row with no
    such entries (e.g. all padding) contributes zero everywhere.

    Args:
        row: 1-D array of K mutation ids; slots equal to `pad_value` are ignored.
        row_probability: scalar probability mass of this row.
        unique_mutations: 1-D array of M mutation ids to accumulate mass for.
        pad_value: sentinel marking empty slots in `row`.

    Returns:
        Array of shape (M,) with this row's mass per entry of `unique_mutations`.
    """
    # mask out the padding
    valid_mask = row != pad_value                   # (K,)

    # broadcast comparison: (K,1) vs (1,M) -> (K,M) bool
    match = (row[:, None] == unique_mutations[None, :]) & valid_mask[:, None]

    # count how many times each mutation appears in this row
    counts = jnp.sum(match, axis=0)                 # (M,) int

    # total matched entries (counting duplicates)
    n_real = jnp.sum(counts)                        # scalar

    # jnp.where evaluates both branches, so divide by max(n_real, 1) to keep the
    # discarded branch finite (an inf/NaN there would still poison gradients).
    per_occurrence_mass = jnp.where(n_real > 0,
                                    row_probability / jnp.maximum(n_real, 1),
                                    0.0)

    # each mutation's probability = occurrences * per-occurrence mass
    return counts * per_occurrence_mass

def get_mutation_probs_from_checkpoint(run_path, step, include_no_mutation=False):
    """
    Compute how the replay buffer's sampling probability is split across mutation ids.

    Loads the runner state saved at `step` under `run_path` and recomputes the buffer's
    replay weights with `BufferManager.problem_weights`. It then spreads each buffer
    row's weight over the mutation ids recorded for it (see `get_row_probabilities`).

    Args:
        run_path: checkpoint directory of a single run.
        step: checkpoint step to load.
        include_no_mutation: if True, padding slots are attributed to
            MUTATION_PADDING_ID ("no mutation"). If False, rows with no recorded
            mutation are dropped and the remaining weights are renormalised to sum to 1.

    Returns:
        unique_mutations: sorted unique mutation ids found in the buffer, followed by
            MUTATION_PADDING_ID. The padding id is always appended; its mass is 0 when
            include_no_mutation is False.
        mutation_prob_mass: probability mass for each entry of `unique_mutations`.
    """
    runner_state = load_runner_state(
        path=Path(run_path).resolve(),  # resolve to an absolute path
        target=None,
        step=step,
    )
    buffer = runner_state['buffer']

    buffer_manager = BufferManager(capacity=BUFFER_CAPACITY, staleness_coeff=STALENESS_COEFF)
    probs = buffer_manager.problem_weights(buffer)
    # One row of mutation ids per buffer entry (rows are vmapped over below), padded with PAD_VAL.
    mutation_ids = buffer['extra']['mutation_ids']
    is_pad = mutation_ids == PAD_VAL

    # Real mutation ids (padding excluded), plus MUTATION_PADDING_ID standing in for "no mutation".
    unique_mutations = jnp.unique(mutation_ids[~is_pad])
    unique_mutations = jnp.concatenate([unique_mutations, jnp.array([MUTATION_PADDING_ID])])

    if include_no_mutation:
        # Every padding slot becomes an occurrence of the "no mutation" id.
        # NOTE(review): this also counts padding slots of rows that do contain mutations,
        # shifting part of their mass to "None" — confirm this is the intended attribution.
        mutation_ids = jnp.where(is_pad, MUTATION_PADDING_ID, mutation_ids)
    else:
        # Drop rows whose slots are all padding, then renormalise the rest. The check is
        # per element (not `row sum == -10`) so it doesn't depend on the slots per row.
        no_mutations_mask = jnp.all(is_pad, axis=-1)
        probs = probs.at[no_mutations_mask].set(0.0)
        total = probs.sum()
        if total > 0:  # avoid 0/0 -> NaN when no row has any mutation
            probs = probs / total

    # vmap over the rows: row -> axis 0, row probability -> axis 0, unique ids shared
    get_probs_batched = jax.vmap(get_row_probabilities, in_axes=(0, 0, None))
    all_row_probs = get_probs_batched(mutation_ids, probs, unique_mutations)

    # Sum across all rows to get total probability mass per mutation
    mutation_prob_mass = all_row_probs.sum(axis=0)

    return unique_mutations, mutation_prob_mass


### Sanity Check on a Single Checkpoint

In [None]:
import re
def mutation_to_str(mutation_id: Mutations) -> str:
    match mutation_id:
        case Mutations.ADD_NON_DOOR_OBJ:
            string = "AddObject"
        case Mutations.RM_NON_DOOR_OBJ:
            string = "RemoveObject"
        case Mutations.MOVE_NON_DOOR_OBJ:
            string = "MoveObject"
        case Mutations.MOVE_AGENT:
            string = "MoveAgent"
        case Mutations.REPLACE_DOOR:
            string = "ReplaceDoor"
        case Mutations.REPLACE_NON_DOOR:
            string = "ReplaceNonDoor"
        case Mutations.ADD_ROOMS:
            string = "AddRooms"
        case Mutations.RM_ROOMS:
            string = "RemoveRooms"
        case Mutations.SWITCH_PROP:
            string = "SwitchProposition"
        case Mutations.ADD_TRANSITION:
            string = "AddState"
        case Mutations.RM_TRANSITION:
            string = "RemoveState"
        case Mutations.HINDSIGHT_LVL_ONLY:
            string = "hindsight_lvl"
        case Mutations.HINDSIGHT_PRED:
            string = "ExtractPreceding"
        case Mutations.HINDSIGHT_SUCC:
            string = "ExtractSucceeding"
    return rf"\textsc{{{string}}}"

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import itertools

# A trailing "-<tag>" at the end of a run name, e.g. the "-i" in "accel_full-i".
_suffix_re = re.compile(r"(-[a-zA-Z0-9_]+)$")

def base_name(raw: str) -> str:
    """Strip one trailing '-<tag>' suffix from a run name and lower-case the result."""
    without_suffix = _suffix_re.sub("", raw)
    return without_suffix.lower()

# ── 1. global style ────────────────────────────────────────────────────────────
# matplotlib font-size keywords used for axis labels, titles and ticks.
BIG, BIGGER = "x-large", "xx-large"

# Eight Set2 colours: the first four are assigned to methods below, the rest are spare.
set2 = sns.color_palette("Set2", 8)

# Display name for each method, keyed by base run name (run name without its "-<tag>").
base_to_pretty = dict(
    dr="DR",
    plr=r"PLR$^{\bot}$",
    accel_full="ACCEL",
    accel_scratch="ACCEL-0",
)

# Colour assigned to each method, keyed the same way.
base_to_color = dict(
    plr=set2[0],
    accel_full=set2[1],
    accel_scratch=set2[2],
    dr=set2[3],
)

# Endless iterator over the Set2 colours not assigned above.
spares = itertools.cycle(set2[4:])



In [None]:


# Inspect the mutation distribution of one checkpoint of the first accel_full-i run.
run_path = NOTEBOOK_PATH / 'ckpts' / 'sequential' / 'accel_full-i' / RUNS["accel_full-i"][0]
step = 1900
unique_mutations, mutation_prob_mass = get_mutation_probs_from_checkpoint(run_path, step, include_no_mutation=False)

print(unique_mutations)
print(mutation_prob_mass)
print(f"Total probability mass: {mutation_prob_mass.sum()}")

# Human-readable label per id; the padding id stands for "no mutation".
mutation_labels = [
    'None' if mutation_id == MUTATION_PADDING_ID else mutation_to_str(mutation_id)
    for mutation_id in unique_mutations.tolist()
]
for label, prob_mass in zip(mutation_labels, mutation_prob_mass.tolist()):
    print(f"Mutation {label}: {prob_mass}")

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(x=mutation_labels, y=mutation_prob_mass, palette="crest", ax=ax)
plt.xticks(rotation=45, ha='right')
ax.set_xlabel('Mutation Type')
ax.set_ylabel('Probability Mass')
ax.set_title('Mutation Probability Distribution')
fig.tight_layout()
plt.show()

### Dump

In [None]:
from collections import defaultdict
import pickle
import os

# Per-run, per-step mutation distributions are cached so a re-run doesn't reload
# every checkpoint. Delete the file to force regeneration.
cache_file = NOTEBOOK_PATH / "mutation_data_cache.pkl"


def _cache_matches_config(data):
    """Return True if cached `data` has exactly the experiments in RUNS, with one
    entry per run id and one entry per element of STEPS for each run."""
    expected_keys = set(RUNS.keys())
    cached_keys = set(data.keys())
    if cached_keys != expected_keys:
        print(f"Cache keys mismatch. Expected: {expected_keys}, Got: {cached_keys}")
        return False
    for experiment, run_ids in RUNS.items():
        for field in ("mutation_ids", "mutation_prob_mass"):
            per_run = data[experiment].get(field, [])
            if len(per_run) != len(run_ids) or any(len(per_step) != len(STEPS) for per_step in per_run):
                print(f"Cache shape mismatch for {experiment!r} / {field!r}: "
                      f"expected {len(run_ids)} runs x {len(STEPS)} steps.")
                return False
    return True


run_data = None
if cache_file.exists():
    print("Loading cached mutation data...")
    # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
    with open(cache_file, 'rb') as f:
        cached_run_data = pickle.load(f)

    if _cache_matches_config(cached_run_data):
        print("Cache is valid, using cached data.")
        run_data = cached_run_data
    else:
        print("Regenerating data...")
else:
    print("No cache found, generating data...")

if run_data is None:
    ckpt_subdir = 'sequential' if SEQUENTIAL else 'non_sequential'
    # Plain dict pre-populated for every experiment so it pickles directly and matches the cached form.
    run_data = {experiment: {"mutation_ids": [], "mutation_prob_mass": []} for experiment in RUNS}

    for experiment, run_ids in RUNS.items():
        for run_id in run_ids:
            ckpt_dir = os.path.join(NOTEBOOK_PATH, "ckpts", ckpt_subdir, experiment, run_id)
            mutation_ids_list = []
            mutation_prob_mass_list = []
            for step in STEPS:
                mutation_ids, mutation_prob_mass = get_mutation_probs_from_checkpoint(
                    ckpt_dir,
                    step,
                    include_no_mutation=False,
                )
                mutation_ids_list.append(mutation_ids)
                mutation_prob_mass_list.append(mutation_prob_mass)
            run_data[experiment]["mutation_ids"].append(mutation_ids_list)
            run_data[experiment]["mutation_prob_mass"].append(mutation_prob_mass_list)

    # Save to cache
    print("Saving data to cache...")
    with open(cache_file, 'wb') as f:
        pickle.dump(run_data, f)
    print("Data cached successfully.")


In [None]:
# Create stacked histogram plot showing mutation evolution over time
import re
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as mpatches

def create_mutation_distribution_plot(experiment, run_data, ax, show_legend=False, include_no_mutation=False):
    """
    Draw a stacked area plot of how replay-buffer probability mass is split across
    mutation types over training, averaged over the runs of one experiment.

    Args:
        experiment: key into `run_data` (e.g. "accel_full-i"); its base name picks the title.
        run_data: mapping experiment -> {"mutation_ids": [run][step] id arrays,
            "mutation_prob_mass": [run][step] mass arrays}, one step entry per element of STEPS.
        ax: matplotlib Axes to draw into.
        show_legend: add the "Edit Type" legend (and omit the y-axis label).
        include_no_mutation: keep the MUTATION_PADDING_ID ("None") series. Off by default,
            matching data computed with include_no_mutation=False, where that mass is 0.
    """
    from matplotlib.ticker import FormatStrFormatter, MultipleLocator, StrMethodFormatter

    def training_steps_to_env_steps(training_steps):
        """Convert training steps to environment steps in millions (x 2**21, then / 10**6)."""
        # int64 explicitly: 2000 * 2**21 overflows the int32 that numpy<2 defaults to on Windows.
        return np.asarray(training_steps, dtype=np.int64) * (2**21) // 10**6

    ids_per_run = run_data[experiment]["mutation_ids"]
    mass_per_run = run_data[experiment]["mutation_prob_mass"]
    num_runs = len(mass_per_run)
    num_steps = len(STEPS)

    # Union of mutation ids over all runs and steps, sorted for a consistent stack order.
    all_mutation_ids = sorted({
        mut_id
        for run_steps in ids_per_run
        for step_ids in run_steps
        for mut_id in step_ids.tolist()
    })
    if not include_no_mutation:
        # Drop the "no mutation" column so series, labels and colours line up one-to-one.
        all_mutation_ids = [mut_id for mut_id in all_mutation_ids if mut_id != MUTATION_PADDING_ID]
    column_of = {mut_id: col for col, mut_id in enumerate(all_mutation_ids)}

    # steps x mutations x runs; ids absent from a checkpoint keep mass 0.
    prob_matrix = np.zeros((num_steps, len(all_mutation_ids), num_runs))
    for run_idx in range(num_runs):
        for step_idx in range(num_steps):
            step_ids = ids_per_run[run_idx][step_idx].tolist()
            step_mass = np.asarray(mass_per_run[run_idx][step_idx])
            for mut_id, mass in zip(step_ids, step_mass):
                col = column_of.get(mut_id)
                if col is not None:
                    prob_matrix[step_idx, col, run_idx] = mass

    # Average across runs for each step -> (steps, mutations)
    prob_matrix = prob_matrix.mean(axis=2)

    # Sanity check: total mass per step (expected ~1 for renormalised data).
    print(prob_matrix.sum(axis=1))

    mutation_labels = [
        'None' if mut_id == MUTATION_PADDING_ID else mutation_to_str(mut_id)
        for mut_id in all_mutation_ids
    ]

    # Get colors from crest palette
    colors = sns.color_palette("crest", len(mutation_labels))

    # Create hatching patterns - alternate between no hatch and various patterns
    hatch_patterns = ['', '///', '', '\\\\\\', '', '...', '', '+++', '', 'xxx', '', '|||', '', '---']
    hatches = [hatch_patterns[i % len(hatch_patterns)] for i in range(len(mutation_labels))]

    env_steps_m = training_steps_to_env_steps(STEPS)
    polys = ax.stackplot(env_steps_m, *prob_matrix.T, labels=mutation_labels, colors=colors, alpha=0.8)
    for poly, hatch in zip(polys, hatches):
        poly.set_hatch(hatch)

    pretty_name = base_to_pretty.get(base_name(experiment), experiment)

    ax.set_xlim(int(env_steps_m.min()), int(env_steps_m.max()))
    ax.set_ylim(0, 1)
    # Ticks every 1000M env steps at their true positions. Calling set_xticklabels alone
    # would pin fixed strings onto whatever ticks the automatic locator happened to pick.
    ax.xaxis.set_major_locator(MultipleLocator(1000))
    ax.xaxis.set_major_formatter(StrMethodFormatter("{x:.0f}"))
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))
    ax.set_xlabel('Number of Environment Steps (in millions)', fontsize=BIGGER, labelpad=10, fontdict={'family': 'helvetica'})
    if not show_legend:
        ax.set_ylabel('Fractional Share', fontsize=BIGGER, labelpad=10, fontdict={'family': 'helvetica'})
    ax.set_title(f'{pretty_name}', fontsize=BIGGER)
    ax.tick_params(axis="both", which="major", labelsize=BIGGER)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)

    # Legend in reversed order so it reads top-to-bottom like the stacked regions.
    if show_legend:
        # The \textsc{} labels need TeX. NOTE: this rcParam is global and stays on for later cells.
        plt.rcParams.update({
            "text.usetex": True,
        })
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='large', title='Edit Type')
        legend.get_title().set_fontsize('x-large')

# One panel per ACCEL variant; only the right-hand panel carries the legend.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

create_mutation_distribution_plot("accel_full-i", run_data, ax1, show_legend=False)
create_mutation_distribution_plot("accel_scratch-i", run_data, ax2, show_legend=True)

fig.tight_layout()
output_dir = Path("plot_outputs")
output_dir.mkdir(parents=True, exist_ok=True)  # savefig raises if the folder is missing
fig.savefig(output_dir / "mutation_distribution_comparison.pdf", dpi=300, bbox_inches='tight')
plt.show()