In [None]:
import os
import sys
# Make both the current working directory and the `src` directory one level
# above it importable from this notebook.
sys.path.extend([os.getcwd(), os.getcwd() + '/../src'])


In [None]:
import logging
import wandb
from concurrent.futures import ThreadPoolExecutor
import multiprocessing


In [None]:
# Route INFO-and-above records to stdout and to a local 'download.log' file.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers -- confirm that is not the case in the kernel running this.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(stream=sys.stdout),
        logging.FileHandler(filename='download.log'),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Select the 'spawn' start method; force=True overrides any method that was
# already chosen earlier in this process.
multiprocessing.set_start_method(method="spawn", force=True)

In [None]:
# --- Run selection ---------------------------------------------------------
# `entity` and `project_name` form the "<entity>/<project>" path used when
# querying runs and artifacts; `group_name` is the run-group filter and is also
# embedded in the name of the results pickle.
entity = "wlp9800-new-york-university"
project_name = "oho_exps"
group_name = "mlr_search-1_aa9c06652fb34624bebe972b1fe7292f"

# --- Download settings -----------------------------------------------------
# Root directory that receives the artifacts and the results pickle.
download_dir = "/scratch/downloaded_artifacts"
# Size of the thread pool used for the downloads.
max_download_workers = 20

In [None]:
# Create the download directory, including missing parents; an already
# existing directory is not an error.
os.makedirs(name=download_dir, exist_ok=True)

In [None]:
def download_artifact(run_data):
    """Download the ``logs_<run id>:v0`` artifact belonging to one run.

    Args:
        run_data: Mapping with an ``"id"`` entry (the run id) and a
            ``"config"`` entry (passed through unchanged to the result).

    Returns:
        dict with keys ``"run_id"``, ``"artifact_dir"``, ``"config"`` and
        ``"status"``. On success ``"artifact_dir"`` is the directory the
        artifact was downloaded into and ``"status"`` is ``"success"``. On
        failure ``"artifact_dir"`` is ``None`` and ``"status"`` is
        ``"error: <message>"``.

    Raises:
        KeyError: If ``run_data`` lacks ``"id"`` or ``"config"``. Any exception
            raised while looking up or downloading the artifact is logged and
            reported through ``"status"`` instead of propagating.
    """
    run_id, config = run_data["id"], run_data["config"]
    outcome = {"run_id": run_id, "artifact_dir": None, "config": config}
    try:
        # A new Api client is built on every call rather than shared.
        artifact = wandb.Api().artifact(f'{entity}/{project_name}/logs_{run_id}:v0')
        target_dir = os.path.join(download_dir, artifact.name)
        artifact.download(root=target_dir)
        logger.info(f"Downloaded {artifact.name} to {target_dir}")
    except Exception as exc:
        # Best effort: record the failure so the caller still gets one result
        # per run.
        reason = str(exc)
        logger.error(f"Error downloading artifact for run {run_id}: {reason}")
        outcome["status"] = f"error: {reason}"
    else:
        outcome["artifact_dir"] = target_dir
        outcome["status"] = "success"
    return outcome

In [None]:
import pickle

# Look up every run in the configured project that belongs to the group.
api = wandb.Api()
runs = api.runs(path=f"{entity}/{project_name}", filters={"group": group_name})

# Reduce each run to the id and config that download_artifact consumes.
run_data = [dict(id=run.id, config=run.config) for run in runs]
logger.info(f"Found {len(run_data)} runs to download")

# Run the downloads on a thread pool and collect one result dict per run.
with ThreadPoolExecutor(max_workers=max_download_workers) as executor:
    download_results = [
        result for result in executor.map(download_artifact, run_data)
    ]

# Persist the results for the processing script; the group name is part of the
# file name.
results_file = os.path.join(download_dir, f'download_results_{group_name}.pkl')
with open(results_file, mode='wb') as f:
    pickle.dump(download_results, f)
logger.info(f"Saved download results to {results_file}")