In [68]:
"""Compute chrY z-scores with regards to different distributions, and plot the results."""
# pylint: disable=line-too-long, redefined-outer-name, import-error, pointless-statement, use-dict-literal, expression-not-assigned, unused-import, too-many-lines, unreachable
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots

from epi_ml.core.metadata import Metadata
from epi_ml.utils.general_utility import get_valid_filename

## Setup

In [None]:
ASSAY = "assay_epiclass"
SEX = "harmonized_donor_sex"
TRACK = "track_type"

In [None]:
epilap_project = Path.home() / "Projects" / "epilap"

# Metadata JSON for the dfreeze v2.1 release (filename says it includes ENCODE non-core).
metadata = epilap_project.joinpath(
    "input",
    "metadata",
    "dfreeze-v2",
    "hg38_2023-epiatlas-dfreeze_v2.1_w_encode_noncore_2.json",
)

# Log/output directory for this release and the per-file chrX/chrY coverage table.
logs = epilap_project.joinpath("output", "logs", "epiatlas-dfreeze-v2.1")
chrY_coverage_file = logs.joinpath("chrY_coverage_results", "chrXY_coverage_all.csv")

In [None]:
# Load the dataset metadata from the JSON file at `metadata` (project Metadata class).
meta = Metadata(metadata)

In [None]:
# Per-file coverage table; the index is the first CSV column (presumably md5sum,
# since it is joined below with meta_df, which is indexed by md5sum).
df_chrY = pd.read_csv(chrY_coverage_file, index_col=0, header=0)

In [None]:
# Inspect the available coverage columns (the "chrY" column is used below).
print(df_chrY.columns)

## Merge metadata and chrY coverage info

In [None]:
# One row per dataset, indexed by md5sum so it can be joined with the coverage table.
meta_df = pd.DataFrame.from_records(list(meta.datasets))
meta_df = meta_df.set_index("md5sum")

In [None]:
# Number of datasets per assay, before removing non-core/CTCF assays.
meta_df[ASSAY].value_counts()

In [None]:
# Drop non-core assays and CTCF before building the reference distributions.
is_excluded_assay = meta_df[ASSAY].str.contains("non-core|CTCF")
meta_df = meta_df[~is_excluded_assay]

In [None]:
# Inner join on the index: keep only files that have both coverage and metadata.
merged_df = pd.merge(df_chrY, meta_df, left_index=True, right_index=True)

In [None]:
# Shapes of: filtered metadata, coverage table, inner-merged table.
print(meta_df.shape, df_chrY.shape, merged_df.shape)

In [None]:
# Check the merge: every filtered metadata entry found a coverage row, and all
# columns of both tables are present in the result.
expected_rows = meta_df.shape[0]
expected_cols = df_chrY.shape[1] + meta_df.shape[1]
assert merged_df.shape == (expected_rows, expected_cols), "not right shape"

In [None]:
# for item in set(meta_df.index.values) - set(df_chrY.index.values):
#     print(item)

In [None]:
# Display the merged coverage + metadata table.
merged_df

## Compute chrY coverage z-scores per group (currently assay and track type)

In [None]:
# Columns defining the reference chrY distributions.
# Other groupings used before: [ASSAY, SEX, TRACK] and [ASSAY, SEX].
groupby_columns = [ASSAY, TRACK]

chrY_dists = merged_df.groupby(groupby_columns).agg(
    {"chrY": ["mean", "std", "count"]}
)
display(chrY_dists)

In [None]:
partial_name = "assay_track"

In [None]:
# Save the per-group mean/std/count of chrY coverage.
dists_csv = logs / "chrY_coverage_results" / f"chrY_coverage_distributions_{partial_name}.csv"
chrY_dists.to_csv(dists_csv)

In [None]:
score_name = f"expected_{partial_name}_chrY_z-score"

In [None]:
# z-score of each file's chrY coverage against the distribution of its own group,
# as defined by `groupby_columns`. The group filter is built from
# `groupby_columns`, so switching grouping only requires editing that list.
new_data = []
for group_keys, stats in chrY_dists["chrY"].iterrows():
    # With a single grouping column the index labels are scalars, not tuples.
    if not isinstance(group_keys, tuple):
        group_keys = (group_keys,)

    in_group = pd.Series(True, index=merged_df.index)
    for column, value in zip(groupby_columns, group_keys):
        in_group &= merged_df[column] == value

    partial_df = merged_df.loc[in_group, ["chrY"]].copy()
    # std is NaN for single-file groups, which yields NaN z-scores for them.
    partial_df[score_name] = (partial_df["chrY"] - stats["mean"]) / stats["std"]
    partial_df[f"count_{score_name}"] = stats["count"]

    new_data.append(partial_df)

