Two heatmaps for comparison:

1. **True heatmap**: four columns (PICI, CFPICI, P4, phage) with one sequence per row; each gene is colored by its `pcat` value in the dataset. Genes are ordered by `start`, and the order is reversed (descending `start`) when a sequence has more genes on the -1 strand than on the +1 strand.
2. **Predicted heatmap**: the same layout, one sequence per row; each gene is colored by its predicted function (`top_function` in `predictions_df`). Genes are ordered with the same strand rule as the true heatmap.

In [89]:
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageColor, ImageDraw, ImageFont

### Create heatmaps

In [90]:
def make_segment_heatmap(
    samples,
    function_col,
    colors,
    segment_types,
    block_size=10,
    col_gap=3,
    title_height=30,
    font_path=None,
):
    """Render a heatmap with one column per segment type, one segment per row.

    Every gene of a segment is drawn as a square block coloured by its label in
    ``function_col``. Genes are ordered by ``start``. The order is reversed
    (descending ``start``) when a segment has more genes with negative strand
    than with positive strand.

    Parameters
    ----------
    samples : pandas.DataFrame
        Gene table with the columns ``what`` (segment type), ``acc`` (segment
        accession), ``start``, ``strand`` and ``function_col``.
    function_col : str
        Column whose value selects each block's colour.
    colors : dict
        Maps a label to a colour string understood by ``PIL.ImageColor``.
        Labels missing from the dict are drawn in ``"#F5F5F5"``.
    segment_types : sequence of str
        Values of ``what`` to draw, one column each, left to right.
    block_size : int
        Edge length of one gene block in pixels.
    col_gap : int
        Horizontal gap between columns, measured in blocks.
    title_height : int
        Pixels reserved above the blocks for the column titles.
    font_path : str or None
        Optional TrueType font (size 18) for the titles. When ``None``,
        Pillow's default font is used.

    Returns
    -------
    PIL.Image.Image
        RGBA image on a white background.

    Raises
    ------
    ValueError
        If ``segment_types`` is empty or ``block_size`` is smaller than 1.
    """
    from PIL import Image, ImageColor, ImageDraw, ImageFont
    import pandas as pd

    if len(segment_types) == 0:
        raise ValueError("segment_types must contain at least one segment type")
    if block_size < 1:
        raise ValueError("block_size must be at least 1")

    # Group each type once instead of re-filtering the whole table for every row.
    # Rows follow first appearance of each accession, like Series.unique().
    acc_lists = []
    groups_per_type = []
    max_genes_per_type = []
    for seg_type in segment_types:
        subset = samples[samples["what"] == seg_type]
        groups = dict(tuple(subset.groupby("acc", sort=False)))
        acc_lists.append(subset["acc"].unique())
        groups_per_type.append(groups)
        # A type with no rows still gets a 1-block-wide column.
        # The original computed NaN here and crashed in Image.new.
        max_genes_per_type.append(max(map(len, groups.values()), default=1))

    max_rows = max(len(accs) for accs in acc_lists)
    width = (sum(max_genes_per_type) + (len(segment_types) - 1) * col_gap) * block_size
    height = max_rows * block_size + title_height

    im = Image.new("RGBA", (width, height), "white")
    draw = ImageDraw.Draw(im)
    font = ImageFont.truetype(font_path, 18) if font_path else None

    x_offset = 0
    for seg_type, accs, groups, max_genes in zip(
        segment_types, acc_lists, groups_per_type, max_genes_per_type
    ):
        # Column title, centred over the column.
        title_x = x_offset + (max_genes * block_size) // 2
        draw.text((title_x, 5), seg_type, fill="black", anchor="ma", font=font)
        for row, acc in enumerate(accs):
            seg = groups.get(acc)
            if seg is None or seg.empty:
                continue  # keeps an empty row, as before
            strand = seg["strand"]
            reverse = (strand < 0).sum() > (strand > 0).sum()
            seg = seg.sort_values("start", ascending=not reverse, kind="stable")
            y0 = title_height + row * block_size
            for i, label in enumerate(seg[function_col]):
                # Skip missing labels but keep their slot so positions stay aligned.
                if pd.isna(label):
                    continue
                color = ImageColor.getcolor(colors.get(label, "#F5F5F5"), "RGBA")
                x0 = x_offset + i * block_size
                # rectangle() includes both corner pixels, hence the -1.
                draw.rectangle(
                    [x0, y0, x0 + block_size - 1, y0 + block_size - 1], fill=color
                )
        x_offset += (max_genes + col_gap) * block_size
    return im

