In [25]:
"""
Analyze non-core predictions from 9n-nc classifier and co.
"""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

'\nAnalyze non-core predictions from 9n-nc classifier and co.\n'

In [26]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
from __future__ import annotations

from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.metrics import confusion_matrix as sk_cm

from epi_ml.core.confusion_matrix import ConfusionMatrixWriter
from epi_ml.utils.notebooks.paper.paper_utilities import (
    ASSAY,
    IHECColorMap,
    MetadataHandler,
    SplitResultsHandler,
)

# import plotly.graph_objects as go
# from plotly.subplots import make_subplots

In [28]:
# All paper inputs/outputs live under one root in the user's home directory.
paper_dir = base_dir = Path.home() / "Projects" / "epiclass" / "output" / "paper"
base_data_dir, base_fig_dir = base_dir / "data", base_dir / "figures"

# Fail fast when the figures directory is missing.
if not base_fig_dir.exists():
    raise FileNotFoundError(f"Directory {base_fig_dir} does not exist.")

In [29]:
# Build the IHEC colour map from the figures directory and keep its assay -> colour mapping.
# NOTE(review): this rebinds the imported class name `IHECColorMap` to an instance, so
# re-running this cell will fail (the instance is presumably not callable). Consider a
# separate name such as `ihec_color_map` once later uses of the instance are updated.
IHECColorMap = IHECColorMap(base_fig_dir)
assay_colors = IHECColorMap.assay_color_map

In [30]:
# Helpers used below: one reads/concatenates split results, the other loads metadata
# relative to the paper directory.
split_results_handler, metadata_handler = SplitResultsHandler(), MetadataHandler(paper_dir)

In [31]:
# Load the "v2-encode" metadata table, keeping assay labels unmerged.
metadata_v2_df = metadata_handler.load_metadata_df(
    "v2-encode",
    merge_assays=False,
)

## Assay classifier: EpiClass 9c-nc

### Load predictions and build an informative dataframe

In [32]:
# Location of the "9c-nc" training results (the "10fold-oversampling" run).
results_dir = (
    base_data_dir
    / "training_results"
    / "dfreeze_v2"
    / "hg38_100kb_all_none_w_encode_noncore"
    / f"{ASSAY}_1l_3000n"
    / "9c-nc"
    / "10fold-oversampling"
)
if not results_dir.exists():
    raise FileNotFoundError(f"Directory {results_dir} does not exist.")

In [33]:
results = split_results_handler.read_split_results(results_dir)

# The handler works on a name -> results mapping; wrap the single run and unwrap it again.
results_by_name = {"9c-nc": results}
concat_results = split_results_handler.concatenate_split_results(
    results_by_name, concat_first_level=True
)["9c-nc"]

In [34]:
# concat_results.columns

In [35]:
def add_second_highest_prediction(
    df: pd.DataFrame, pred_cols: List[str], col_name: str = "2nd_pred_class"
) -> pd.DataFrame:
    """Add a column naming the second-highest scoring prediction column of each row.

    Args:
        df: DataFrame holding one score column per class (plus any other columns).
        pred_cols: Names of the score columns to rank, in a fixed order.
        col_name: Name of the output column (default "2nd_pred_class").

    Returns:
        The same DataFrame object, modified in place with `col_name` added.

    Raises:
        ValueError: If fewer than two prediction columns are given.
    """
    if len(pred_cols) < 2:
        raise ValueError(
            "Need at least two prediction columns to find a second highest, "
            f"got {len(pred_cols)}."
        )

    # Scores come from `df` itself (previously a global frame was read by mistake).
    predictions = df[pred_cols].to_numpy()

    # argsort is ascending, so position -2 is the runner-up; ties are broken by the sort.
    second_highest_indices = np.argsort(predictions, axis=1)[:, -2]

    # Map column positions back to column names.
    df[col_name] = np.asarray(pred_cols)[second_highest_indices]
    return df

In [36]:
# Score columns = every column except the label columns ("True/Predicted class") and the
# "Max pred" summary; excluding "Max pred" explicitly keeps this cell correct when it is
# re-run after `add_max_pred` has added that column.
pred_cols = [
    col
    for col in concat_results.columns
    if "class" not in col and col != "Max pred"
]

In [37]:
# Add the "Max pred" column, then the "2nd_pred_class" column (name of the runner-up score column).
# NOTE(review): `concat_results` is rebound in place, so this cell is not idempotent; re-running
# it re-applies both steps to the already-augmented frame (presumably harmless if
# `add_max_pred` overwrites its column — confirm). Keep the two statements in this order.
concat_results = split_results_handler.add_max_pred(concat_results)
concat_results = add_second_highest_prediction(concat_results, pred_cols)

### Analyze "non-core" predictions that are mislabels (true class is a different assay)

In [38]:
# Fraction of samples whose top score reaches the confidence threshold.
min_pred = 0.6
pred_mask = concat_results["Max pred"].ge(min_pred)
nb_pred = pred_mask.sum()
print(
    f"Nb pred (pred score >= {min_pred:.02f}): "
    f"{nb_pred / len(concat_results) * 100:.02f}% "
    f"({nb_pred}/{len(concat_results)})"
)

Nb pred (pred score >= 0.60): 99.17% (20682/20855)


In [39]:
# # save a confusion matrix
# df = concat_results[pred_mask]
# cm = sk_cm(df["True class"], df["Predicted class"])
# cm_writer = ConfusionMatrixWriter(labels=pred_cols, confusion_matrix=cm)

