In [None]:
# %% [markdown]
# # Grid Analysis with Whisker Plots for Test Loss and Accuracy
# This notebook processes machine learning run data from two groups (`mlr_search` and `fixed_search`) to generate a grid of plots indexed by inner learning rate (rows) and outer learning rate (columns). Each grid cell contains up to four whisker plots: test loss and accuracy for the main group, and test loss and accuracy for the fixed group at the inner learning rate with the lowest average validation loss for that outer learning rate.

# %%
import os
import sys
# Extend the import search path with the working directory and with the
# '../src' directory resolved relative to it.
for _extra_path in (os.getcwd(), os.path.join(os.getcwd(), '../src')):
    sys.path.append(_extra_path)



In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import re
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
import jax.numpy as jnp
# Set JAX's "jax_platform_name" option to "cpu" right after importing it.
jax.config.update("jax_platform_name", "cpu")



In [None]:
# %%
# Logging setup: INFO level, with one handler writing to stdout and one
# writing to 'grid_analysis_whisker.log' in the working directory.
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler('grid_analysis_whisker.log'),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# Force the 'spawn' multiprocessing start method to avoid JAX fork issues
multiprocessing.set_start_method("spawn", force=True)



In [None]:
# %%
# Configuration
# Directory that holds the `download_results_<group>.pkl` files read below.
download_dir = "/scratch/downloaded_artifacts"
# Root directory for saved figures; plots go to <results_dir>/<group>/<config subfolder>/.
results_dir = "/scratch/results"
# Main group whose runs fill the inner-LR x outer-LR grid.
group_name = "mlr_search-5_abcdef123456"
# Comparison group, reduced to its best inner learning rate per outer learning rate.
fixed_group_name = "fixed_search-5_ghijk789012"
# Size of the thread pools that run process_run over the downloaded runs.
max_process_workers = 10
# NOTE(review): not referenced by any code visible in this notebook - confirm whether it is still needed.
success_threshold = 0.95
y_limits_loss = (0.4, 0.8)  # Y-axis limits for test loss
y_limits_acc = (0.0, 1.0)  # Y-axis limits for accuracy
# A run counts as failed when any logged inner learning rate is <= this value (see process_run).
lr_failure_threshold = 1e-4
config_keys = ["inner_optimizer", "inner_learner"]  # Config keys for grouping

# Make sure the results root exists before any figure is written.
os.makedirs(results_dir, exist_ok=True)

# %%


In [None]:
def sanitize_folder_name(name):
    """Return ``str(name)`` with every character that is neither a word
    character nor ``-`` replaced by ``_``."""
    text = str(name)
    cleaned = re.sub(r'[^\w\-]', '_', text)
    return cleaned