In [91]:
samples = pd.read_csv(
    "../dataset/demonstration_samples/known_segments/annotation_200.csv"
)
predictions_df = pd.read_csv(
    "../results/demonstration/prediction_known_segments_200_best_thresholds.csv"
)

# Normalise the human-readable pcat labels to snake_case so they can be
# compared directly with the predicted `top_function` labels.
PCAT_RENAMES = {
    "DNA, RNA and nucleotide metabolism": "dna_rna_and_nucleotide_metabolism",
    "unknown function": "unknown_function",
    "head and packaging": "head_and_packaging",
    "transcription regulation": "transcription_regulation",
    "integration and excision": "integration_and_excision",
    "moron, auxiliary metabolic gene and host takeover": "moron_auxiliary_metabolic_gene_and_host_takeover",
    "unknown_no_hit": "no_hit",
}
samples["pcat"] = samples["pcat"].replace(PCAT_RENAMES)

# Attach the predicted function of every gene (NaN when no prediction matches).
# validate="many_to_one" raises if an id occurs more than once in predictions_df.
# Otherwise the left merge would silently duplicate genes and skew every count below.
samples_pred = samples.merge(
    predictions_df[["id", "top_function"]],
    left_on="name",
    right_on="id",
    how="left",
    validate="many_to_one",
)

# Uninformative labels are dropped before plotting.
# The true heatmap filters on the true label, the predicted heatmap on the predicted label.
EXCLUDED_LABELS = ["no_hit", "unknown_function"]
samples_pred_filtered_true = samples_pred[
    ~samples_pred["pcat"].isin(EXCLUDED_LABELS)
].copy()
samples_pred_filtered_pred = samples_pred[
    ~samples_pred["top_function"].isin(EXCLUDED_LABELS)
].copy()


In [92]:
# Colour (hex RGB string) for each functional category label.
# Insertion order is preserved.
colors = dict(
    lysis="#f35f49",
    tail="#07e9a2",
    connector="#35d7ff",
    dna_rna_and_nucleotide_metabolism="#ffdf59",
    head_and_packaging="#3e83f6",
    transcription_regulation="#a861e3",
    moron_auxiliary_metabolic_gene_and_host_takeover="#ff59f5",
    integration_and_excision="#fea328",
    other="#838383",
    unknown_function="#313131",
    no_hit="#f5f5f5",
)

In [93]:
# All heatmaps go to one folder; create it so save() cannot fail on a fresh checkout.
HEATMAP_DIR = Path("../results/demonstration/known_segment_heatmaps")
HEATMAP_DIR.mkdir(parents=True, exist_ok=True)

# Segment types drawn as columns of the satellite heatmaps.
SATELLITE_TYPES = ["PICI", "CFPICI", "P4"]


def render_heatmap(frame, function_col, segment_types, filename):
    """Draw a segment heatmap with this notebook's shared layout and save it.

    Parameters
    ----------
    frame : pandas.DataFrame
        Gene table passed to ``make_segment_heatmap``.
    function_col : str
        ``"pcat"`` for the true annotation, ``"top_function"`` for the prediction.
    segment_types : list of str
        Segment types drawn as columns.
    filename : str
        Output file name inside ``HEATMAP_DIR``.

    Returns
    -------
    PIL.Image.Image
        The rendered image, which is also written to disk.
    """
    image = make_segment_heatmap(
        frame,
        function_col=function_col,
        colors=colors,
        segment_types=segment_types,
        block_size=10,
        col_gap=3,
        title_height=30,
        font_path=None,  # or a .ttf path for a custom title font
    )
    image.save(HEATMAP_DIR / filename)
    return image


im_pred_satellites = render_heatmap(
    samples_pred_filtered_pred,
    "top_function",
    SATELLITE_TYPES,
    "im_satellites_pred_best_thresholds_simple.png",
)

# Further variants; uncomment to render them too.
# Swap samples_pred and the *_filtered_* frames to include or exclude
# no_hit / unknown_function genes.
# im_true_satellites = render_heatmap(
#     samples_pred_filtered_true, "pcat", SATELLITE_TYPES, "im_satellites_true_simple.png"
# )
# im_true_phage = render_heatmap(samples_pred, "pcat", ["phage"], "im_phage_true.png")
# im_pred_phage = render_heatmap(samples_pred, "top_function", ["phage"], "im_phage_pred.png")

### Check prediction quality

