In [None]:
"""Notebook to analyze the values in an HDF5 file."""
# %pip list | grep "ka"
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument

In [None]:
import copy
from pathlib import Path
from typing import Dict, List, Tuple

import h5py
import numpy as np
import pandas as pd
import plotly.express as px  # type: ignore
import plotly.graph_objects as go  # type: ignore

from epi_ml.core.data_source import EpiDataSource
from epi_ml.core.epiatlas_treatment import ACCEPTED_TRACKS
from epi_ml.core.hdf5_loader import Hdf5Loader
from epi_ml.core.metadata import Metadata

# Metadata column names reused throughout the notebook.
ASSAY, TRACK_TYPE = "assay_epiclass", "track_type"

In [None]:
# NOTE(review): no cell in this notebook plots with matplotlib (figures use plotly), so this magic is likely unnecessary.
%matplotlib inline

In [None]:
# Project root; inputs and outputs live in its "input" and "output" subdirectories.
# base = Path("/lustre06/project/6007017/rabyj/epilap/input/")
# NOTE(review): the cluster path above ends in "input/", but "input" is appended again below,
# which would give ".../input/input" — it likely should stop at ".../epilap".
base = Path.home().joinpath("Projects", "epilap")
input_base, output_base = base / "input", base / "output"

# Reference files: chromosome sizes and sample metadata.
chromsize_path = input_base.joinpath("chromsizes", "hg38.noy.chrom.sizes")
metadata_path = input_base.joinpath(
    "metadata", "hg38_2023_epiatlas_dfreeze_formatted_JR.json"
)

# Directory where the figures produced by this notebook are written.
base_logdir = output_base / "logs"
logdir = base_logdir.joinpath("hg38_2022-epiatlas", "hdf5_stats")

In [None]:
# Previous list: base / "hdf5_list" / "100kb_all_none.list"
# Current list: 100 kb bins from the 2023-01 EpiAtlas freeze ("0blklst" variant).
hdf5_list_path = input_base.joinpath(
    "hdf5_list", "hg38_2023-01-epiatlas-freeze", "100kb_all_none_0blklst.list"
)

In [None]:
# Load the data source and its metadata, then keep only the ACCEPTED_TRACKS track types.
datasource = EpiDataSource(hdf5_list_path, chromsize_path, metadata_path)
my_meta = Metadata(datasource.metadata_file)
my_meta.display_labels(TRACK_TYPE)  # label counts before filtering

my_meta.select_category_subsets(TRACK_TYPE, ACCEPTED_TRACKS)
my_meta.display_labels(TRACK_TYPE)  # label counts after filtering

In [25]:
# Every line of the hdf5 list file is one hdf5 file path.
paths = Path(hdf5_list_path).read_text(encoding="utf8").splitlines()

In [None]:
# md5s listed in the SHAP assay-explanation list (presumably the samples SHAP was computed on).
# A few of them are highlighted in the feature-value plots below.
shap_md5s_path = input_base.joinpath("hdf5_list", "md5_shap_assay_explain.list")
shap_md5s = set(shap_md5s_path.read_text(encoding="utf8").splitlines())

In [None]:
# Feature bin indices to inspect: the union of two selections, sorted ascending.
# NOTE(review): the origin of these bin indices is not recorded in this notebook — TODO document it.
# fmt:off
regions = sorted(
    {29956, 28774, 28775, 16809, 26345, 15821, 15888, 7249, 5651, 15219, 28889, 11325, 8574}
    .union({29956, 28774, 28775, 16809, 26345, 29551, 15888, 5651, 15219, 28889, 11325, 8574})
)
# fmt:on

In [None]:
from epi_ml.utils.bed_utils import bins_to_bed_ranges

# Bin size in base pairs (100 kb), matching the "100kb" hdf5 list loaded above.
resolution = 100_000
regions_bed = bins_to_bed_ranges(regions, datasource.load_chrom_sizes(), resolution)
# Map each bin index to its genomic range (same positional pairing as regions/regions_bed).
regions_dict = {regions[idx]: bed_range for idx, bed_range in enumerate(regions_bed)}

In [None]:
# Show the mapping from feature bin index to its genomic range.
regions_dict

