Two heatmaps for comparison

1. True heatmap: 4 columns (PICI / CFPICI / P4 / phage), one sequence per row; genes colored by the annotated `pcat` in the dataset; genes ordered by ascending `start`, reversed (descending `start`) when the sequence has more -1-strand than +1-strand genes. (The three satellite columns and the phage column are saved as separate images.)
2. Predicted heatmap: same layout, but genes colored by the predicted function (`top_function` in `predictions_df`); same gene-ordering rule as the true heatmap.

In [16]:
import pandas as pd
import numpy as np
from PIL import Image, ImageColor, ImageDraw, ImageFont

### generate samples

In [17]:
# annotation = pd.read_parquet("../dataset/Phage_and_Satellites_Pann_Pcat_Pcol.pa")

# # id 200 random samples for each type
# pici_acc_200 = list(
#     np.random.choice(
#         annotation[annotation["what"] == "PICI"]["acc"].unique(), 200, replace=False
#     )
# )
# cf_acc_200 = list(
#     np.random.choice(
#         annotation[annotation["what"] == "CFPICI"]["acc"].unique(), 200, replace=False
#     )
# )
# p4_acc_200 = list(
#     np.random.choice(
#         annotation[annotation["what"] == "P4"]["acc"].unique(), 200, replace=False
#     )
# )
# phage_acc_200 = list(
#     np.random.choice(
#         annotation[annotation["what"] == "phage"]["acc"].unique(), 200, replace=False
#     )
# )
# all_acc_200 = pici_acc_200 + cf_acc_200 + p4_acc_200 + phage_acc_200

# # merge all the samples
# samples = annotation[annotation["acc"].isin(all_acc_200)]
# print(samples["acc"].nunique())

# samples.to_csv(
#     "../dataset/demonstration_samples/known_segments/annotation_200.csv", index=False
# )

### parse seqs

In [18]:
# samples = pd.read_csv(
#     "../dataset/demonstration_samples/known_segments/annotation_200.csv"
# )
# samples_acc = samples["acc"].unique()
# samples_protein_id = samples["name"].unique()

# print(len(samples_protein_id))
# print(len(samples_acc))
# print(samples.shape)

# with open("../dataset/demonstration_samples/known_segments/proteins_200.faa", "w") as f:
#     for idx, row in samples.iterrows():
#         f.write(f">{row['name']}\n")
#         f.write(f"{row['translation']}\n")

### create heatmap

In [19]:
def make_segment_heatmap(
    samples,
    function_col,
    colors,
    segment_types,
    block_size=10,
    col_gap=3,
    title_height=30,
    font_path=None,
):
    """Render one coloured block per gene, one row per segment, as a PIL image.

    Each entry of ``segment_types`` becomes a column group. Inside a group,
    every distinct ``acc`` of that type (in order of first appearance in
    ``samples``) is one row. Each gene of that segment is one
    ``block_size`` x ``block_size`` square filled with ``colors[label]``, where
    ``label`` is the gene's value in ``function_col``.

    Genes are laid out left to right by ascending ``start``. If a segment has
    more ``strand == -1`` genes than ``strand == 1`` genes, the order is
    reversed (descending ``start``).

    Because rows are assigned per call, two calls on differently filtered
    frames may place the same ``acc`` on different rows.

    Args:
        samples: DataFrame with at least the columns ``what``, ``acc``,
            ``start``, ``strand`` and ``function_col``.
        function_col: Column whose value selects each gene's colour. Missing
            values (``None``/NaN) are left as white background.
        colors: Mapping label -> colour string accepted by
            ``PIL.ImageColor.getcolor``. Labels not in the mapping are drawn
            as ``"#F5F5F5"``.
        segment_types: Values of ``what`` to draw, one column group each
            (must be non-empty). A type with no rows still gets a blank
            column one block wide.
        block_size: Edge length of one gene square, in pixels (>= 1).
        col_gap: Horizontal gap between column groups, in units of
            ``block_size``.
        title_height: Pixels reserved above the rows for the group titles.
        font_path: Optional path to a TrueType font (size 18) for the titles;
            ``None`` uses Pillow's default font.

    Returns:
        An RGBA ``PIL.Image.Image`` with a white background.

    Raises:
        ValueError: If ``segment_types`` is empty.
    """
    from PIL import Image, ImageColor, ImageDraw, ImageFont
    import numpy as np

    if len(segment_types) == 0:
        raise ValueError("segment_types must contain at least one segment type")

    # Group once per type instead of re-filtering the full frame for every
    # accession. sort=False keeps first-appearance order, i.e. the same row
    # order that Series.unique() produced.
    groups_per_type = [
        [seg for _, seg in samples[samples["what"] == t].groupby("acc", sort=False)]
        for t in segment_types
    ]

    max_rows = max(len(groups) for groups in groups_per_type)
    # Widest segment of each type, in genes. default=1 gives an empty type a
    # blank one-block column (previously this was NaN and Image.new failed).
    max_genes_per_type = [
        max((len(seg) for seg in groups), default=1) for groups in groups_per_type
    ]
    width = (
        sum(max_genes_per_type) * block_size
        + (len(segment_types) - 1) * col_gap * block_size
    )
    height = max_rows * block_size + title_height

    im = Image.new("RGBA", (width, height), "white")
    draw = ImageDraw.Draw(im)
    font = ImageFont.truetype(font_path, 18) if font_path else None

    x_offset = 0
    for t, groups, max_genes in zip(
        segment_types, groups_per_type, max_genes_per_type
    ):
        # Column title, anchored at the horizontal middle of this group.
        title_x = x_offset + (max_genes * block_size) // 2
        draw.text((title_x, 5), t, fill="black", anchor="ma", font=font)

        for row, seg in enumerate(groups):
            strand = seg["strand"]
            reverse = (strand == -1).sum() > (strand == 1).sum()
            # Stable sort so genes sharing a 'start' keep a deterministic order.
            seg = seg.sort_values("start", ascending=not reverse, kind="stable")

            y0 = title_height + row * block_size
            for i, label in enumerate(seg[function_col]):
                # Missing labels stay as white background.
                if label is None or (isinstance(label, float) and np.isnan(label)):
                    continue
                fill = ImageColor.getcolor(colors.get(label, "#F5F5F5"), "RGBA")
                x0 = x_offset + i * block_size
                # Pillow's rectangle bounds are inclusive on both ends.
                draw.rectangle(
                    [x0, y0, x0 + block_size - 1, y0 + block_size - 1], fill=fill
                )
        x_offset += (max_genes + col_gap) * block_size
    return im

