# Import necessary packages

In [None]:
import os
import pandas as pd 
import numpy as np
import yaml
import pickle as pkl
from collections import Counter

from scipy.stats import pearsonr, ttest_ind 

import plotly.graph_objects as go
from plotly.subplots import make_subplots

import matplotlib.pyplot as plt
import seaborn as sns

# Load configuration and data 

In [None]:
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Colour palettes from the config, keyed by cofactor, ligand class and kingdom.
cf_colors = config["cofactor_color"]
class_color = config["ligand_color"]
king_colors = config["kingdom_color"]

# Per-species lookups built from the config's "species" list.
species_entries = config["species"]
species_map = {entry["id"]: entry["kingdom"] for entry in species_entries}
species_colors = {entry["id"]: entry["color"] for entry in species_entries}
SPECIES = [entry["id"] for entry in species_entries]

# Input/output locations.
paths_cfg = config["paths"]
PLOT_DIR = paths_cfg["plots"]
KNOWN_POCKETS_PATH = paths_cfg["known_pockets"]
POCKET_PATH = paths_cfg["pockets"]
PROBIS_PATH = paths_cfg["probis"]
RESULTS_PATH = paths_cfg["results"]

In [None]:
# Collect the pocket records of every species into one DataFrame and attach domain annotations.
# NOTE(review): pickle can execute arbitrary code on load; only read trusted pipeline outputs.
bs_all = []
dom_all = {}
for species in SPECIES:
    # Values of bind_dict are lists of pocket records (the keys are not used here).
    with open(f"{KNOWN_POCKETS_PATH}/{species}_pockets.pkl", "rb") as fh:
        bind_dict = pkl.load(fh)
    # Domain annotation keyed by pocket_id (it is mapped onto df["pocket_id"] below).
    with open(f"{POCKET_PATH}/{species}/dom_dict.pkl", "rb") as fh:
        dom_dict = pkl.load(fh)
    dom_all.update(dom_dict)

    for pockets in bind_dict.values():
        bs_all.extend(pockets)

df = pd.DataFrame.from_records(bs_all)
# Protein id: first "_"-separated token of the pocket id with its first 3 characters dropped.
df["prot"] = [pocket_id.split("_")[0][3:] for pocket_id in df["pocket_id"]]
# NOTE(review): the fallback "Other species" is a label, not a colour value.
df["color_species"] = [species_colors.get(org_id, "Other species") for org_id in df["species"]]
df["domain"] = df["pocket_id"].map(dom_all)

# Fail with a clear message instead of "'float' object is not subscriptable" on NaN below.
missing_dom = df["domain"].isna()
if missing_dom.any():
    raise KeyError(
        f"No domain annotation for {int(missing_dom.sum())} pocket(s), e.g. "
        f"{df.loc[missing_dom, 'pocket_id'].head(5).tolist()}"
    )
df["num_dom"] = [len(dom["domains"]) for dom in df["domain"]]
df["kingdom"] = df["species"].map(species_map)
df.head()

In [None]:
# Pocket size = number of entries in each pocket's residue list; show its median.
df["num_res"] = df["res"].map(len)
df["num_res"].median()

# pLDDT and PAE per compound class

In [None]:
# Median pLDDT per ligand class, ordered from highest to lowest; the violins follow this order.
# groupby skips pockets without a ligand class (the previous set()-based loop produced an empty
# "nan" trace and a NaN median that broke the sort).
plddt_med = (
    df.groupby("lig_class")["plddt"]
    .median()
    .sort_values(ascending=False, kind="stable")
    .to_dict()
)

fig = go.Figure()
for compound_class in plddt_med:
    in_class = df["lig_class"] == compound_class
    fig.add_trace(go.Violin(x=df.loc[in_class, "lig_class"],
                            y=df.loc[in_class, "plddt"],
                            name=compound_class,
                            box_visible=True,
                            line_color="black",
                            legendgroup="group1",
                            showlegend=True,
                            # None (class missing from config) -> Plotly default colour
                            fillcolor=class_color.get(compound_class),
                            ))

# Single y-axis only: the stray yaxis2 "PAE" title was a leftover from the PAE figure.
fig.update_layout(width=1200, height=700,
                  plot_bgcolor="white",
                  font=dict(size=20, color="black"),
                  yaxis=dict(title="pLDDT"),
                  )
fig.update_xaxes(linecolor="grey", tickangle=45)
fig.update_yaxes(linecolor="grey")