In [94]:
# Overall accuracy: the share of all genes whose predicted function equals the
# true category. For single-label predictions this equals micro-averaged
# precision. A gene with no matching prediction (NaN top_function) never
# compares equal, so it counts as incorrect.
# Per-class precision and recall are computed in the next cell.
accuracy = (samples_pred["pcat"] == samples_pred["top_function"]).mean()
print(f"accuracy: {accuracy}")

precision: 0.5232916440094254


In [95]:
def calculate_recall(samples_pred, function_name):
    """Print and return precision and recall of the predictions for one class.

    Parameters
    ----------
    samples_pred : pandas.DataFrame
        Must contain the true label column ``pcat`` and the predicted label
        column ``top_function``.
    function_name : str
        The class to evaluate, e.g. ``"head_and_packaging"``.

    Returns
    -------
    dict
        Keys ``precision``, ``recall``, ``intersection``, ``real_positive`` and
        ``predicted_positive``. ``precision`` / ``recall`` are ``nan`` when
        their denominator is zero.
    """
    is_real = samples_pred["pcat"] == function_name
    is_predicted = samples_pred["top_function"] == function_name
    real_positive = int(is_real.sum())
    predicted_positive = int(is_predicted.sum())
    intersection = int((is_real & is_predicted).sum())
    # Classes that are never predicted, or never present, would otherwise raise
    # ZeroDivisionError.
    precision = intersection / predicted_positive if predicted_positive else float("nan")
    recall = intersection / real_positive if real_positive else float("nan")
    print(
        f"precision: {precision}, recall: {recall}, intersection: {intersection}, real positive: {real_positive}, predicted positive: {predicted_positive}"
    )
    return {
        "precision": precision,
        "recall": recall,
        "intersection": intersection,
        "real_positive": real_positive,
        "predicted_positive": predicted_positive,
    }


# Per-class metrics for head and packaging genes (the summary is printed; ";" hides the return value).
calculate_recall(samples_pred, function_name="head_and_packaging");

precision: 0.7892772310824567, recall: 0.8750940556809631, intersection: 2326, real positive: 2658, predicted positive: 2947


In [96]:
# Frequency of each true category (pcat) among the merged genes, most common first.
samples_pred["pcat"].value_counts(sort=True, ascending=False)

pcat
no_hit                                              11096
unknown_function                                     4225
dna_rna_and_nucleotide_metabolism                    3338
head_and_packaging                                   2658
tail                                                 1516
transcription_regulation                             1267
integration_and_excision                             1068
connector                                             766
other                                                 761
moron_auxiliary_metabolic_gene_and_host_takeover      481
lysis                                                 409
Name: count, dtype: int64

In [97]:
# Frequency of each predicted function (top_function), most common first.
samples_pred["top_function"].value_counts(sort=True, ascending=False)

top_function
unknown_function                                    6998
dna_rna_and_nucleotide_metabolism                   5275
head_and_packaging                                  2947
transcription_regulation                            2124
tail                                                2080
moron_auxiliary_metabolic_gene_and_host_takeover    1809
connector                                           1552
other                                               1382
no_hit                                              1362
integration_and_excision                            1155
lysis                                                901
Name: count, dtype: int64

In [98]:
# Sanity check: the predicted label set should coincide with the true label set.
set(predictions_df["top_function"].unique()) == set(samples_pred["pcat"].unique())

True

### Legend

In [99]:
# import matplotlib.pyplot as plt


# def plot_legend_matplotlib(
#     colors, square_size=0.5, text_size=16, row_gap=0.2, out_path=None, dpi=150
# ):
#     n = len(colors)
#     fig_height = n * (square_size + row_gap)
#     fig, ax = plt.subplots(figsize=(5, fig_height))
#     for i, (func, color) in enumerate(colors.items()):
#         y = n - i - 1  # So the first color is at the top
#         label = func.replace("_", " ")
#         # Draw square
#         ax.add_patch(
#             plt.Rectangle((0, y), square_size, square_size, color=color, ec="black")
#         )
#         # Draw text with space to the right of the square
#         ax.text(
#             square_size + 0.2,
#             y + square_size / 2,
#             label,
#             va="center",
#             fontsize=text_size,
#         )
#     ax.set_xlim(0, 4)
#     ax.set_ylim(0, n)
#     ax.axis("off")
#     plt.tight_layout()
#     if out_path:
#         plt.savefig(out_path, bbox_inches="tight", dpi=dpi)
#     plt.show()


# # Usage:
# plot_legend_matplotlib(
#     colors,
#     square_size=0.4,
#     text_size=15,
#     row_gap=0.15,
#     out_path="../results/demonstration/known_segment_heatmaps/legend.png",
# )