In [None]:
# print(new_data)

In [None]:
# Stack the per-group z-score tables into one table, indexed like merged_df.
all_zscores_df = pd.concat(new_data)

In [None]:
# Left-join the z-scores back onto the full coverage table (files without a
# z-score get NaN), then discard the duplicated chrY column from the z-score table.
final_df = (
    df_chrY.merge(
        all_zscores_df,
        how="left",
        left_index=True,
        right_index=True,
        suffixes=("", "_DROP"),
    )
    .sort_index()
    .drop(columns=["chrY_DROP"])
)

In [None]:
# Persist coverage + z-scores for this grouping.
final_df_path = logs.joinpath(
    "chrY_coverage_results", f"chrY_coverage_zscores_{partial_name}.csv"
)
final_df.to_csv(final_df_path)

## Map predictions to chrY z-scores

In [None]:
sex_pred_dir = Path.home() / "downloads" / "temp"
sex_pred_file = (
    sex_pred_dir / "sex3_oversample_full-10fold-validation_prediction_augmented-all.csv"
)

full_chrY_df = pd.read_csv(final_df_path, index_col=0, header=0)

In [None]:
# Load predictions; the index is the first CSV column (presumably md5sum, as it
# is joined on the index with full_chrY_df below).
pred_df = pd.read_csv(sex_pred_file, index_col=0, header=0)
print(pred_df.shape)

In [None]:
# Attach chrY coverage/z-scores to each prediction (inner join on the index).
merged_pred_df = pred_df.merge(
    full_chrY_df, how="inner", left_index=True, right_index=True, suffixes=("", "_DROP")
).sort_index()
# Overlapping columns receive the "_DROP" suffix in the *merged* frame, so look
# them up there; pred_df itself never has "_DROP" columns.
drop_columns = [
    column for column in merged_pred_df.columns if column.endswith("_DROP")
]
merged_pred_df = merged_pred_df.drop(columns=drop_columns)
print(merged_pred_df.shape)

In [None]:
raise UserWarning("stop here")

## chrY z-scores confusion matrix for sex3 preds

In [None]:
COLORS_DICT = {"female": "red", "male": "blue", "mixed": "purple"}

In [None]:
# merged_pred_df.to_csv(sex_pred_dir / "sex3_oversample_full-10fold-validation_prediction_augmented-all_chrY_zscores.csv")

In [None]:
print([col for col in merged_pred_df.columns if "z-score" in col])

In [None]:
# merged_pred_df = merged_pred_df[
#     ~merged_pred_df[ASSAY].str.contains(pat="wgb", case=False)
# ]

### Graphs per assay

In [None]:
coverage_label = "expected_assay_chrY_z-score"
count_label = f"count_{coverage_label}"

classes = merged_pred_df["True class"].unique()
assays = merged_pred_df[ASSAY].unique()

matrix_logdir = chrY_coverage_file = (
    logs
    / "chrY_coverage_results"
    / "10fold_valid"
    / "z-score"
    / "per_assay"
    / "w-unknown"
    / "per_assay_graph"
)

