In [1]:
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.getcwd() + '/../src')

In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import re
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")

In [3]:
# Send INFO-and-above log records both to the notebook's stdout and to a log file.
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler('grid_analysis.log'),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Use the 'spawn' start method for child processes to avoid JAX fork issues;
# force=True overrides any start method already set in this kernel.
# NOTE(review): only a ThreadPoolExecutor is visible in this notebook, and the start
# method has no effect on threads — confirm whether this is still needed.
multiprocessing.set_start_method("spawn", force=True)

In [4]:
# Configuration.
# Paths and the group name can be overridden through environment variables so the
# notebook runs outside the original cluster without edits; defaults are unchanged.
download_dir = os.environ.get("GRID_DOWNLOAD_DIR", "/scratch/downloaded_artifacts")
results_dir = os.environ.get("GRID_RESULTS_DIR", "/scratch/results")
group_name = os.environ.get("GRID_GROUP_NAME", "mlr_search-1_aa9c06652fb34624bebe972b1fe7292f")
# Worker count for the pool that processes run logs (a ThreadPoolExecutor, despite the name).
max_process_workers = 10
# Minimum fraction of successful runs an (outer LR, inner LR) cell needs to be plotted.
success_threshold = 0.95

In [5]:
# Ensure the output root exists before any plot is saved (no-op when it already does).
os.makedirs(name=results_dir, exist_ok=True)

In [6]:
# Anything that is not a word character or a hyphen is unsafe in a folder name here.
_UNSAFE_FOLDER_CHARS = re.compile(r'[^\w\-]')


def sanitize_folder_name(name):
    """Return ``str(name)`` with every character other than ``\\w`` or ``-`` replaced by ``_``."""
    text = str(name)
    return _UNSAFE_FOLDER_CHARS.sub('_', text)

In [7]:
def _failed_run(run_id, status):
    """Result record for a run that produced no usable data (``data`` is None)."""
    return {
        "run_id": run_id,
        "status": status,
        "data": None
    }


def process_run(run_result, failure_hyperparameter_value=1e-4):
    """Load one downloaded run's ``logs.pkl`` and summarize it.

    Parameters
    ----------
    run_result : dict
        One entry of the download results. Must contain ``"run_id"``; may contain
        ``"status"``, ``"artifact_dir"`` and ``"config"`` (a dict). Entries missing
        any of the latter, or whose status is not ``"success"``, are skipped.
    failure_hyperparameter_value : float, optional
        A run is counted as unsuccessful if any entry of ``logs.hyperparameters``
        equals this value exactly. Defaults to ``1e-4``, the value this notebook has
        always used. NOTE(review): presumably a clip bound the hyperparameters hit when
        a run diverges — confirm against the training code.

    Returns
    -------
    dict
        ``{"run_id", "status", "data"}``. On success ``status == "success"`` and
        ``data`` holds ``config``, ``is_success``, ``final_test_loss`` (last test loss,
        only for successful runs), ``ts``, ``inner_optimizer`` and ``inner_learner``.
        Otherwise ``data`` is None and ``status`` is ``"skipped"``, ``"no_log_file"``,
        ``"invalid_logs"`` or ``"error: <message>"``.
    """
    run_id = run_result["run_id"]
    # .get: a failed download may not carry these keys; such runs are skipped below.
    artifact_dir = run_result.get("artifact_dir")
    config = run_result.get("config")

    if run_result.get("status") != "success" or not artifact_dir or not config:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return _failed_run(run_id, "skipped")

    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return _failed_run(run_id, "no_log_file")

    try:
        # SECURITY: pickle.load can execute arbitrary code; only process artifacts
        # produced by trusted runs.
        with open(log_file, "rb") as f:
            logs = pickle.load(f)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return _failed_run(run_id, "invalid_logs")

        # Exact comparison against the sentinel value, as in the original analysis.
        is_success = not np.any(logs.hyperparameters == failure_hyperparameter_value)
        final_test_loss = (
            float(logs.test_loss[-1])
            if is_success and logs.test_loss is not None
            else None
        )

        logger.info(f"Processed run {run_id}: success={is_success}, final_test_loss={final_test_loss}")

        return {
            "run_id": run_id,
            "status": "success",
            "data": {
                "config": config,
                "is_success": is_success,
                "final_test_loss": final_test_loss,
                "ts": tuple(config.get("ts", ())),
                "inner_optimizer": config.get("inner_optimizer", "unknown"),
                "inner_learner": config.get("inner_learner", "unknown")
            }
        }
    except Exception as e:
        # logger.exception also records the traceback, which str(e) alone loses.
        logger.exception(f"Error processing logs for run {run_id}: {str(e)}")
        return _failed_run(run_id, f"error: {str(e)}")

