In [None]:
import os

import numpy as np
import pandas as pd
import wandb

# W&B entity (team) and project queried throughout this notebook.
ENTITY = "confopt-team"
PROJECT = "ConfoptAutoML25"

In [None]:
def pull_run_data(run_id, entity="confopt-team", project="ConfoptAutoML25", include_history=True):
    """
    Pull the summary, config and (optionally) the logged history of a single W&B run.

    Args:
        run_id (str): The run identifier.
        entity (str): The W&B entity (username or team).
        project (str): The name of the W&B project.
        include_history (bool): Whether to also download the run history. This is an
            additional API call; pass False when only summary/config are needed.

    Returns:
        dict | None: ``{"summary": dict, "config": ..., "history": DataFrame | None}``.
        ``history`` is None when ``include_history`` is False. The whole result is
        None if the run could not be retrieved (the error is printed).
    """
    api = wandb.Api()
    run_path = f"{entity}/{project}/{run_id}"
    try:
        run = api.run(run_path)
    except Exception as e:  # best-effort lookup: report and let the caller decide
        print(f"Error retrieving run: {e}")
        return None

    # Summary metrics (final logged values) as a plain dict.
    summary = run.summary._json_dict

    # Configuration the run was launched with.
    config = run.config

    # Logged metrics as a pandas DataFrame. Note: `run.history()` returns a
    # sampled view by default, not necessarily every logged row.
    history = run.history(pandas=True) if include_history else None

    return {"summary": summary, "config": config, "history": history}

def get_run_ids_with_filter(filter_dict, entity="confopt-team", project="ConfoptAutoML25"):
    """
    Retrieve the IDs and names of all runs in a project that match the given filter.

    Args:
        filter_dict (dict): Filter passed straight to ``wandb.Api().runs`` (MongoDB-style
            operators such as ``$in`` / ``$nin`` are allowed).
            For example: {"config.learning_rate": 0.001}
        entity (str): The W&B entity (username or team).
        project (str): The name of the W&B project.

    Returns:
        tuple[list, list]: ``(run_ids, run_names)``, index-aligned.
    """
    api = wandb.Api()
    runs = api.runs(f"{entity}/{project}", filters=filter_dict)

    # Collect both fields in a single pass so the (paginated) result is iterated once.
    run_ids, run_names = [], []
    for run in runs:
        run_ids.append(run.id)
        run_names.append(run.name)
    return run_ids, run_names

def print_wandb_links(run_ids, run_names, entity=ENTITY, project=PROJECT):
    """Print one W&B overview URL per run, annotated with the run's name and ID."""
    base_url = f"https://wandb.ai/{entity}/{project}/runs"
    for name, rid in zip(run_names, run_ids):
        print(f"{base_url}/{rid}/overview ({name} {rid})")

def get_best_run_data(run_ids, last_epoch):
    """
    Fetch every run in ``run_ids`` and return the data of the run with the lowest final eval loss.

    Every run is sanity-checked first. All failed checks raise ``AssertionError``
    (which the sweep loop catches and reports) instead of an opaque TypeError/KeyError.

    Args:
        run_ids (list[str]): W&B run IDs of the repeated runs of one configuration.
        last_epoch (int): Expected value of ``summary["_step"]`` for a completed run.

    Returns:
        dict: The ``pull_run_data`` result (``summary`` / ``config`` / ``history``) of the best run.

    Raises:
        AssertionError: if ``run_ids`` is empty, a run cannot be retrieved, two runs share
            a seed, a run's last step differs from ``last_epoch``, a run has no
            ``eval/loss``, or no run has a loss below infinity (e.g. all NaN).
    """
    assert len(run_ids) > 0, "No run IDs given"

    all_data = {}
    best_run = None
    best_loss = np.inf
    losses = []
    seeds = []

    for run_id in run_ids:
        data = pull_run_data(run_id)
        # pull_run_data returns None when the run cannot be retrieved.
        assert data is not None, f"Could not retrieve run {run_id}"
        all_data[run_id] = data

        seed = data["config"]["trainer"]["seed"]
        assert seed not in seeds, f"Duplicate seed {seed} found in run {run_id}"
        seeds.append(seed)

        summary = data["summary"]
        last_step = summary.get("_step")
        assert last_step == last_epoch, f"Last step {last_step} is not {last_epoch}"

        assert "eval/loss" in summary, f"Run {run_id} has no 'eval/loss' in its summary"
        eval_loss = summary["eval/loss"]
        losses.append(eval_loss)

        # NaN compares False against everything, so a diverged run never becomes the best.
        if eval_loss < best_loss:
            best_loss = eval_loss
            best_run = run_id

    print(losses, run_ids)

    assert best_run is not None, f"No run has a comparable eval/loss (all NaN or inf?): {losses}"
    return all_data[best_run]

