# Test Functional Coherence of GO Terms in Clusters

In [47]:
# @title Download and Install PHILHARMONIC

import importlib.util

# Detect Google Colab without importing it.
# find_spec() returns None (it does not raise) when the parent package
# "google" exists but has no "colab" submodule -- common on machines with
# protobuf or other google-* packages installed -- so the result itself must
# be checked. It raises only when the parent package cannot be imported, and
# ValueError when a module's __spec__ is unset.
try:
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except (ImportError, ValueError):
    IN_COLAB = False

if IN_COLAB:
    !pip install philharmonic
    !curl https://current.geneontology.org/ontology/go.obo -o go.obo
    !curl https://current.geneontology.org/ontology/subsets/goslim_generic.obo -o goslim_generic.obo

from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from scipy.spatial.distance import jaccard
from scipy.stats import ttest_ind
from tqdm import tqdm

from philharmonic.utils import (
    add_GO_function,
    load_cluster_json,
    parse_GO_database,
    parse_GO_map,
)

In [48]:
# @title Loading PHILHARMONIC results
# @markdown Upload your zipped results file using the "Files" tab to the left.
# @markdown Set the `RUN_NAME` variable to the name of the run from your config file you want to analyze.
# @markdown If you are running this locally, set the `RESULTS_DIR` variable to your results directory and place a downloaded `go.obo` file (plus `goslim_generic.obo` if `USE_GO_SLIM` is enabled) in that directory. Otherwise, you can leave it as is.

ZIP_FILE_NAME = ""  # @param {"type":"string","placeholder":"philharmonic_results.zip"}
RUN_NAME = ""  # @param {"type":"string","placeholder":"philharmonic_run"}
RESULTS_DIR = "."  # @param {"type":"string","placeholder":"."}
USE_GO_SLIM = False  # @param {"type":"boolean"}

if IN_COLAB:
    !unzip -o $ZIP_FILE_NAME
    results_dir = Path(".")
    GO_OBO_PATH = "go.obo"
    GO_SLIM_PATH = "goslim_generic.obo"
else:
    results_dir = Path(RESULTS_DIR)
    GO_OBO_PATH = results_dir / "go.obo"
    GO_SLIM_PATH = results_dir / "goslim_generic.obo"

CLUSTER_FILE_PATH = results_dir / f"{RUN_NAME}_clusters.json"
CLUSTER_FUNC_PATH = results_dir / f"{RUN_NAME}_cluster_graph_functions.tsv"
GO_MAP_PATH = results_dir / f"{RUN_NAME}_GO_map.csv"
IMG_DIR = results_dir / "img"
!mkdir -p $IMG_DIR

# Load the cluster definitions and the protein -> GO annotation map.
clusters = load_cluster_json(CLUSTER_FILE_PATH)
go_map = parse_GO_map(GO_MAP_PATH)

# Parse either the GO Slim subset or the full ontology, depending on the toggle.
ontology_path = GO_SLIM_PATH if USE_GO_SLIM else GO_OBO_PATH
go_db = parse_GO_database(ontology_path)

# Attach GO terms to every cluster and check that each attached term is
# present in the parsed ontology.
for cluster in clusters.values():
    cluster["GO_terms"] = add_GO_function(cluster, go_map, go_db=go_db)
    assert all(term in go_db.keys() for term in cluster["GO_terms"].keys())

---

## Native Cluster Coherence

In [None]:
# Gather every GO id annotated to any cluster member ...
annotated_terms = set()
for cluster in tqdm(clusters.values()):
    annotated_terms = annotated_terms.union(
        *(go_map.get(member, []) for member in cluster["members"])
    )

# ... keep only ids present in the loaded ontology, in sorted order so that
# each term has a stable position.
go_assigned = sorted(annotated_terms.intersection(go_db.keys()))

logger.info(f"{len(go_assigned)} GO terms assigned")

In [None]:
# Flat list of every protein per cluster: the original members followed by
# the entries stored under recipe -> degree -> "0.75".
proteins_in_clusters = [
    protein
    for cluster in tqdm(clusters.values())
    for protein in (*cluster["members"], *cluster["recipe"]["degree"]["0.75"])
]
logger.info(f"{len(proteins_in_clusters)} proteins in clusters")