In [20]:
samples = pd.read_csv(
    "../dataset/demonstration_samples/known_segments/annotation_200.csv"
)
predictions_df = pd.read_csv(
    "../results/demonstration/prediction_known_segments_200_threshold_0.8.csv"
)

# Rename the human-readable `pcat` categories to the snake_case labels used in
# predictions_df["top_function"] and in the `colors` table, so the true and
# predicted heatmaps share one label vocabulary. One mapping replaces the
# previous chain of single-value replace() calls; no replacement result is
# itself a key, so applying them together is equivalent.
PCAT_RENAMES = {
    "DNA, RNA and nucleotide metabolism": "dna_rna_and_nucleotide_metabolism",
    "unknown function": "unknown_function",
    "head and packaging": "head_and_packaging",
    "transcription regulation": "transcription_regulation",
    "integration and excision": "integration_and_excision",
    "moron, auxiliary metabolic gene and host takeover": (
        "moron_auxiliary_metabolic_gene_and_host_takeover"
    ),
    "unknown_no_hit": "no_hit",
}
samples["pcat"] = samples["pcat"].replace(PCAT_RENAMES)

# Attach each protein's predicted function; proteins without a prediction get
# NaN in `top_function` (left join). If predictions_df ever contains duplicate
# ids, rows are duplicated here. NOTE(review): consider validate="many_to_one".
samples_pred = samples.merge(
    predictions_df[["id", "top_function"]], left_on="name", right_on="id", how="left"
)

# Drop uninformative labels separately for the true and the predicted view.
# NOTE(review): because the two frames are filtered on different columns, an
# accession can vanish from one but not the other, which shifts the row
# positions of later accessions between the true and predicted heatmaps.
UNINFORMATIVE = ["no_hit", "unknown_function"]
samples_pred_filtered_true = samples_pred[
    ~samples_pred["pcat"].isin(UNINFORMATIVE)
].copy()
samples_pred_filtered_pred = samples_pred[
    ~samples_pred["top_function"].isin(UNINFORMATIVE)
].copy()


In [21]:
# Colour for each functional category, keyed by the snake_case labels used
# in `pcat` (after renaming) and in `top_function`. Insertion order is kept
# deliberately: it is the row order of the (commented-out) legend plot.
colors = {
    # structural / life-cycle modules
    "lysis": "#f35f49",
    "tail": "#07e9a2",
    "connector": "#35d7ff",
    "dna_rna_and_nucleotide_metabolism": "#ffdf59",
    "head_and_packaging": "#3e83f6",
    "transcription_regulation": "#a861e3",
    "moron_auxiliary_metabolic_gene_and_host_takeover": "#ff59f5",
    "integration_and_excision": "#fea328",
    # catch-all and uninformative labels
    "other": "#838383",
    "unknown_function": "#313131",
    "no_hit": "#f5f5f5",
}

In [22]:
from pathlib import Path

