In [None]:
import os
import sys
# Make modules in the working directory and in ../src (e.g. `recurrent`) importable.
sys.path.extend([os.getcwd(), os.getcwd() + '/../src'])

In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from recurrent.parameters import AllLogs
import jax
# Run JAX on the CPU backend only. NOTE(review): presumably so that any JAX
# arrays inside the pickled logs can be loaded without a GPU — confirm.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Configure logging: INFO and above go to stdout and to plot_seeds.log.
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger first — without it, basicConfig silently does nothing when the kernel
# or an imported library has configured logging before this cell runs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('plot_seeds.log')
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Set multiprocessing start method to 'spawn' to avoid JAX fork issues.
# force=True lets this cell be re-run without a RuntimeError. Note: the artifact
# processing in this notebook uses a ThreadPoolExecutor, so this only takes
# effect if a process-based pool is used.
multiprocessing.set_start_method("spawn", force=True)


In [None]:
# Configuration. Each value can be overridden with an environment variable so
# the notebook runs outside the original /scratch layout; the defaults are the
# original hard-coded values.
download_dir = os.environ.get("PLOT_SEEDS_DOWNLOAD_DIR", "/scratch/downloaded_artifacts")
results_dir = os.environ.get("PLOT_SEEDS_RESULTS_DIR", "/scratch/results")
group_name = os.environ.get(
    "PLOT_SEEDS_GROUP_NAME", "fixed_weight_seed-1_cf21e420f71a4529bed03b4c48fda84c"
)
# Number of worker threads for the ThreadPoolExecutor that processes artifacts.
max_process_workers = int(os.environ.get("PLOT_SEEDS_MAX_WORKERS", "10"))

# Ensure results directory exists
os.makedirs(results_dir, exist_ok=True)



In [None]:
# Function to process a single run's logs
def process_run(run_result):
    """Load one run's ``logs.pkl`` and extract its inner-LR and validation-loss curves.

    Args:
        run_result: One record from ``download_results``. ``"run_id"`` is
            required. ``"status"``, ``"artifact_dir"`` and ``"config"`` are read
            with ``.get``, so a missing key is treated like a failed download.

    Returns:
        A dict with keys ``"run_id"``, ``"status"`` and ``"data"``.

        - On success, ``status`` is ``"success"``. ``data`` holds ``"run_id"``,
          ``"config"``, ``"inner_learning_rate"`` and ``"validation_loss"``;
          the curves are numpy arrays and at least 1-D.
        - Otherwise ``data`` is ``None`` and ``status`` is one of
          ``"skipped"``, ``"no_log_file"``, ``"invalid_logs"``, ``"no_data"``
          or ``"error: <message>"``.

        Exceptions raised while loading or converting the logs are caught,
        logged with a traceback, and reported through ``status``.
    """
    run_id = run_result["run_id"]
    # .get so one incomplete record is skipped instead of raising KeyError,
    # which would abort the whole executor.map over every run.
    artifact_dir = run_result.get("artifact_dir")
    config = run_result.get("config")

    def _failure(status):
        # Uniform result for every non-success outcome.
        return {"run_id": run_id, "status": status, "data": None}

    def _to_curve(values):
        # A missing or None attribute maps to an empty array ("no data").
        # atleast_1d turns a stray scalar into a 1-point curve; the original
        # len() on a 0-d array (e.g. np.array(None)) raised TypeError.
        if values is None:
            return np.array([])
        return np.atleast_1d(np.array(values))

    if run_result.get("status") != "success" or not artifact_dir or not config:
        logger.warning(f"Skipping run {run_id}: download failed or no config")
        return _failure("skipped")

    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return _failure("no_log_file")

    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only run this on artifacts from a trusted source.
        with open(log_file, "rb") as f:
            logs = pickle.load(f)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return _failure("invalid_logs")

        inner_lr = _to_curve(getattr(logs, "inner_learning_rate", None))
        val_loss = _to_curve(getattr(logs, "validation_loss", None))

        if inner_lr.size == 0 or val_loss.size == 0:
            logger.warning(f"No inner_learning_rate or validation_loss data for run {run_id}")
            return _failure("no_data")

        logger.info(f"Processed run {run_id}")

        return {
            "run_id": run_id,
            "status": "success",
            "data": {
                # Carried in the payload so plotting code can label each curve.
                "run_id": run_id,
                "config": config,
                "inner_learning_rate": inner_lr,
                "validation_loss": val_loss,
            },
        }
    except Exception as e:
        # logger.exception also records the traceback, which str(e) alone drops.
        logger.exception(f"Error processing logs for run {run_id}: {str(e)}")
        return _failure(f"error: {str(e)}")


In [None]:
# Load the list of per-run download records; each record is handed to process_run.
download_results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
if not os.path.exists(download_results_file):
    missing_msg = f"Download results file not found at {download_results_file}"
    logger.error(missing_msg)
    raise FileNotFoundError(missing_msg)