In [None]:
# Median PAE per ligand class; classes are drawn in ascending order of that median.
class_labels = df["lig_class"]
pae_med = {
    cls: df["pae"][class_labels == cls].median()
    for cls in set(class_labels)
}
pae_med = dict(sorted(pae_med.items(), key=lambda kv: kv[1]))

fig = go.Figure()
for cls in pae_med:
    rows = class_labels == cls
    fig.add_trace(
        go.Violin(
            x=class_labels[rows],
            y=df["pae"][rows],
            name=cls,
            box_visible=True,
            legendgroup="group2",
            showlegend=True,
            line_color="black",
            fillcolor=class_color.get(cls),
        )
    )

fig.update_layout(
    width=1200,
    height=700,
    plot_bgcolor="white",
    font=dict(size=20, color="black"),
    yaxis=dict(title="PAE"),
)
fig.update_xaxes(linecolor="grey", tickangle=45)
fig.update_yaxes(linecolor="grey")

# Domain composition of the largest pocket communities

In [None]:
# For each species, report which share of the pockets in its two largest communities carries
# `dom_of_interest` among their domain annotations.
dom_of_interest = 'Protein kinase domain'
membership_suffix = "_scores.csv_discretized_TRUE_rewired_FALSE_randomShuffle_FALSE_t0.1_r0.01_membership_vec.txt"
for species in SPECIES:
    clust_path = os.path.join(PROBIS_PATH, f"{species}/{species}{membership_suffix}")
    mem_df = pd.read_csv(clust_path, sep=" ", header=None, names=["pocket_id", "community"])
    # The membership file writes "." where the pocket ids used in dom_all have "-".
    mem_df["pocket_id"] = [pocket_id.replace(".", "-") for pocket_id in mem_df["pocket_id"]]
    grouped_mem = mem_df.groupby("community")
    clust_count = Counter(mem_df["community"])
    # rank 0 = largest community, rank 1 = second largest.
    for rank, (com_id, _size) in enumerate(clust_count.most_common(2)):
        community = grouped_mem.get_group(com_id)
        # Count pockets, not domain occurrences: a pocket listing the domain several times must
        # only be counted once, otherwise the "percent of pockets" can be inflated.
        # dom_all[...] raises KeyError naming the pocket if an annotation is missing.
        n_with_dom = sum(
            dom_of_interest in dom_all[pocket_id]["domains"]
            for pocket_id in community["pocket_id"]
        )
        print(f"Percent of {dom_of_interest} pockets in community {rank} (n={len(community)}) of {species}: {round(n_with_dom/len(community)*100, 2)} %")

# Known pockets 

In [None]:
# Load known pockets and flatten them to one row per pocket.
# NOTE(review): pickle can execute arbitrary code on load; only read trusted pipeline outputs.
with open(f"{KNOWN_POCKETS_PATH}/known_pockets_up.pkl", "rb") as fh:
    known_pockets_raw = pkl.load(fh)

# Row index is "<key>_<position in list>"; keys with an empty/None value contribute no rows.
flat_pockets = {}
for key, pockets in known_pockets_raw.items():
    if not pockets:
        continue
    for idx, pocket in enumerate(pockets):
        flat_pockets[f"{key}_{idx}"] = pocket

known_pockets = pd.DataFrame.from_dict(flat_pockets, orient="index")
known_pockets["num_res"] = [len(res) for res in known_pockets["res"]]
known_pockets.head()

In [None]:
# Median pocket size (residue count) across all known pockets.
median_known_res = known_pockets["num_res"].median()
print(f"Median #res for known pockets: {median_known_res}")

In [None]:
# Legend text "<class> (<count>)" per ligand class among the known pockets, plus a colour map
# keyed by that text, restricted to classes that occur among the known pockets.
counter = Counter(known_pockets["class_name"])
num_dict = {name: f"{name} ({count})" for name, count in counter.items()}
class_color2 = {
    num_dict[name]: col
    for name, col in class_color.items()
    if name in num_dict
}

In [None]:
# Two pies: ligand-class composition of known pockets that were found vs. not found.
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "pie"}, {"type": "pie"}]],
                    subplot_titles=["<b>Found</b>", "<b>Not found</b>"])

# "Found" = found is True AND prob >= 0.5. "Not found" is its exact complement, so each known
# pocket lands in exactly one pie (the former separate condition could drop rows with missing
# `found`/`prob` values from both).
found_mask = known_pockets["found"].eq(True) & (known_pockets["prob"] >= 0.5)

