In [None]:
"""Collect from comet-ml results using a list of experimental keys."""

# pylint: disable=import-error, redefined-outer-name, too-many-branches, unnecessary-lambda

## SETUP

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict

import pandas as pd
from comet_ml.api import API, APIExperiment
from tqdm import tqdm

In [97]:
# comet-ml client; collect_exp_info() uses it to look up each experiment by key.
# NOTE(review): API() is called without arguments, so credentials presumably
# come from the comet-ml configuration/environment — confirm before running.
api = API()

## Initial download

In [None]:
# Locate the paper output directory and load the table listing the experiments.
paper_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
exp_list_path = paper_dir.joinpath("tables", "training_experiment_keys.csv")
exp_df = pd.read_csv(exp_list_path, low_memory=False)

# Quick look at what was loaded.
print(exp_df.shape)
print(exp_df.columns)

In [None]:
# Deduplicate the experiment keys listed in the table.
exp_key_values = exp_df["exp_key"].tolist()
exp_keys = set(exp_key_values)
print(len(exp_keys))

In [None]:
# Keys read directly from the dict returned by APIExperiment.get_metadata().
needed_meta = ["startTimeMillis"]

# Key of interest in the dict returned by APIExperiment.get_git_metadata().
# NOTE(review): collect_exp_info() hard-codes "parent" and does not read this
# list, so it is unused by the code visible in this notebook.
needed_git_meta = ["parent"]

# Entry names kept from APIExperiment.get_others_summary(); each summary
# element is matched on its "name" field.
needed_others = [
    "category",
    "Code version / commit",
    "Name",
    "SLURM_JOB_ID",
    "Total nb of files",
    "Training size",
    "Validation size",
    "train size",
    "validation size",
]

# Entry names kept from APIExperiment.get_parameters_summary(), matched on "name".
needed_parameters = [
    "output_size",
]

# Entry names kept from APIExperiment.get_metrics_summary(), matched on "name".
needed_metrics = ["tra_Accuracy", "tra_F1Score", "val_Accuracy", "val_F1Score"]

In [None]:
def collect_exp_info(exp_key: str) -> Dict[str, Any]:
    """Collect experiment info from comet-ml using the API.

    Uses the module-level ``api`` client and the module-level ``needed_meta``,
    ``needed_others``, ``needed_parameters`` and ``needed_metrics`` lists to
    decide which values to keep.

    Args:
        exp_key (str): Experiment key.

    Returns:
        Dict[str, Any]: Collected info. Always contains "Experiment key",
            "Exact commit" and every key of ``needed_meta`` (``None`` when the
            value is unavailable). Entries from the others/parameters/metrics
            summaries are only present when the experiment reports them.

    Raises:
        ValueError: If the API returns no experiment for ``exp_key``.
    """
    exp_data: APIExperiment = api.get_experiment_by_key(exp_key)  # type: ignore
    if exp_data is None:
        # Fail with an explicit message instead of an AttributeError on None below.
        raise ValueError(f"No comet-ml experiment found for key: {exp_key!r}")

    collected_info: Dict[str, Any] = {"Experiment key": exp_key}

    # - Direct attributes -

    # Metadata
    metadata = exp_data.get_metadata()
    for attr in needed_meta:
        collected_info[attr] = metadata.get(attr, None)

    # Git metadata: an experiment may have no git information attached, so a
    # missing dict or a missing "parent" entry yields None instead of raising.
    git_metadata = exp_data.get_git_metadata() or {}
    collected_info["Exact commit"] = git_metadata.get("parent", None)

    # - Nested attributes -

    # Each summary is a list of elements carrying a "name" and (optionally) a
    # "valueCurrent"; keep only the names requested for that summary.
    summaries = (
        (exp_data.get_others_summary(), needed_others),
        (exp_data.get_parameters_summary(), needed_parameters),
        (exp_data.get_metrics_summary(), needed_metrics),
    )
    for summary, wanted_names in summaries:
        for elem in summary:
            name = elem["name"]
            if name in wanted_names:
                collected_info[name] = elem.get("valueCurrent", None)

    return collected_info

In [None]:
# `exp_keys` is a set, so its iteration order is not stable from one kernel run
# to the next. Iterate in sorted order so the collected dict (and therefore the
# rows of the table built from it) come out in the same order every time.
all_collected_info = {}
for exp_key in tqdm(sorted(exp_keys), desc="Experiments", unit="exp"):
    all_collected_info[exp_key] = collect_exp_info(exp_key)

## Formatting

In [None]:
# One row per experiment key, one column per collected field.
collected_df = pd.DataFrame.from_dict(
    data=all_collected_info,
    orient="index",
)
print(collected_df.columns)

In [None]:
# Set sizes were logged under two spellings: 'Training size' exists when
# 'train size' doesn't (and vice-versa); 'Validation size' / 'validation size'
# are collected as the same kind of pair. Fold each capitalized column into its
# lowercase counterpart so no value is lost when the capitalized column is dropped.
# The guards make the cell safe to re-run and tolerate an absent spelling.
for legacy_col, current_col in (
    ("Training size", "train size"),
    ("Validation size", "validation size"),
):
    if legacy_col not in collected_df.columns:
        # Nothing to merge (already merged, or never logged under this name).
        continue
    if current_col in collected_df.columns:
        # Only fill rows where the lowercase column has no value.
        missing = collected_df[current_col].isnull()
        collected_df.loc[missing, current_col] = collected_df.loc[missing, legacy_col]
    else:
        collected_df[current_col] = collected_df[legacy_col]
    collected_df = collected_df.drop(columns=[legacy_col])

In [None]:
# Get date

# Rename first: renaming a column that is already renamed is a no-op, so this
# cell can be re-run without failing on a missing "startTimeMillis" column.
start_time_col = "Server start time (Unix, ms)"
collected_df = collected_df.rename(columns={"startTimeMillis": start_time_col})

# Server start time is in milliseconds; convert it directly. Rounding to whole
# seconds first could move a timestamp within half a second of midnight onto
# the next day.
start_datetime = pd.to_datetime(collected_df[start_time_col], unit="ms")
collected_df["Date (YYYY-MM-DD)"] = start_datetime.dt.date

In [None]:
# Final column layout of the exported table.
column_order = [
    # Experiment description
    "Name",
    "category",
    "output_size",
    # Dataset sizes
    "Total nb of files",
    "train size",
    "validation size",
    # Metrics
    "tra_Accuracy",
    "tra_F1Score",
    "val_Accuracy",
    "val_F1Score",
    # Provenance
    "Experiment key",
    "SLURM_JOB_ID",
    "Code version / commit",
    "Exact commit",
    "Server start time (Unix, ms)",
    "Date (YYYY-MM-DD)",
]

# Keep only the listed columns, in that order, then export without the index.
output_csv_path = paper_dir / "tables" / "collected_experiments_info_cometml.csv"
collected_df = collected_df.loc[:, column_order]
collected_df.to_csv(output_csv_path, index=False)