In [None]:
def protein_GO_bit_vector(
    protein_id, go_map, full_go_list, id_col="seq", go_col="GO_ids"
):
    """
    Encode a protein's GO annotations as a 0/1 vector over ``full_go_list``.

    Parameters:
    protein_id: key looked up in ``go_map``
    go_map: mapping from protein id to an iterable of GO ids
    full_go_list: sequence of GO ids defining the vector positions
    id_col, go_col: accepted for backward compatibility; not used

    Returns:
    numpy.ndarray of floats with ``len(full_go_list)`` entries; position ``i``
    is 1 when ``full_go_list[i]`` is annotated to the protein. Proteins absent
    from ``go_map`` and GO ids absent from ``full_go_list`` contribute nothing.
    """
    go_bv = np.zeros(len(full_go_list))
    prot_go = go_map.get(protein_id)
    if prot_go is None:
        return go_bv

    # One pass builds an id -> position lookup, replacing a linear
    # `in` + `.index` scan of the list for every annotated GO id.
    # setdefault keeps the FIRST position of a repeated id, like list.index.
    position = {}
    for idx, gid in enumerate(full_go_list):
        position.setdefault(gid, idx)

    for gid in prot_go:
        idx = position.get(gid)
        if idx is not None:
            go_bv[idx] = 1
    return go_bv


# Map every clustered protein to its GO bit vector (positions follow go_assigned)
protein_GO_bvs = {
    protein: protein_GO_bit_vector(protein, go_map, go_assigned)
    for protein in tqdm(proteins_in_clusters)
}

In [None]:
# For each cluster, the Jaccard similarity (1 - Jaccard distance) of the GO
# bit vectors of every unordered pair of its proteins.
# NOTE(review): SciPy documents that `jaccard` returns 0 for two all-zero
# vectors (since 1.2), so a pair of proteins without any GO term would score
# a similarity of 1 here -- confirm this is intended.
cluster_jaccards = {}

for cluster_id, cluster in tqdm(clusters.items()):
    cluster_proteins = cluster["members"] + list(cluster["recipe"]["degree"]["0.75"])
    similarities = [
        1 - jaccard(protein_GO_bvs[first], protein_GO_bvs[second])
        for first, second in combinations(cluster_proteins, 2)
    ]
    cluster_jaccards[cluster_id] = np.array(similarities)

---

## Shuffled Cluster Coherence

In [53]:
# Null model: keep the protein ids (and therefore the clusters) fixed and
# reassign the GO bit vectors among them with a seeded permutation.
rng = np.random.default_rng(seed=42)
permuted_vectors = rng.permutation(list(protein_GO_bvs.values()))
shuffled_bit_vectors = dict(zip(protein_GO_bvs.keys(), permuted_vectors))

In [None]:
# Same pairwise similarity as for the native clusters, but computed on the
# shuffled GO bit vectors.
cluster_jaccards_perm = {}

for cluster_id, cluster in tqdm(clusters.items()):
    pairs = combinations(
        cluster["members"] + list(cluster["recipe"]["degree"]["0.75"]), 2
    )
    cluster_jaccards_perm[cluster_id] = np.array(
        [
            1 - jaccard(shuffled_bit_vectors[left], shuffled_bit_vectors[right])
            for left, right in pairs
        ]
    )

--- 


## Compare shuffled and original coherence

In [None]:
# Mean pairwise similarity per cluster, for the native and shuffled annotations
phil_mean = [np.mean(scores) for scores in cluster_jaccards.values()]
permute_mean = [np.mean(scores) for scores in cluster_jaccards_perm.values()]

# Long-format table: one row per (cluster, method) pair
coherence_wide = pd.DataFrame(
    {
        "cluster": list(cluster_jaccards.keys()),
        "PHILHARMONIC": phil_mean,
        "Random Clustering": permute_mean,
    }
)
coherence_df = coherence_wide.melt(
    id_vars="cluster",
    var_name="Clustering Method",
    value_name="Mean Jaccard Similarity",
)
coherence_df.head()

In [56]:
from scipy import interpolate