# %%
def process_run(run_result):
    """Extract the final metrics and the success flag of one downloaded run.

    Args:
        run_result: Mapping describing one downloaded run. The keys
            ``run_id``, ``artifact_dir``, ``config`` and ``status`` are
            required; ``summary`` is optional and may be missing or ``None``.

    Returns:
        A dict with keys ``run_id``, ``status`` and ``data``. ``status`` is
        ``"success"`` when the logs were read, otherwise ``"skipped"``,
        ``"no_log_file"``, ``"invalid_logs"`` or ``"error: <message>"``.
        ``data`` is ``None`` unless ``status`` is ``"success"``; in that case
        it holds the run config, the ``is_success`` flag, the final test and
        validation losses, the ``test_statistic`` taken from the summary, and
        the optimizer / learner / learning-rate entries of the config.
    """
    run_id = run_result["run_id"]
    artifact_dir = run_result["artifact_dir"]
    config = run_result["config"]
    # `or {}` also covers a summary stored as None: a plain .get(..., {})
    # would return None in that case and the later summary lookup would fail,
    # turning an otherwise valid run into an "error" result.
    summary = run_result.get("summary") or {}

    if run_result["status"] != "success" or not artifact_dir or not config:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return {
            "run_id": run_id,
            "status": "skipped",
            "data": None
        }

    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return {
            "run_id": run_id,
            "status": "no_log_file",
            "data": None
        }

    try:
        # NOTE(review): pickle.load can execute arbitrary code contained in
        # the file; only use this on artifacts from a trusted source.
        with open(log_file, "rb") as f:
            logs = pickle.load(f)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return {
                "run_id": run_id,
                "status": "invalid_logs",
                "data": None
            }

        # A run is considered failed as soon as any logged inner learning
        # rate is at or below the failure threshold.
        is_success = not np.any(logs.inner_learning_rate <= lr_failure_threshold)
        final_test_loss = float(logs.test_loss[-1]) if logs.test_loss is not None else None
        final_validation_loss = float(logs.validation_loss[-1]) if logs.validation_loss is not None else None
        raw_test_statistic = summary.get("test_statistic")
        test_statistic = float(raw_test_statistic) if raw_test_statistic is not None else None

        logger.info(f"Processed run {run_id}: success={is_success}, final_test_loss={final_test_loss}, accuracy={test_statistic}")

        return {
            "run_id": run_id,
            "status": "success",
            "data": {
                "config": config,
                "is_success": is_success,
                "final_test_loss": final_test_loss,
                "final_validation_loss": final_validation_loss,
                "test_statistic": test_statistic,
                "inner_optimizer": config.get("inner_optimizer", "unknown"),
                "inner_learner": config.get("inner_learner", "unknown"),
                "inner_learning_rate": config.get("inner_learning_rate"),
                "outer_learning_rate": config.get("outer_learning_rate")
            }
        }
    except Exception as e:
        # Best effort: one unreadable run must not abort the whole batch, so
        # the failure is reported through the returned status instead.
        logger.error(f"Error processing logs for run {run_id}: {str(e)}")
        return {
            "run_id": run_id,
            "status": f"error: {str(e)}",
            "data": None
        }

# %%
def process_fixed_group_data(fixed_download_results):
    """Summarise the fixed_search group per outer learning rate.

    Every entry is passed through ``process_run`` on a thread pool. For each
    outer learning rate, the inner learning rate whose successful runs have
    the lowest mean final validation loss is selected, and the test losses
    and accuracies of the successful runs at that inner learning rate are
    collected.

    Returns:
        Dict mapping outer learning rate to a dict with keys
        ``"test_losses"``, ``"accuracies"`` and ``"best_lr"``. Outer learning
        rates without a selectable inner learning rate, or without any
        collected value, are left out.
    """
    with ThreadPoolExecutor(max_workers=max_process_workers) as pool:
        processed = list(pool.map(process_run, fixed_download_results))

    # Keep only the payloads of runs that were processed successfully.
    usable_runs = [
        entry["data"] for entry in processed
        if entry["status"] == "success" and entry["data"]
    ]

    # Bucket the runs by outer learning rate (insertion order is preserved).
    by_outer_lr = {}
    for record in usable_runs:
        by_outer_lr.setdefault(record["outer_learning_rate"], []).append(record)

    outer_lr_stats = {}
    for outer_lr, group in by_outer_lr.items():
        successful = [record for record in group if record["is_success"]]

        # Validation losses of the successful runs, keyed by inner learning rate.
        val_losses_by_inner_lr = {}
        for record in successful:
            if record["final_validation_loss"] is None:
                continue
            val_losses_by_inner_lr.setdefault(
                record["inner_learning_rate"], []
            ).append(record["final_validation_loss"])

        # Strict '<' keeps the first candidate on ties and never selects a
        # candidate whose mean is NaN.
        best_lr, best_mean = None, float('inf')
        for candidate_lr, losses in val_losses_by_inner_lr.items():
            mean_loss = np.mean(losses)
            if mean_loss < best_mean:
                best_lr, best_mean = candidate_lr, mean_loss

        if best_lr is None:
            continue

        best_runs = [
            record for record in successful
            if record["inner_learning_rate"] == best_lr
        ]
        test_losses = [
            record["final_test_loss"] for record in best_runs
            if record["final_test_loss"] is not None
        ]
        accuracies = [
            record["test_statistic"] for record in best_runs
            if record["test_statistic"] is not None
        ]
        if not (test_losses or accuracies):
            continue

        outer_lr_stats[outer_lr] = {
            "test_losses": test_losses,
            "accuracies": accuracies,
            "best_lr": best_lr
        }

    return outer_lr_stats

