In [None]:
"""Workbook to create figures destined for the paper."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import

In [None]:
from __future__ import annotations

import itertools
from collections import defaultdict
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    CELL_TYPE,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
    merge_similar_assays,
)

In [None]:
# Root of the paper output tree; the `data` and `figures` folders sit directly under it.
base_dir = Path.home().joinpath("Projects/epiclass/output/paper")
paper_dir = base_dir
base_data_dir, base_fig_dir = base_dir / "data", base_dir / "figures"

# Stop right away when the figure folder is absent instead of failing in a later cell.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# The import cell binds `IHECColorMap` to the class, and this cell rebinds the same
# name to an instance. Resolve the class explicitly so that re-running this cell
# does not end up calling the instance (a plain `IHECColorMap(...)` would).
_ihec_color_map_cls = (
    IHECColorMap if isinstance(IHECColorMap, type) else type(IHECColorMap)
)
# The public name stays `IHECColorMap` because other cells may rely on it.
IHECColorMap = _ihec_color_map_cls(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map
cell_type_colors = IHECColorMap.cell_type_color_map

In [None]:
# Instantiate the SplitResultsHandler helper imported from paper_utilities (no constructor arguments).
split_results_handler = SplitResultsHandler()

## Fig 2 - EpiClass results on EpiAtlas other metadata

For all sub-figures of Fig 2 and Fig 3, use v1.1 of the sample metadata (called v2.1 internally).

A) Histogram of performance (accuracy and F1 scores) for each category (using metadata v1)  
B) Violin plot of the average z-score on chrY per sex, with black dots where the predicted class is the same and red dots where it differs.  
- Make split male/female violins per assay (only FC, merge 2xwgbs and 2xrna, no rna unique_raw).
- Use scatter points on each side: agreeing points in the same color as the violin, disagreeing points in the other color.
- Point labels: uuid, epirr  

C) ---  
D) ---  
E) SHAP cell-types GO  


### Fig 2.A

Check if all training runs were done with oversampling on.

In [None]:
# The `dfreeze_v1` results folder is expected directly under the data directory.
v1_results_dir = base_data_dir.joinpath("dfreeze_v1")

# Abort early with an explicit message when the folder is missing.
if not v1_results_dir.exists():
    raise FileNotFoundError(f"Directory {v1_results_dir} does not exist.")

In [None]:
def check_for_oversampling(base_data_dir: Path):
    """Display how many training runs used each oversampling setting.

    Experiment keys are read from the job stdout files found at
    `<base_data_dir>/dfreeze_v1/<category>/*/output_job*.o`, matched against the
    `experimentKey` column of
    `<base_data_dir>/all_results_cometml_filtered_oversampling-fixed.csv`, and the
    value counts of the `hparams/oversampling` column for the matching rows are
    displayed.

    Args:
        base_data_dir: Data root containing both the `dfreeze_v1` results folder
            and the run-metadata CSV.

    Returns:
        None. The counts are shown through `display`.

    Raises:
        FileNotFoundError: If the results folder or the metadata CSV is missing.
    """
    exp_key_line = "The current experiment key is"

    # Derive the results folder from the argument instead of reading the
    # notebook-level `v1_results_dir`, so the function only depends on its input.
    results_dir = base_data_dir / "dfreeze_v1"

    # Identify experiments: the key is whatever follows the marker on a log line.
    all_exp_keys = set()
    for category in results_dir.iterdir():
        for stdout_file in category.glob("*/output_job*.o"):
            with open(stdout_file, "r", encoding="utf8") as f:
                all_exp_keys.update(
                    line.split(exp_key_line)[1].strip()
                    for line in f
                    if exp_key_line in line
                )

    # Get all hparam values
    gen_run_metadata = (
        base_data_dir / "all_results_cometml_filtered_oversampling-fixed.csv"
    )
    run_metadata = pd.read_csv(gen_run_metadata, header=0)

    # Check oversampling values, restricted to the runs found in the logs.
    df = run_metadata[run_metadata["experimentKey"].isin(all_exp_keys)]
    display(df["hparams/oversampling"].value_counts())

In [None]:
# Show the `hparams/oversampling` value counts for the runs found under `base_data_dir`.
check_for_oversampling(base_data_dir)