In [17]:
def _sorted_lr_values(runs, key):
    """Sorted distinct values of ``run["config"][key]`` over ``runs``, ignoring missing/None."""
    return sorted({run["config"].get(key) for run in runs} - {None})


def _ts_label(ts):
    """Tick label for a ``ts`` tuple: its first element as an int (``'-'`` for an empty tuple)."""
    return f"{int(ts[0])}" if ts else "-"


def _summarize_losses_by_ts(runs, ts_list):
    """Mean and standard error of the final test loss per ``ts`` over successful runs.

    Returns ``(valid_ts, means, errors)`` in ``ts_list`` order, keeping only the ``ts``
    values that have at least one recorded loss.
    """
    losses_by_ts = defaultdict(list)
    for run in runs:
        if run["is_success"] and run["final_test_loss"] is not None:
            losses_by_ts[run["ts"]].append(run["final_test_loss"])

    valid_ts, means, errors = [], [], []
    for ts in ts_list:
        losses = losses_by_ts.get(ts)
        if not losses:
            continue
        valid_ts.append(ts)
        means.append(np.mean(losses))
        # Standard error of the mean uses the *sample* std (ddof=1); a single run
        # gives no spread estimate, so its error bar is 0.
        errors.append(np.std(losses, ddof=1) / np.sqrt(len(losses)) if len(losses) > 1 else 0)
    return valid_ts, means, errors


def _draw_loss_panel(ax, valid_ts, means, errors, title, y_limits):
    """Plot mean final test loss vs. ``ts`` with error bars on ``ax``."""
    x = range(len(valid_ts))
    ax.errorbar(x, means, yerr=errors, fmt='o-', capsize=5, capthick=2, elinewidth=2)
    ax.set_xticks(x)
    ax.set_xticklabels([_ts_label(ts) for ts in valid_ts], rotation=45)
    ax.set_xlabel('ts (first element)')
    ax.set_ylabel('Average Final Test Loss')
    ax.set_title(title)
    if y_limits is not None:
        ax.set_ylim(y_limits)  # shared y-scale so panels are directly comparable
    ax.grid(True)