for assay_label in assays:
    assay_df = merged_pred_df[merged_pred_df[ASSAY] == assay_label]

    # confusion matrix for chrY z-score
    for threshold in [0, 0.7, 0.9]:
        row = 1
        col = 1
        fig = make_subplots(
            rows=3,
            cols=3,
            shared_yaxes=True,
            x_title="Predicted class (nb of predictions)",
            y_title="z-score vs expected assay",
            row_titles=list(classes),
            column_titles=list(classes),
            vertical_spacing=0.08,
            horizontal_spacing=0.01,
        )
        threshold_df = assay_df[assay_df["Max pred"] >= threshold]

        for label in classes:
            df_label = threshold_df[threshold_df["True class"] == label]

            # Iterate over each target and add a violin plot for it
            for target in classes:
                sub_df = df_label[df_label["Predicted class"] == target]

                if sub_df.shape[0] == 0:
                    y_values = [
                        threshold_df[coverage_label].mean()
                    ]  # Minimal synthetic data
                    sample_count = 0
                    hovertext = ["PLACEHOLDER - NO DATA"]
                else:
                    y_values = sub_df[coverage_label]
                    hovertext = [
                        f"{md5sum, assay}:(z-score={z_score:.3f} (n={int(count)}), pred={pred:.3f})"
                        for md5sum, pred, z_score, count, assay in zip(
                            sub_df.index,
                            sub_df["Max pred"],
                            sub_df[coverage_label],
                            sub_df[count_label],
                            sub_df[ASSAY],
                        )
                    ]

                fig.add_trace(
                    go.Violin(
                        y=y_values,
                        name=f"{target} ({sub_df.shape[0]})",
                        box_visible=True,
                        meanline_visible=True,
                        points="all",
                        text=hovertext,
                        line_color=COLORS_DICT[target],
                        hovertemplate="%{text}",
                    ),
                    row=row,
                    col=col,
                )

                # Move to the next subplot position
                col += 1
                if col > 3:
                    col = 1
                    row += 1

        # Update global layout and traces
        fig.update_traces(marker=dict(size=1))
        fig.update_yaxes(
            range=[
                min(assay_df[coverage_label]) - 0.01,
                max(assay_df[coverage_label]) + 0.01,
            ]
        )

        # Directly using annotations param does not work with make_subplots
        existing_annotations = fig.layout.annotations
        new_annotation = dict(
            x=1.01,  # Position on the x-axis
            y=0.5,  # Position on the y-axis
            showarrow=False,  # Do not show arrow
            text="Reference class",  # The text you want to display
            xref="paper",  # 'x' coordinate is set in relative coordinates
            yref="paper",  # 'y' coordinate is set in relative coordinates
            xanchor="left",  # Text starts from the left of the x-coordinate
            yanchor="middle",  # Middle aligned vertically
            font=dict(size=16),
            textangle=90,
        )
        updated_annotations = list(existing_annotations) + [new_annotation]

        title = f"z-score(mean chrY coverage per file):{assay_label} (pred>{threshold})"

        fig.update_layout(
            title_text=f"{title} (n={threshold_df.shape[0]})",
            showlegend=False,
            annotations=updated_annotations,
        )

        # fig.show()

        title = get_valid_filename(title).replace("_br_", "_")
        html_file = matrix_logdir / f"{title}.html"
        png_file = matrix_logdir / f"{title}.png"
        if not png_file.exists():
            fig.write_image(png_file, scale=2)
        if not html_file.exists():
            fig.write_html(html_file)

### Graphs per assay and track

In [None]:
coverage_label = f"expected_{partial_name}_chrY_z-score"
count_label = f"count_{coverage_label}"

classes = merged_pred_df["True class"].unique()
assays = merged_pred_df[ASSAY].unique()

In [None]:
# Output directory for the per assay+track confusion graphs. Only matrix_logdir
# is assigned: the input path chrY_coverage_file must not be overwritten.
matrix_logdir = (
    logs
    / "chrY_coverage_results"
    / "10fold_valid"
    / "z-score"
    / "per_assay_track"
    / "per_assay_track_confusion_graphs"
)
# parents=True: the intermediate "per_assay_track" directory may not exist yet.
matrix_logdir.mkdir(parents=True, exist_ok=True)

#### Confusion matrices style

In [None]:
# Violin "confusion matrices" of the chrY z-score: one figure per assay, track
# type and prediction-confidence threshold. Rows = true class, columns = predicted class.
n_classes = len(classes)