def get_run_ids_and_names(sampler, subspace, opset, dataset, batch_size, other=None, tag="first-full-run"):
    """
    Find finished, non-debug runs for one (sampler, subspace, opset, variant) configuration.

    Args:
        sampler (str): Value of ``config.sampler_type`` (e.g. "darts").
        subspace (str): Search subspace; combined with ``opset`` into ``config.benchmark``.
        opset (str): Operation set; combined with ``subspace`` into ``config.benchmark``.
        dataset (str): Value of ``config.dataset``.
        batch_size (int): Value of ``config.trainer.batch_size``.
        other (str | None): Method variant: None (baseline), "oles", "pcdarts",
            "sdarts" or "fairdarts". The baseline filter excludes all variants.
        tag (str): Value of ``config.tag``.

    Returns:
        tuple[list, list]: ``(run_ids, run_names)`` as returned by ``get_run_ids_with_filter``.

    Raises:
        ValueError: if ``other`` is not a known variant (previously it was silently
            ignored, returning baseline runs under the variant's label).
    """
    filter_dict = {
        "config.benchmark": f"{subspace}-{opset}",
        "config.dataset": dataset,
        "config.sampler_type": sampler,
        "config.trainer.batch_size": batch_size,
        "config.is_debug_run": False,
        "config.oles.oles": False,
        "config.partial_connector": {"$in": [None]},
        "config.perturbator": {"$in": [None]},
        "config.sampler.arch_combine_fn": "default",
        "config.tag": tag,
        "state": "finished",
    }

    # Each variant flips exactly one of the baseline-excluding entries above.
    # Built per call so the mutable filter values are never shared between calls.
    variant_overrides = {
        "oles": ("config.oles.oles", True),
        "pcdarts": ("config.partial_connector", {"$nin": [None]}),
        "sdarts": ("config.perturbator", {"$nin": [None]}),
        "fairdarts": ("config.sampler.arch_combine_fn", "sigmoid"),
    }

    if other is not None:
        if other not in variant_overrides:
            raise ValueError(
                f"Unknown variant {other!r}; expected None or one of {sorted(variant_overrides)}"
            )
        key, value = variant_overrides[other]
        filter_dict[key] = value

    print(filter_dict)

    return get_run_ids_with_filter(filter_dict)

# Batch size per sampler and subspace, used to select runs by config.trainer.batch_size.
# DARTS and DRNAS share one schedule; GDAS uses 5x larger batches. For every sampler,
# `single_cell` uses the same batch size as `wide`.
batch_sizes = {
    sampler_name: {"deep": deep_bs, "wide": wide_bs, "single_cell": wide_bs}
    for sampler_name, (deep_bs, wide_bs) in (
        ("darts", (64, 96)),
        ("drnas", (64, 96)),
        ("gdas", (320, 480)),
    )
}

In [None]:
# Sweep every (sampler, subspace, opset, variant) configuration, pick the best-loss
# seed for each and collect its genotype. Configurations that cannot be evaluated
# are recorded in one of the three lists below instead of aborting the sweep.
dataset = "cifar10_supernet"

# Minimum number of seeds required before a configuration is evaluated.
MIN_RUNS = 3