In [None]:
def analyze_feature_vals(
    regions_dict: Dict[int, Tuple],
    md5s: List[str],
    hdf5_list: Path,
    logdir: Path,
    name: str,
):
    """
    Generate and save a violin plot of feature values for the provided md5s, with some md5s highlighted.

    One violin is drawn per region in `regions_dict`. It shows the distribution of that
    feature's normalized value across all loaded samples. Up to three samples that are also
    in the notebook-level `shap_md5s` set are overlaid as lines+markers, so individual
    samples can be followed across regions. The figure is saved as an HTML file and a PNG
    file in `logdir`.

    Args:
        regions_dict (Dict[int, Tuple]): Maps a feature (bin) index to its genomic range.
            Each tuple is rendered as "<t[0]>:<t[1]>-<t[2]>" (presumably chrom, start, end).
        md5s (List[str]): md5sums of the samples to load and plot.
        hdf5_list (Path): Path to the list of hdf5 files to load the samples from.
        logdir (Path): Directory where the plots are saved; created if it does not exist.
        name (str): Label used in the plot title and in the output filenames.

    Note:
        Reads the notebook globals `chromsize_path` and `shap_md5s`.
    """
    hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=True)
    hdf5_loader.load_hdf5s(hdf5_list, md5s, strict=True)
    signals = hdf5_loader.signals
    N = len(signals)

    # One (symbol, color) marker style per highlighted sample.
    marker_formats = [("cross", "black"), ("circle", "blue"), ("diamond", "red")]
    nb_highlight = len(marker_formats)
    # Sort before slicing. Iteration order of a set of strings changes between interpreter
    # runs (hash randomization), so an unsorted slice would highlight different samples each run.
    highlight_md5s = sorted(set(md5s) & shap_md5s)[:nb_highlight]

    traces = []
    highlight_values = {highlight_md5: [] for highlight_md5 in highlight_md5s}
    for region, region_bed in regions_dict.items():
        region_str = f"{region_bed[0]}:{region_bed[1]}-{region_bed[2]}"
        values = [signal[region] for signal in signals.values()]
        traces.append(
            go.Violin(
                y=values,
                name=region_str,
                points="all",
                box_visible=True,
                meanline_visible=True,
            )
        )

        for highlight_md5 in highlight_md5s:
            highlight_values[highlight_md5].append(
                (region_str, signals[highlight_md5][region])
            )

    for (highlight_md5, points), (symbol, color) in zip(
        highlight_values.items(), marker_formats
    ):
        if not points:
            # No regions were given; zip(*[]) below could not be unpacked into x, y.
            continue
        x, y = zip(*points)
        traces.append(
            go.Scatter(
                x=x,
                y=y,
                mode="lines+markers",
                name=f"{highlight_md5}",
                marker={"size": 6, "symbol": symbol, "color": color},
            )
        )

    layout = go.Layout(
        title=f"Feature values distributions for {N} {name} samples (0blklst)",
        yaxis={"title": "z-score"},
        xaxis={"title": "Region"},
        showlegend=False,
    )
    fig = go.Figure(data=traces, layout=layout)

    # Writing fails if the output directory does not exist yet.
    logdir.mkdir(parents=True, exist_ok=True)
    fig.write_html(logdir / f"feature_values_{name}.html")

    width = 1200
    # Integer height with a 4:3 aspect ratio (width * 3 / 4 would be a float).
    fig.write_image(
        logdir / f"feature_values_{name}.png", width=width, height=width * 3 // 4
    )
    # fig.show()

In [None]:
# Presumably drops samples lacking a "harmonized_donor_sex" label; the next cell splits samples by this label.
my_meta.remove_missing_labels("harmonized_donor_sex")

In [None]:
# Plot the selected regions separately for male and female samples of a single assay.
assay = "h3k27ac"  # loop-invariant, so defined once outside the loop
for label in ["male", "female"]:
    # Filter a copy: select_category_subsets presumably filters the metadata object in place,
    # and my_meta must stay intact for the other label.
    meta = copy.deepcopy(my_meta)
    meta.select_category_subsets("harmonized_donor_sex", [label])
    meta.select_category_subsets(ASSAY, [assay])
    md5s = list(meta.md5s)
    print(f"Number of {label} samples: {len(md5s)}")
    analyze_feature_vals(
        regions_dict, md5s, hdf5_list_path, logdir, name=f"{label}-{assay}"
    )

In [None]:
def plot_single_file(md5: str, zscore: bool = True, outdir: Path = Path(".")):
    """Produce a violin plot of all feature values for a single sample, save it to HTML and show it.

    Args:
        md5: md5sum of the sample to plot. It is loaded from the notebook-level `hdf5_list_path`.
        zscore: If True, load with `normalization=True` and label the plot "z-scores";
            otherwise load raw values.
        outdir: Directory where "<md5>-<mode>.html" is written. Defaults to the current
            working directory (the previous behavior); created if missing.
    """
    mode = "z-scores" if zscore else "raw values"

    hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=zscore)
    signals = hdf5_loader.load_hdf5s(hdf5_list_path, [md5], strict=True).signals
    # A single sample was requested, so the first loaded signal is that sample.
    signal = list(signals.values())[0]

    fig = px.violin(
        data_frame=signal,
        box=True,
        points="all",
        title=f"Violin plot for {md5} {mode}",
    )
    outdir.mkdir(parents=True, exist_ok=True)
    fig.write_html(outdir / f"{md5}-{mode}.html")
    fig.show()

In [None]:
# md5 = "402a78740e46888266209a5b7c3ece4c"
# plot_single_file(md5, zscore=True)