for assay_label in assays:
    assay_df = merged_pred_df[merged_pred_df[ASSAY] == assay_label]

    for track_type in assay_df[TRACK].unique():
        assay_track_df = assay_df[assay_df[TRACK] == track_type]

        # Same y-range for every threshold of this assay+track. Series.min/max
        # skip NaN z-scores, whereas the built-in min/max can return NaN.
        y_range = [
            assay_track_df[coverage_label].min() - 0.01,
            assay_track_df[coverage_label].max() + 0.01,
        ]

        # confusion matrix for chrY z-score
        for threshold in [0, 0.7, 0.9]:
            title = f"z-score(mean chrY coverage per file):{assay_label},{track_type} (pred>{threshold})"
            filename = get_valid_filename(title).replace("_br_", "_")

            html_file = matrix_logdir / f"{filename}.html"
            png_file = matrix_logdir / f"{filename}.png"
            if png_file.exists() or html_file.exists():
                continue

            fig = make_subplots(
                rows=n_classes,
                cols=n_classes,
                shared_yaxes=True,
                x_title="Predicted class (nb of predictions)",
                # z-scores are relative to the `partial_name` grouping (e.g. assay+track)
                y_title=f"z-score vs expected {partial_name.replace('_', '+')}",
                row_titles=list(classes),
                column_titles=list(classes),
                vertical_spacing=0.08,
                horizontal_spacing=0.01,
            )
            threshold_df = assay_track_df[assay_track_df["Max pred"] >= threshold]

            for row, label in enumerate(classes, start=1):
                df_label = threshold_df[threshold_df["True class"] == label]

                # One violin per predicted class
                for col, target in enumerate(classes, start=1):
                    sub_df = df_label[df_label["Predicted class"] == target]

                    if sub_df.empty:
                        # Single synthetic point so the empty cell is still drawn.
                        y_values = [threshold_df[coverage_label].mean()]
                        hovertext = ["PLACEHOLDER - NO DATA"]
                    else:
                        y_values = sub_df[coverage_label]
                        hovertext = [
                            f"{md5sum}:(z-score={z_score:.3f} (n={int(count)}), pred={pred:.3f})"
                            for md5sum, pred, z_score, count in zip(
                                sub_df.index,
                                sub_df["Max pred"],
                                sub_df[coverage_label],
                                sub_df[count_label],
                            )
                        ]

                    fig.add_trace(
                        go.Violin(
                            y=y_values,
                            name=f"{target} ({sub_df.shape[0]})",
                            box_visible=True,
                            meanline_visible=True,
                            points="all",
                            text=hovertext,
                            line_color=COLORS_DICT[target],
                            hovertemplate="%{text}",
                        ),
                        row=row,
                        col=col,
                    )

            # Update global layout and traces
            fig.update_traces(marker=dict(size=1))
            fig.update_yaxes(range=y_range)

            # Directly using annotations param does not work with make_subplots
            existing_annotations = fig.layout.annotations
            new_annotation = dict(
                x=1.01,  # just right of the plotting area
                y=0.5,  # vertically centred
                showarrow=False,
                text="Reference class",
                xref="paper",
                yref="paper",
                xanchor="left",
                yanchor="middle",
                font=dict(size=16),
                textangle=90,
            )
            updated_annotations = list(existing_annotations) + [new_annotation]

            fig.update_layout(
                title_text=f"{title} (n={threshold_df.shape[0]})",
                showlegend=False,
                annotations=updated_annotations,
            )

            fig.write_image(png_file, scale=2)
            fig.write_html(html_file)

#### Global assay/track distributions

In [None]:
# Remove WGBS assays (any assay name containing "wgb"); the remaining assays
# are recomputed for the plots below.
merged_pred_df = merged_pred_df[~merged_pred_df[ASSAY].str.contains("wgb")]
assays = merged_pred_df[ASSAY].unique()

In [None]:
# Number of predictions per assay after removing assays containing "wgb".
merged_pred_df[ASSAY].value_counts()

In [None]:
# Output directory for the global distribution plot. Only matrix_logdir is
# assigned: the input path chrY_coverage_file must not be overwritten.
matrix_logdir = (
    logs / "chrY_coverage_results" / "10fold_valid" / "z-score" / "per_assay_track"
)

# Subplot grid: one column per (assay, track type) combination present in the
# data, followed by one column per track type pooled over all assays.
num_assay_track_combinations = sum(
    len(merged_pred_df[merged_pred_df[ASSAY] == assay_label][TRACK].unique())
    for assay_label in assays
)
total_cols = num_assay_track_combinations + len(merged_pred_df[TRACK].unique())

In [None]:
from itertools import cycle

# One colour per assay (assay+track violins) and per track type (pooled violins).
# cycle() reuses palette colours instead of silently dropping keys (and causing
# a KeyError later) when there are more keys than palette entries.
assay_colors_dict = dict(zip(assays, cycle(px.colors.qualitative.Dark24)))
track_colors_dict = dict(
    zip(merged_pred_df[TRACK].unique(), cycle(px.colors.qualitative.Dark24_r))
)

In [None]:
# Single-row figure: z-score distribution per (assay, track type), followed by
# the distribution per track type pooled over all assays.
col = 1
row = 1
x_title = "Assay+Track distribution"
fig = make_subplots(
    rows=row,
    cols=total_cols,
    shared_yaxes=True,
    x_title=x_title,
    y_title="z-score",
    horizontal_spacing=0.001,
)

