In [6]:
import pandas as pd
import numpy as np
from PIL import Image, ImageColor
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [7]:
# Hex colour assigned to each functional category. The keys are looked up
# with the `top_function` labels from the prediction table.
colors = dict(
    lysis="#f35f49",
    tail="#07e9a2",
    connector="#35d7ff",
    dna_rna_and_nucleotide_metabolism="#ffdf59",
    head_and_packaging="#3e83f6",
    other="#838383",
    transcription_regulation="#a861e3",
    moron_auxiliary_metabolic_gene_and_host_takeover="#ff59f5",
    # TODO: consider white for this category, or hiding it entirely.
    unknown_function="#313131",
    integration_and_excision="#fea328",
    # TODO: consider hiding this category.
    no_hit="#F5F5F5",
)

In [8]:
def create_genome_heatmap(
    gff_df,
    predictions_df,
    colors,
    block_size=10,
    exclude=("no_hit", "unknown_function"),
):
    """
    Create a single heatmap image for all contigs, each contig is one row.

    Every retained gene is drawn as a ``block_size`` x ``block_size`` square
    coloured by its predicted ``top_function``. Within a row, genes are ordered
    by ``start`` (ties keep their order in ``gff_df``). Rows follow the order in
    which contigs first appear in ``gff_df``. A contig with no retained genes
    still gets an (all-white) row.

    Parameters
    ----------
    gff_df : pandas.DataFrame
        Gene table; must contain ``contig``, ``start`` and ``protein_id``.
    predictions_df : pandas.DataFrame
        Prediction table; must contain ``id`` (matched against ``protein_id``)
        and ``top_function``.
    colors : dict
        Maps a function label to a colour string accepted by
        ``PIL.ImageColor.getcolor``. Labels missing from the mapping fall back
        to ``colors["no_hit"]``. This includes genes with no matching prediction,
        whose label is NaN after the left merge.
    block_size : int, default 10
        Edge length, in pixels, of each gene square.
    exclude : iterable of str, default ("no_hit", "unknown_function")
        Function labels whose genes are dropped before drawing. NaN labels are
        never dropped by this filter.

    Returns
    -------
    PIL.Image.Image
        RGBA image on a white background. Its size is
        (max retained genes on any contig * block_size,
        number of contigs * block_size).
    """
    contigs = gff_df["contig"].unique()

    # Attach labels with a single merge instead of filtering + merging per contig.
    labelled = gff_df.merge(
        predictions_df[["id", "top_function"]],
        left_on="protein_id",
        right_on="id",
        how="left",
    )
    labelled = labelled[~labelled["top_function"].isin(list(exclude))]
    # Stable sort so genes sharing a start coordinate keep a deterministic order.
    labelled = labelled.sort_values("start", kind="stable")

    # groupby(sort=False) keeps the sorted-by-start order inside each group.
    labels_by_contig = {
        contig_id: list(labels)
        for contig_id, labels in labelled.groupby("contig", sort=False)["top_function"]
    }
    # Contigs with nothing left (or NaN contig ids, which groupby drops) -> empty row.
    contig_gene_lists = [labels_by_contig.get(contig_id, []) for contig_id in contigs]

    # default=0 avoids ValueError from max() when gff_df has no contigs.
    max_genes = max((len(genes) for genes in contig_gene_lists), default=0)
    width = max_genes * block_size
    height = len(contigs) * block_size
    im = Image.new("RGBA", (width, height), "white")

    rgba_by_label = {}
    for row, gene_labels in enumerate(contig_gene_lists):
        top = row * block_size
        for col, label in enumerate(gene_labels):
            if label not in rgba_by_label:
                rgba_by_label[label] = ImageColor.getcolor(
                    colors.get(label, colors["no_hit"]), "RGBA"
                )
            left = col * block_size
            # Pasting a colour value fills the (left, upper, right, lower) box in
            # one call, replacing block_size**2 putpixel calls. Boxes always lie
            # inside the image because col < max_genes and row < len(contigs).
            im.paste(
                rgba_by_label[label],
                (left, top, left + block_size, top + block_size),
            )
    return im

In [9]:
# Genome accession shared by both inputs. Keeping it in one constant prevents
# the gene table and the prediction table from silently referring to
# different genomes.
SAMPLE_ACCESSION = "GCF_000175755.1"

# Gene coordinates table for the demonstration genome.
gff_df = pd.read_csv(f"../dataset/demonstration_samples/{SAMPLE_ACCESSION}/gff_df.csv")
# Per-protein function predictions for the same genome.
predictions_df = pd.read_csv(f"../results/demonstration/prediction_{SAMPLE_ACCESSION}.csv")

In [10]:
# Keep only the first whitespace-delimited token of each prediction id, so it
# can be joined against the gene table's `protein_id`.
# NOTE(review): the rest is presumably a FASTA header description; confirm.
predictions_df["id"] = predictions_df["id"].str.split(n=1).str.get(0)

In [11]:
# Render every contig of the demonstration genome as one heatmap row and save it.
# NOTE(review): the output folder is named after strain FRIK2000, while the
# inputs are loaded by accession GCF_000175755.1. Confirm both refer to the
# same genome.
HEATMAP_PATH = "../results/demonstration/FRIK2000_contig_heatmaps/genome_heatmap.png"
# Image.save does not create missing parent directories, so create them first.
os.makedirs(os.path.dirname(HEATMAP_PATH), exist_ok=True)
im = create_genome_heatmap(gff_df, predictions_df, colors, block_size=10)
im.save(HEATMAP_PATH)