# %%
def save_best_lr_figure(fixed_outer_lr_stats, group_name, config_combination):
    """Write a text-only figure listing the best inner LR per outer LR.

    The image is saved as ``best_learning_rates.png`` below
    ``<results_dir>/<group_name>/<subfolder>``, where the subfolder name is
    built from the sanitized key/value pairs of ``config_combination``.
    """
    plt.figure(figsize=(4, len(fixed_outer_lr_stats) * 0.5 + 1))

    # Header line followed by one line per outer learning rate that has a best LR.
    text_lines = ["Best Inner Learning Rates (Fixed Search):\n"]
    for outer_lr in sorted(fixed_outer_lr_stats.keys()):
        best_lr = fixed_outer_lr_stats[outer_lr].get("best_lr")
        if best_lr is None:
            continue
        text_lines.append(f"Outer LR={outer_lr:.1e}: {best_lr:.1e}\n")

    plt.text(0.1, 0.9, "".join(text_lines), fontsize=10, verticalalignment='top')
    plt.axis('off')

    # One "<key>_<value>" part per config entry, joined into the subfolder name.
    name_parts = [
        f"{sanitize_folder_name(key)}_{sanitize_folder_name(value)}"
        for key, value in config_combination
    ]
    subfolder = "_".join(name_parts)

    group_results_dir = os.path.join(results_dir, group_name, subfolder)
    os.makedirs(group_results_dir, exist_ok=True)
    output_file = os.path.join(group_results_dir, 'best_learning_rates.png')

    plt.savefig(output_file, bbox_inches='tight')
    plt.close()
    logger.info(f"Saved best learning rates figure for {subfolder} to {output_file}")

# %%
def _draw_whisker_cell(ax, series):
    """Draw the whisker plots of one grid cell.

    Loss series are drawn on ``ax`` (left y-axis); accuracy series are drawn
    on a twin axis (right y-axis) so that each box is read against the scale
    and limits that belong to it.

    Args:
        ax: Matplotlib axes of the grid cell.
        series: Non-empty list of ``(label, values, color, is_accuracy)``
            tuples, one per box, in left-to-right order.
    """
    # Narrower boxes the more series share a cell.
    box_width = {4: 0.2, 3: 0.25, 2: 0.33}.get(len(series), 0.5)

    acc_ax = ax.twinx()
    for position, (_, values, color, is_accuracy) in enumerate(series, start=1):
        target_ax = acc_ax if is_accuracy else ax
        box = target_ax.boxplot(
            [values],
            positions=[position],
            widths=box_width,
            patch_artist=True,
            showmeans=True,
            whis=[5, 95],  # Whiskers at the 5th and 95th percentiles
            manage_ticks=False,  # Ticks are set once below for both axes
        )
        for patch in box['boxes']:
            patch.set_facecolor(color)
            patch.set_alpha(0.6)

    # The twin axis shares the x-axis, so ticks and limits are set on `ax` only.
    ax.set_xticks(list(range(1, len(series) + 1)))
    ax.set_xticklabels([label for label, _, _, _ in series])
    ax.set_xlim(0.5, len(series) + 0.5)

    # Primary y-axis: test loss
    ax.set_ylim(y_limits_loss)
    ax.set_ylabel('Test Loss', color='blue')
    ax.tick_params(axis='y', labelcolor='blue')

    # Secondary y-axis: accuracy
    acc_ax.set_ylim(y_limits_acc)
    acc_ax.set_ylabel('Accuracy', color='green')
    acc_ax.tick_params(axis='y', labelcolor='green')

    ax.grid(True, linestyle='--', alpha=0.7)


