In [1]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.getcwd() + '/../src')


In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import re
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")

In [3]:
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('processing.log')
    ]
)
logger = logging.getLogger(__name__)

# Set multiprocessing start method to 'spawn' to avoid JAX fork issues
multiprocessing.set_start_method("spawn", force=True)

In [4]:
# Configuration
download_dir = "/scratch/downloaded_artifacts"
results_dir = "/scratch/results"
group_name = "mlr_search-1_aa9c06652fb34624bebe972b1fe7292f"
max_process_workers = 10

In [5]:
# Ensure results directory exists
os.makedirs(results_dir, exist_ok=True)

In [6]:
# Function to sanitize folder names
def sanitize_folder_name(name):
    return re.sub(r'[^\w\-]', '_', str(name))

In [7]:
# Function to process a single run's logs
def process_run(run_result):
    run_id = run_result["run_id"]
    artifact_dir = run_result["artifact_dir"]
    config = run_result["config"]
    
    if run_result["status"] != "success" or not artifact_dir or not config:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return {
            "run_id": run_id,
            "status": "skipped",
            "data": None
        }
    
    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return {
            "run_id": run_id,
            "status": "no_log_file",
            "data": None
        }
    
    try:
        with open(log_file, "rb") as f:
            logs = pickle.load(f)
        
        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return {
                "run_id": run_id,
                "status": "invalid_logs",
                "data": None
            }
        
        is_success = not np.any(logs.hyperparameters == 1e-4)
        logger.info(f"Processed run {run_id}: success={is_success}")
        
        return {
            "run_id": run_id,
            "status": "success",
            "data": {
                "config": config,
                "is_success": is_success
            }
        }
    except Exception as e:
        logger.error(f"Error processing logs for run {run_id}: {str(e)}")
        return {
            "run_id": run_id,
            "status": f"error: {str(e)}",
            "data": None
        }

In [8]:
download_results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
if not os.path.exists(download_results_file):
    logger.error(f"Download results file not found at {download_results_file}")

with open(download_results_file, 'rb') as f:
    download_results = pickle.load(f)

# Process downloaded artifacts
with ThreadPoolExecutor(max_workers=max_process_workers) as executor:
    process_results = list(executor.map(process_run, download_results))


INFO:2025-04-20 13:18:00,326:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


2025-04-20 13:18:00,326 - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-04-20 13:18:00,334:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


2025-04-20 13:18:00,334 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-04-20 13:18:00,668 - INFO - Processed run g1mnjrxk: success=True
2025-04-20 13:18:00,669 - INFO - Processed run 9ob4c701: success=False
2025-04-20 13:18:00,669 - INFO - Processed run 79ygnuzq: success=True
2025-04-20 13:18:00,669 - INFO - Processed run e0fh6y0e: success=False
2025-04-20 13:18:00,669 - INFO - Processed run gvfaz9pa: success=True
2025-04-20 13:18:00,669 - INFO - Processed run 0lfnlsx6: success=True
2025-04-20 13:18:00,669 - INFO - Processed run di860t9p: success=True
2025-04-20 13:18:00,669 - INFO - Processed run abwe9znm: success=True
2025-04-20 13:18:00,669 - INFO - Processed run 1nr7lnlj: success=True
2025-04-20 13:18:00,670 - INFO - Processed run 7u6yxcmw: success=True
2025-04-20 13:18:00,937 - INFO - Processed run 47pw2620: success=True
2025-04-20 13:18:00,940 - INFO - Processed run mu7o991