best_genotypes = {}
runs_without_results = []   # no matching finished runs
incomplete_runs = []        # fewer than MIN_RUNS matching runs
failed_runs = []            # (config, message) pairs that failed a check in get_best_run_data

samplers = ("darts", "drnas", "gdas")
subspaces = ("deep", "wide", "single_cell")
opsets = ("regular", "all_skip", "no_skip")
darts_others = (None, "oles", "pcdarts", "sdarts", "fairdarts")

for sampler in samplers:
    # Only DARTS is combined with the method variants; other samplers run the baseline only.
    others = darts_others if sampler == "darts" else (None,)
    # Expected final `_step` of a finished run for this sampler.
    last_epoch = 99 if sampler in ("darts", "drnas") else 299

    for subspace in subspaces:
        batch_size = batch_sizes[sampler][subspace]

        for opset in opsets:
            for other in others:
                config_key = (sampler, subspace, opset, other)

                run_ids, run_names = get_run_ids_and_names(
                    sampler,
                    subspace,
                    opset,
                    dataset,
                    batch_size=batch_size,
                    other=other,
                    tag="first-full-run",
                )

                if len(run_ids) == 0:
                    print("No results for: ", config_key)
                    runs_without_results.append(config_key)
                    continue
                if len(run_ids) < MIN_RUNS:
                    print(f"Fewer than {MIN_RUNS} runs for: ", config_key)
                    incomplete_runs.append(config_key)
                    continue

                print("\n", sampler, subspace, opset, "" if other is None else other)
                print_wandb_links(run_ids, run_names)

                try:
                    best_run_data = get_best_run_data(run_ids, last_epoch)
                except AssertionError as e:
                    print(f"Assertion error: {e}")
                    failed_runs.append((config_key, str(e)))
                    continue

                other_str = "-baseline" if other is None else f"-{other}"
                best_genotypes[f"{sampler}{other_str}-{subspace}-{opset}"] = best_run_data["summary"]["genotype"]

In [None]:
# Persist each configuration's best genotype as genotypes/<experiment>.txt
# (read back further down to generate the model-training launch script).
# Create the folder on demand so the cell also works on a fresh checkout.
os.makedirs("genotypes", exist_ok=True)

for exp, genotype in best_genotypes.items():
    with open(os.path.join("genotypes", f"{exp}.txt"), "w", encoding="utf-8") as f:
        f.write(genotype)

In [None]:
# Configurations (sampler, subspace, opset, other) for which no finished run matched the filter.
runs_without_results

In [None]:
# Configurations with at least one but fewer than 3 matching runs; these were skipped.
incomplete_runs

In [None]:
# Experiment keys ("<sampler>-<variant|baseline>-<subspace>-<opset>") that received a best genotype.
best_genotypes.keys()

In [None]:
import os

def list_files(folder_path):
    """Return the alphabetically sorted names of regular files directly inside `folder_path`."""
    entries = os.listdir(folder_path)
    regular_files = [name for name in entries if os.path.isfile(os.path.join(folder_path, name))]
    regular_files.sort()
    return regular_files

# Turn the saved genotype files into a bash script that launches one model training per genotype.
folder_path = "./genotypes"
files = list_files(folder_path)

# Genotype files are named "<optimizer>-<variant|baseline>-<subspace>-<opset>.txt"
# (the best_genotypes keys). Skip anything else in the folder (e.g. .DS_Store),
# which would otherwise break the 4-way unpacking below.
pieces = []
for file_name in files:
    stem, ext = os.path.splitext(file_name)
    if ext != ".txt":
        continue
    parts = tuple(stem.split("-"))
    assert len(parts) == 4, f"Unexpected genotype file name: {file_name}"
    pieces.append(parts)

hpsets = "0,1,2,3,4"
epochs = 300

print("#!/bin/bash\n")
for opt, other, subspace, opset in pieces:
    print(f"python launch_model_train.py --optimizer {opt} --subspace {subspace} --opset {opset} --dataset cifar10_model --hpsets {hpsets} --seed 0 --epochs {epochs} --other {other} --tag models-train --genotypes_folder exp/genotypes & sleep 5")