In [None]:
"""Collect from comet-ml results important for the paper, e.g. epiatlas dfreeze '2.1' data."""
# pylint: disable=import-error, redefined-outer-name, too-many-branches, unnecessary-lambda

In [None]:
from __future__ import annotations

import json
import re
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Set

import numpy as np
import pandas as pd
from comet_ml.api import API

from epi_ml.utils.time import seconds_to_str

In [None]:
# comet-ml API client; it is the object `collect_run_information` expects as its `api` argument.
# NOTE(review): presumably requires comet-ml credentials to be configured in the environment — confirm.
api = API()

In [None]:
def collect_run_information(
    api: API, project: str = "rabyj/epilap"
) -> Dict[str, Dict[str, Any]]:
    """Collect NN training metadata+metrics from comet.ml.

    Args:
        api (API): comet-ml API client used to list the experiments.
        project (str): path passed to `api.get`. Defaults to "rabyj/epilap".

    Returns:
        - results (Dict[str, Dict[str, Any]]): one entry per experiment, keyed
          by the experiment key found in its metadata. Each entry holds the
          experiment name, label category, metadata, hyperparameters, metrics,
          other metadata and tags.
    """
    results = {}
    for experiment in api.get(project):  # type: ignore
        hyperparams = experiment.get_parameters_summary()

        # The label category is the value of the "run_arg_0" hyperparameter.
        # It is looked up afresh for every experiment: a run that lacks the
        # hyperparameter gets None instead of silently inheriting the value of
        # the previous run (or failing with an unbound variable on the first run).
        label_category = next(
            (
                hparam_dict["valueMax"]
                for hparam_dict in hyperparams
                if hparam_dict["name"] == "run_arg_0"
            ),
            None,
        )

        meta = experiment.get_metadata()

        results[meta["experimentKey"]] = {
            "name": experiment.name,
            "label_category": label_category,
            "metadata": meta,
            "hyperparameters": hyperparams,
            "metrics": experiment.get_metrics_summary(),
            "other_metadata": experiment.get_others_summary(),
            "tags": experiment.get_tags(),
        }

    return results

In [None]:
# all_results_dict = collect_run_information(api)

In [None]:
# json_dir = Path().home() / "projects" / "epiclass" / "output"
# json_path = json_dir / "all_results_cometml.json"
# with open(json_path, "w", encoding="utf8") as f:
#     json.dump(all_results_dict, f, indent=2)

In [None]:
def remove_unwanted_info(all_results_dict: Dict[str, Dict[str, Any]]):
    """Strip undesired entries (mostly system metrics) from every experiment.

    Every list-valued section of an experiment, except "tags", is treated as a
    list of dicts carrying a "name" key. Entries whose name contains "sys", or
    is one of a few fixed bookkeeping names, are filtered out. The input dict
    is modified in place and also returned.
    """
    # First pass: gather every entry name present in the list-valued sections.
    seen_names = {
        entry["name"]
        for run in all_results_dict.values()
        for section_name, section in run.items()
        if isinstance(section, list) and section_name != "tags"
        for entry in section
    }

    rejected_names = {name for name in seen_names if "sys" in name} | {
        "offline_experiment",
        "storage_size_bytes",
        "throttled_by_params",
    }

    # Second pass: replace each section by its filtered copy.
    for run in all_results_dict.values():
        for section_name in list(run):
            section = run[section_name]
            if not isinstance(section, list) or section_name == "tags":
                continue
            run[section_name] = [
                entry for entry in section if entry["name"] not in rejected_names
            ]

    return all_results_dict

In [None]:
# all_results_dict = remove_unwanted_info(all_results_dict)

In [None]:
# json_path = json_dir / "all_results_cometml_filtered.json"
# with open(json_path, "w", encoding="utf8") as f:
#     json.dump(all_results_dict, f, indent=2)

In [None]:
def select_time_slice(experiments: Dict[str, Any], date1: str, date2: str) -> List[str]:
    """Select experiments within a time slice.

    The start time of each experiment is read from
    `experiment["metadata"]["startTimeMillis"]` (milliseconds since the epoch)
    and compared, in UTC, to the two bounds. Both bounds are exclusive.

    Args:
        experiments (): object loaded from custom cometml json file
        date1 (str): start date, ISO format. A date without UTC offset is
            interpreted as UTC.
        date2 (str): end date, ISO format. A date without UTC offset is
            interpreted as UTC.
    Returns:
        List[str]: List of experiment keys
    Raises:
        ValueError: if `date1` or `date2` is not a valid ISO date.
    """

    def _as_utc(moment: datetime) -> datetime:
        """Attach UTC to naive datetimes so that all comparisons are tz-aware."""
        if moment.tzinfo is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    # Bounds are loop invariants: parse them once, before touching the experiments.
    start = _as_utc(datetime.fromisoformat(date1))
    end = _as_utc(datetime.fromisoformat(date2))

    valid_list = []
    for exp_key, experiment in experiments.items():  # type: ignore
        start_seconds = int(experiment["metadata"]["startTimeMillis"]) / 1000
        # Timezone-aware replacement for the deprecated datetime.utcfromtimestamp.
        started_at = datetime.fromtimestamp(start_seconds, tz=timezone.utc)
        if start < started_at < end:
            valid_list.append(exp_key)

    return valid_list