def find_histogram_intersection(data1, data2, bins=50, return_curves=False):
    """
    Find the right-most crossing point of two density histograms.

    Both datasets are binned on the edges derived from ``data1``; the bin
    densities are linearly interpolated between bin centers and compared on a
    1000-point grid spanning the first to the last bin center.

    Parameters:
    data1: array-like, first dataset (defines the bin edges)
    data2: array-like, second dataset
    bins: int, number of bins for histogram
    return_curves: bool, whether to return the interpolated curves

    Returns:
    float: largest grid x-value immediately before the sign of
        (density1 - density2) changes.
    If ``return_curves`` is True, a tuple ``(x, f1, f2, x_range)`` where
    ``f1``/``f2`` are the interpolation functions and ``x_range`` is the grid.

    Raises:
    ValueError: if either dataset has no finite value, or the curves never
        cross on the grid.
    """
    # Drop NaN/inf (e.g. the mean of an empty similarity array); np.histogram
    # raises when it has to derive a bin range from non-finite data.
    data1 = np.asarray(data1, dtype=float)
    data2 = np.asarray(data2, dtype=float)
    data1 = data1[np.isfinite(data1)]
    data2 = data2[np.isfinite(data2)]
    if data1.size == 0 or data2.size == 0:
        raise ValueError("Both datasets must contain at least one finite value")

    # Create histograms on a shared set of bin edges
    hist1, bin_edges = np.histogram(data1, bins=bins, density=True)
    hist2, _ = np.histogram(data2, bins=bin_edges, density=True)

    # Get bin centers
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Create interpolation functions
    f1 = interpolate.interp1d(
        bin_centers, hist1, kind="linear", fill_value="extrapolate"
    )
    f2 = interpolate.interp1d(
        bin_centers, hist2, kind="linear", fill_value="extrapolate"
    )

    # Evaluate both curves on a common grid
    x_range = np.linspace(bin_centers[0], bin_centers[-1], 1000)
    y1 = f1(x_range)
    y2 = f2(x_range)

    # Index i is reported when the sign bit of the difference differs between
    # grid points i and i + 1
    diff = y1 - y2
    intersection_indices = np.where(np.diff(np.signbit(diff)))[0]

    if len(intersection_indices) == 0:
        raise ValueError("No intersection found")

    # Get x values at intersections and keep the right-most one
    intersection_points = x_range[intersection_indices]
    rightmost = max(intersection_points)

    if return_curves:
        return rightmost, f1, f2, x_range

    return rightmost

In [None]:
# Figure: distribution of per-cluster mean Jaccard similarity for the native
# annotations vs. the shuffled ones (histogram on top, box plot below).
# NOTE(review): the palette set on the next line appears to be superseded by
# set_theme(palette="pastel") right after it -- confirm and drop one of them.
sns.set_palette("colorblind")
sns.set_theme(style="white", palette="pastel", font_scale=1)

# Create figure and a 2-row gridspec (height ratio 3:2, small vertical gap)
fig = plt.figure(figsize=(10, 8))
gs = fig.add_gridspec(2, 1, height_ratios=[3, 2], hspace=0.05)

# Top axes hold the histogram, bottom axes the box plot
ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1])

# Overlaid histograms (with KDE curves), one per clustering method
sns.histplot(
    data=coherence_df,
    x="Mean Jaccard Similarity",
    hue="Clustering Method",
    alpha=0.3,
    bins=np.arange(0, 1.05, 0.01),
    kde=True,
    palette=["blue", "red"],
    common_norm=False,
    ec="white",
    ax=ax0,
)

# Box plot of the same values, using the same colours per method
sns.boxplot(
    data=coherence_df,
    x="Mean Jaccard Similarity",
    hue="Clustering Method",
    palette=["blue", "red"],
    ax=ax1,
)

ax0.set_xlabel("")  # Remove x-label from top plot
ax0.set_xticklabels([])  # Hide x tick labels on the top plot

ax0.set_xlim(ax1.get_xlim())  # Align the x-axis of both subplots
ax1.get_legend().remove()  # Remove legend from bottom plot

# One-sided independent two-sample t-test; alternative="greater" tests whether
# the PHILHARMONIC means exceed the shuffled means. The p-value goes in the title.
tstat, p = ttest_ind(phil_mean, permute_mean, alternative="greater")
go_type = "GO Slim" if USE_GO_SLIM else "All GO Terms"
ax0.set_title(f"Functional Enrichment: PHILHARMONIC, {go_type} (p={p:.3})")

# Mark where the two distributions cross. The crossing is computed on a
# separate 50-bin histogram, not on the 0.01-wide bins drawn above.
# intersection_x is reused by the per-function figure later in the notebook.
intersection_x = find_histogram_intersection(phil_mean, permute_mean, bins=50)
logger.info(f"Distributions cross at {intersection_x:.3f}")
ax0.axvline(intersection_x, linestyle="--", color="black")

