In [None]:
"""Workbook to analyse Chip-Atlas predictions, destined for the paper.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches, pointless-statement

In [None]:
# Load IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution (edits to project code are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import copy
import itertools
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Set, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots

from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    ASSAY_MERGE_DICT,
    ASSAY_ORDER,
    CELL_TYPE,
    LIFE_STAGE,
    SEX,
    IHECColorMap,
    merge_similar_assays,
)

In [None]:
# Show the assay ordering imported from paper_utilities; `core_assays` below is sliced from it.
ASSAY_ORDER

In [None]:
# Locations used throughout the notebook, all rooted in the user's home directory.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
paper_dir = base_dir
base_data_dir = base_dir.joinpath("data")
base_fig_dir = base_dir.joinpath("figures")

# Fail early when the figure directory is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [None]:
# Tab-separated table read into `ca_pred_df`; the cells below use its
# `manual_target_consensus`, `Predicted_class_assay_*` and `Max_pred_assay_*` columns.
ca_pred_path = base_data_dir.joinpath(
    "training_results", "C-A", "CA_metadata_4DB+all_pred_rabyj.tsv"
)
# low_memory=False: parse the file in one pass instead of chunk-wise type inference.
ca_pred_df = pd.read_csv(ca_pred_path, sep="\t", low_memory=False)

| Assay | Exp Key                               | Nb Files | Training Size | Oversampling | Expected Nb Files                      |
|-------|---------------------------------------|----------|---------------|--------------|---------------------------------------|
| 13c   | dd3710b73c0341af85a17ce1998362d0      | 24989    | 116550        | true         | 24989                                 |
| 11c   | 0f8e5eb996114868a17057bebe64f87c      | 20922    | 46128         | true         | 20922                                 |
| 7c    | 69488630801b4a05a53b5d9e572f0aaa      | 16788    | 34413         | true         | 16788 (cross-checked)                 |

\* Using `hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2`.


In [None]:
# List the columns available in the loaded prediction table.
ca_pred_df.columns

In [None]:
# Create input filename list (for umap).
# Disabled snippet: selects the "Experimental-id" of rows whose
# manual_target_consensus is "input" and writes "<id>_100kb_all_none.hdf5"
# names, one per line, to a .list file.
# NOTE(review): `input_ids` is a Series here, so `input_ids["filename"] = ...`
# does not create a column — likely needs `.to_frame()` before re-enabling; verify.
# input_ids = ca_pred_df[ca_pred_df["manual_target_consensus"] == "input"]["Experimental-id"]
# input_ids["filename"] = input_ids + "_100kb_all_none.hdf5"
# input_ids["filename"].to_csv(base_data_dir / "training_results" / "C-A" / "C-A_100kb_all_none_input.list", sep="\t", index=False, header=False)

In [None]:
# First seven entries of ASSAY_ORDER.
# NOTE(review): presumably these are exactly the core assays — confirm against paper_utilities.
core_assays = ASSAY_ORDER[:7]

**FIXME:** fix the commas in `ca_pred_df["Max_pred_assay_11c"]`.

In [None]:
# Keep only the rows whose `Max_pred_assay_13c` or `Max_pred_assay_7c` value
# is strictly above `min_pred`; the shape is printed before and after.
# The 11c score column is not part of this filter.
print(ca_pred_df.shape)
min_pred = 0.8
score_13c = ca_pred_df["Max_pred_assay_13c"].astype(float)
score_7c = ca_pred_df["Max_pred_assay_7c"].astype(float)
ca_pred_df = ca_pred_df[score_13c.gt(min_pred) | score_7c.gt(min_pred)]
print(ca_pred_df.shape)

In [None]:
# For every core assay, take the rows whose manual target is that assay and
# show, for each classifier column, the share (in %) of every predicted class.
# Percentages are relative to the number of rows selected for the assay.
for assay in core_assays:
    print(f"{assay}")
    assay_df = ca_pred_df.loc[ca_pred_df["manual_target_consensus"] == assay]
    for col in (
        "Predicted_class_assay_7c",
        "Predicted_class_assay_11c",
        "Predicted_class_assay_13c",
    ):
        display(assay_df[col].value_counts() / len(assay_df) * 100)

    # 13c only: among rows whose 13c prediction differs from the assay,
    # show the share (in %) of each second-ranked 13c class.
    wrong_pred = assay_df.loc[assay_df["Predicted_class_assay_13c"] != assay]
    second_choice_counts = wrong_pred["2nd_pred_class_assay_13c"].value_counts()
    display(second_choice_counts / len(wrong_pred) * 100)
    print("\n")

In [None]:
# Manual targets of the rows that the 13c column labels "wgbs-standard":
# raw counts first, then the same counts as percentages of those rows.
is_wgbs_13c = ca_pred_df["Predicted_class_assay_13c"] == "wgbs-standard"
wgbs_dist = ca_pred_df.loc[is_wgbs_13c, "manual_target_consensus"]
wgbs_counts = wgbs_dist.value_counts()
display(wgbs_counts)
display(wgbs_counts / len(wgbs_dist) * 100)

In [None]:
# Same question for the 11c and 13c columns: for rows predicted as
# "wgbs-standard", display the manual-target counts and their percentages.
print("What is the actual target when wgbs-standard is predicted?")
for col in ("Predicted_class_assay_11c", "Predicted_class_assay_13c"):
    print(col)
    is_wgbs = ca_pred_df[col] == "wgbs-standard"
    wgbs_dist = ca_pred_df.loc[is_wgbs, "manual_target_consensus"]
    wgbs_counts = wgbs_dist.value_counts()
    display(wgbs_counts)
    display(wgbs_counts / len(wgbs_dist) * 100)

In [None]:
# Manual targets of the rows that the 13c column labels "non-core":
# raw counts, then percentages of those rows.
print("What is the actual target when non-core is predicted?")
col = "Predicted_class_assay_13c"
is_non_core = ca_pred_df[col] == "non-core"
# NOTE(review): `wgbs_dist` is reused here although it now holds the non-core
# selection; the name is kept in case other cells read it.
wgbs_dist = ca_pred_df.loc[is_non_core, "manual_target_consensus"]
non_core_counts = wgbs_dist.value_counts()
display(non_core_counts)
display(non_core_counts / len(wgbs_dist) * 100)