In [1]:
# %% [markdown]
# # Single Plot of Best Fixed Learning Rate Test Loss
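# This notebook processes machine learning run data from the run group configured in `fixed_group_name` (originally written for the `fixed_search-4` group). For each combination of the `config_keys` values it generates a single plot of final test loss against the first element of `ts`, using for every `ts` the inner learning rate with the lowest average final validation loss.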

# %%
import os
import sys
# Make both the working directory and the sibling ``src`` tree importable.
_cwd = os.getcwd()
sys.path.extend([_cwd, _cwd + '/../src'])



In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import re
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
import jax.numpy as jnp
# Ask JAX to use the "cpu" platform for everything run from this notebook.
jax.config.update("jax_platform_name", "cpu")



In [3]:
# %%
# Configure logging: every record is written to stdout and to a log file
# in the current working directory.
_log_format = '%(asctime)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
_file_handler = logging.FileHandler('fixed_lr_analysis.log')
logging.basicConfig(
    level=logging.INFO,
    format=_log_format,
    handlers=[_stdout_handler, _file_handler],
)
logger = logging.getLogger(__name__)

# Force the 'spawn' start method so that any process created from here on
# is not forked from a process that already initialised JAX.
multiprocessing.set_start_method("spawn", force=True)



In [4]:
# %%
# Configuration
# Directory that holds `download_results_<group>.pkl` (read in the load cell).
download_dir = "/scratch/downloaded_artifacts"
# Root directory under which the plots are written.
results_dir = "/scratch/results"
# Name of the run group to analyse; also used as the output sub-directory.
fixed_group_name = "batched_bptt_fixed-1_bc4b77390a40422781b8eebf7f6c819c"
# Size of the thread pool used to run `process_run` over the download results.
max_process_workers = 10
# A run is flagged as failed when any logged inner learning rate is <= this value.
lr_failure_threshold = 0.0
y_limits = (0.45, 0.7)  # Fixed y-axis limits for the plot
config_keys = ["tr_avg_per"]  # Config keys for grouping; one plot per distinct value combination

# Create the results root up front so later saves cannot fail on a missing parent.
os.makedirs(results_dir, exist_ok=True)



In [5]:
# %%
def sanitize_folder_name(name):
    """Return ``str(name)`` with every character that is neither a word
    character nor ``-`` replaced by ``_``."""
    unsafe_chars = re.compile(r'[^\w\-]')
    return unsafe_chars.sub('_', str(name))

# %%
def process_run(run_result):
    """Load the pickled logs of one downloaded run and extract its final losses.

    Args:
        run_result: Dict with the keys "run_id", "artifact_dir", "config"
            and "status" describing one download attempt.

    Returns:
        Dict with "run_id", "status" and "data". "data" is ``None`` unless
        "status" is "success", in which case it holds the run config, the
        success flag, the final test/validation losses, the ``ts`` tuple and
        the configured inner learning rate.
    """
    run_id = run_result["run_id"]
    artifact_dir = run_result["artifact_dir"]
    config = run_result["config"]

    def _without_data(status):
        # Shared shape of every early-exit / failure result.
        return {
            "run_id": run_id,
            "status": status,
            "data": None
        }

    download_usable = run_result["status"] == "success" and artifact_dir and config
    if not download_usable:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return _without_data("skipped")

    log_path = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_path):
        logger.error(f"Logs file not found for run {run_id}")
        return _without_data("no_log_file")

    try:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # this assumes the artifacts come from a trusted source.
        with open(log_path, "rb") as handle:
            logs = pickle.load(handle)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return _without_data("invalid_logs")

        # A run counts as failed as soon as any logged inner learning rate
        # is at or below the failure threshold.
        lr_hit_threshold = np.any(logs.inner_learning_rate <= lr_failure_threshold)
        is_success = not lr_hit_threshold

        if logs.test_loss is None:
            final_test_loss = None
        else:
            final_test_loss = float(logs.test_loss[-1])

        if logs.validation_loss is None:
            final_validation_loss = None
        else:
            final_validation_loss = float(logs.validation_loss[-1])

        logger.info(f"Processed run {run_id}: success={is_success}, final_test_loss={final_test_loss}")

        payload = {
            "config": config,
            "is_success": is_success,
            "final_test_loss": final_test_loss,
            "final_validation_loss": final_validation_loss,
            "ts": tuple(config.get("ts", ())),
            "inner_learning_rate": config.get("inner_learning_rate")
        }
        return {
            "run_id": run_id,
            "status": "success",
            "data": payload
        }
    except Exception as exc:
        # Best effort: a broken run is reported through its status rather
        # than aborting the whole batch.
        logger.error(f"Error processing logs for run {run_id}: {str(exc)}")
        return _without_data(f"error: {str(exc)}")

def process_fixed_group_data(fixed_download_results, config_combination=None):
    """Summarise final test loss per ``ts`` for the best inner learning rate.

    Every entry of ``fixed_download_results`` is passed through
    ``process_run`` on a thread pool. Runs that loaded successfully are
    optionally restricted to one config combination and grouped by their
    ``ts`` tuple. For each ``ts`` the inner learning rate with the lowest mean
    final validation loss is selected, and the final test losses of the runs
    using that learning rate are aggregated. Only runs flagged ``is_success``
    take part in the selection and in the statistics.

    Args:
        fixed_download_results: Iterable of download-result dicts as accepted
            by ``process_run``.
        config_combination: Optional iterable of ``(key, value)`` pairs. When
            truthy, only runs whose config equals every pair are kept; a key
            missing from a config compares as ``"unknown"``.

    Returns:
        Dict mapping each ``ts`` tuple to a dict with "mean_test_loss",
        "std_error" (standard error of the mean, 0 for a single run),
        "num_runs" and "best_lr". ``ts`` values without any eligible run are
        omitted.
    """
    with ThreadPoolExecutor(max_workers=max_process_workers) as executor:
        fixed_process_results = list(executor.map(process_run, fixed_download_results))

    # Include all runs with status=success and valid data
    fixed_runs_data = [
        result["data"] for result in fixed_process_results
        if result["status"] == "success" and result["data"]
    ]

    # Filter by config combination if provided
    if config_combination:
        fixed_runs_data = [
            run for run in fixed_runs_data
            if all(run["config"].get(key, "unknown") == value for key, value in config_combination)
        ]

    # Organize by ts
    ts_runs = defaultdict(list)
    for run in fixed_runs_data:
        ts_runs[run["ts"]].append(run)

    ts_stats = {}
    for ts, runs in ts_runs.items():
        # Group the final validation losses of successful runs by learning rate
        lr_val_losses = defaultdict(list)
        for run in runs:
            if run["is_success"] and run["final_validation_loss"] is not None:
                lr_val_losses[run["inner_learning_rate"]].append(run["final_validation_loss"])

        # Strict "<" against +inf keeps the first learning rate on ties and
        # never selects one whose mean validation loss is NaN or infinite.
        best_lr = None
        lowest_avg_val_loss = float('inf')
        for lr, val_losses in lr_val_losses.items():
            avg_val_loss = np.mean(val_losses)
            if avg_val_loss < lowest_avg_val_loss:
                lowest_avg_val_loss = avg_val_loss
                best_lr = lr

        if best_lr is None:
            continue

        # Collect test losses for runs with the best inner learning rate
        test_losses = [
            run["final_test_loss"] for run in runs
            if run["inner_learning_rate"] == best_lr and run["is_success"] and run["final_test_loss"] is not None
        ]
        if not test_losses:
            continue

        num_runs = len(test_losses)
        # Standard error of the mean needs the sample standard deviation
        # (ddof=1); the population form (ddof=0) understates the error bars,
        # noticeably so for the handful of runs available per learning rate.
        if num_runs > 1:
            std_error = np.std(test_losses, ddof=1) / np.sqrt(num_runs)
        else:
            std_error = 0

        ts_stats[ts] = {
            "mean_test_loss": np.mean(test_losses),
            "std_error": std_error,
            "num_runs": num_runs,
            "best_lr": best_lr
        }

    return ts_stats

def create_fixed_lr_plot(fixed_ts_stats, group_name, config_combination):
    """Plot mean final test loss per ``ts`` for one config combination and save it as a PNG.

    Args:
        fixed_ts_stats: Mapping from ``ts`` tuple to a stats dict providing
            "mean_test_loss", "std_error" and "best_lr".
        group_name: Run-group name; shown in the title and used as an output
            sub-directory.
        config_combination: Iterable of ``(key, value)`` pairs; shown in the
            title and encoded into the output sub-folder name.

    Returns:
        None. The figure is written to
        ``<results_dir>/<group_name>/<subfolder>/best_fixed_lr_test_loss.png``.
        Nothing is written when there is no plottable ``ts``.
    """
    if not fixed_ts_stats:
        logger.warning(f"No valid data to plot for config combination: {config_combination}")
        return

    # Entries are ordered and labelled by the first element of ``ts``. A run
    # whose config has no "ts" ends up under the empty tuple, which has no
    # first element; drop it instead of failing with IndexError.
    ts_list = sorted((ts for ts in fixed_ts_stats if len(ts) > 0), key=lambda ts: ts[0])
    dropped = len(fixed_ts_stats) - len(ts_list)
    if dropped:
        logger.warning(f"Ignoring {dropped} empty ts key(s) for config combination: {config_combination}")
    if not ts_list:
        logger.warning(f"No valid data to plot for config combination: {config_combination}")
        return

    # Prepare everything that can fail before a figure exists.
    means = [fixed_ts_stats[ts]["mean_test_loss"] for ts in ts_list]
    errors = [fixed_ts_stats[ts]["std_error"] for ts in ts_list]
    x = list(range(len(ts_list)))
    tick_labels = [f"{int(ts[0])}" for ts in ts_list]

    lr_text = "Best Inner Learning Rates:\n"
    for ts in ts_list:
        best_lr = fixed_ts_stats[ts].get("best_lr")
        if best_lr is not None:
            lr_text += f"ts={int(ts[0])}: {best_lr:.1e}\n"

    config_title = ", ".join(f"{key}: {value}" for key, value in config_combination)
    subfolder = "_".join(
        f"{sanitize_folder_name(key)}_{sanitize_folder_name(value)}"
        for key, value in config_combination
    )
    output_dir = os.path.join(results_dir, group_name, subfolder)
    output_file = os.path.join(output_dir, 'best_fixed_lr_test_loss.png')

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.errorbar(x, means, yerr=errors, fmt='s--', color='red', capsize=5, capthick=2, elinewidth=2, label='Fixed Search')
        ax.text(0.02, 0.98, lr_text, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        # Customize the plot
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('Task Difficulty')
        ax.set_ylabel('Average Final Test Loss')
        ax.set_title(f'Average Final Test Loss for Best Inner Learning Rate\n(Group: {group_name}, {config_title})')
        ax.set_ylim(y_limits)
        ax.grid(True)
        ax.legend()

        # Save the plot
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails, so
        # repeated calls in a loop do not accumulate open figures.
        plt.close(fig)

    logger.info(f"Saved best fixed learning rate test loss plot for {subfolder} to {output_file}")

In [6]:
# Load the download results for the run group configured in `fixed_group_name`
fixed_download_results_file = os.path.join(download_dir, f'download_results_{fixed_group_name}.pkl')
if not os.path.exists(fixed_download_results_file):
    logger.error(f"Fixed search results file not found at {fixed_download_results_file}")
    # Raise a descriptive exception instead of calling sys.exit(1): the cell
    # still stops here, but the failure shows up as an ordinary traceback
    # that names the missing path.
    raise FileNotFoundError(f"Fixed search results file not found at {fixed_download_results_file}")

# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; only point `download_dir` at artifacts from a trusted source.
with open(fixed_download_results_file, 'rb') as f:
    fixed_download_results = pickle.load(f)



In [7]:
# Run process_run over every download result (in input order) so that the
# config combinations present in the group can be derived afterwards.
with ThreadPoolExecutor(max_workers=max_process_workers) as pool:
    process_results = [outcome for outcome in pool.map(process_run, fixed_download_results)]



INFO:2025-04-26 12:23:06,057:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


2025-04-26 12:23:06,057 - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-04-26 12:23:06,076:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


2025-04-26 12:23:06,076 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-04-26 12:23:06,532 - INFO - Processed run ck3ahys1: success=True, final_test_loss=0.477783203125
2025-04-26 12:23:06,533 - INFO - Processed run ezlkygn5: success=True, final_test_loss=0.61962890625
2025-04-26 12:23:06,533 - INFO - Processed run k7jmee8w: success=True, final_test_loss=0.7802734375
2025-04-26 12:23:06,534 - INFO - Processed run v2y0eopc: success=True, final_test_loss=0.58203125
2025-04-26 12:23:06,534 - INFO - Processed run uf6brcme: success=True, final_test_loss=0.568359375
2025-04-26 12:23:06,534 - INFO - Processed run b23o5ie6: success=True, final_test_loss=0.474853515625
2025-04-26 12:23:06,534 - INFO - Processed run thyk73mc: success=True, final_test_loss=0.53369140625
2025-04-26 12:23:06,534 - INFO - Processed run rvdbrvt6: success=True, final_test_loss=0.55712890625
2025-04-26 12:23:06,53

In [8]:
# Keep the payload of every run that loaded successfully and produced data.
all_runs_data = []
for outcome in process_results:
    if outcome["status"] == "success" and outcome["data"]:
        all_runs_data.append(outcome["data"])

In [9]:
# Determine unique config combinations
config_combinations = {
    tuple((key, run["config"].get(key, "unknown")) for key in config_keys)
    for run in all_runs_data
}


def _combination_sort_key(combination):
    """Sort key that stays well-defined when values of different types meet.

    A run lacking a key contributes the string "unknown" while other runs
    contribute numbers, and comparing those directly raises TypeError.
    Numbers are ordered among themselves first; any other value is bucketed
    by its type name so only like-typed values are ever compared.
    """
    return tuple(
        (key, 0, "", value) if isinstance(value, (int, float))
        else (key, 1, type(value).__name__, value)
        for key, value in combination
    )


# Generate a plot for each config combination
# NOTE(review): process_fixed_group_data runs process_run over all download
# results again for every combination, so the pickles are re-read each time.
for config_combination in sorted(config_combinations, key=_combination_sort_key):
    logger.info(f"Processing config combination: {config_combination}")
    fixed_ts_stats = process_fixed_group_data(fixed_download_results, config_combination)
    create_fixed_lr_plot(fixed_ts_stats, fixed_group_name, config_combination)

2025-04-26 12:23:27,549 - INFO - Processing config combination: (('tr_avg_per', 20),)
2025-04-26 12:23:27,630 - INFO - Processed run thyk73mc: success=True, final_test_loss=0.53369140625
2025-04-26 12:23:27,631 - INFO - Processed run k7jmee8w: success=True, final_test_loss=0.7802734375
2025-04-26 12:23:27,632 - INFO - Processed run dlag8l36: success=True, final_test_loss=0.537109375
2025-04-26 12:23:27,632 - INFO - Processed run 6qippb90: success=True, final_test_loss=0.477294921875
2025-04-26 12:23:27,632 - INFO - Processed run uf6brcme: success=True, final_test_loss=0.568359375
2025-04-26 12:23:27,633 - INFO - Processed run b23o5ie6: success=True, final_test_loss=0.474853515625
2025-04-26 12:23:27,633 - INFO - Processed run ezlkygn5: success=True, final_test_loss=0.61962890625
2025-04-26 12:23:27,634 - INFO - Processed run v2y0eopc: success=True, final_test_loss=0.58203125
2025-04-26 12:23:27,634 - INFO - Processed run ck3ahys1: success=True, final_test_loss=0.477783203125
2025-04-26