def create_grid_plot(runs_data, optimizer, learner, group_name, y_limits=None):
    """Save a grid of "final test loss vs. ts" panels for one (optimizer, learner) pair.

    Rows are inner learning rates and columns are outer learning rates, both ascending.
    A panel is drawn only if the fraction of successful runs in its cell is at least the
    notebook-global ``success_threshold``; otherwise it is blanked. The figure is written
    to ``<results_dir>/<group_name>/opt_<optimizer>_learner_<learner>/test_loss_grid.png``.

    Parameters
    ----------
    runs_data : list of dict
        Per-run records with keys ``"config"``, ``"is_success"``, ``"final_test_loss"``
        and ``"ts"`` (the ``data`` dicts built by ``process_run``). Include unsuccessful
        runs too, otherwise every success fraction is trivially 1.
    optimizer, learner : str
        Values of ``config["inner_optimizer"]`` / ``config["inner_learner"]`` to select.
    group_name : str
        Run group; used in the title and the output path.
    y_limits : tuple of float, optional
        Shared y-axis limits. If None, a notebook-global ``y_limits`` is used when one is
        defined, otherwise matplotlib autoscales each panel.
    """
    filtered_runs = [
        run for run in runs_data
        if run["config"].get("inner_optimizer", "unknown") == optimizer
        and run["config"].get("inner_learner", "unknown") == learner
    ]
    if not filtered_runs:
        logger.warning(f"No runs found for inner_optimizer={optimizer}, inner_learner={learner}")
        return

    # Runs without a learning rate cannot be placed on the grid (and None breaks sorting).
    outer_lrs = _sorted_lr_values(filtered_runs, "outer_learning_rate")
    inner_lrs = _sorted_lr_values(filtered_runs, "inner_learning_rate")
    if not outer_lrs or not inner_lrs:
        logger.warning(f"No valid learning rates for inner_optimizer={optimizer}, inner_learner={learner}")
        return

    if y_limits is None:
        # Backward compatibility: earlier versions read `y_limits` only from notebook globals.
        y_limits = globals().get("y_limits")

    lr_grid = defaultdict(list)
    for run in filtered_runs:
        cfg = run["config"]
        lr_grid[(cfg.get("outer_learning_rate"), cfg.get("inner_learning_rate"))].append(run)

    # Sort whole tuples: same order as sorting by the first element, but deterministic on
    # ties and safe for an empty `ts` (the default when a config has no "ts").
    ts_list = sorted({run["ts"] for run in filtered_runs})

    fig, axes = plt.subplots(
        len(inner_lrs), len(outer_lrs),
        figsize=(len(outer_lrs) * 4, len(inner_lrs) * 4),
        squeeze=False,
    )
    try:
        for i, inner_lr in enumerate(inner_lrs):        # rows: inner LR
            for j, outer_lr in enumerate(outer_lrs):    # columns: outer LR
                ax = axes[i, j]
                runs = lr_grid.get((outer_lr, inner_lr), [])
                total_runs = len(runs)
                successful_runs = sum(1 for run in runs if run["is_success"])
                success_fraction = successful_runs / total_runs if total_runs > 0 else 0

                valid_ts, means, errors = [], [], []
                if total_runs > 0 and success_fraction >= success_threshold:
                    valid_ts, means, errors = _summarize_losses_by_ts(runs, ts_list)

                if means:
                    _draw_loss_panel(ax, valid_ts, means, errors,
                                     f'Success: {successful_runs}/{total_runs}', y_limits)
                else:
                    ax.axis('off')

        fig.suptitle(f'Average Final Test Loss by ts\n(Group: {group_name}, Optimizer: {optimizer}, Learner: {learner}, Success Threshold: {success_threshold})', fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        # Hyperaxis labels, positioned from the final (post-tight_layout) axes boxes so
        # they are centred on their column/row rather than on an equal split of the figure.
        for j, outer_lr in enumerate(outer_lrs):
            box = axes[0, j].get_position()
            fig.text(
                (box.x0 + box.x1) / 2, 1.01,
                f'Outer LR: {outer_lr:.1e}',
                ha='center', va='bottom', fontsize=10, transform=fig.transFigure
            )
        for i, inner_lr in enumerate(inner_lrs):
            box = axes[i, 0].get_position()
            fig.text(
                -0.01, (box.y0 + box.y1) / 2,
                f'Inner LR: {inner_lr:.1e}',
                ha='right', va='center', fontsize=10, transform=fig.transFigure
            )

        subfolder = f"opt_{sanitize_folder_name(optimizer)}_learner_{sanitize_folder_name(learner)}"
        group_results_dir = os.path.join(results_dir, group_name, subfolder)
        os.makedirs(group_results_dir, exist_ok=True)
        output_file = os.path.join(group_results_dir, 'test_loss_grid.png')
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Always release the figure, even on error, so repeated calls don't leak memory.
        plt.close(fig)
    logger.info(f"Saved test loss grid plot for inner_optimizer={optimizer}, inner_learner={learner} to {output_file}")

In [9]:
# Location of the pickled download results for this run group.
download_results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
if not os.path.exists(download_results_file):
    logger.error(f"Download results file not found at {download_results_file}")
    # Fail fast: every later cell needs this file, so continuing would only
    # surface the problem later as a less informative error.
    raise FileNotFoundError(download_results_file)

In [10]:
# SECURITY: pickle.load can execute arbitrary code — only load result files you produced.
with open(download_results_file, mode='rb') as results_handle:
    download_results = pickle.load(results_handle)

In [11]:
# Process every downloaded run concurrently; results keep the order of download_results.
with ThreadPoolExecutor(max_workers=max_process_workers) as pool:
    process_results = [outcome for outcome in pool.map(process_run, download_results)]

INFO:2025-04-19 22:39:34,394:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


2025-04-19 22:39:34,394 - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-04-19 22:39:34,396:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


2025-04-19 22:39:34,396 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-04-19 22:39:34,741 - INFO - Processed run e0fh6y0e: success=False, final_test_loss=None
2025-04-19 22:39:34,743 - INFO - Processed run 9ob4c701: success=False, final_test_loss=None
2025-04-19 22:39:34,981 - INFO - Processed run isv73zvp: success=True, final_test_loss=0.46704187989234924
2025-04-19 22:39:34,981 - INFO - Processed run 79ygnuzq: success=True, final_test_loss=0.4624350368976593
2025-04-19 22:39:34,981 - INFO - Processed run gvfaz9pa: success=True, final_test_loss=0.4585985839366913
2025-04-19 22:39:34,981 - INFO - Processed run 1nr7lnlj: success=True, final_test_loss=0.4590325653553009
2025-04-19 22:39:34,982 - INFO - Processed run 7u6yxcmw: success=True, final_test_loss=0.45910096168518066
2025-04-19 22:39:34,982 - INFO - Processed run ixrtrjdn: success=True, final_test_loss=0.46439990401268005
2

In [12]:
# Keep every run whose logs were processed, successful or not: create_grid_plot computes
# each cell's success fraction from `is_success`. Dropping unsuccessful runs here would
# make every fraction 1.0 and the success threshold meaningless.
all_runs_data = [
    result["data"]
    for result in process_results
    if result["status"] == "success" and result["data"]
]

In [13]:
# Distinct (inner_optimizer, inner_learner) combinations present in the data; one grid
# plot is produced per pair. A set comprehension keeps the loop variables from leaking
# into the notebook namespace as hidden state.
optimizer_learner_pairs = {
    (
        run["config"].get("inner_optimizer", "unknown"),
        run["config"].get("inner_learner", "unknown"),
    )
    for run in all_runs_data
}

In [19]:
# Fixed y-axis range shared by every subplot so panels are directly comparable.
# create_grid_plot picks this up from the notebook globals, so it must be set before the loop.
y_limits = (0.45, 0.7)

# One grid figure per (inner_optimizer, inner_learner) combination, in sorted order.
for optimizer, learner in sorted(optimizer_learner_pairs):
    logger.info(f"Generating grid plot for inner_optimizer={optimizer}, inner_learner={learner}")
    create_grid_plot(
        runs_data=all_runs_data,
        optimizer=optimizer,
        learner=learner,
        group_name=group_name,
    )

2025-04-19 22:57:57,540 - INFO - Generating grid plot for inner_optimizer=sgd, inner_learner=rtrl
2025-04-19 22:58:02,981 - INFO - Saved test loss grid plot for inner_optimizer=sgd, inner_learner=rtrl to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/opt_sgd_learner_rtrl/test_loss_grid.png
2025-04-19 22:58:02,983 - INFO - Generating grid plot for inner_optimizer=sgd_normalized, inner_learner=rtrl
2025-04-19 22:58:07,615 - INFO - Saved test loss grid plot for inner_optimizer=sgd_normalized, inner_learner=rtrl to /scratch/results/mlr_search-1_aa9c06652fb34624bebe972b1fe7292f/opt_sgd_normalized_learner_rtrl/test_loss_grid.png