In [None]:
# Filtered comet-ml export, read from the user's home directory.
experiments_path = Path.home().joinpath(
    "projects", "epiclass", "output", "all_results_cometml_filtered.json"
)
with experiments_path.open("r", encoding="utf8") as f:
    experiments = json.load(f)

In [None]:
def find_info_keys(experiments: Dict[str, Any]) -> Dict[str, Set[str]]:
    """List, per information category, every entry name seen across experiments.

    A category is any list-valued field of an experiment other than "tags".
    Items that cannot be indexed by "name" are reported on stdout and skipped.
    """
    names_per_category = defaultdict(set)
    for run in experiments.values():  # type: ignore
        for category, entries in run.items():
            if not isinstance(entries, list) or category in ["tags"]:
                continue
            for entry in entries:
                try:
                    entry_name = entry["name"]
                    names_per_category[category].add(entry_name)
                except TypeError:
                    print(category, entries)

    return names_per_category

In [None]:
def nested_json_to_flat_df(experiments: Dict[str, Any]) -> pd.DataFrame:
    """Convert nested json to flat DataFrame.

    Each experiment becomes one row, indexed by its "experimentKey":
    - "tags" is kept as a list,
    - every "metadata" entry becomes a column,
    - every item of the other list-valued fields becomes a column named after
      the item "name", holding the item "valueMax".

    The frame is then cleaned up: the "mapping" columns are merged into a single
    ";"-separated "mapping" column, bookkeeping columns are dropped, the two
    oversampling columns and the two train size columns are merged, and
    "datetime.timedelta(seconds=X)" strings are converted with `seconds_to_str`.

    Args:
        experiments (Dict[str, Any]): object loaded from custom cometml json file
    Returns:
        pd.DataFrame: flat DataFrame
    Raises:
        KeyError: if one of the columns expected by the clean-up steps is absent.
        TypeError: if the mapping columns hold non-string values (the offending
            rows are printed before re-raising).
    """
    flat_list = []
    for experiment in experiments.values():  # type: ignore
        flat_dict = {}
        for k, v in experiment.items():
            if k == "tags":
                flat_dict[k] = v if isinstance(v, list) else [v]
            elif k == "metadata":
                flat_dict.update(sorted(v.items()))
            elif isinstance(v, list):
                for item in v:
                    try:
                        name = item["name"]
                        value = item["valueMax"]
                        flat_dict[name] = value
                    except TypeError:
                        print(k, v)

        flat_list.append(flat_dict)

    # Mapping columns are ordered by the integer that ends their name ("<prefix>/<int>").
    mapping_cols = {
        key for flat_dict in flat_list for key in flat_dict if "mapping" in key
    }
    cols_to_cat = sorted(mapping_cols, key=lambda x: int(x.split("/")[-1]))

    df = pd.DataFrame.from_records(flat_list, index="experimentKey")

    # Combine all mapping columns into one.
    df[cols_to_cat] = df[cols_to_cat].fillna("")
    try:
        df["mapping"] = df[cols_to_cat].apply(";".join, axis=1)
    except TypeError:
        # Show the offending rows to ease debugging, then propagate the error.
        df[cols_to_cat].apply(lambda x: print(x), axis=1)  # type: ignore
        raise
    # Missing mapping columns leave empty fields: collapse runs of ";" and
    # remove the trailing separator.
    df["mapping"] = df["mapping"].str.replace(";+", ";", regex=True)
    df["mapping"] = df["mapping"].str.replace(";$", "", regex=True)

    df = df.drop(columns=cols_to_cat)

    # Remove useless columns.
    to_drop = [
        "optimizationId",
        "userName",
        "projectId",
        "projectName",
        "workspaceName",
        "throttle",
        "throttleMessage",
        "throttlingReasons",
        "running",
        "error",
        "hasCrashed",
        "archived",
        "Category",
        "Data source",
        "Experience key",
    ]

    df = df.drop(columns=to_drop)

    # Combine oversampling status columns.
    oversampling_replace = {"TRUE": True, "FALSE": False}
    oversampling_cat = ["hparams/oversampling", "hparams/oversample"]
    df[oversampling_cat] = df[oversampling_cat].replace(oversampling_replace)
    df["hparams/oversampling"] = df["hparams/oversampling"].fillna(
        df["hparams/oversample"]
    )
    df = df.drop(columns="hparams/oversample")

    # Transform datetime.timedelta(seconds=X) into HH:MM:SS format
    timedelta_pattern = re.compile(r"datetime\.timedelta\(seconds=(.*)\)")
    for col in ["Training time", "Loop time"]:
        df[col] = df[col].fillna("")
        # Iterate over a snapshot since the column is edited along the way.
        for row_key, time_value in list(df[col].items()):
            if not isinstance(time_value, str):
                continue
            re_search = timedelta_pattern.search(time_value)
            if re_search is None:
                continue
            df.at[row_key, col] = seconds_to_str(int(re_search.group(1)))

    # Combine train size cols
    # Plain assignment rather than a chained `inplace=True` fillna: the latter
    # acts on a temporary object and leaves `df` untouched under pandas
    # Copy-on-Write.
    df["train size"] = df["train size"].fillna(df["Training size"])
    df = df.drop(columns="Training size")

    return df

