In [None]:
r"""Heatmaps notebook.
 _____  _______  _    _
|  __ \|__   __|| |  | |
| |  | |  | |   | |  | |
| |  | |  | |   | |  | |
| |__| |  | |   | |__| |
|_____/   |_|   |______|

__authors__ = Marco Reverenna & Konstantinos Kalogeropoulus
__copyright__ = Copyright 2025-2026
__research-group__ = DTU Biosustain (Multi-omics Network Analytics) and DTU Bioengineering
__date__ = 25 Jun 2025
__maintainer__ = Marco Reverenna
__email__ = marcor@dtu.dk
__status__ = Dev
"""

' Full assembly workflow with a dbg approach.\n _____  _______  _    _ \n|  __ \\|__   __|| |  | |\n| |  | |  | |   | |  | |\n| |  | |  | |   | |  | |\n| |__| |  | |   | |__| |\n|_____/   |_|   |______|\n\n__authors__ = Marco Reverenna & Konstantinos Kalogeropoulus\n__copyright__ = Copyright 2025-2026\n__research-group__ = DTU Biosustain (Multi-omics Network Analytics) and DTU Bioengineering\n__date__ = 25 Jun 2025\n__maintainer__ = Marco Reverenna\n__email__ = marcor@dtu.dk\n__status__ = Dev\n'

In [None]:
import os
import json
import pandas as pd
import plotly.graph_objects as go

In [None]:
# Short labels used when rendering "parameter=value" tick labels on the heatmaps.
# Parameters missing from this mapping are shown with their full column name.
abbrev = dict(
    max_mismatches="mm",
    min_identity="id",
    size_threshold="st",
    conf="c",
    kmer_size="k",
    min_overlap="mo",
)

In [None]:
def parse_theme(theme_json):
    """Decode a JSON-encoded colorscale string and return the resulting Python object.

    Any valid JSON document is accepted; invalid JSON raises ``json.JSONDecodeError``.
    """
    decoded = json.loads(theme_json)
    return decoded

In [None]:
def _format_param_label(names, key):
    """Build a "short=value, short=value" tick label for one pivot-table row/column key."""
    # A single grouping column yields scalar keys instead of tuples; wrap them so
    # zip() pairs the column name with the whole value rather than iterating it.
    values = key if isinstance(key, tuple) else (key,)
    return ", ".join(f"{abbrev.get(name, name)}={val}" for name, val in zip(names, values))