# NOTE(review): pickle.load can execute arbitrary code; only load result files
# produced by a trusted download step.
with open(download_results_file, 'rb') as results_fh:
    download_results = pickle.load(results_fh)


In [None]:
# Process downloaded artifacts on a thread pool; results keep the input order.
with ThreadPoolExecutor(max_workers=max_process_workers) as pool:
    process_results = [outcome for outcome in pool.map(process_run, download_results)]



In [None]:
# Keep only the data payloads of runs that were processed successfully.
all_runs_data = [
    outcome["data"]
    for outcome in process_results
    if outcome["status"] == "success" and outcome["data"]
]



In [None]:
# Function to create seed-based plots
def _plot_seed_row(ax_lr, ax_loss, seed, runs, is_first_row, show_legend):
    """Draw every run of one seed onto its (inner-LR, validation-loss) pair of axes."""
    for run in runs:
        # Prefer the run_id carried in the payload; the config may not have one,
        # which would label every line "unknown".
        label = run.get("run_id") or run["config"].get("run_id", "unknown")
        inner_lr = run["inner_learning_rate"]
        val_loss = run["validation_loss"]
        # Each curve gets its own x range: the two series need not be the same
        # length, and Axes.plot raises if x and y lengths differ.
        ax_lr.plot(np.arange(len(inner_lr)), inner_lr, alpha=0.5, label=label)
        ax_loss.plot(np.arange(len(val_loss)), val_loss, alpha=0.5, label=label)

    for ax, quantity in ((ax_lr, "Inner Learning Rate"), (ax_loss, "Validation Loss")):
        ax.set_ylabel(f'{quantity} (Seed {seed})')
        ax.set_xlabel('Epoch')
        ax.grid(True)
        if is_first_row:
            ax.set_title(f'{quantity} vs Epoch')
        if show_legend:
            ax.legend()


def create_seed_plots(seed_type, seed_key, runs_data=None, max_legend_entries=10):
    """Plot inner learning rate and validation loss per seed value and save a PNG.

    Runs are grouped by ``run["config"][seed_key]``. Each seed value gets one
    row with two panels (inner learning rate, validation loss) and one line
    per run. Runs whose config lacks ``seed_key`` are left out. The figure is
    written to ``<results_dir>/<group_name>/<seed_type.lower()>_seed_plots.png``.

    Args:
        seed_type: Human-readable label used in the title, log messages and
            output filename (e.g. ``"Parameter"``).
        seed_key: Config key to group runs by (e.g. ``"parameter_seed"``).
        runs_data: Run payloads as produced by ``process_run``. Defaults to the
            notebook-level ``all_runs_data``.
        max_legend_entries: A row gets legends only if it has at most this many
            runs, so crowded panels stay readable.

    Returns:
        None. Logs a warning and returns early when no run has a seed value.
    """
    if runs_data is None:
        runs_data = all_runs_data

    # Group runs by seed; runs missing seed_key land under "unknown".
    seed_groups = defaultdict(list)
    for run in runs_data:
        seed_groups[run["config"].get(seed_key, "unknown")].append(run)

    candidates = [s for s in seed_groups if s != "unknown"]
    if not candidates:
        logger.warning(f"No valid {seed_type} seeds found")
        return
    try:
        seeds = sorted(candidates)
    except TypeError:
        # Mixed seed types (e.g. int and str/None) are not mutually orderable;
        # fall back to a deterministic ordering by string form.
        seeds = sorted(candidates, key=str)

    # One row per seed, two panels per row. squeeze=False always yields a 2-D
    # (n_seeds, 2) Axes array, so a single seed needs no special-casing.
    n_seeds = len(seeds)
    fig, axes = plt.subplots(
        n_seeds, 2, figsize=(12, 4 * n_seeds), sharex=True, squeeze=False
    )
    try:
        for row, seed in enumerate(seeds):
            runs = seed_groups[seed]
            _plot_seed_row(
                axes[row][0],
                axes[row][1],
                seed,
                runs,
                is_first_row=(row == 0),
                show_legend=len(runs) <= max_legend_entries,
            )

        fig.suptitle(f'{seed_type} Seed Plots for {group_name}', fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        output_dir = os.path.join(results_dir, group_name)
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f'{seed_type.lower()}_seed_plots.png')
        fig.savefig(output_file, bbox_inches='tight', dpi=300)
    finally:
        # Close even if plotting/saving fails so repeated calls don't pile up
        # open figures in pyplot.
        plt.close(fig)
    logger.info(f"Saved {seed_type} seed plots to {output_file}")



In [None]:
# Generate one seed-grouped figure per seed kind: parameter seed, then data seed.
for seed_label, config_key in (("Parameter", "parameter_seed"), ("Data", "data_seed")):
    logger.info(f"Generating plots for {config_key} in {group_name}")
    create_seed_plots(seed_label, config_key)