In [58]:
import os
import sys

# Make the notebook's working directory and the sibling ``src`` directory
# importable. Each entry is appended only when it is not already present, so
# re-running this cell does not keep growing ``sys.path`` with duplicates.
for _extra_path in (os.getcwd(), os.path.join(os.getcwd(), os.pardir, 'src')):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [59]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import wandb
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
import logging
import sys
from recurrent.parameters import AllLogs

import jax 
import jax.numpy as jnp
# Pin JAX's platform to "cpu" for the rest of this kernel session.
# NOTE(review): presumably because this notebook only post-processes logs and
# needs no accelerator — confirm.
jax.config.update("jax_platform_name", "cpu")


In [60]:
# Logging setup: records at INFO and above are sent both to the notebook's
# stdout and to ``processing.log`` in the current working directory.
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler('processing.log'),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Select the 'spawn' start method for multiprocessing to avoid JAX fork
# issues; force=True replaces any start method that was already chosen.
multiprocessing.set_start_method("spawn", force=True)

In [61]:
# --- Weights & Biases source: which runs this notebook analyses ---
entity = "wlp9800-new-york-university"
project_name = "oho_exps"
group_name = "time_test_oho_d1efc05e0903463ca4e95a52714389a0"

# --- Local directories (relative to the working directory) ---
download_dir = "downloaded_artifacts"
results_dir = "results"

# --- Thread-pool sizes for the download and processing stages ---
max_download_workers = 20
max_process_workers = 10  # Reduced for ThreadPoolExecutor stability

# The download directory has to exist before artifacts are fetched into it.
os.makedirs(download_dir, exist_ok=True)

In [62]:
# Download stage: fetch the log artifact that belongs to one run.
def download_artifact(run_data):
    """Download the ``logs_<run_id>:v0`` artifact of a single W&B run.

    Args:
        run_data: Mapping with the keys ``"id"`` (the run id) and ``"config"``
            (the run's config, passed through unchanged).

    Returns:
        A dict with the keys ``"run_id"``, ``"artifact_dir"``, ``"config"``
        and ``"status"``. On success ``"status"`` is ``"success"`` and
        ``"artifact_dir"`` is the directory under ``download_dir`` the
        artifact was downloaded to. If anything inside the download raises,
        the error is logged, ``"artifact_dir"`` is ``None`` and ``"status"``
        is ``"error: <message>"``.
    """
    run_id, config = run_data["id"], run_data["config"]
    outcome = {"run_id": run_id, "artifact_dir": None, "config": config}

    try:
        # A new API client is created on every call.
        artifact_ref = f'{entity}/{project_name}/logs_{run_id}:v0'
        artifact = wandb.Api().artifact(artifact_ref)
        target_dir = os.path.join(download_dir, artifact.name)
        artifact.download(root=target_dir)
        logger.info(f"Downloaded {artifact.name} to {target_dir}")
    except Exception as e:
        # Report the failure through the returned status instead of raising.
        logger.error(f"Error downloading artifact for run {run_id}: {str(e)}")
        outcome["status"] = f"error: {str(e)}"
        return outcome

    outcome["artifact_dir"] = target_dir
    outcome["status"] = "success"
    return outcome

# Processing stage: turn one download result into a heatmap data point.
def process_run(run_result):
    """Load one run's pickled logs and classify the run as success or failure.

    Args:
        run_result: A dict as returned by ``download_artifact`` (keys
            ``"run_id"``, ``"artifact_dir"``, ``"config"``, ``"status"``).

    Returns:
        A dict with the keys ``"run_id"``, ``"status"`` and ``"data"``.
        ``"status"`` is one of:

        * ``"skipped"`` - the download did not succeed,
        * ``"no_log_file"`` - ``logs.pkl`` is missing from the artifact dir,
        * ``"invalid_logs"`` - the unpickled object is not an ``AllLogs``,
        * ``"missing_config"`` - ``ts``, ``outer_learning_rate`` or
          ``inner_learning_rate`` is absent/empty in the config,
        * ``"error: <message>"`` - loading or inspecting the logs raised,
        * ``"success"`` - ``"data"`` holds ``ts``, ``outer_lr``, ``inner_lr``
          and ``is_success``.

        ``"data"`` is ``None`` for every status except ``"success"``.
    """
    run_id = run_result["run_id"]

    def _outcome(status, data=None):
        # Every exit path returns the same three-key record.
        return {"run_id": run_id, "status": status, "data": data}

    artifact_dir = run_result["artifact_dir"]
    config = run_result["config"]

    download_ok = run_result["status"] == "success" and artifact_dir
    if not download_ok:
        logger.warning(f"Skipping run {run_id}: download failed")
        return _outcome("skipped")

    log_file = os.path.join(artifact_dir, "logs.pkl")
    if not os.path.exists(log_file):
        logger.error(f"Logs file not found for run {run_id}")
        return _outcome("no_log_file")

    try:
        # SECURITY: pickle.load can execute arbitrary code; only use this on
        # artifacts from a trusted W&B project.
        with open(log_file, "rb") as f:
            logs = pickle.load(f)

        if not isinstance(logs, AllLogs):
            logger.error(f"Logs for run {run_id} is not an AllLogs instance")
            return _outcome("invalid_logs")

        ts = tuple(config.get("ts", ()))
        outer_lr = config.get("outer_learning_rate")
        inner_lr = config.get("inner_learning_rate")

        config_complete = outer_lr is not None and inner_lr is not None and ts
        if not config_complete:
            logger.error(f"Missing config values for run {run_id}")
            return _outcome("missing_config")

        # A run counts as successful when no logged hyperparameter equals
        # exactly 1e-4.
        # NOTE(review): presumably 1e-4 is a floor/sentinel value written by
        # the training code — confirm.
        is_success = not jnp.any(logs.hyperparameters == 1e-4)
        logger.info(f"Processed run {run_id}: success={is_success}")

        return _outcome(
            "success",
            {
                "ts": ts,
                "outer_lr": outer_lr,
                "inner_lr": inner_lr,
                "is_success": is_success,
            },
        )
    except Exception as e:
        logger.error(f"Error processing logs for run {run_id}: {str(e)}")
        return _outcome(f"error: {str(e)}")