# Classes present among known pockets, in the order of the configured colour map. Legend labels
# include the total per class, e.g. "<class> (<n>)".
plot_classes = [lig_class for lig_class in class_color if lig_class in num_dict]
labels = [num_dict[lig_class] for lig_class in plot_classes]
slice_colors = [class_color[lig_class] for lig_class in plot_classes]

for col, subset in ((1, known_pockets[found_mask]), (2, known_pockets[~found_mask])):
    count_dict = Counter(subset["class_name"])
    vals = [count_dict.get(lig_class, 0) for lig_class in plot_classes]
    # Colours are set explicitly on both pies so each class is coloured identically in both.
    fig.add_trace(go.Pie(labels=labels, values=vals, marker=dict(colors=slice_colors)),
                  row=1, col=col)

fig.update_layout(width=900, height=350,
                  font=dict(size=12, color="black", family="Arial"),
                  margin=dict(l=20, r=20, t=20, b=10))
fig.update_traces(texttemplate='%{percent:.2%}')
fig.write_image(f"{PLOT_DIR}/ligands_found_pie.png", scale=1, width=900, height=350)
fig.show()

# Foldseek clusters vs. pocket communities

In [None]:
# Per-species community statistics from the clustering run, plus derived ratios and natural logs.
cluster_df = pd.read_csv(f"{PROBIS_PATH}/clusterStats_discretized_TRUE_rewired_FALSE_randomShuffle_FALSE_t0.1_r0.01.txt", delimiter=" ")
# Keep only the part of the species column before the first "_".
cluster_df["species"] = [species.split("_")[0] for species in cluster_df["species"]]
# `org_df` (used here previously) is never defined in this notebook, so the cell raised NameError
# on a fresh kernel. Pocket counts per species are taken from `df` instead.
# NOTE(review): assumes n_pockets should be the number of pocket records in `df` per species;
# confirm this is the intended denominator for n_singletons/n_pockets.
cluster_df["n_pockets"] = cluster_df["species"].map(df["species"].value_counts())
cluster_df["n_unique_bs"] = cluster_df["n_communities"] + cluster_df["n_singletons"]
cluster_df["n_cluster_n_fs"] = cluster_df["n_unique_bs"] / cluster_df["n_fs_clust"]
cluster_df["n_singletons_n_fs"] = cluster_df["n_singletons"] / cluster_df["n_fs_clust"]
cluster_df["n_real_n_fs"] = cluster_df["n_communities"] / cluster_df["n_fs_clust"]
# np.log = natural logarithm.
cluster_df["log(n_fs_cluster)"] = np.log(cluster_df["n_fs_clust"].to_numpy())
cluster_df["log(n_unique_bs)"] = np.log(cluster_df["n_unique_bs"].to_numpy())
cluster_df["log(n_communities)"] = np.log(cluster_df["n_communities"].to_numpy())
cluster_df["n_singletons/n_pockets"] = cluster_df["n_singletons"] / cluster_df["n_pockets"]
cluster_df.index = cluster_df["species"]
cluster_df

In [None]:
# One line per species: number of communities, number of singletons, singleton fraction.
for species, n_comm, n_single, single_frac in zip(
    cluster_df["species"],
    cluster_df["n_communities"],
    cluster_df["n_singletons"],
    cluster_df["n_singletons/n_pockets"],
):
    print(f"{species}: {n_comm} communities, {n_single} singletons, {single_frac:.2f} fraction of singletons")

In [None]:
fig_width = 7.5
fig_height = 3.3
dpi = 300
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(fig_width, fig_height), dpi=dpi)
# A: (communities + singletons) per Foldseek cluster; B: communities per Foldseek cluster.
sns.scatterplot(data=cluster_df, x="n_fs_clust", y="n_cluster_n_fs", hue="species",
                palette=species_colors, ax=ax1, s=50, legend=False)
sns.scatterplot(data=cluster_df, x="n_fs_clust", y="n_real_n_fs", hue="species",
                palette=species_colors, ax=ax2, s=50, legend=True)

# Mathtext ignores plain spaces, so multi-word subscripts use \mathrm{...} with "\ ".
fs_label = r"$N_{\mathrm{FS\ cluster}}$"

ax1.set_facecolor("white")
ax1.set_xlabel(fs_label, fontsize=10)
ax1.set_ylabel(r"$N_{\mathrm{Pocket\ cluster}}$/" + fs_label, fontsize=10)
ax1.set_ylim(0, 0.5)
ax1.set_title("A", loc='left', fontsize=16, fontweight='bold')
ax1.tick_params(axis='both', which='major', labelsize=10)