# One violin per (assay, track type), coloured by assay.
for assay_label in sorted(assays):
    assay_df = merged_pred_df[merged_pred_df[ASSAY] == assay_label]

    for track_type in sorted(assay_df[TRACK].unique()):
        sub_df = assay_df[assay_df[TRACK] == track_type]

        y_values = sub_df[coverage_label]
        hovertext = [
            f"{md5sum}:z-score={z_score:.3f} (n={int(count)}), pred={pred:.3f}"
            for md5sum, pred, z_score, count in zip(
                sub_df.index,
                sub_df["Max pred"],
                sub_df[coverage_label],
                sub_df[count_label],
            )
        ]

        fig.add_trace(
            go.Violin(
                y=y_values,
                name=f"{assay_label},{track_type} ({sub_df.shape[0]})",
                box_visible=True,
                meanline_visible=True,
                points="all",
                text=hovertext,
                hovertemplate="%{text}",
                line_color=assay_colors_dict[assay_label],
            ),
            row=row,
            col=col,
        )

        col += 1


# global track type distribution (all assays), coloured by track type
for track_type in sorted(merged_pred_df[TRACK].unique()):
    sub_df = merged_pred_df[merged_pred_df[TRACK] == track_type]

    y_values = sub_df[coverage_label]
    hovertext = [
        f"{md5sum,assay}:z-score={z_score:.3f} (n={int(count)}), pred={pred:.3f}"
        for md5sum, pred, z_score, count, assay in zip(
            sub_df.index,
            sub_df["Max pred"],
            sub_df[coverage_label],
            sub_df[count_label],
            sub_df[ASSAY],
        )
    ]

    fig.add_trace(
        go.Violin(
            y=y_values,
            name=f"{track_type} ({sub_df.shape[0]})",
            box_visible=True,
            meanline_visible=True,
            points="all",
            text=hovertext,
            hovertemplate="%{text}",
            line_color=track_colors_dict[track_type],
        ),
        row=row,
        col=col,
    )

    col += 1

# Final graphing

# Update global layout and traces
fig.update_traces(marker=dict(size=1))
# Series.min/max skip NaN z-scores, whereas the built-in min/max can return NaN.
fig.update_yaxes(
    range=[
        merged_pred_df[coverage_label].min() - 0.01,
        merged_pred_df[coverage_label].max() + 0.01,
    ]
)

fig.update_xaxes(tickangle=50)

# Reposition the shared x-axis title annotation to y=1.5.
fig.update_annotations(y=1.5, selector={"text": x_title})

title = "z-score(mean chrY coverage per file) distribution per assay+track type"
fig.update_layout(
    title_text=f"{title}",
    showlegend=False,
    autosize=True,
    width=3000,
)

fig.show()

filename = get_valid_filename(title).replace("_br_", "_")
html_file = matrix_logdir / f"{filename}.html"
png_file = matrix_logdir / f"{filename}.png"

matrix_logdir.mkdir(parents=True, exist_ok=True)
fig.write_image(png_file, scale=3)
fig.write_html(html_file)

### Graphs for all assays mixed

In [None]:
coverage_label = "expected_assay_chrY_z-score"
count_label = f"count_{coverage_label}"

classes = merged_pred_df["True class"].unique()

matrix_logdir = chrY_coverage_file = (
    logs
    / "chrY_coverage_results"
    / "10fold_valid"
    / "z-score"
    / "per_assay"
    / "w-unknown"
)