In [63]:
def create_heatmap(ts, runs_data, group_idx, group_name):
    """Render and save a success-fraction heatmap for one ``ts`` group.

    The grid has one row per distinct inner learning rate and one column per
    distinct outer learning rate (both sorted ascending). Each cell shows the
    fraction of runs at that (inner, outer) pair whose ``"is_success"`` is
    truthy. Cells without any run are left empty (NaN) and labelled "N/A".

    The colour scale is pinned to [0, 1] so that the white/black annotation
    threshold of 0.5 matches the cell colours and heatmaps from different
    groups are directly comparable.

    Args:
        ts: Identifier of the group; only used in the title and log messages.
        runs_data: Iterable of dicts with the keys ``"outer_lr"``,
            ``"inner_lr"`` and ``"is_success"``.
        group_idx: Index used in the output file name
            (``heatmap_ts_group_<group_idx>.png``).
        group_name: Name of the run group; used in the title and as the
            sub-directory of ``results_dir`` the image is written to.

    Returns:
        None. The image is written to disk; if ``runs_data`` yields no
        learning rates a warning is logged and nothing is written.
    """
    outer_lrs = sorted({run["outer_lr"] for run in runs_data})
    inner_lrs = sorted({run["inner_lr"] for run in runs_data})

    if not outer_lrs or not inner_lrs:
        logger.warning(f"No valid learning rates for ts group {ts} in {group_name}")
        return

    # Value -> grid position lookups (avoids a linear list.index per run).
    outer_pos = {lr: idx for idx, lr in enumerate(outer_lrs)}
    inner_pos = {lr: idx for idx, lr in enumerate(inner_lrs)}

    success_counts = defaultdict(int)
    total_counts = defaultdict(int)
    for run in runs_data:
        key = (inner_pos[run["inner_lr"]], outer_pos[run["outer_lr"]])
        total_counts[key] += 1
        if run["is_success"]:
            success_counts[key] += 1

    # NaN marks "no runs here" so those cells are not drawn as 0% success.
    grid = np.full((len(inner_lrs), len(outer_lrs)), np.nan)
    for key, total in total_counts.items():
        # total is at least 1 for every key present in total_counts.
        grid[key] = success_counts.get(key, 0) / total

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        im = ax.imshow(
            grid,
            origin='lower',
            cmap='viridis',
            interpolation='nearest',
            vmin=0.0,
            vmax=1.0,
        )
        fig.colorbar(im, ax=ax, label='Fraction of Successful Runs')

        for (i, j), fraction in np.ndenumerate(grid):
            if np.isnan(fraction):
                # Empty cells are not coloured, so use dark text.
                text, color = "N/A", 'black'
            else:
                text = f"{fraction:.2f}"
                color = 'white' if fraction < 0.5 else 'black'
            ax.text(j, i, text, ha='center', va='center', color=color)

        ax.set_xticks(np.arange(len(outer_lrs)))
        ax.set_xticklabels([f"{lr:.1e}" for lr in outer_lrs], rotation=45)
        ax.set_yticks(np.arange(len(inner_lrs)))
        ax.set_yticklabels([f"{lr:.1e}" for lr in inner_lrs])
        ax.set_xlabel('Outer Learning Rate')
        ax.set_ylabel('Inner Learning Rate')
        ax.set_title(f'Success Fraction Heatmap for ts={ts} in {group_name}')

        # Save to group-specific results directory
        group_results_dir = os.path.join(results_dir, group_name)
        os.makedirs(group_results_dir, exist_ok=True)
        output_file = os.path.join(group_results_dir, f'heatmap_ts_group_{group_idx}.png')
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Always release the figure, even when rendering or saving fails.
        plt.close(fig)
    logger.info(f"Saved heatmap for ts={ts} in {group_name} to {output_file}")