def plot_grid_search_clustermap(
    df,
    index_cols,
    column_cols,
    theme,
    value_col,
    title="",
    aggfunc="max",
    output_file=None,
    zmin=0,
    zmax=1,
):
    """Render grid-search results as a Plotly heatmap with white cell borders.

    Despite the name, no clustering is performed: rows and columns are only
    sorted by their parameter values.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format results containing ``index_cols``, ``column_cols`` and ``value_col``.
    index_cols, column_cols : list of str (or a single str)
        Parameter columns used for the heatmap rows / columns. Tick labels use the
        short names from ``abbrev`` where available.
    theme : list or str
        Plotly colorscale, passed as ``go.Heatmap(colorscale=...)``.
    value_col : str
        Column aggregated into the cells; also used as the colour-bar title.
    title : str
        Figure title.
    aggfunc : str or callable
        Aggregation passed to ``DataFrame.pivot_table`` when several rows share a cell.
    output_file : str or None
        If given, the figure is written to this path as SVG (parent directory created
        when the path has one). Requires Plotly's static-image export backend.
    zmin, zmax : float
        Colour-scale limits (default 0 and 1).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Accept a single column name as well as a list of names.
    index_cols = [index_cols] if isinstance(index_cols, str) else list(index_cols)
    column_cols = [column_cols] if isinstance(column_cols, str) else list(column_cols)

    pivot = df.pivot_table(
        values=value_col, index=index_cols, columns=column_cols, aggfunc=aggfunc
    )
    pivot = pivot.sort_index(level=index_cols).sort_index(axis=1, level=column_cols)

    row_labels = [_format_param_label(index_cols, key) for key in pivot.index]
    col_labels = [_format_param_label(column_cols, key) for key in pivot.columns]

    # Parameter combinations that were never run are NaN after pivoting; fill them with
    # the mean over ALL observed cells. DataFrame.stack() only stacks the last column
    # level, so with several column_cols its mean was a per-level Series rather than a
    # scalar; flattening the values gives a true global (NaN-skipping) mean.
    # NOTE(review): imputed cells are visually indistinguishable from real results.
    global_mean = pd.Series(pivot.to_numpy(dtype=float).ravel()).mean()
    pivot = pivot.fillna(global_mean)

    heatmap = go.Heatmap(
        z=pivot.values,
        x=col_labels,
        y=row_labels,
        colorscale=theme,
        zmin=zmin,
        zmax=zmax,
        showscale=True,
        colorbar=dict(title=value_col, len=0.75, thickness=20),
    )

    fig = go.Figure(data=[heatmap])

    # White borders: horizontal and vertical lines between cells. Categorical cells are
    # centred on integer positions, so the boundaries lie at k - 0.5.
    n_rows, n_cols = pivot.shape
    shapes = [
        dict(
            type="line",
            x0=-0.5,
            x1=n_cols - 0.5,
            y0=i - 0.5,
            y1=i - 0.5,
            line=dict(color="white", width=2),
        )
        for i in range(n_rows + 1)
    ] + [
        dict(
            type="line",
            x0=j - 0.5,
            x1=j - 0.5,
            y0=-0.5,
            y1=n_rows - 0.5,
            line=dict(color="white", width=2),
        )
        for j in range(n_cols + 1)
    ]

    fig.update_layout(
        width=950,
        height=850,
        title=title,
        xaxis=dict(tickangle=-45, showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False, autorange="reversed"),
        plot_bgcolor="white",
        paper_bgcolor="white",
        shapes=shapes,
    )

    if output_file:
        out_dir = os.path.dirname(output_file)
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.write_image(output_file, format="svg", scale=2)

    return fig

In [None]:
def combine_json_to_csv(run, type_sequence, base_dir="../outputs"):
    """Merge per-combination ``<type_sequence>_stats.json`` files into one CSV per method.

    Every directory under ``<base_dir>/<run>`` is visited recursively, in sorted order.
    A directory whose name starts with ``comb_dbg`` belongs to method ``"dbg"``, one
    starting with ``comb_greedy`` to ``"greedy"``, and anything else to ``"other"``.
    If the directory contains ``statistics/<type_sequence>_stats.json``, the file is
    flattened with ``pandas.json_normalize`` and tagged with a ``source`` column holding
    the directory name. Files that fail to load are reported and skipped.

    For each method with at least one file, the rows are concatenated and tagged with
    ``sequence_type``, ``run`` and ``ass_method`` columns. They are then written to
    ``<base_dir>/<run>/<type_sequence>_combined_stats_<method>.csv``.

    Parameters
    ----------
    run : str
        Run directory name under ``base_dir``.
    type_sequence : str
        Sequence kind, e.g. ``"contigs"`` or ``"scaffolds"``; selects the JSON filename.
    base_dir : str
        Root directory containing the run folders (default ``"../outputs"``).

    Returns
    -------
    dict
        Mapping ``method -> path of the CSV written``; empty if no stats file was found.
    """
    run_dir = os.path.join(base_dir, run)
    frames_by_method = {}

    for root, dirs, _ in os.walk(run_dir):
        # Sorting in place fixes the traversal order, so the CSV row order is reproducible.
        dirs.sort()
        for dir_name in dirs:
            if dir_name.startswith("comb_dbg"):
                method = "dbg"
            elif dir_name.startswith("comb_greedy"):
                method = "greedy"
            else:
                method = "other"

            json_path = os.path.join(
                root, dir_name, "statistics", f"{type_sequence}_stats.json"
            )
            if not os.path.exists(json_path):
                continue

            # Best-effort: a corrupt or unreadable file is reported, not fatal.
            try:
                with open(json_path, "r") as f:
                    data = json.load(f)
                df = pd.json_normalize(data)
                df["source"] = dir_name
            except Exception as e:
                print(f"Error loading {json_path}: {e}")
                continue

            frames_by_method.setdefault(method, []).append(df)

    written = {}
    for method, frames in frames_by_method.items():
        combined_df = pd.concat(frames, ignore_index=True)
        combined_df["sequence_type"] = type_sequence
        combined_df["run"] = run
        combined_df["ass_method"] = method

        output_file = os.path.join(
            run_dir, f"{type_sequence}_combined_stats_{method}.csv"
        )
        combined_df.to_csv(output_file, index=False, sep=",", header=True)
        print(f"[{method}] Combined JSON saved to CSV: {output_file}")
        print(f"[{method}] Files successfully added: {len(frames)}")
        written[method] = output_file

    return written

In [None]:
def get_category(run_name):
    """Map a run directory name to its sample category (used to choose the colour theme).

    ``"bsa"`` (exact match) -> ``"bsa"``; a name starting with ``"NB"`` -> ``"nanobodies"``;
    a name starting with ``"BIND"`` -> ``"binders"``; anything else -> ``"antibodies"``.
    """
    if run_name == "bsa":
        return "bsa"
    prefix_rules = (("NB", "nanobodies"), ("BIND", "binders"))
    return next(
        (label for prefix, label in prefix_rules if run_name.startswith(prefix)),
        "antibodies",
    )

In [None]:
output_base: str = "../outputs"  # root folder holding one sub-directory per run

In [None]:
# Every run directory under output_base. Sorting makes the processing order (and so the
# order of printed messages and loaded rows) reproducible; os.listdir order is arbitrary.
# Hidden entries such as ".ipynb_checkpoints" are not runs and are skipped.
runs = sorted(
    d
    for d in os.listdir(output_base)
    if not d.startswith(".") and os.path.isdir(os.path.join(output_base, d))
)

In [None]:
type_sequences = "contigs scaffolds".split()  # assembly outputs summarised per run

In [None]:
# (Re)build the per-method combined CSVs for every run and every sequence type.
for run_name in runs:
    for sequence_kind in type_sequences:
        combine_json_to_csv(run_name, sequence_kind)

In [None]:
# Load the per-method combined CSVs written above into a list of frames.
# Only the "dbg" and "greedy" CSVs are read; any "other" CSV is ignored.
# Paths are built from output_base so they match the directory the runs were listed from.
all_dfs = []

for r in runs:
    category = get_category(r)
    for seq in type_sequences:
        for method in ["dbg", "greedy"]:
            csv_path = os.path.join(
                output_base, r, f"{seq}_combined_stats_{method}.csv"
            )
            if not os.path.exists(csv_path):
                continue
            df = pd.read_csv(csv_path)
            # Re-tag from the loop variables so the columns match the file actually read.
            df["run"] = r
            df["sequence_type"] = seq
            df["ass_method"] = method
            df["category"] = category
            all_dfs.append(df)

In [None]:
# Stack every loaded run into one long table. When nothing was loaded, fall back to an
# empty frame that still has the columns the plotting loop filters on, so that loop does
# nothing instead of raising KeyError on a column-less DataFrame.
if all_dfs:
    df_all = pd.concat(all_dfs, ignore_index=True)
else:
    df_all = pd.DataFrame(columns=["run", "sequence_type", "ass_method", "category"])

# One Plotly colorscale per category (keys match get_category's return values).
# Stops are [position in 0..1, colour]; the stops are concentrated in the upper range.
theme_map = {
    "bsa": [
        [0.0, "#fee8c8"],
        [0.8, "#fdbb84"],
        [0.9, "#ef6548"],
        [0.95, "#b30000"],
        [1.0, "#7f0000"],
    ],
    "antibodies": [
        [0.0, "#c7e9c0"],
        [0.7, "#a1d99b"],
        [0.8, "#74c476"],
        [0.9, "#41ab5d"],
        [1.0, "#238b45"],
    ],
    "nanobodies": [
        [0.0, "#deebf7"],
        [0.7, "#9ecae1"],
        [0.8, "#6baed6"],
        [0.9, "#3182bd"],
        [1.0, "#08519c"],
    ],
    "binders": [
        [0.0, "#f2f0f7"],
        [0.7, "#cbc9e2"],
        [0.8, "#9e9ac8"],
        [0.9, "#756bb1"],
        [1.0, "#54278f"],
    ],
}

In [None]:
# Heatmap configuration: the parameters on each axis, the plotted metric, how duplicate
# cells are aggregated, and the root folder for the SVG outputs.
index_cols, column_cols = ["max_mismatches", "min_identity"], ["conf", "size_threshold"]
value_col, aggfunc = "coverage", "max"
base_out = "heatmaps"

In [None]:
# Render one heatmap per (run, sequence type, assembly method), with one output folder
# per category. Title and filename follow value_col, so changing the metric in the
# config cell does not mislabel the outputs.
if "category" not in df_all.columns:
    print("No combined statistics loaded; no heatmaps generated.")
else:
    for category in df_all["category"].unique():
        cat_df = df_all[df_all["category"] == category]
        theme = theme_map[category]
        out_dir = os.path.join(base_out, category)
        for r in cat_df["run"].unique():
            for seq in type_sequences:
                for method in ["dbg", "greedy"]:
                    subset = cat_df[
                        (cat_df["run"] == r)
                        & (cat_df["sequence_type"] == seq)
                        & (cat_df["ass_method"] == method)
                    ]

                    if subset.empty:
                        continue

                    title = f"{r} - {seq} - {method} {value_col}"
                    os.makedirs(out_dir, exist_ok=True)
                    output_file = os.path.join(
                        out_dir, f"{r}_{seq}_{method}_{value_col}_clustermap.svg"
                    )

                    plot_grid_search_clustermap(
                        subset,
                        index_cols,
                        column_cols,
                        theme,
                        value_col,
                        title=title,
                        aggfunc=aggfunc,
                        output_file=output_file,
                    )