def create_grid_plot(runs_data, fixed_outer_lr_stats, config_combination, group_name):
    """Save a grid of whisker plots for one configuration combination.

    Rows are inner learning rates and columns are outer learning rates of the
    main group. Each cell shows up to four boxes: test loss and accuracy of
    the successful main-group runs, and test loss and accuracy of the fixed
    group for the same outer learning rate. The cell title reports how many
    main-group runs succeeded.

    Args:
        runs_data: List of per-run dicts as produced by ``process_run``
            (the ``"data"`` payloads).
        fixed_outer_lr_stats: Dict mapping outer learning rate to a dict with
            ``"test_losses"`` and ``"accuracies"`` lists.
        config_combination: Iterable of ``(key, value)`` pairs a run's config
            must match to be included.
        group_name: Group name, used in the title and in the output path.

    Returns:
        None. The figure is written to
        ``<results_dir>/<group_name>/<subfolder>/test_loss_accuracy_whisker_grid.png``;
        nothing is written when no run matches or no run has both learning rates.
    """
    # Filter runs by the specified configuration combination
    filtered_runs = [
        run for run in runs_data
        if all(run["config"].get(key, "unknown") == value for key, value in config_combination)
    ]

    if not filtered_runs:
        logger.warning(f"No runs found for config combination: {config_combination}")
        return

    # Organize data by outer and inner learning rates. Runs lacking either
    # rate cannot be placed in the grid (and None can be neither sorted with
    # numbers nor formatted with ':.1e'), so they are skipped.
    lr_grid = defaultdict(list)
    skipped_runs = 0
    for run in filtered_runs:
        outer_lr = run["config"].get("outer_learning_rate")
        inner_lr = run["config"].get("inner_learning_rate")
        if outer_lr is None or inner_lr is None:
            skipped_runs += 1
            continue
        lr_grid[(outer_lr, inner_lr)].append(run)

    if skipped_runs:
        logger.warning(
            f"Skipped {skipped_runs} run(s) without inner/outer learning rate "
            f"for config combination: {config_combination}"
        )

    # Collect outer and inner learning rates
    outer_lrs = sorted({outer_lr for outer_lr, _ in lr_grid})
    inner_lrs = sorted({inner_lr for _, inner_lr in lr_grid})

    if not outer_lrs or not inner_lrs:
        logger.warning(f"No valid learning rates for config combination: {config_combination}")
        return

    # Create grid plot; squeeze=False keeps `axes` two-dimensional even for
    # a single row or column.
    fig, axes = plt.subplots(
        len(inner_lrs), len(outer_lrs),
        figsize=(len(outer_lrs) * 4, len(inner_lrs) * 4),
        squeeze=False
    )

    for i, inner_lr in enumerate(inner_lrs):
        for j, outer_lr in enumerate(outer_lrs):
            ax = axes[i, j]
            runs = lr_grid.get((outer_lr, inner_lr), [])

            if not runs:
                ax.axis('off')
                continue

            # Collect test losses and accuracies for main group
            successful_runs = [run for run in runs if run["is_success"]]
            test_losses = [
                run["final_test_loss"] for run in successful_runs
                if run["final_test_loss"] is not None
            ]
            accuracies = [
                run["test_statistic"] for run in successful_runs
                if run["test_statistic"] is not None
            ]

            # Get test losses and accuracies for fixed group
            fixed_stats = fixed_outer_lr_stats.get(outer_lr, {})
            fixed_test_losses = fixed_stats.get("test_losses", [])
            fixed_accuracies = fixed_stats.get("accuracies", [])

            # (label, values, color, is_accuracy); empty series are dropped
            candidates = [
                ('MLR Loss', test_losses, 'blue', False),
                ('MLR Acc', accuracies, 'green', True),
                ('Fixed Loss', fixed_test_losses, 'red', False),
                ('Fixed Acc', fixed_accuracies, 'purple', True),
            ]
            series = [entry for entry in candidates if entry[1]]

            if not series:
                ax.axis('off')
                continue

            _draw_whisker_cell(ax, series)

            # Success fraction for main group
            ax.set_title(f'Success: {len(successful_runs)}/{len(runs)}', fontsize=8)

    # Add hyperaxis labels for outer and inner learning rates
    for j, outer_lr in enumerate(outer_lrs):
        fig.text(
            (j + 0.5) / len(outer_lrs), 1.01,
            f'Outer LR: {outer_lr:.1e}',
            ha='center', va='bottom', fontsize=10, transform=fig.transFigure
        )

    for i, inner_lr in enumerate(inner_lrs):
        fig.text(
            -0.01, (len(inner_lrs) - i - 0.5) / len(inner_lrs),
            f'Inner LR: {inner_lr:.1e}',
            ha='right', va='center', fontsize=10, transform=fig.transFigure
        )

    # Create title based on config combination
    config_title = ", ".join(f"{key}: {value}" for key, value in config_combination)
    fig.suptitle(f'Test Loss and Accuracy by Learning Rates\n(Group: {group_name}, {config_title})', fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    # Save plot
    subfolder = "_".join(
        f"{sanitize_folder_name(key)}_{sanitize_folder_name(value)}"
        for key, value in config_combination
    )
    group_results_dir = os.path.join(results_dir, group_name, subfolder)
    os.makedirs(group_results_dir, exist_ok=True)
    output_file = os.path.join(group_results_dir, 'test_loss_accuracy_whisker_grid.png')
    fig.savefig(output_file, bbox_inches='tight')
    # Close this figure explicitly so figures do not pile up across combinations.
    plt.close(fig)
    logger.info(f"Saved whisker grid plot for {subfolder} to {output_file}")



In [None]:
# %%
# Load and process main group data
# The results pickle for the main group is looked up in download_dir;
# the cell exits with status 1 when the file is absent.
download_results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
if not os.path.exists(download_results_file):
    logger.error(f"Download results file not found at {download_results_file}")
    sys.exit(1)

# NOTE(review): pickle.load can execute arbitrary code contained in the file -
# only load result files that come from a trusted source.
with open(download_results_file, 'rb') as f:
    download_results = pickle.load(f)

# Run process_run over every entry on a thread pool; executor.map yields the
# results in the same order as the inputs.
with ThreadPoolExecutor(max_workers=max_process_workers) as executor:
    process_results = list(executor.map(process_run, download_results))



In [None]:
# Keep only the data payloads of main-group runs that process_run reported as
# "success" with a non-empty payload; skipped and failed runs are dropped.
all_runs_data = [
    result["data"] for result in process_results
    if result["status"] == "success" and result["data"]
]



In [None]:
# %%
# Load and process fixed_search data
# Same lookup as for the main group, but for fixed_group_name;
# the cell exits with status 1 when the file is absent.
fixed_download_results_file = os.path.join(download_dir, f'download_results_{fixed_group_name}.pkl')
if not os.path.exists(fixed_download_results_file):
    logger.error(f"Fixed search results file not found at {fixed_download_results_file}")
    sys.exit(1)

# NOTE(review): pickle.load can execute arbitrary code contained in the file -
# only load result files that come from a trusted source.
with open(fixed_download_results_file, 'rb') as f:
    fixed_download_results = pickle.load(f)

# Per outer learning rate: test losses / accuracies at the best inner learning rate.
fixed_outer_lr_stats = process_fixed_group_data(fixed_download_results)



In [None]:
# %%
# Generate plots for each config combination
# A combination is the tuple of (key, value) pairs, one per entry of
# config_keys, with "unknown" standing in for keys a run's config lacks.
config_combinations = {
    tuple((key, run["config"].get(key, "unknown")) for key in config_keys)
    for run in all_runs_data
}

# Sorted so the figures are produced in a stable order.
for config_combination in sorted(config_combinations):
    logger.info(f"Generating whisker plot for config combination: {config_combination}")
    create_grid_plot(all_runs_data, fixed_outer_lr_stats, config_combination, group_name)
    save_best_lr_figure(fixed_outer_lr_stats, group_name, config_combination)