# Grid Analysis with Fixed Search Overlay
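This notebook processes machine learning run data from two groups: the main learned-learning-rate group (configured as `group_name`; described originally as `mlr_search-1`, currently a `parametrized_lr-1` group) and the baseline `fixed_search-1` group. For each inner optimizer/learner pair it generates a grid of plots (rows: inner learning rate, columns: outer learning rate) of the final test loss versus `ts` (labelled by its first element). Each plot overlays the final test loss of the `fixed_search-1` runs that use the inner learning rate with the lowest average validation loss for that `ts`.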

In [23]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.getcwd() + '/../src')

import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import re
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")

In [24]:
# Send every log record both to the notebook's stdout and to a persistent
# log file ('grid_analysis.log' in the working directory), at INFO level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler('grid_analysis.log')],
)
logger = logging.getLogger(__name__)

# Use the 'spawn' start method for any child processes to avoid JAX fork
# issues; force=True overrides a start method that may already be set.
# NOTE(review): the visible cells only use thread pools, so this is a safeguard.
multiprocessing.set_start_method("spawn", force=True)

In [25]:
# Configuration

# --- Input / output locations ----------------------------------------------
download_dir = "/scratch/downloaded_artifacts"  # holds download_results_<group>.pkl manifests
results_dir = "/scratch/results"                # root folder for all generated figures

# --- Run groups being compared ---------------------------------------------
group_name = "parametrized_lr-1_4574d04339a645f6bc69a3022f2b316a"
fixed_group_name = "fixed_search-1_209e70fddac44ed4a229a4267c6f1976"

# --- Processing and plotting knobs -----------------------------------------
max_process_workers = 10    # thread-pool size used when loading run logs
success_threshold = 0.95    # reported in the grid plot titles
y_limits = (0.45, 0.7)      # identical y-axis range for every subplot
lr_failure_threshold = 0.0  # run fails if its inner LR ever drops to <= this (alternative value: 1e-4)

os.makedirs(results_dir, exist_ok=True)

In [26]:
def sanitize_folder_name(name):
    """Turn ``name`` into a string that is safe to use as one path component.

    ``name`` is converted with ``str`` first. Every character that is neither
    a regex word character (``\\w``) nor a hyphen is replaced by ``_``.
    """
    unsafe_chars = re.compile(r'[^\w\-]')
    return unsafe_chars.sub('_', str(name))

In [27]:
def _final_value(series):
    """Return the last entry of ``series`` as a float, or ``None`` if it is missing or empty."""
    if series is None:
        return None
    values = np.asarray(series)
    if values.size == 0:
        return None
    return float(values[-1])


def process_run(run_result):
    """Load one run's ``logs.pkl`` artifact and summarize it.

    Args:
        run_result: Download record for a single run. ``"run_id"`` is required.
            ``"status"``, ``"artifact_dir"`` and ``"config"`` are read when
            present; a missing key is treated like a failed download.

    Returns:
        A dict with keys ``"run_id"``, ``"status"`` and ``"data"``.

        ``status`` is ``"success"`` only when the logs were loaded. Otherwise
        it is ``"skipped"``, ``"no_log_file"``, ``"invalid_logs"`` or
        ``"error: <message>"``, and ``data`` is ``None``.

        On success ``data`` holds:
          * the run config;
          * ``is_success``: ``False`` if the logged inner learning rate ever
            dropped to ``lr_failure_threshold`` or below, or became non-finite;
          * the final test and validation losses (``None`` when the series is
            absent or empty);
          * the config fields used for grouping.
    """
    run_id = run_result["run_id"]
    artifact_dir = run_result.get("artifact_dir")
    config = run_result.get("config")

    if run_result.get("status") != "success" or not artifact_dir or not config:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return {"run_id": run_id, "status": "skipped", "data": None}

    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return {"run_id": run_id, "status": "no_log_file", "data": None}

    try:
        # NOTE: pickle.load can execute arbitrary code stored in the file;
        # only load artifacts that come from a trusted source.
        with open(log_file, "rb") as f:
            logs = pickle.load(f)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return {"run_id": run_id, "status": "invalid_logs", "data": None}

        inner_lrs = np.asarray(logs.inner_learning_rate)
        # NaN compares False against any threshold, so a plain `<=` check
        # would count a diverged (NaN/inf) learning rate as a success.
        # Require every value to be finite and above the threshold.
        is_success = bool(np.all(np.isfinite(inner_lrs) & (inner_lrs > lr_failure_threshold)))
        final_test_loss = _final_value(logs.test_loss)
        final_validation_loss = _final_value(logs.validation_loss)

        logger.info(f"Processed run {run_id}: success={is_success}, final_test_loss={final_test_loss}")

        return {
            "run_id": run_id,
            "status": "success",
            "data": {
                "config": config,
                "is_success": is_success,
                "final_test_loss": final_test_loss,
                "final_validation_loss": final_validation_loss,
                "ts": tuple(config.get("ts", ())),
                "inner_optimizer": config.get("inner_optimizer", "unknown"),
                "inner_learner": config.get("inner_learner", "unknown"),
                "inner_learning_rate": config.get("inner_learning_rate"),
            },
        }
    except Exception as e:
        # Per-run boundary: record the failure (with traceback) and let the
        # remaining runs continue.
        logger.exception(f"Error processing logs for run {run_id}: {str(e)}")
        return {"run_id": run_id, "status": f"error: {str(e)}", "data": None}

In [28]:
def process_fixed_group_data(fixed_download_results, inner_optimizer=None, inner_learner=None):
    """Summarize the fixed-search group's test loss at its best inner LR per ``ts``.

    Every download record is processed with ``process_run``. Successfully
    loaded runs are grouped by ``ts``. Within each group:
      * Among runs with ``is_success`` set and a validation loss, the inner
        learning rate with the lowest mean final validation loss is chosen.
      * The final test losses of that learning rate's successful runs are
        then averaged.

    Args:
        fixed_download_results: Download records accepted by ``process_run``.
        inner_optimizer: If not ``None``, only runs whose ``inner_optimizer``
            equals this value are used. The default keeps all runs.
        inner_learner: If not ``None``, only runs whose ``inner_learner``
            equals this value are used. The default keeps all runs.

    Returns:
        Dict mapping each ``ts`` tuple to a dict with keys
        ``"mean_test_loss"``, ``"std_error"``, ``"num_runs"`` and ``"best_lr"``.
        ``std_error`` is the standard error of the mean based on the sample
        standard deviation (``ddof=1``); it is 0 when only one run is
        available. A ``ts`` without usable runs is omitted.

    NOTE(review): with the default filters, runs are grouped by ``ts`` only,
    so the best LR is picked across all inner optimizers/learners in the
    group. Pass the filters if the group mixes them.
    """
    with ThreadPoolExecutor(max_workers=max_process_workers) as executor:
        fixed_process_results = list(executor.map(process_run, fixed_download_results))

    # Keep successfully loaded runs (optionally restricted to one
    # optimizer/learner pair) and group them by ts.
    ts_runs = defaultdict(list)
    for result in fixed_process_results:
        run = result["data"]
        if result["status"] != "success" or not run:
            continue
        if inner_optimizer is not None and run["inner_optimizer"] != inner_optimizer:
            continue
        if inner_learner is not None and run["inner_learner"] != inner_learner:
            continue
        ts_runs[run["ts"]].append(run)

    ts_stats = {}
    for ts, runs in ts_runs.items():
        successful = [run for run in runs if run["is_success"]]

        # Collect final validation losses per inner learning rate.
        lr_val_losses = defaultdict(list)
        for run in successful:
            if run["final_validation_loss"] is not None:
                lr_val_losses[run["inner_learning_rate"]].append(run["final_validation_loss"])

        # Choose the LR with the lowest mean validation loss. A NaN mean never
        # compares below the running best, so such an LR is never selected.
        best_lr = None
        lowest_avg_val_loss = float('inf')
        for lr, val_losses in lr_val_losses.items():
            avg_val_loss = np.mean(val_losses)
            if avg_val_loss < lowest_avg_val_loss:
                lowest_avg_val_loss = avg_val_loss
                best_lr = lr
        if best_lr is None:
            continue

        test_losses = [
            run["final_test_loss"] for run in successful
            if run["inner_learning_rate"] == best_lr and run["final_test_loss"] is not None
        ]
        if not test_losses:
            continue

        num_runs = len(test_losses)
        # SEM = sample std (ddof=1) / sqrt(n); undefined for a single run,
        # reported as 0 in that case.
        std_error = np.std(test_losses, ddof=1) / np.sqrt(num_runs) if num_runs > 1 else 0
        ts_stats[ts] = {
            "mean_test_loss": np.mean(test_losses),
            "std_error": std_error,
            "num_runs": num_runs,
            "best_lr": best_lr,
        }

    return ts_stats

In [29]:
def save_best_lr_figure(fixed_ts_stats, group_name, optimizer, learner):
    """Save a text-only PNG listing the best fixed-search inner LR for each ``ts``.

    Args:
        fixed_ts_stats: Mapping ``ts`` tuple -> stats dict. The optional
            ``"best_lr"`` entry is a number, formatted as ``.1e``; entries
            without it are skipped.
        group_name: Name of the results sub-folder under ``results_dir``.
        optimizer: Inner optimizer name, used in the output sub-folder name.
        learner: Inner learner name, used in the output sub-folder name.

    The image is written to
    ``<results_dir>/<group_name>/opt_<optimizer>_learner_<learner>/best_learning_rates.png``,
    with the optimizer and learner parts sanitized. The figure is always
    closed, even when saving fails.
    """
    lines = ["Best Inner Learning Rates (Fixed Search):"]
    for ts in sorted(fixed_ts_stats, key=lambda t: t[0]):
        best_lr = fixed_ts_stats[ts].get("best_lr")
        if best_lr is not None:
            lines.append(f"ts={int(ts[0])}: {best_lr:.1e}")
    lr_text = "\n".join(lines) + "\n"

    # Height grows with the number of listed ts values.
    fig = plt.figure(figsize=(4, len(fixed_ts_stats) * 0.5 + 1))
    try:
        ax = fig.add_subplot(111)
        ax.text(0.1, 0.9, lr_text, fontsize=10, verticalalignment='top')
        ax.axis('off')

        subfolder = f"opt_{sanitize_folder_name(optimizer)}_learner_{sanitize_folder_name(learner)}"
        group_results_dir = os.path.join(results_dir, group_name, subfolder)
        os.makedirs(group_results_dir, exist_ok=True)
        output_file = os.path.join(group_results_dir, 'best_learning_rates.png')
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Release the figure even on failure so repeated calls don't pile up
        # open figures.
        plt.close(fig)
    logger.info(f"Saved best learning rates figure for inner_optimizer={optimizer}, inner_learner={learner} to {output_file}")

def _mean_and_sem(values):
    """Return ``(mean, standard error of the mean)`` for a non-empty sequence.

    The standard error uses the sample standard deviation (``ddof=1``). With a
    single value it is undefined and reported as 0.
    """
    n = len(values)
    sem = np.std(values, ddof=1) / np.sqrt(n) if n > 1 else 0
    return np.mean(values), sem


def _lr_sort_key(lr):
    """Sort key ordering numeric learning rates ascending, with ``None`` last."""
    return (lr is None, 0.0 if lr is None else lr)


def _format_lr(lr):
    """Format a learning rate for an axis label (``None`` is shown literally)."""
    return "None" if lr is None else f"{lr:.1e}"


def _plot_grid_cell(ax, runs, ts_list, fixed_ts_stats):
    """Draw one (inner LR, outer LR) cell of the grid on ``ax``.

    Draws the following, and hides the axes when there is nothing to plot:
      * Blue: the mean final test loss of the successful runs in ``runs``, for
        each ``ts`` in ``ts_list`` that has data, with standard-error bars.
      * Red: the fixed-search statistics for the same ``ts`` values, when any
        of them are available.
    """
    if not runs:
        ax.axis('off')
        return

    total_runs = len(runs)
    successful_runs = sum(1 for run in runs if run["is_success"])
    success_fraction = successful_runs / total_runs

    ts_test_losses = defaultdict(list)
    for run in runs:
        if run["is_success"] and run["final_test_loss"] is not None:
            ts_test_losses[run["ts"]].append(run["final_test_loss"])

    valid_ts = [ts for ts in ts_list if ts_test_losses.get(ts)]
    if not valid_ts:
        ax.axis('off')
        return

    stats = [_mean_and_sem(ts_test_losses[ts]) for ts in valid_ts]
    means = [mean for mean, _ in stats]
    errors = [sem for _, sem in stats]

    # Original (MLR search) data in blue.
    x = range(len(valid_ts))
    ax.errorbar(x, means, yerr=errors, fmt='o-', color='blue', capsize=5, capthick=2, elinewidth=2, label='MLR Search')
    ax.set_xticks(x)
    ax.set_xticklabels([f"{int(ts[0])}" for ts in valid_ts], rotation=45)
    ax.set_xlabel('Task Difficulty')
    ax.set_ylabel('Average Final Test Loss')
    # Flag cells whose success fraction falls below the configured threshold.
    title_color = 'red' if success_fraction < success_threshold else 'black'
    ax.set_title(f'Success: {successful_runs}/{total_runs}', color=title_color)
    # Shared y-range across all cells; points outside y_limits are not visible.
    ax.set_ylim(y_limits)
    ax.grid(True)

    # Overlay of the fixed-search best-LR statistics in red. A ts missing from
    # fixed_ts_stats becomes NaN, i.e. a gap in the line.
    fixed_means = [fixed_ts_stats.get(ts, {}).get("mean_test_loss", np.nan) for ts in valid_ts]
    fixed_errors = [fixed_ts_stats.get(ts, {}).get("std_error", 0) for ts in valid_ts]
    if np.any(np.isfinite(fixed_means)):
        ax.errorbar(x, fixed_means, yerr=fixed_errors, fmt='s--', color='red', capsize=5, capthick=2, elinewidth=2, label='Fixed Search')

    ax.legend()


def create_grid_plot(runs_data, fixed_ts_stats, optimizer, learner, group_name):
    """Save a grid of test-loss-vs-ts plots for one inner optimizer/learner pair.

    Grid layout:
      * Rows are inner learning rates and columns are outer learning rates,
        both read from each run's config.
      * Each cell plots the mean final test loss of its successful runs per
        ``ts``, with standard-error bars, and overlays ``fixed_ts_stats``.
      * A cell title is drawn in red when that cell's success fraction is
        below ``success_threshold``.

    Args:
        runs_data: Run summaries as produced by ``process_run``'s ``"data"``.
        fixed_ts_stats: Mapping ``ts`` -> fixed-search stats dict, read for
            the ``"mean_test_loss"`` and ``"std_error"`` entries.
        optimizer: Inner optimizer whose runs are plotted.
        learner: Inner learner whose runs are plotted.
        group_name: Name of the results sub-folder under ``results_dir``.

    The image is written to
    ``<results_dir>/<group_name>/opt_<optimizer>_learner_<learner>/test_loss_grid_with_fixed.png``,
    with the folder parts sanitized. Nothing is written, and a warning is
    logged, when no run matches the pair.
    """
    filtered_runs = [
        run for run in runs_data
        if run["config"].get("inner_optimizer", "unknown") == optimizer
        and run["config"].get("inner_learner", "unknown") == learner
    ]
    if not filtered_runs:
        logger.warning(f"No runs found for inner_optimizer={optimizer}, inner_learner={learner}")
        return

    # Group runs by their (outer LR, inner LR) grid cell.
    lr_grid = defaultdict(list)
    for run in filtered_runs:
        cell = (run["config"].get("outer_learning_rate"), run["config"].get("inner_learning_rate"))
        lr_grid[cell].append(run)

    outer_lrs = sorted({outer for outer, _ in lr_grid}, key=_lr_sort_key)
    inner_lrs = sorted({inner for _, inner in lr_grid}, key=_lr_sort_key)
    ts_list = sorted({run["ts"] for run in filtered_runs}, key=lambda ts: ts[0])

    n_rows, n_cols = len(inner_lrs), len(outer_lrs)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False)
    try:
        for i, inner_lr in enumerate(inner_lrs):
            for j, outer_lr in enumerate(outer_lrs):
                _plot_grid_cell(axes[i, j], lr_grid.get((outer_lr, inner_lr), []), ts_list, fixed_ts_stats)

        # Column (outer LR) and row (inner LR) headers in figure coordinates.
        for j, outer_lr in enumerate(outer_lrs):
            fig.text(
                (j + 0.5) / n_cols, 1.01,
                f'Outer LR: {_format_lr(outer_lr)}',
                ha='center', va='bottom', fontsize=10, transform=fig.transFigure
            )
        for i, inner_lr in enumerate(inner_lrs):
            fig.text(
                -0.01, (n_rows - i - 0.5) / n_rows,
                f'Inner LR: {_format_lr(inner_lr)}',
                ha='right', va='center', fontsize=10, transform=fig.transFigure
            )

        fig.suptitle(f'Average Final Test Loss by ts\n(Group: {group_name}, Optimizer: {optimizer}, Learner: {learner}, Success Threshold: {success_threshold})', fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        subfolder = f"opt_{sanitize_folder_name(optimizer)}_learner_{sanitize_folder_name(learner)}"
        group_results_dir = os.path.join(results_dir, group_name, subfolder)
        os.makedirs(group_results_dir, exist_ok=True)
        output_file = os.path.join(group_results_dir, 'test_loss_grid_with_fixed.png')
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Always release the (potentially large) figure.
        plt.close(fig)
    logger.info(f"Saved test loss grid plot with fixed search overlay for inner_optimizer={optimizer}, inner_learner={learner} to {output_file}")

In [30]:
# Load the main group's download manifest and summarize every run in it.
download_results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
if os.path.exists(download_results_file):
    with open(download_results_file, 'rb') as f:
        download_results = pickle.load(f)
else:
    logger.error(f"Download results file not found at {download_results_file}")
    sys.exit(1)

# Log loading is I/O-bound, so a thread pool is used.
with ThreadPoolExecutor(max_workers=max_process_workers) as pool:
    process_results = list(pool.map(process_run, download_results))


2025-04-20 14:39:22,917 - INFO - Processed run g4y3ilcy: success=True, final_test_loss=0.4663831293582916
2025-04-20 14:39:22,918 - INFO - Processed run go6xvdx8: success=False, final_test_loss=0.5529025197029114
2025-04-20 14:39:22,918 - INFO - Processed run kz5maoum: success=True, final_test_loss=0.45876312255859375
2025-04-20 14:39:22,919 - INFO - Processed run m8i6e6s9: success=True, final_test_loss=0.46501439809799194
2025-04-20 14:39:22,919 - INFO - Processed run ri6v61bd: success=True, final_test_loss=0.5662210583686829
2025-04-20 14:39:22,920 - INFO - Processed run i9ui9frf: success=True, final_test_loss=0.458822101354599
2025-04-20 14:39:22,920 - INFO - Processed run t42w9ab1: success=True, final_test_loss=0.4944952428340912
2025-04-20 14:39:22,920 - INFO - Processed run duoaifgw: success=True, final_test_loss=0.45891308784484863
2025-04-20 14:39:22,921 - INFO - Processed run 0me2ttuh: success=True, final_test_loss=0.4589952528476715
2025-04-20 14:39:22,922 - INFO - Processed 

In [31]:
# Keep only the summaries of runs whose logs were loaded successfully.
all_runs_data = [
    result["data"]
    for result in process_results
    if result["status"] == "success" and result["data"]
]

In [32]:
# Load the fixed_search-1 manifest and reduce it to per-ts best-LR statistics.
fixed_download_results_file = os.path.join(download_dir, f'download_results_{fixed_group_name}.pkl')
if os.path.exists(fixed_download_results_file):
    with open(fixed_download_results_file, 'rb') as f:
        fixed_download_results = pickle.load(f)
else:
    logger.error(f"Fixed search results file not found at {fixed_download_results_file}")
    sys.exit(1)

fixed_ts_stats = process_fixed_group_data(fixed_download_results)

2025-04-20 14:43:43,584 - INFO - Processed run 1zuko15r: success=True, final_test_loss=0.6171742081642151
2025-04-20 14:43:43,586 - INFO - Processed run 8dsff0nn: success=True, final_test_loss=0.6753073334693909
2025-04-20 14:43:43,606 - INFO - Processed run be0cvx1e: success=True, final_test_loss=0.7680957913398743
2025-04-20 14:43:43,608 - INFO - Processed run inu643zs: success=True, final_test_loss=0.6999733448028564
2025-04-20 14:43:43,608 - INFO - Processed run 8vmysoek: success=True, final_test_loss=0.671892523765564
2025-04-20 14:43:43,610 - INFO - Processed run 064xikfd: success=True, final_test_loss=0.6653053760528564
2025-04-20 14:43:43,610 - INFO - Processed run 13ddsopj: success=True, final_test_loss=0.6741185784339905
2025-04-20 14:43:43,612 - INFO - Processed run s7kv66jm: success=True, final_test_loss=0.6947146058082581
2025-04-20 14:43:43,613 - INFO - Processed run nowocf97: success=True, final_test_loss=0.5980479121208191
2025-04-20 14:43:43,613 - INFO - Processed run 

In [33]:
# Generate plots: one grid figure and one best-LR listing per distinct
# (inner_optimizer, inner_learner) combination in the main group.
optimizer_learner_pairs = {
    (
        run["config"].get("inner_optimizer", "unknown"),
        run["config"].get("inner_learner", "unknown"),
    )
    for run in all_runs_data
}

# NOTE(review): the same fixed_ts_stats is passed for every pair, so the
# fixed-search overlay is not split by optimizer/learner; confirm intended.
for optimizer, learner in sorted(optimizer_learner_pairs):
    logger.info(f"Generating grid plot for inner_optimizer={optimizer}, inner_learner={learner}")
    create_grid_plot(all_runs_data, fixed_ts_stats, optimizer, learner, group_name)
    save_best_lr_figure(fixed_ts_stats, group_name, optimizer, learner)

2025-04-20 14:44:41,016 - INFO - Generating grid plot for inner_optimizer=sgd, inner_learner=rtrl


2025-04-20 14:44:47,205 - INFO - Saved test loss grid plot with fixed search overlay for inner_optimizer=sgd, inner_learner=rtrl to /scratch/results/parametrized_lr-1_4574d04339a645f6bc69a3022f2b316a/opt_sgd_learner_rtrl/test_loss_grid_with_fixed.png
2025-04-20 14:44:47,259 - INFO - Saved best learning rates figure for inner_optimizer=sgd, inner_learner=rtrl to /scratch/results/parametrized_lr-1_4574d04339a645f6bc69a3022f2b316a/opt_sgd_learner_rtrl/best_learning_rates.png
2025-04-20 14:44:47,260 - INFO - Generating grid plot for inner_optimizer=sgd_normalized, inner_learner=rtrl
2025-04-20 14:44:53,142 - INFO - Saved test loss grid plot with fixed search overlay for inner_optimizer=sgd_normalized, inner_learner=rtrl to /scratch/results/parametrized_lr-1_4574d04339a645f6bc69a3022f2b316a/opt_sgd_normalized_learner_rtrl/test_loss_grid_with_fixed.png
2025-04-20 14:44:53,197 - INFO - Saved best learning rates figure for inner_optimizer=sgd_normalized, inner_learner=rtrl to /scratch/results/