# name = f"full-10fold-validation_prediction-confusion-matrix-threshold-{min_pred:.02f}"
# cm_writer.to_all_formats(logdir=results_dir, name=name)

In [40]:
# Class balance of the ground-truth labels.
display(concat_results.loc[:, "True class"].value_counts())

h3k27ac     4689
non-core    3599
h3k4me1     2889
h3k4me3     2397
h3k36me3    2085
h3k27me3    2025
h3k9me3     1926
input        777
CTCF         468
Name: True class, dtype: int64

In [41]:
# Samples predicted as "non-core" whose true label is something else.
nc_pred_df = concat_results.loc[
    concat_results["Predicted class"].eq("non-core")
    & concat_results["Predicted class"].ne(concat_results["True class"])
]
print(nc_pred_df.shape)

(29, 13)


In [42]:
# True where the runner-up class matches the true label.
second_pred_ok_mask = nc_pred_df["2nd_pred_class"].eq(nc_pred_df["True class"])
print(
    "Number of non-core predictions mislabels where the second highest prediction is correct: "
    f"{second_pred_ok_mask.sum()}/{nc_pred_df.shape[0]}"
)

Number of non-core predictions mislabels where the second highest prediction is correct: 24/29


In [43]:
# Everything that is not a per-class score column (labels and summary columns).
non_pred_cols = list(filter(lambda column: column not in pred_cols, concat_results.columns))

In [44]:
# Mislabels whose runner-up class is also wrong, shown with 3-decimal floats.
with pd.option_context("display.float_format", lambda value: f"{value:.3f}"):
    display(nc_pred_df.loc[~second_pred_ok_mask, non_pred_cols])

Unnamed: 0,True class,Predicted class,Max pred,2nd_pred_class
ENCFF994BAB,CTCF,non-core,0.376,input
635d949488a00a63742edf2491d6ceee,h3k4me1,non-core,0.398,h3k27ac
e46b5e67b91880a3add2d7983560ceb9,h3k4me1,non-core,0.589,h3k27ac
ENCFF422IST,CTCF,non-core,0.995,h3k4me3
ENCFF300LQE,CTCF,non-core,0.603,input


#### Summary
- Nb pred (pred score >= 0.60): 99.17% (20682/20855)
- Number of non-core predictions mislabels where the second highest prediction is correct: 24/29
- Of the 5 mislabels whose 2nd prediction is also incorrect, 2 have max pred >= 0.6 (both CTCF)

### Analyze non-core files predicted as another class

In [45]:
# Files labelled "non-core" that the classifier assigned to another class.
nc_pred_df = concat_results.loc[
    concat_results["True class"].eq("non-core")
    & concat_results["Predicted class"].ne("non-core")
]

In [46]:
# Which classes absorb the misclassified non-core files.
display(nc_pred_df.loc[:, "Predicted class"].value_counts())

h3k4me3     18
h3k27me3    17
input       13
h3k4me1     11
h3k27ac     11
h3k9me3     10
CTCF         3
h3k36me3     2
Name: Predicted class, dtype: int64

In [47]:
# Attach metadata (e.g. the "Assay" column used below) to each misclassified non-core file.
# The prediction index is matched against the metadata "md5sum" column (inner join).
# Guarded so that re-running this cell does not merge the metadata a second time.
if "md5sum" not in nc_pred_df.columns:
    n_before_merge = len(nc_pred_df)
    nc_pred_df = nc_pred_df.merge(metadata_v2_df, left_index=True, right_on="md5sum")
    # An inner join silently drops files missing from the metadata (or duplicates rows
    # when md5sums repeat); surface that instead of hiding it.
    if len(nc_pred_df) != n_before_merge:
        print(
            f"Warning: metadata merge changed row count ({n_before_merge} -> {len(nc_pred_df)}); "
            "some files are missing from, or duplicated in, the metadata."
        )

In [48]:
# For each wrongly assigned class, list the original (non-core) assays that landed there.
for predicted_class, group in nc_pred_df.groupby("Predicted class"):
    print(f"\nPredicted class: {predicted_class}", group["Assay"].value_counts(), sep="\n")


Predicted class: CTCF
RAD21     1
ETV6      1
TRIM22    1
Name: Assay, dtype: int64

Predicted class: h3k27ac
H3K9ac     5
H3F3A      3
H2BK5ac    3
Name: Assay, dtype: int64

Predicted class: h3k27me3
EZH2               8
SUZ12              4
EZH2phosphoT487    4
CBX8               1
Name: Assay, dtype: int64

Predicted class: h3k36me3
ZZZ3        1
H3K79me1    1
Name: Assay, dtype: int64

Predicted class: h3k4me1
H3K9me1     3
H3K79me2    2
WHSC1       1
H4K20me1    1
GATA1       1
PRDM6       1
SPI1        1
H3F3A       1
Name: Assay, dtype: int64

Predicted class: h3k4me3
H3K4me2    10
H3K9ac      5
KDM4A       1
ZNF777      1
H2AFZ       1
Name: Assay, dtype: int64

Predicted class: h3k9me3
ZNF274     4
CBX3       2
CREB1      1
H3K9me2    1
CBX5       1
SETDB1     1
Name: Assay, dtype: int64

Predicted class: input
H3K9me2            2
POLR2A             2
SOX15              1
FOXA1              1
EZH2phosphoT487    1
GRHL2              1
NR2C2              1
GATA1              