ax2.set_facecolor("white")
ax2.set_xlabel(fs_label, fontsize=10)
ax2.set_ylabel(r"$N_{\mathrm{Communities}}$/" + fs_label, fontsize=10)
ax2.set_title("B", loc='left', fontsize=16, fontweight='bold')
ax2.set_ylim(0, 0.1)
ax2.tick_params(axis='both', which='major', labelsize=10)
ax2.legend(fontsize=10, title="Species", bbox_to_anchor=(1.05, 1.0), loc='upper left')

plt.tight_layout()
plt.savefig(f"{PLOT_DIR}/cluster_fs_pocket_probis_comm.tif", dpi=300)


def _report_loglog_fit(frame, x_col, y_col, description):
    """Fit y = a*x + b between two (natural-)log columns of `frame`; print a, b and Pearson r."""
    x = frame[x_col]
    y = frame[y_col]
    slope, intercept = np.polyfit(x, y, 1)
    r_val, p = pearsonr(x, y)
    # .3g keeps very small p-values visible instead of rounding them to 0.0.
    print(f"Log-log linear model for {description}: y = {round(slope, 3)} * x + {round(intercept, 3)}, "
          f"Pearson r: {round(r_val, 3)} with p-value {p:.3g}.")


# Column names must match those created in the cluster-statistics cell (lowercase, "unique").
_report_loglog_fit(cluster_df, "log(n_fs_cluster)", "log(n_unique_bs)", "#FS cluster and # unique BS")
_report_loglog_fit(cluster_df, "log(n_fs_cluster)", "log(n_communities)", "#FS cluster and #communities")

In [None]:
# Cluster-size distribution per species: for every community size, count how many communities
# have it, then plot log10(count) against log10(size).
membership_suffix = "_scores.csv_discretized_TRUE_rewired_FALSE_randomShuffle_FALSE_t0.1_r0.01_membership_vec.txt"
size_frames = []
for species in SPECIES:
    clust_path = os.path.join(PROBIS_PATH, f"{species}/{species}{membership_suffix}")
    mem_df = pd.read_csv(clust_path, sep=" ", header=None, names=["pocket_id", "community"])
    # Size of each community, then a histogram {size: number of communities with that size}.
    size_hist = Counter(mem_df["community"].value_counts().tolist())
    com_len_df = pd.DataFrame(list(size_hist.items()), columns=["Cluster size", "# Cluster"])
    com_len_df["Species"] = species
    size_frames.append(com_len_df)

# Concatenate once (repeated concat onto an empty frame is quadratic and deprecated in pandas).
columns = ["Species", "Cluster size", "# Cluster"]
if size_frames:
    com_len_all = pd.concat(size_frames, ignore_index=True)[columns]
else:
    com_len_all = pd.DataFrame(columns=columns)

com_len_all["Cluster size"] = com_len_all["Cluster size"].astype(int)
com_len_all["# Cluster"] = com_len_all["# Cluster"].astype(int)
com_len_all["log(Cluster size)"] = np.log10(com_len_all["Cluster size"].to_numpy())
com_len_all["log(# Cluster)"] = np.log10(com_len_all["# Cluster"].to_numpy())

# Create the line plot
dpi = 300
fig_width = 4
fig_height = 4

fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
sns.lineplot(
    data=com_len_all,
    x="log(Cluster size)",
    y="log(# Cluster)",
    hue="Species",
    palette=species_colors,
    marker="o",
)

# Raw string: "\m" in a normal string literal is an invalid escape sequence (SyntaxWarning).
plt.ylabel(r"log($N_{\mathrm{Clusters}}$)", fontsize=10)
plt.xlabel("log(Cluster size)", fontsize=10)
ax = plt.gca()
plt.grid(False)
ax.legend(title="Species")
plt.savefig(os.path.join(PLOT_DIR, "cluster_sizes.tif"), bbox_inches="tight", dpi=dpi)

# Singletons vs. Non-singletons 

In [None]:

def get_cohen_d(x, y):
    """
    Calculate Cohen's d effect size for two independent samples.

    d = (mean(x) - mean(y)) / s_pooled, with
    s_pooled = sqrt(((n_x - 1) * s_x**2 + (n_y - 1) * s_y**2) / (n_x + n_y - 2))
    where s_x, s_y are sample standard deviations (ddof=1).

    NaN values are ignored, and n_x, n_y count only non-missing values. The previous version
    counted NaNs in n while pandas Series silently skipped them in mean/std.

    Args:
        x: A list, NumPy array or pandas Series containing the first sample.
        y: A list, NumPy array or pandas Series containing the second sample.

    Returns:
        A tuple ``(d, std_x, std_y)``:
            d: Cohen's d, positive when x has the larger mean.
            std_x, std_y: the sample standard deviations of x and y.
        All three are NaN if a sample has fewer than two non-missing values.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x = x[~np.isnan(x)]
    y = y[~np.isnan(y)]

    nx = len(x)
    ny = len(y)
    dof = nx + ny - 2

    std_x = np.std(x, ddof=1)
    std_y = np.std(y, ddof=1)
    pooled_std = np.sqrt(((nx - 1) * std_x**2 + (ny - 1) * std_y**2) / dof)
    return (np.mean(x) - np.mean(y)) / pooled_std, std_x, std_y

# Label every clustered pocket as a singleton (community of size 1) or non-singleton, compare
# pocket properties between the two groups (violins + t-test + Cohen's d) and save the results.
membership_suffix = "_scores.csv_discretized_TRUE_rewired_FALSE_randomShuffle_FALSE_t0.1_r0.01_membership_vec.txt"
singleton_dict = {}
for species in SPECIES:
    clust_path = os.path.join(PROBIS_PATH, f"{species}/{species}{membership_suffix}")
    mem_df = pd.read_csv(clust_path, sep=" ", header=None, names=["pocket_id", "community"])
    # The membership file writes "." where df["pocket_id"] has "-".
    mem_df["pocket_id"] = [pocket_id.replace(".", "-") for pocket_id in mem_df["pocket_id"]]
    for _com_id, group in mem_df.groupby("community"):
        label = "Singletons" if len(group) == 1 else "Non-singletons"
        singleton_dict.update(dict.fromkeys(group["pocket_id"], label))

# Pockets absent from the clustering get NaN and are excluded from both groups.
df["is_singleton"] = df["pocket_id"].map(singleton_dict)
df["num_res"] = [len(res) for res in df["res"]]

singles_all = df[df["is_singleton"] == "Singletons"]
probis_all = df[df["is_singleton"] == "Non-singletons"]

prop_dict = {"hydro": "Hydrophobicity", "aro": "Aromaticity", "net_charge": "Net charge",
             "sasa": "SASA [\u212B\u00b2]", "plddt": "pLDDT", "pae": "PAE [\u212B]",
             "prob": "Probability", "num_res": "# Residues",
             }
titles = ["A", "B", "C", "D", "E", "F", "G", "H"]

dpi = 300
fig_width = 6.2
fig_height = 8.75

n_rows = 4
n_cols = 2

fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(fig_width, fig_height), dpi=dpi)

ttest_results = []
# axes.flat walks the grid row by row, one panel per property.
for ax, title, (prop, label) in zip(axes.flat, titles, prop_dict.items()):
    sns.violinplot(x='is_singleton', y=prop, data=df, hue='is_singleton',
                   palette={'Non-singletons': 'lightseagreen', 'Singletons': 'darkorange'},
                   inner='box', linewidth=1.5, ax=ax)
    ax.set_title(title, fontsize=12, weight='bold', loc='left')
    ax.set_ylabel(label, fontsize=10)
    ax.set_xlabel('')
    ax.tick_params("x", rotation=0)

    non_single = probis_all[prop]
    single = singles_all[prop]
    # Both t and Cohen's d use the same direction (non-singletons minus singletons); previously
    # d was computed singletons minus non-singletons, so the two had opposite signs.
    # nan_policy="omit" matches the NaN-skipping medians and effect size.
    # NOTE(review): Student's t-test (equal_var=True); consider Welch if variances differ.
    t_stat, p_val = ttest_ind(non_single, single, nan_policy="omit")
    cohen_d, std_non_single, std_single = get_cohen_d(non_single, single)
    ttest_results.append([prop, t_stat, p_val, non_single.median(), single.median(),
                          round(std_single, 3), round(std_non_single, 3), round(cohen_d, 3)])

plt.tight_layout()
plt.savefig(f"{PLOT_DIR}/single_vs_non_single.tif", dpi=dpi, format="tif")

# Save t-test results and DataFrame
ttest_results = pd.DataFrame(ttest_results, columns=["Property", "t_stat", "p_val", "non_single_med", "single_med", "std_single", "std_non_single", "cohen_d"])
ttest_results.to_csv(f"{RESULTS_PATH}/ttest_single_vs_non_singles.csv")
df.to_csv(f"{RESULTS_PATH}/pocket_df.tsv.gz", index=False, sep="\t")