# confusion matrix for chrY z-score
for threshold in [0, 0.7, 0.9]:
    row = 1
    col = 1
    fig = make_subplots(
        rows=3,
        cols=3,
        shared_yaxes=True,
        x_title="Predicted class (nb of predictions)",
        y_title="z-score for expected assay",
        row_titles=list(classes),
        column_titles=list(classes),
        vertical_spacing=0.08,
        horizontal_spacing=0.01,
    )
    threshold_df = merged_pred_df[merged_pred_df["Max pred"] >= threshold]

    title = (
        f"z-score(mean chrY coverage per file) - (pred>{threshold})<br>w fc/pval, no wgbs"
    )

    filename = get_valid_filename(title).replace("_br_", "_")
    html_file = matrix_logdir / f"{filename}.html"
    png_file = matrix_logdir / f"{filename}.png"
    if png_file.exists() or html_file.exists():
        continue

    for label in classes:
        df_label = threshold_df[threshold_df["True class"] == label]

        # Iterate over each target and add a violin plot for it
        for target in classes:
            sub_df = df_label[df_label["Predicted class"] == target]

            if sub_df.shape[0] == 0:
                y_values = [threshold_df[coverage_label].mean()]  # Minimal synthetic data
                sample_count = 0
                hovertext = ["PLACEHOLDER - NO DATA"]
            else:
                y_values = sub_df[coverage_label]
                hovertext = [
                    f"{md5sum, assay}:(z-score={z_score:.3f} (n={int(count)}), pred={pred:.3f})"
                    for md5sum, pred, z_score, count, assay in zip(
                        sub_df.index,
                        sub_df["Max pred"],
                        sub_df[coverage_label],
                        sub_df[count_label],
                        sub_df[ASSAY],
                    )
                ]

            fig.add_trace(
                go.Violin(
                    y=y_values,
                    name=f"{target} ({sub_df.shape[0]})",
                    box_visible=True,
                    meanline_visible=True,
                    points="all",
                    text=hovertext,
                    line_color=COLORS_DICT[target],
                    hovertemplate="%{text}",
                ),
                row=row,
                col=col,
            )

            # Move to the next subplot position
            col += 1
            if col > 3:
                col = 1
                row += 1

    # Update global layout and traces
    fig.update_traces(marker=dict(size=1))
    fig.update_yaxes(
        range=[
            min(merged_pred_df[coverage_label]) - 0.01,
            max(merged_pred_df[coverage_label]) + 0.01,
        ]
    )

    # Directly using annotations param does not work with make_subplots
    existing_annotations = fig.layout.annotations  # type: ignore
    new_annotation = dict(
        x=1.01,  # Position on the x-axis
        y=0.5,  # Position on the y-axis
        showarrow=False,  # Do not show arrow
        text="Reference class",  # The text you want to display
        xref="paper",  # 'x' coordinate is set in relative coordinates
        yref="paper",  # 'y' coordinate is set in relative coordinates
        xanchor="left",  # Text starts from the left of the x-coordinate
        yanchor="middle",  # Middle aligned vertically
        font=dict(size=16),
        textangle=90,
    )
    updated_annotations = list(existing_annotations) + [new_annotation]

    fig.update_layout(
        title_text=f"{title} (n={threshold_df.shape[0]})",
        showlegend=False,
        annotations=updated_annotations,
    )

    fig.show()

    fig.write_image(png_file, scale=2)
    fig.write_html(html_file)

## Merge with global results file

In [None]:
target_file_dir = logs / "merged_results"
target_file = target_file_dir / "merged_pred_results_all_2.1_chrY.csv"

target_df = pd.read_csv(target_file, index_col=0, header=0)

In [None]:
# Left join keeps every row of the results file; columns also present in the
# z-score table come back with a "_DROP" suffix so they can be compared.
updated_target_df = pd.merge(
    target_df,
    full_chrY_df,
    how="left",
    left_index=True,
    right_index=True,
    suffixes=("", "_DROP"),
).sort_index()

In [None]:
# Base names of the columns present in both tables ("<name>_DROP" -> "<name>").
same_cols = [
    name[: -len("_DROP")] for name in updated_target_df.columns if name.endswith("_DROP")
]

In [None]:
# Check that columns present in both tables hold the same values.
for column in same_cols:
    original_values = updated_target_df[column]
    merged_values = updated_target_df[f"{column}_DROP"]
    # Sum of *absolute* differences: a signed sum lets positive and negative
    # mismatches cancel out. NaN differences (e.g. rows without a match in the
    # left join) are still skipped by .sum(), as before.
    total_abs_diff = (original_values - merged_values).abs().sum()
    assert np.isclose(0, total_abs_diff), f"{column} not equal"

In [None]:
# Remove the "_DROP" duplicates now that they have been checked.
is_dropped = updated_target_df.columns.str.endswith("_DROP")
updated_target_df = updated_target_df.loc[:, ~is_dropped]

In [None]:
# Write the results file enriched with chrY z-scores.
zscores_output_file = target_file_dir / "merged_pred_results_all_2.1_chrY_zscores.csv"
updated_target_df.to_csv(zscores_output_file)