# Load necessary packages

In [None]:
import pandas as pd 
import numpy as np
import yaml
from math import log
import plotly.express as px
import plotly.graph_objects as go
import pickle as pkl
from plotly.subplots import make_subplots
from collections import Counter
from sklearn.metrics import pairwise_distances
from scipy.stats import spearmanr, fisher_exact
import ast
import matplotlib.pyplot as plt
import seaborn as sns

# Load configuration and data

In [None]:
# Parse the YAML configuration once; every constant below is read from it.
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Colour settings, taken verbatim from the configuration
cf_colors = config["cofactor_color"]
class_color = config["ligand_color"]
king_colors = config["kingdom_color"]

# Per-species lookups, built in a single pass over the species entries
SPECIES = []
species_map = {}     # species id -> kingdom
species_colors = {}  # species id -> colour
for entry in config["species"]:
    SPECIES.append(entry["id"])
    species_map[entry["id"]] = entry["kingdom"]
    species_colors[entry["id"]] = entry["color"]

# Input / output locations
paths = config["paths"]
BIND_DICT_PATH = paths["bind_dict"]
POCKET_PATH = paths["pockets"]
PLOT_DIR = paths["plots"]

In [None]:
def plot_dim_red(df, x_col, y_col, scatter_configs, out_file):
    """Draw a two-column grid of scatter plots and save it to ``out_file``.

    Every panel shows the same two coordinates (``x_col`` vs. ``y_col``),
    coloured by a different column of ``df``, and gets its own colorbar that
    spans the min/max of that column.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding ``x_col``, ``y_col`` and every colour column named in
        ``scatter_configs``.
    x_col, y_col : str
        Columns plotted on the x and y axis. The axis label is the part of
        the column name before the first underscore.
    scatter_configs : sequence of (color_col, plot_label, cbar_title, cmap)
        One entry per panel: the column used for colouring, the panel title,
        the colorbar label and the colormap name.
    out_file : str
        Path the figure is written to.

    Returns
    -------
    int
        Always 0.
    """
    n_cols = 2
    # Ceiling division: an odd number of panels needs one extra row.
    n_rows = (len(scatter_configs) + n_cols - 1) // n_cols
    dpi = 300
    fig_width = 6.8
    fig_height = 8.75
    # squeeze=False keeps `axes` two-dimensional even when there is only one
    # row, so the panels can always be addressed the same way.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height),
                             dpi=dpi, squeeze=False)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    panels = axes.ravel()  # row-major: left-to-right, then top-to-bottom

    for ax, (color_col, plot_label, cbar_title, cmap) in zip(panels, scatter_configs):
        sns.scatterplot(
            x=x_col, y=y_col,
            hue=color_col,
            palette=cmap,
            data=df,
            ax=ax,
            s=2,
            legend=False
        )
        ax.set_title(plot_label, fontsize=12, weight='bold', loc='left')
        ax.set_xlabel(x_col.split("_")[0], fontsize=10)
        ax.set_ylabel(y_col.split("_")[0], fontsize=10)
        # Set the tick size on every panel, not only on the last one drawn.
        ax.tick_params(labelsize=10)

        # The legend is disabled above, so build the colorbar by hand from the
        # value range of the colouring column.
        norm = plt.Normalize(df[color_col].min(), df[color_col].max())
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=10)
        cbar.set_label(cbar_title, fontsize=10)

    # Hide the left-over panel when the number of configs is odd.
    for ax in panels[len(scatter_configs):]:
        ax.set_visible(False)

    # Save and release this specific figure rather than "the current one".
    fig.tight_layout()
    fig.savefig(out_file, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    return 0

In [None]:
# Collect the pocket and domain dictionaries of all species.
# NOTE(review): pickle can execute arbitrary code while loading; only point
# BIND_DICT_PATH / POCKET_PATH at files produced by this project.
bs_all = {}
dom_all = {}

for species in SPECIES:
    # Context managers make sure every file handle is closed again.
    with open(f"{BIND_DICT_PATH}/{species}_pockets.pkl", "rb") as handle:
        bind_dict = pkl.load(handle)
    bs_all.update(bind_dict)

    with open(f"{POCKET_PATH}/{species}/dom_dict.pkl", "rb") as handle:
        dom_dict = pkl.load(handle)
    dom_all.update(dom_dict)

# Flatten the per-key pocket lists into a single lookup keyed by pocket_id.
bs_all = {pocket["pocket_id"]: pocket for pockets in bs_all.values() for pocket in pockets}

In [None]:
# Read the table with the precomputed coordinates and descriptors.
df = pd.read_csv("output/dim_red_all.csv", keep_default_na=True)

# "res" and "seq" are stored as Python literals in the CSV; parse them back.
for literal_col in ("res", "seq"):
    df[literal_col] = [ast.literal_eval(cell) for cell in df[literal_col]]
df["#res"] = [len(residues) for residues in df["res"]]

# Domain annotation per pocket: the number of domains and their string form.
domain_records = df["pocket_id"].map(dom_all)
df["domain"] = domain_records
df["num_dom"] = [len(record["domains"]) for record in domain_records]
df["domain"] = [str(record["domains"]) for record in domain_records]

# Kingdom of the species and the "is_tm" flag stored with each pocket.
df["kingdom"] = df["species"].map(species_map)
df["is_tm"] = [bs_all.get(pid)["is_tm"] for pid in df["pocket_id"]]
df.head()

# Dimensionality reduction plots

In [None]:
# (output name, x column, y column) for every embedding plotted below.
dim_red_types = [("pca", "PC1", "PC2"), ("ica", "IC1", "IC2")]
dim_red_types += [
    (f"umap_n{n}", f"UMAP1_n{n}", f"UMAP2_n{n}")
    for n in (10, 20, 50, 100, 200)
]
dim_red_types += [
    (f"tsne_p{p}", f"tSNE1_p{p}", f"tSNE2_p{p}")
    for p in (10, 20, 30, 40, 50, 100)
]

# (colour column, panel label, colorbar title, colormap) for each panel.
scatter_configs = [
    ("hydro", "A", "Hydrophobicity", "RdBu"),
    ("aro", "B", "Aromaticity", "viridis"),
    ("net_charge", "C", "Net charge", "RdBu"),
    ("sasa", "D", "SASA [\u212B\u00b2]", "viridis"),
    ("plddt", "E", "pLDDT", "viridis"),
    ("pae", "F", "PAE [\u212B]", "viridis"),
    ("#res", "G", "# residues", "viridis"),
    ("prob", "H", "Probability", "viridis"),
]

# One multi-panel figure per embedding.
for red_type, red1, red2 in dim_red_types:
    out_file = f"{PLOT_DIR}/{red_type}_desc.tif"
    plot_dim_red(df, red1, red2, scatter_configs, out_file)

In [None]:
# Kingdom view of the tSNE*_p50 coordinates:
#   (A) scatter plot coloured by kingdom,
#   (B) Shannon entropy (natural log) of the kingdom composition in every
#       cell of a regular grid, with sparsely populated cells covered.
red = ("tSNE1_p50", "tSNE2_p50")
data = {
    "x": df[red[0]],
    "y": df[red[1]],
    "category": df["kingdom"]
}

# Points without coordinates cannot be assigned to a cell, so they are left
# out of the grid statistics.
tsne_df = pd.DataFrame(data).dropna(subset=["x", "y"]).copy()

# Parameters for the grid
grid_size = (30, 30)  # number of cells along x and y
cell_th = 20          # cells holding fewer points than this are covered in panel B
x_min, x_max = tsne_df["x"].min(), tsne_df["x"].max()
y_min, y_max = tsne_df["y"].min(), tsne_df["y"].max()
x_bins = np.linspace(x_min, x_max, grid_size[0] + 1)
y_bins = np.linspace(y_min, y_max, grid_size[1] + 1)


def get_cell(x, y):
    """Return the (x, y) grid-cell indices of a point or of arrays of points.

    np.digitize puts a value equal to the last bin edge one past the last
    bin, so the indices are clipped; otherwise the points with the largest
    x or y coordinate would fall outside the grid and be lost.
    """
    x_idx = np.clip(np.digitize(x, x_bins) - 1, 0, grid_size[0] - 1)
    y_idx = np.clip(np.digitize(y, y_bins) - 1, 0, grid_size[1] - 1)
    return x_idx, y_idx


# Assign each point to a grid cell (vectorised over all points)
cell_x, cell_y = get_cell(tsne_df["x"].to_numpy(), tsne_df["y"].to_numpy())
tsne_df["cell"] = pd.Series(list(zip(cell_x.tolist(), cell_y.tolist())), index=tsne_df.index)
tsne_df["color"] = tsne_df["category"].map(king_colors)

# Compute cell-specific entropy from the fraction of each category in the cell
entropy_dict = {}
for cell, group in tsne_df.groupby("cell"):
    category_counts = group["category"].value_counts(normalize=True)
    entropy_dict[cell] = -sum(p * log(p) for p in category_counts if p > 0)

cell_count = Counter(tsne_df["cell"])

# Fill the grids: NaN marks "no value", which imshow leaves blank
entropy_grid = np.full(grid_size, np.nan)
cell_grid = np.full(grid_size, np.nan)
for (x_idx, y_idx), entropy in entropy_dict.items():
    entropy_grid[x_idx, y_idx] = entropy
    if cell_count[(x_idx, y_idx)] < cell_th:
        cell_grid[x_idx, y_idx] = 1

dpi = 300
fig_width_px = 2250
fig_height_px = 1150
figsize_in = (fig_width_px / dpi, fig_height_px / dpi)
# Set up plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize_in, dpi=dpi)
plt.subplots_adjust(wspace=0.2)

# Panel A: one scatter call per kingdom, in the order of king_colors
for king in king_colors.keys():
    filtered_df = df[df["kingdom"] == king]

    sc = sns.scatterplot(
            x=red[0], y=red[1],
            hue="kingdom",
            palette=king_colors,
            data=filtered_df,
            ax=ax1,
            s=2,
            )

ax1.set_title("A", loc='left', fontsize=16, weight='bold')
ax1.set_xlabel(red[0].split("_")[0], fontsize=10)
ax1.set_ylabel(red[1].split("_")[0], fontsize=10)
ax1.legend(loc='lower center', ncol=2, bbox_to_anchor=(0.5, -0.45), fontsize=10, frameon=False, markerscale=4)

# Bin centres (imshow below is positioned through the bin extent instead)
x_centers = (x_bins[:-1] + x_bins[1:]) / 2
y_centers = (y_bins[:-1] + y_bins[1:]) / 2

# Panel B: entropy heatmap. The grids are indexed [x, y], imshow expects
# [row (y), column (x)], hence the transpose.
entropy_img = ax2.imshow(
    entropy_grid.T,
    cmap="viridis",
    origin="lower",
    aspect="auto",
    extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
    alpha=0.6
)

# Overlay for sparse cells; every cell that is not flagged is NaN and masked
masked = np.ma.masked_invalid(cell_grid.T)
cell_img = ax2.imshow(
    masked,
    cmap=plt.cm.gray,
    origin="lower",
    aspect="auto",
    extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
    alpha=1
)

# Colorbar for entropy
cbar = fig.colorbar(entropy_img, ax=ax2, fraction=0.046, pad=0.04)
cbar.set_label("Entropy", fontsize=10)

ax2.set_title("B", loc='left', fontsize=16, weight='bold')
ax2.set_xlabel(red[0].split("_")[0], fontsize=10)
ax2.set_ylabel(red[1].split("_")[0], fontsize=10)

for ax in [ax1, ax2]:
    ax.tick_params(labelsize=10)
    ax.set_facecolor("white")

fig.tight_layout()
# Save figure
fig.savefig(f"{PLOT_DIR}/tsne_king_scatt_entro.tif", dpi=dpi, bbox_inches='tight', format="tif")
plt.show()

In [None]:

# Species view of the tSNE*_p50 coordinates:
#   (A) scatter plot coloured by species,
#   (B) Shannon entropy (natural log) of the species composition in every
#       cell of a regular grid, with sparsely populated cells covered.
red = ("tSNE1_p50", "tSNE2_p50")
data = {
    "x": df[red[0]],
    "y": df[red[1]],
    "category": df["species"]
}

# Points without coordinates cannot be assigned to a cell, so they are left
# out of the grid statistics.
tsne_df = pd.DataFrame(data).dropna(subset=["x", "y"]).copy()

# Parameters for the grid
grid_size = (30, 30)  # 30x30 grid
cell_th = 20          # cells holding fewer points than this are covered in panel B
x_min, x_max = tsne_df["x"].min(), tsne_df["x"].max()
y_min, y_max = tsne_df["y"].min(), tsne_df["y"].max()
x_bins = np.linspace(x_min, x_max, grid_size[0] + 1)
y_bins = np.linspace(y_min, y_max, grid_size[1] + 1)


def get_cell(x, y):
    """Return the (x, y) grid-cell indices of a point or of arrays of points.

    np.digitize puts a value equal to the last bin edge one past the last
    bin, so the indices are clipped; otherwise the points with the largest
    x or y coordinate would fall outside the grid and be lost.
    """
    x_idx = np.clip(np.digitize(x, x_bins) - 1, 0, grid_size[0] - 1)
    y_idx = np.clip(np.digitize(y, y_bins) - 1, 0, grid_size[1] - 1)
    return x_idx, y_idx


# Assign each point to a grid cell (vectorised over all points)
cell_x, cell_y = get_cell(tsne_df["x"].to_numpy(), tsne_df["y"].to_numpy())
tsne_df["cell"] = pd.Series(list(zip(cell_x.tolist(), cell_y.tolist())), index=tsne_df.index)
tsne_df["color"] = tsne_df["category"].map(species_colors)

# Compute cell-specific entropy from the fraction of each category in the cell
entropy_dict = {}
for cell, group in tsne_df.groupby("cell"):
    category_counts = group["category"].value_counts(normalize=True)
    entropy_dict[cell] = -sum(p * log(p) for p in category_counts if p > 0)

cell_count = Counter(tsne_df["cell"])

# Fill the grids: NaN marks "no value", which imshow leaves blank
entropy_grid = np.full(grid_size, np.nan)
cell_grid = np.full(grid_size, np.nan)
for (x_idx, y_idx), entropy in entropy_dict.items():
    entropy_grid[x_idx, y_idx] = entropy
    if cell_count[(x_idx, y_idx)] < cell_th:
        cell_grid[x_idx, y_idx] = 1

dpi = 300
fig_width_px = 2250
fig_height_px = 850
figsize_in = (fig_width_px / dpi, fig_height_px / dpi)
# Set up plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize_in, dpi=dpi)
plt.subplots_adjust(wspace=0.2)

# Panel A: species are plotted in reverse configuration order, so the first
# species in SPECIES is drawn last.
organisms = SPECIES.copy()
organisms.reverse()

for org in organisms:
    filtered_df = df[df["species"] == org]
    sc = sns.scatterplot(
            x=red[0], y=red[1],
            hue="species",
            palette=species_colors,
            data=filtered_df,
            ax=ax1,
            s=2,
            )

ax1.set_title("A", loc='left', fontsize=16, weight='bold')
ax1.set_xlabel(red[0].split("_")[0], fontsize=10)
ax1.set_ylabel(red[1].split("_")[0], fontsize=10)
ax1.legend(loc='upper right', ncol=1, fontsize=10, frameon=False, markerscale=4, bbox_to_anchor=(1.5, 1.05))

# Bin centres (imshow below is positioned through the bin extent instead)
x_centers = (x_bins[:-1] + x_bins[1:]) / 2
y_centers = (y_bins[:-1] + y_bins[1:]) / 2

# Panel B: entropy heatmap. The grids are indexed [x, y], imshow expects
# [row (y), column (x)], hence the transpose.
entropy_img = ax2.imshow(
    entropy_grid.T,
    cmap="viridis",
    origin="lower",
    aspect="auto",
    extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
    alpha=0.6
)

# Overlay for sparse cells; every cell that is not flagged is NaN and masked
masked = np.ma.masked_invalid(cell_grid.T)
cell_img = ax2.imshow(
    masked,
    cmap=plt.cm.gray,
    origin="lower",
    aspect="auto",
    extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
    alpha=1
)

# Colorbar for entropy
cbar = fig.colorbar(entropy_img, ax=ax2, fraction=0.046, pad=0.04)
cbar.set_label("Entropy", fontsize=12)

ax2.set_title("B", loc='left', fontsize=16, weight='bold')
ax2.set_xlabel(red[0].split("_")[0], fontsize=10)
ax2.set_ylabel(red[1].split("_")[0], fontsize=10)

# Same tick size and background for both panels
for ax in [ax1, ax2]:
    ax.tick_params(labelsize=10)
    ax.set_facecolor("white")

fig.tight_layout(pad=0.1)
# Save figure
fig.savefig(f"{PLOT_DIR}/tsne_org_scatt_entro.tif", dpi=dpi, bbox_inches='tight', format="tif")
plt.show()

In [None]:
# Scatter plot of the tSNE*_p50 coordinates coloured by ligand class. The
# legend lists every class together with its number of pockets.
red = ("tSNE1_p50", "tSNE2_p50")
dpi = 300
fig_width = 3.7
fig_height = 3
fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
ax = fig.add_subplot(111)

# Pockets per ligand class, largest class first
lig_classes = df["lig_class"].value_counts().sort_values(ascending=False)
df["lig_class_num"] = [f"{lig} ({lig_classes.get(lig, 0)})" for lig in df["lig_class"]]

# One scatter call per class, in order of decreasing class size
for lig in lig_classes.index:
    sns.scatterplot(
        x=red[0], y=red[1],
        hue="lig_class",
        palette=class_color,
        data=df[df["lig_class"] == lig],
        ax=ax,
        s=2,
    )

# Append the class size to every legend entry. The counts are looked up by
# the label text, so the result does not depend on the order of the handles.
count_by_label = {str(lig): count for lig, count in lig_classes.items()}
handles, labels = ax.get_legend_handles_labels()
legend_labels = [f"{label} ({count_by_label.get(label, 0)})" for label in labels]
ax.legend(handles, legend_labels, loc='upper right', ncol=1, fontsize=10,
          frameon=False, markerscale=4, bbox_to_anchor=(2, 1.1))
ax.set_xlabel(red[0].split("_")[0], fontsize=10)
ax.set_ylabel(red[1].split("_")[0], fontsize=10)

fig.savefig(f"{PLOT_DIR}/tsne_lig_scatt.tif", dpi=dpi, bbox_inches='tight', format="tif")

In [None]:
# Interactive version of the species scatter plot, written to an HTML file.
# Uses `red` (the pair of coordinate columns) set in the cells above.
x_name, y_name = red
fig = px.scatter(
    df,
    x=x_name,
    y=y_name,
    color="species",
    color_discrete_map=species_colors,
    # Legend / trace order follows the order of the colour dictionary
    category_orders={"species": list(species_colors)},
    hover_data=["pocket_id", "species", "lig_class"],
)
fig.update_traces(marker={"size": 4})
fig.update_layout(
    width=2000,
    height=1500,
    font={"size": 24, "color": "black"},
    plot_bgcolor='white',
    legend=go.layout.Legend(itemsizing='constant'),
    margin={"l": 20, "r": 20, "t": 20, "b": 20},
)
# Axis titles are the column names without their "_..." suffix
fig.update_xaxes(title=x_name.split("_")[0])
fig.update_yaxes(title=y_name.split("_")[0])
fig.write_html(f"{PLOT_DIR}/tSNE_species.html")

# Check unknown regions

In [None]:
# Flag the pockets that fall into one of four hand-picked rectangles of the
# tSNE*_p50 plane. Each rectangle is
# (tSNE1 lower, tSNE1 upper, tSNE2 lower, tSNE2 upper); all bounds are exclusive.
low_prob_boxes = [
    (-200, -75, -90, 10),
    (-140, -90, 10, 50),
    (-90, -30, -60, -10),
    (-30, -10, -60, -20),
]
tsne1 = df["tSNE1_p50"]
tsne2 = df["tSNE2_p50"]
inside_any_box = pd.Series(False, index=df.index)
for t1_lo, t1_hi, t2_lo, t2_hi in low_prob_boxes:
    inside_any_box |= (tsne1 > t1_lo) & (tsne1 < t1_hi) & (tsne2 > t2_lo) & (tsne2 < t2_hi)
df["low_prob"] = inside_any_box

# Highlight the flagged pockets in orange on top of the full scatter plot.
dpi = 300
fig_width = 2.7
fig_height = 2.7
fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
ax = fig.add_subplot(111)
sns.scatterplot(
    data=df,
    x=red[0],
    y=red[1],
    hue="low_prob",
    palette={True: "orange", False: "lightgray"},
    s=1,
    legend=False,
    ax=ax,
)
ax.set_xlabel(red[0].split("_")[0], fontsize=10)
ax.set_ylabel(red[1].split("_")[0], fontsize=10)
fig.savefig(f"{PLOT_DIR}/tsne_low_prob_scatt.tif", dpi=dpi, bbox_inches='tight', format="tif")
plt.show()

In [None]:
# Compare the pockets inside the flagged "low probability" region with all
# remaining pockets. The reference group is the rest of the data set, not the
# whole data set, and the printed labels say so.
low_prob = df[df["low_prob"] == True]
rest = df[df["low_prob"] == False]
num_tm = int(low_prob["is_tm"].eq(True).sum())
num_tm_rest = int(rest["is_tm"].eq(True).sum())

prob_in, prob_rest = low_prob["prob"].median(), rest["prob"].median()
print(f"Number of pockets found in low probability area: {len(low_prob)}")
print(f"Median prob in low probability area: {prob_in} vs. rest: {prob_rest}, difference = {round(prob_in - prob_rest, 2)}")
print(f"Median aro in low probability area: {low_prob['aro'].median()} vs. rest: {rest['aro'].median()}")
print(f"Median hydro in low probability area: {low_prob['hydro'].median()} vs. rest: {rest['hydro'].median()}")
print(f"Number of TM pockets in low probability area: {num_tm} vs. {num_tm_rest} in the rest")

# "domain" holds the string form of each pocket's domain list, so this counts
# distinct domain combinations.
count = Counter(low_prob["domain"])
print(f"Number of domains in low probability area: {len(count)}")
print(f"Most common domain in low probability area: {count.most_common(10)}")


In [None]:
# Flag the pockets in the outer lower-left area of the tSNE*_p50 plane
# ("no man's land") and show them interactively. Bounds are exclusive.
tsne1 = df["tSNE1_p50"]
tsne2 = df["tSNE2_p50"]
df["no_mans"] = ((tsne1 < -148) & (tsne2 < -91)) | ((tsne1 < -200) & (tsne2 < -20))

# Hover shows the "species" column, the same column every other cell of this
# notebook uses for the organism of a pocket.
fig = px.scatter(df, x=red[0], y=red[1], color="no_mans", color_discrete_map={True: "orange", False: "lightgray"},
                hover_data=["pocket_id", "species", "lig_class"])
fig.update_traces(marker=dict(size=4))
fig.update_layout(width=800, height=700,
                  font=dict(size=20),
                  plot_bgcolor='white',
                  legend=go.layout.Legend(itemsizing='constant',
                                         traceorder="reversed"),
                  showlegend=True
                 )
fig.show()

In [None]:
# Compare the pockets inside the flagged "no man's land" region with all
# remaining pockets. The reference group is the rest of the data set, not the
# whole data set, and the printed labels say so.
no_mans = df[df["no_mans"] == True]
rest = df[df["no_mans"] == False]
num_tm = int(no_mans["is_tm"].eq(True).sum())
num_tm_rest = int(rest["is_tm"].eq(True).sum())

# Pockets of the region with a hydrophobicity value above 2
high_hydro = no_mans[no_mans["hydro"] > 2]
num_tm_hydro = int(high_hydro["is_tm"].eq(True).sum())

prob_in, prob_rest = no_mans["prob"].median(), rest["prob"].median()
print(f"Number of pockets found in nomans land: {len(no_mans)}")
print(f"Median prob in nomans land: {prob_in} vs. rest: {prob_rest}, difference = {round(prob_in - prob_rest, 2)}")
print(f"Median aro in nomans land: {no_mans['aro'].median()} vs. rest: {rest['aro'].median()}")
print(f"Median hydro in nomans land: {no_mans['hydro'].median()} vs. rest: {rest['hydro'].median()}")
print(f"Number of TM pockets in nomans land: {num_tm} vs. {num_tm_rest} in the rest")

# "domain" holds the string form of each pocket's domain list, so this counts
# distinct domain combinations.
count = Counter(no_mans["domain"])
print(f"Number of domains in nomans land: {len(count)}")
print(f"Most common domain in nomans land: {count.most_common(10)}")
print(f"Number of pockets with high hydrophobicity in nomans land: {len(high_hydro)} with {num_tm_hydro} TM pockets")