In [64]:
# Ask W&B for every run of the project that belongs to the configured group.
api = wandb.Api()
runs = api.runs(path=f"{entity}/{project_name}", filters={"group": group_name})

# Keep only what the download stage needs from each run: its id and config.
run_data = [
    {"id": wandb_run.id, "config": wandb_run.config}
    for wandb_run in runs
]
logger.info(f"Found {len(run_data)} runs to process")

2025-04-18 23:07:05,448 - INFO - Found 10 runs to process


In [65]:
# Download all artifacts concurrently. Every task is submitted first, then the
# results are collected in submission order, so download_results[i] belongs
# to run_data[i].
with ThreadPoolExecutor(max_workers=max_download_workers) as executor:
    download_results = [
        future.result()
        for future in [executor.submit(download_artifact, item) for item in run_data]
    ]

[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,036 - INFO - Downloaded logs_gicdoysn:v0 to downloaded_artifacts/logs_gicdoysn:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,041 - INFO - Downloaded logs_y6klzrn2:v0 to downloaded_artifacts/logs_y6klzrn2:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,050 - INFO - Downloaded logs_mgwak6ta:v0 to downloaded_artifacts/logs_mgwak6ta:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,057 - INFO - Downloaded logs_83hjf1k8:v0 to downloaded_artifacts/logs_83hjf1k8:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,064 - INFO - Downloaded logs_wgeu3744:v0 to downloaded_artifacts/logs_wgeu3744:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,065 - INFO - Downloaded logs_drme1azv:v0 to downloaded_artifacts/logs_drme1azv:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,078 - INFO - Downloaded logs_lky8uxxr:v0 to downloaded_artifacts/logs_lky8uxxr:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,082 - INFO - Downloaded logs_270lhq5v:v0 to downloaded_artifacts/logs_270lhq5v:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,083 - INFO - Downloaded logs_9f568pgc:v0 to downloaded_artifacts/logs_9f568pgc:v0


[34m[1mwandb[0m:   1 of 1 files downloaded.  


2025-04-18 23:07:06,098 - INFO - Downloaded logs_vuae6fly:v0 to downloaded_artifacts/logs_vuae6fly:v0


In [66]:
# Process all download results concurrently. Every task is submitted first,
# then the results are collected in submission order, so process_results[i]
# belongs to download_results[i].
with ThreadPoolExecutor(max_workers=max_process_workers) as executor:
    process_results = [
        future.result()
        for future in [executor.submit(process_run, item) for item in download_results]
    ]

2025-04-18 23:07:06,297 - INFO - Processed run 9f568pgc: success=False
2025-04-18 23:07:06,301 - INFO - Processed run lky8uxxr: success=False
2025-04-18 23:07:06,302 - INFO - Processed run vuae6fly: success=False
2025-04-18 23:07:06,303 - INFO - Processed run wgeu3744: success=True
2025-04-18 23:07:06,303 - INFO - Processed run 270lhq5v: success=False
2025-04-18 23:07:06,303 - INFO - Processed run mgwak6ta: success=True
2025-04-18 23:07:06,306 - INFO - Processed run drme1azv: success=True
2025-04-18 23:07:06,306 - INFO - Processed run y6klzrn2: success=True
2025-04-18 23:07:06,308 - INFO - Processed run gicdoysn: success=True
2025-04-18 23:07:06,310 - INFO - Processed run 83hjf1k8: success=True


In [67]:
# Bucket every successfully *processed* run by its ts tuple. "success" here is
# the processing status, so runs whose is_success flag is False are included;
# runs that were skipped or failed to load are dropped.
run_groups = defaultdict(list)
for result in process_results:
    if result["status"] != "success" or not result["data"]:
        continue
    run_groups[result["data"]["ts"]].append(result["data"])

In [68]:
# One heatmap per ts group; idx is the group's position in run_groups and
# ends up in the output file name.
for idx, ts in enumerate(run_groups):
    runs_data = run_groups[ts]
    logger.info(f"Generating heatmap for ts group {ts} in {group_name}")
    create_heatmap(ts, runs_data, idx, group_name)

2025-04-18 23:07:06,328 - INFO - Generating heatmap for ts group (14, 16) in time_test_oho_d1efc05e0903463ca4e95a52714389a0
2025-04-18 23:07:06,564 - INFO - Saved heatmap for ts=(14, 16) in time_test_oho_d1efc05e0903463ca4e95a52714389a0 to results/time_test_oho_d1efc05e0903463ca4e95a52714389a0/heatmap_ts_group_0.png