In [None]:
# Flatten the experiments into one row per run, then order the columns by sorted label.
df = nested_json_to_flat_df(experiments)
df = df.reindex(columns=sorted(df.columns))

In [None]:
# df.to_csv(experiments_path.parent / "all_results_cometml_filtered.csv")

### Verify `train size` against the oversampling flag

For a given number of files, compare the recorded sample sizes of runs with oversampling set to false versus true. Some runs appear to have a wrong oversampling flag.

In [None]:
# for col in df.columns:
#     print(col)

In [None]:
# Show how many runs there are for each recorded "test size" value.
df["test size"].value_counts()

In [None]:
relevant_cols = ["Name", "Total nb of files", "train size", "validation size", "hparams/oversampling"]

# A run counts as oversampled when train + validation sizes exceed the number
# of files by more than this factor.
OVERSAMPLING_TOLERANCE = 1.01

# Is oversampling expected?
# The flag is normalised before comparison: `nested_json_to_flat_df` turns
# "TRUE"/"FALSE" into booleans, so the column can mix booleans and strings.
oversampling_flag = df["hparams/oversampling"].astype(str).str.lower()
expected_no_oversampling = oversampling_flag == "false"
expected_oversampling = oversampling_flag == "true"
print(f"Expected no oversampling: {expected_no_oversampling.sum()}")
print(f"Expected oversampling: {expected_oversampling.sum()}")
# display(df[expected_no_oversampling][relevant_cols].value_counts())

# Is oversampling observed?
nb_real_files = df["Total nb of files"].fillna(0).astype(int)
max_files = nb_real_files * OVERSAMPLING_TOLERANCE

total_sample_size = df[["train size", "validation size"]].astype(float).sum(axis=1, skipna=True)

# Runs without a file count or without any recorded size cannot be classified.
# Without this guard a missing file count becomes 0, and every such run would
# be reported (and relabelled below) as oversampled.
is_measurable = (nb_real_files > 0) & (total_sample_size > 0)
print(f"Not measurable (missing file count or sizes): {(~is_measurable).sum()}")

observed_oversampling = is_measurable & (total_sample_size > max_files)
# Complement of the test above; an exact float equality with `max_files` would
# almost never hold.
observed_no_oversampling = is_measurable & (total_sample_size <= max_files)
print(f"Observed oversampling: {observed_oversampling.sum()}")
print(f"Observed no oversampling: {observed_no_oversampling.sum()}")


should_be_true = expected_no_oversampling & observed_oversampling
should_be_false = expected_oversampling & observed_no_oversampling
print(f"Should be true: {should_be_true.sum()}")
print(f"Should be false: {should_be_false.sum()}")

# display(df[should_be_true][relevant_cols].value_counts())
# display(df[should_be_false][relevant_cols].value_counts())

# Only the "flag says no, sizes say yes" mismatches are corrected.
df.loc[should_be_true, "hparams/oversampling"] = "true"

In [None]:
# Write the table to a csv in the same directory as the json it was loaded from.
fixed_csv_path = experiments_path.with_name(
    "all_results_cometml_filtered_oversampling-fixed.csv"
)
df.to_csv(fixed_csv_path)

**TODO**
- Collect all general run parameters and fix the oversampling flag where it is missing, i.e. create a new `all_results_cometml_filtered_oversampling-fixed.csv`.
- Get the difference in content between the different metadata groups (compare md5 sums, create a new metadata object containing only the differences, display the labels the usual way).