# Output folder for the four PNGs. Create it up front so Image.save() does not
# raise FileNotFoundError when the folder does not exist yet.
heatmap_dir = Path("../results/demonstration/known_segment_heatmaps")
heatmap_dir.mkdir(parents=True, exist_ok=True)

satellite_types = ["PICI", "CFPICI", "P4"]
# Layout shared by all four images so true and predicted views use the same
# scale. Set "font_path" to a .ttf path for a custom title font.
common_kwargs = {
    "colors": colors,
    "block_size": 10,
    "col_gap": 3,
    "title_height": 30,
    "font_path": None,
}

# Ground truth: genes coloured by the annotated category (pcat).
im_true_satellites = make_segment_heatmap(
    samples_pred_filtered_true,
    function_col="pcat",
    segment_types=satellite_types,
    **common_kwargs,
)
im_true_satellites.save(heatmap_dir / "im_satellites_true.png")

im_true_phage = make_segment_heatmap(
    samples_pred_filtered_true,
    function_col="pcat",
    segment_types=["phage"],
    **common_kwargs,
)
im_true_phage.save(heatmap_dir / "im_phage_true.png")

# Predictions: genes coloured by the predicted top function.
im_pred_satellites = make_segment_heatmap(
    samples_pred_filtered_pred,
    function_col="top_function",
    segment_types=satellite_types,
    **common_kwargs,
)
im_pred_satellites.save(heatmap_dir / "im_satellites_pred.png")

im_pred_phage = make_segment_heatmap(
    samples_pred_filtered_pred,
    function_col="top_function",
    segment_types=["phage"],
    **common_kwargs,
)
im_pred_phage.save(heatmap_dir / "im_phage_pred.png")

### check

In [32]:
# Spot check: among CFPICI genes annotated as head_and_packaging, count them,
# then count how many of those are also predicted as head_and_packaging.
cfpici = samples_pred.loc[samples_pred["what"].eq("CFPICI")]
cfpici_head_packaging = cfpici.loc[cfpici["pcat"].eq("head_and_packaging")]
print(len(cfpici_head_packaging))

cfpici_head_packaging_pred = cfpici_head_packaging.loc[
    cfpici_head_packaging["top_function"].eq("head_and_packaging")
]
print(len(cfpici_head_packaging_pred))

942
355


In [23]:
# Ground-truth label distribution, most frequent first (missing values excluded).
samples_pred["pcat"].value_counts(sort=True, dropna=True)

pcat
no_hit                                              11096
unknown_function                                     4225
dna_rna_and_nucleotide_metabolism                    3338
head_and_packaging                                   2658
tail                                                 1516
transcription_regulation                             1267
integration_and_excision                             1068
connector                                             766
other                                                 761
moron_auxiliary_metabolic_gene_and_host_takeover      481
lysis                                                 409
Name: count, dtype: int64

In [24]:
# Predicted label distribution, most frequent first (missing values excluded).
samples_pred["top_function"].value_counts(sort=True, dropna=True)

top_function
no_hit                                              12793
unknown_function                                     4022
dna_rna_and_nucleotide_metabolism                    3042
head_and_packaging                                   1543
tail                                                 1398
transcription_regulation                             1164
moron_auxiliary_metabolic_gene_and_host_takeover      966
integration_and_excision                              876
connector                                             711
other                                                 606
lysis                                                 464
Name: count, dtype: int64

In [26]:
# Do the renamed ground-truth categories and the predicted labels use the same label set?
{*samples_pred["pcat"].unique()} == {*predictions_df["top_function"].unique()}

True

### legend

In [27]:
# import matplotlib.pyplot as plt


# def plot_legend_matplotlib(
#     colors, square_size=0.5, text_size=16, row_gap=0.2, out_path=None, dpi=150
# ):
#     n = len(colors)
#     fig_height = n * (square_size + row_gap)
#     fig, ax = plt.subplots(figsize=(5, fig_height))
#     for i, (func, color) in enumerate(colors.items()):
#         y = n - i - 1  # So the first color is at the top
#         label = func.replace("_", " ")
#         # Draw square
#         ax.add_patch(
#             plt.Rectangle((0, y), square_size, square_size, color=color, ec="black")
#         )
#         # Draw text with space to the right of the square
#         ax.text(
#             square_size + 0.2,
#             y + square_size / 2,
#             label,
#             va="center",
#             fontsize=text_size,
#         )
#     ax.set_xlim(0, 4)
#     ax.set_ylim(0, n)
#     ax.axis("off")
#     plt.tight_layout()
#     if out_path:
#         plt.savefig(out_path, bbox_inches="tight", dpi=dpi)
#     plt.show()


# # Usage:
# plot_legend_matplotlib(
#     colors,
#     square_size=0.4,
#     text_size=15,
#     row_gap=0.15,
#     out_path="../results/demonstration/known_segment_heatmaps/legend.png",
# )