# Tidy the axes, save the figure under IMG_DIR (file name depends on the
# GO Slim toggle), then display it
sns.despine()
finame = (
    f"{RUN_NAME}_function_enrichment_GOfull.svg"
    if not USE_GO_SLIM
    else f"{RUN_NAME}_function_enrichment_GOslim.svg"
)
plt.savefig(IMG_DIR / finame, bbox_inches="tight", dpi=300)
plt.show()

---
## Display Coherence by GO Slim Function

In [None]:
# Per-cluster function labels, indexed by the "key" column
cluster_top_terms = pd.read_csv(CLUSTER_FUNC_PATH, sep="\t").set_index("key")

# One entry per cluster, all in the iteration order of `clusters`
clens = [len(clust["members"]) for clust in clusters.values()]
# Size bins; not referenced by the remaining cells shown in this notebook
clen_bin = pd.cut(clens, bins=[0, 5, 10, 15, 20, 25, 30])
cjacc = [np.mean(cluster_jaccards[k]) for k in clusters]
# Cluster keys are converted with int() to match the table index
cfunc = [cluster_top_terms.loc[int(k), "go_fn"] for k in clusters]
func_df = pd.DataFrame(
    {
        "Function": cfunc,
        "Mean Jaccard Similarity": cjacc,
        "Cluster Size": clens,
    }
)
func_df["Function"] = func_df["Function"].str.lower()
func_df = func_df.sort_values("Function")
# Assign the filled column back instead of calling fillna(inplace=True) on
# the selected column: that chained in-place form emits a FutureWarning on
# pandas 2.x and does not modify func_df under Copy-on-Write (pandas 3),
# which would leave clusters without a label as NaN.
func_df["Function"] = func_df["Function"].fillna("No dominant function")

logger.info(f"Function dataframe shape: {func_df.shape}")

In [59]:
import xml.etree.ElementTree as ET

# Location of the style sheet (differs between Colab and a local checkout)
xml_file_path = (
    "/content/philharmonic/assets/philharmonic_styles.xml"
    if IN_COLAB
    else "../assets/philharmonic_styles.xml"
)

# Parse the XML file
tree = ET.parse(xml_file_path)
root = tree.getroot()

# Map each lower-cased attributeValue to its colour value, taken from every
# discreteMappingEntry element found in the document
philharmonic_colors = {
    entry.get("attributeValue").lower(): entry.get("value")
    for entry in root.findall(".//discreteMappingEntry")
}

# Functions without an entry in the style sheet fall back to grey
for function_name in func_df["Function"].unique():
    philharmonic_colors.setdefault(function_name, "#aaaaaa")

In [60]:
# Functions ordered by the median coherence of their clusters (ascending);
# the index of this frame fixes the category order in the plot.
grouped_order = (
    func_df[["Function", "Mean Jaccard Similarity"]]
    .groupby("Function")
    .median()
    .sort_values("Mean Jaccard Similarity")
)

In [None]:
# Figure: cluster coherence grouped by cluster function, with one point per
# cluster and a light-grey box plot per function.
sns.set_theme(style="whitegrid", font_scale=1.3)
fig, ax = plt.subplots(figsize=(30, 10))

# Individual clusters, coloured by function
sns.stripplot(
    data=func_df,
    x="Mean Jaccard Similarity",
    y="Function",
    ax=ax,
    hue="Function",
    palette=philharmonic_colors,
    order=grouped_order.index,
    s=10,
    edgecolor="grey",
    linewidth=1,
)
# Per-function summary boxes on the same axes, all in one neutral colour
sns.boxplot(
    data=func_df,
    x="Mean Jaccard Similarity",
    y="Function",
    ax=ax,
    hue="Function",
    palette={k: "#eeeeee" for k in func_df["Function"].unique()},
    order=grouped_order.index,
)
ax.set_xlabel("Cluster Coherence")
ax.set_ylabel("Cluster Function")

# Reference line: crossing point of the native and shuffled distributions
ax.axvline(intersection_x, linestyle="--", color="black")

sns.despine()
# Both variants share one file-name stem; previously the GO Slim name omitted
# the "enrichment_" part that the GO-full name carries.
go_suffix = "GOslim" if USE_GO_SLIM else "GOfull"
finame = f"{RUN_NAME}_enrichment_by_function_swarmbox_{go_suffix}.svg"
plt.savefig(IMG_DIR / finame, bbox_inches="tight", dpi=300)
plt.show()