In [9]:
def create_heatmap(ts, runs_data, group_idx, group_name, subfolder=None, inner_optimizer=None):
    if inner_optimizer:
        runs_data = [run for run in runs_data if inner_optimizer in (run["config"].get("inner_optimizer", ["unknown"]) if isinstance(run["config"].get("inner_optimizer"), list) else [run["config"].get("inner_optimizer", "unknown")])]
    runs_data = [run for run in runs_data if tuple(run["config"].get("ts", ())) == ts]
    outer_lrs = sorted(set(run["config"].get("outer_learning_rate") for run in runs_data))
    inner_lrs = sorted(set(run["config"].get("inner_learning_rate") for run in runs_data), reverse=True)
    
    if not outer_lrs or not inner_lrs:
        logger.warning(f"No valid learning rates for ts group {ts} in {group_name}{f'/{subfolder}' if subfolder else ''}{f' (inner_optimizer={inner_optimizer})' if inner_optimizer else ''}")
        return
    
    grid = np.zeros((len(inner_lrs), len(outer_lrs)))
    success_counts = defaultdict(int)
    total_counts = defaultdict(int)
    
    for run in runs_data:
        outer_idx = outer_lrs.index(run["config"].get("outer_learning_rate"))
        inner_idx = inner_lrs.index(run["config"].get("inner_learning_rate"))
        key = (inner_idx, outer_idx)
        total_counts[key] += 1
        if run["is_success"]:
            success_counts[key] += 1
    
    for (inner_idx, outer_idx), total in total_counts.items():
        successes = success_counts.get((inner_idx, outer_idx), 0)
        grid[inner_idx, outer_idx] = successes / total if total > 0 else 0
    
    plt.figure(figsize=(12, 10))
    im = plt.imshow(grid, origin='lower', cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
    plt.colorbar(im, label='Fraction of Successful Runs')
    
    for i in range(len(inner_lrs)):
        for j in range(len(outer_lrs)):
            fraction = grid[i, j]
            if total_counts.get((i, j), 0) > 0:
                successes = success_counts.get((i, j), 0)
                total = total_counts.get((i, j), 0)
                text = f"{successes}/{total}"
            else:
                text = "N/A"
            plt.text(j, i, text, ha='center', va='center', color='white' if fraction < 0.5 else 'black')
    
    plt.xticks(np.arange(len(outer_lrs)), [f"{lr:.1e}" for lr in outer_lrs], rotation=45)
    plt.yticks(np.arange(len(inner_lrs)), [f"{lr:.1e}" for lr in inner_lrs])
    plt.xlabel('Outer Learning Rate')
    plt.ylabel('Inner Learning Rate')
    title = f'Success Fraction Heatmap for ts={ts} in {group_name}'
    if subfolder:
        title += f' ({subfolder})'
    if inner_optimizer:
        title += f' (inner_optimizer={inner_optimizer})'
    plt.title(title)
    
    group_results_dir = os.path.join(results_dir, group_name, subfolder or '')
    os.makedirs(group_results_dir, exist_ok=True)
    output_file = os.path.join(group_results_dir, f'heatmap_ts_group_{group_idx}.png')
    plt.savefig(output_file, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved heatmap for ts={ts} in {group_name}{f'/{subfolder}' if subfolder else ''}{f' (inner_optimizer={inner_optimizer})' if inner_optimizer else ''} to {output_file}")

def create_aggregated_heatmap(all_runs_data, group_name, subfolder=None, inner_optimizer=None):
    if inner_optimizer:
        all_runs_data = [run for run in all_runs_data if inner_optimizer in (run["config"].get("inner_optimizer", ["unknown"]) if isinstance(run["config"].get("inner_optimizer"), list) else [run["config"].get("inner_optimizer", "unknown")])]
    outer_lrs = sorted(set(run["config"].get("outer_learning_rate") for run in all_runs_data))
    inner_lrs = sorted(set(run["config"].get("inner_learning_rate") for run in all_runs_data), reverse=True)
    
    if not outer_lrs or not inner_lrs:
        logger.warning(f"No valid learning rates for aggregated heatmap in {group_name}{f'/{subfolder}' if subfolder else ''}{f' (inner_optimizer={inner_optimizer})' if inner_optimizer else ''}")
        return
    
    grid = np.zeros((len(inner_lrs), len(outer_lrs)))
    success_counts = defaultdict(int)
    total_counts = defaultdict(int)
    
    for run in all_runs_data:
        outer_idx = outer_lrs.index(run["config"].get("outer_learning_rate"))
        inner_idx = inner_lrs.index(run["config"].get("inner_learning_rate"))
        key = (inner_idx, outer_idx)
        total_counts[key] += 1
        if run["is_success"]:
            success_counts[key] += 1
    
    for (inner_idx, outer_idx), total in total_counts.items():
        successes = success_counts.get((inner_idx, outer_idx), 0)
        grid[inner_idx, outer_idx] = successes / total if total > 0 else 0
    
    plt.figure(figsize=(12, 10))
    im = plt.imshow(grid, origin='lower', cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
    plt.colorbar(im, label='Fraction of Successful Runs')
    
    for i in range(len(inner_lrs)):
        for j in range(len(outer_lrs)):
            fraction = grid[i, j]
            if total_counts.get((i, j), 0) > 0:
                successes = success_counts.get((i, j), 0)
                total = total_counts.get((i, j), 0)
                text = f"{successes}/{total}"
            else:
                text = "N/A"
            plt.text(j, i, text, ha='center', va='center', color='white' if fraction < 0.5 else 'black')
    
    plt.xticks(np.arange(len(outer_lrs)), [f"{lr:.1e}" for lr in outer_lrs], rotation=45)
    plt.yticks(np.arange(len(inner_lrs)), [f"{lr:.1e}" for lr in inner_lrs])
    plt.xlabel('Outer Learning Rate')
    plt.ylabel('Inner Learning Rate')
    title = f'Aggregated Success Fraction Heatmap in {group_name}'
    if subfolder:
        title += f' ({subfolder})'
    if inner_optimizer:
        title += f' (inner_optimizer={inner_optimizer})'
    plt.title(title)
    
    group_results_dir = os.path.join(results_dir, group_name, subfolder or '')
    os.makedirs(group_results_dir, exist_ok=True)
    output_file = os.path.join(group_results_dir, 'heatmap_aggregated.png')
    plt.savefig(output_file, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved aggregated heatmap in {group_name}{f'/{subfolder}' if subfolder else ''}{f' (inner_optimizer={inner_optimizer})' if inner_optimizer else ''} to {output_file}")

def create_ts_grid_heatmap(ts_groups, group_name, optimizer):
    subfolder = f"inner_optimizer_{sanitize_folder_name(optimizer)}"
    group_results_dir = os.path.join(results_dir, group_name, subfolder)
    os.makedirs(group_results_dir, exist_ok=True)
    
    fig, axes = plt.subplots(4, 5, figsize=(25, 20))
    axes = axes.flatten()
    
    ts_list = sorted(ts_groups.keys())[:20]
    
    for idx, ax in enumerate(axes):
        if idx < len(ts_list):
            ts = ts_list[idx]
            runs_data = [run for run in ts_groups[ts] if optimizer in (run["config"].get("inner_optimizer", ["unknown"]) if isinstance(run["config"].get("inner_optimizer"), list) else [run["config"].get("inner_optimizer", "unknown")])]
            outer_lrs = sorted(set(run["config"].get("outer_learning_rate") for run in runs_data))
            inner_lrs = sorted(set(run["config"].get("inner_learning_rate") for run in runs_data), reverse=True)
            
            if not outer_lrs or not inner_lrs:
                logger.warning(f"No valid learning rates for ts group {ts} in {group_name}/{subfolder}")
                ax.axis('off')
                continue
            
            grid = np.zeros((len(inner_lrs), len(outer_lrs)))
            success_counts = defaultdict(int)
            total_counts = defaultdict(int)
            
            for run in runs_data:
                outer_idx = outer_lrs.index(run["config"].get("outer_learning_rate"))
                inner_idx = inner_lrs.index(run["config"].get("inner_learning_rate"))
                key = (inner_idx, outer_idx)
                total_counts[key] += 1
                if run["is_success"]:
                    success_counts[key] += 1
            
            for (inner_idx, outer_idx), total in total_counts.items():
                successes = success_counts.get((inner_idx, outer_idx), 0)
                grid[inner_idx, outer_idx] = successes / total if total > 0 else 0
            
            im = ax.imshow(grid, origin='lower', cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
            
            for i in range(len(inner_lrs)):
                for j in range(len(outer_lrs)):
                    fraction = grid[i, j]
                    if total_counts.get((i, j), 0) > 0:
                        successes = success_counts.get((i, j), 0)
                        total = total_counts.get((i, j), 0)
                        text = f"{successes}/{total}"
                    else:
                        text = "N/A"
                    ax.text(j, i, text, ha='center', va='center', color='white' if fraction < 0.5 else 'black', fontsize=10)
            
            ax.set_xticks(np.arange(len(outer_lrs)))
            ax.set_yticks(np.arange(len(inner_lrs)))
            ax.set_xticklabels([f"{lr:.1e}" for lr in outer_lrs], rotation=45, fontsize=8)
            ax.set_yticklabels([f"{lr:.1e}" for lr in inner_lrs], fontsize=8)
            
            ax.set_xlabel('Outer LR', fontsize=10)
            ax.set_ylabel('Inner LR', fontsize=10)
            ax.set_title(f"ts={ts}", fontsize=12)
        else:
            ax.axis('off')
    
    plt.suptitle(f'Success Fraction Heatmaps for inner_optimizer={optimizer}', fontsize=20)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    plt.colorbar(im, cax=cbar_ax, label='Fraction of Successful Runs')
    
    plt.tight_layout(rect=[0, 0, 0.9, 0.95])
    output_file = os.path.join(group_results_dir, 'grid_4x5_ts_heatmaps.png')
    plt.savefig(output_file, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved 4x5 grid heatmap for inner_optimizer={optimizer} in {group_name}/{subfolder} to {output_file}")

In [10]:
all_runs_data = []
for result in process_results:
    if result["status"] == "success" and result["data"]:
        all_runs_data.append(result["data"])

# Organize runs by ts groups
ts_groups = defaultdict(list)
for run in all_runs_data:
    ts = tuple(run["config"].get("ts", ()))
    ts_groups[ts].append(run)


In [11]:

# Generate heatmaps for each ts group
for idx, (ts, runs_data) in enumerate(ts_groups.items()):
    logger.info(f"Generating heatmap for ts group {ts} in {group_name}")
    create_heatmap(ts, runs_data, idx, group_name)

# Generate aggregated heatmap
logger.info(f"Generating aggregated heatmap in {group_name}")
create_aggregated_heatmap(all_runs_data, group_name)

2025-04-20 13:22:19,259 - INFO - Generating heatmap for ts group (0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f
2025-04-20 13:22:20,140 - INFO - Saved heatmap for ts=(0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/heatmap_ts_group_0.png
2025-04-20 13:22:20,141 - INFO - Generating heatmap for ts group (1, 3) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f
2025-04-20 13:22:20,374 - INFO - Saved heatmap for ts=(1, 3) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/heatmap_ts_group_1.png
2025-04-20 13:22:20,374 - INFO - Generating heatmap for ts group (2, 4) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f
2025-04-20 13:22:20,608 - INFO - Saved heatmap for ts=(2, 4) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/heatmap_ts_group_2.png
2025-04-20 13:22:20,608 - INFO - Generat

In [12]:
# Generate heatmaps for each optimizer
optimizer_groups = set()
for run_data in all_runs_data:
    inner_optimizer = run_data["config"].get("inner_optimizer", "unknown")
    inner_optimizer = inner_optimizer if isinstance(inner_optimizer, list) else [inner_optimizer]
    for optimizer in inner_optimizer:
        optimizer_groups.add(optimizer)

for optimizer in optimizer_groups:
    subfolder = f"inner_optimizer_{sanitize_folder_name(optimizer)}"
    logger.info(f"Generating heatmaps for inner_optimizer={optimizer} in {group_name}/{subfolder}")
    for idx, (ts, runs_data) in enumerate(ts_groups.items()):
        logger.info(f"Generating heatmap for ts group {ts} in {group_name}/{subfolder}")
        create_heatmap(ts, runs_data, idx, group_name, subfolder, inner_optimizer=optimizer)
    logger.info(f"Generating aggregated heatmap in {group_name}/{subfolder}")
    create_aggregated_heatmap(all_runs_data, group_name, subfolder, inner_optimizer=optimizer)
    
    # Generate 4x5 grid heatmap for SGD optimizer
    logger.info(f"Generating 4x5 grid heatmap for inner_optimizer={optimizer} in {group_name}/{subfolder}")
    create_ts_grid_heatmap(ts_groups, group_name, optimizer)

2025-04-20 13:22:25,086 - INFO - Generating heatmaps for inner_optimizer=sgd in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd
2025-04-20 13:22:25,087 - INFO - Generating heatmap for ts group (0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd
2025-04-20 13:22:25,334 - INFO - Saved heatmap for ts=(0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd (inner_optimizer=sgd) to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd/heatmap_ts_group_0.png
2025-04-20 13:22:25,335 - INFO - Generating heatmap for ts group (1, 3) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd
2025-04-20 13:22:25,572 - INFO - Saved heatmap for ts=(1, 3) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd (inner_optimizer=sgd) to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd/heatmap_ts_group_1.png
2025-04-20 13:22:25,573 - INFO - Generating heatma

  plt.tight_layout(rect=[0, 0, 0.9, 0.95])


2025-04-20 13:22:33,392 - INFO - Saved 4x5 grid heatmap for inner_optimizer=sgd in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd/grid_4x5_ts_heatmaps.png
2025-04-20 13:22:33,393 - INFO - Generating heatmaps for inner_optimizer=sgd_normalized in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized
2025-04-20 13:22:33,393 - INFO - Generating heatmap for ts group (0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized
2025-04-20 13:22:33,648 - INFO - Saved heatmap for ts=(0, 2) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized (inner_optimizer=sgd_normalized) to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized/heatmap_ts_group_0.png
2025-04-20 13:22:33,649 - INFO - Generating heatmap for ts group (1, 3) in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner

  plt.tight_layout(rect=[0, 0, 0.9, 0.95])


2025-04-20 13:22:41,881 - INFO - Saved 4x5 grid heatmap for inner_optimizer=sgd_normalized in mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/inner_optimizer_sgd_normalized/grid_4x5_ts_heatmaps.png