In [None]:
# All md5s listed in the hdf5 list file.
# For a quicker run, keep only the first N_SAMPLES (e.g. 100) keys instead.
md5s = set(Hdf5Loader.read_list(hdf5_list_path).keys())

# Metadata of the listed files, indexed by md5sum. Files without metadata are skipped;
# uncomment the check below to fail loudly instead.
# for md5 in md5s:
#     if md5 not in my_meta:
#         raise IndexError(f"Missing metadata for {md5}")
df_md5_metadata = pd.DataFrame(
    [my_meta[md5] for md5 in md5s if md5 in my_meta]
).set_index("md5sum")

In [None]:
# Summary of what will be analyzed: total file count, then counts per track type and per assay.
print(f"{len(df_md5_metadata)} files to analyze.")
for count_column in (TRACK_TYPE, ASSAY):
    print(df_md5_metadata[count_column].value_counts())

In [None]:
# Per-file inspection: one figure per hdf5 file, with one violin per dataset it contains.
# NOTE(review): this shows a figure for every file in `paths`; slice `paths` for long lists.
for filepath in paths:
    traces = []
    # Open read-only. This cell only reads values; "r+" (read/write) could modify the files
    # and fails on read-only storage.
    with h5py.File(filepath, "r") as f:
        for group in f.values():
            for dataset_name, dataset in group.items():
                values = dataset[:]
                traces.append(go.Violin(y=values, name=dataset_name))

                # Optional check: error introduced by casting the stored values to float32.
                # casted_dataset = dataset.astype(np.float32)[:]
                # max_diff = np.max(np.abs(casted_dataset - values))
                # if max_diff > 1e-4:
                #     print("Induced casting error")
                #     print(f"Max value: {np.max(values)}")
                #     print(f"Filepath: {filepath}")
                #     print(f"Dataset name: {dataset_name}")

    layout = go.Layout(
        title=f"Violin Plots ({Path(filepath).name})", yaxis={"title": "Values"}
    )
    fig = go.Figure(data=traces, layout=layout)
    fig.show()

In [None]:
# Load every listed sample with normalization and arrange the result as a table:
# one row per dict key (md5), one column per feature bin.
hdf5_loader = Hdf5Loader(chrom_file=chromsize_path, normalization=True)
signals = hdf5_loader.load_hdf5s(hdf5_list_path, md5s, strict=True).signals
df = pd.DataFrame.from_dict(data=signals, orient="index")

In [None]:
# Descriptive statistics of each sample (row of df) over all of its feature values.
percentiles = [0.01] + list(np.arange(0.05, 1, 0.05)) + [0.99] + [0.999]
# Transposing lets describe() run once, vectorized over all samples, instead of one
# Python-level describe() call per row via apply. Same table: rows = samples, columns = stats.
stats_df = df.T.describe(percentiles=percentiles).T
metrics = set(stats_df.columns.values)

In [None]:
# print(sorted(metrics))

In [None]:
# Attach the metadata columns to the per-sample statistics (joined on the md5 index).
# Only join columns not already present. Re-running a plain join would raise
# "columns overlap but no suffix specified", because the metadata was joined the first time.
new_meta_cols = [col for col in df_md5_metadata.columns if col not in stats_df.columns]
stats_df = stats_df.join(df_md5_metadata[new_meta_cols])

In [None]:
# One figure per statistic, with one violin per assay; count/mean/std are skipped.
allowed_metrics = metrics - set(["count", "mean", "std"])
category_orders = {ASSAY: sorted(my_meta.label_counter(ASSAY, verbose=False).keys())}
# Name outputs after the hdf5 list actually analyzed, so results from different lists
# (e.g. with and without "0blklst") are not mislabeled and do not overwrite each other.
output_prefix = f"{hdf5_list_path.stem}_hdf5"
logdir.mkdir(parents=True, exist_ok=True)
for column in [col for col in stats_df.columns if col in allowed_metrics]:
    fig = px.violin(
        data_frame=stats_df,
        x=column,
        y=ASSAY,
        box=True,
        points="all",
        title=f"Violin plot for {column}",
        color=ASSAY,
        category_orders=category_orders,
        height=800,
        # Take the hover labels from the plotted frame itself, so they stay aligned with its rows.
        hover_data={"md5sum": stats_df.index},
    )
    fig.write_image(logdir / f"{output_prefix}_{column}.png")
    fig.write_html(logdir / f"{output_prefix}_{column}.html")

In [None]:
# Split samples by track type: "fc"/"pval" tracks vs. all other tracks (presumably the "raw" ones).
# TODO: redo the metric plots on the raw tracks only; this cell only previews both subsets.
df_filter = stats_df[TRACK_TYPE].isin(["fc", "pval"])

display(stats_df[df_filter].head(10))  # fc / pval tracks
display(stats_df[~df_filter].head(10))  # all other tracks