In [None]:
%load_ext autoreload
%autoreload 2

import sys
# Machine-specific absolute path that is put on sys.path so modules located
# under it can be imported by later cells.
REPO_ROOT = "/mnt/STORAGE3/sebastian2/DiffSBDD"
# The membership check keeps repeated runs of this cell from appending the
# same entry to sys.path more than once.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import numpy as np
from rdkit import Chem
import umap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import umap.umap_ as umap
from typing import List, Optional, Tuple
from sklearn.decomposition import PCA
import math

# Notebook-wide seaborn defaults using the "talk" context.
# Note: plot_embeddings_separate_pca calls sns.set_theme again with
# style="ticks", context="talk", so the theme changes once that function runs.
sns.set_theme(context="talk")

In [None]:
# Directory prefix for molecule files. load_mols_from_names builds paths by
# plain string concatenation (DATA_SET_PATH + suffix), so the trailing "/"
# is required.
DATA_SET_PATH = "/mnt/STORAGE3/sebastian2/data_hackathon/data/crossdocked_pocket10/"
# .npz archives consumed by load_embeddings, which reads the 'embeddings' and
# 'names' keys from each. Index 0 is loaded into `embeddings`/`names`,
# index 1 into `embeddings_llama`/`names_llama`.
EMBEDDING_PWS = [
    "/mnt/STORAGE3/sebastian2/data_hackathon/data/text_embeddings.npz",
    "/mnt/STORAGE3/sebastian2/data_hackathon/data/all_embed_tinyllama.npz",
]



In [None]:
# Open the text-embedding archive. The path is taken from EMBEDDING_PWS[0]
# (identical to the previously duplicated literal) so it is defined in one place.
# NOTE: for .npz files np.load returns an archive handle that reads arrays on
# key access; it is intentionally left open here for interactive inspection.
data = np.load(EMBEDDING_PWS[0])

In [None]:
def load_mols_from_names(names):
    """Read one molecule per entry of ``names`` from below ``DATA_SET_PATH``.

    For each name only the text after the last ``"pdb_"`` marker is kept (the
    whole name if the marker is absent) and appended to ``DATA_SET_PATH`` by
    plain string concatenation. The resulting path is passed to
    ``Chem.MolFromMolFile``.

    Args:
        names: Iterable of strings identifying the molecule files.

    Returns:
        list: One ``Chem.MolFromMolFile`` result per name, in input order.
        Results are stored unchanged (nothing is filtered out), so positions
        stay aligned with ``names``.
    """
    return [
        Chem.MolFromMolFile(DATA_SET_PATH + entry.split("pdb_")[-1])
        for entry in names
    ]

def reduce_embeddings_umap(embeddings):
    """Project ``embeddings`` onto two UMAP components.

    A fresh ``umap.UMAP`` reducer is created on every call with
    ``n_components=2`` and a fixed ``random_state=42``.

    Args:
        embeddings: Data accepted by ``UMAP.fit_transform``.

    Returns:
        The output of ``fit_transform`` for the two-component reducer.
    """
    return umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)

def load_embeddings(embedding_pw):
    """Load the ``'embeddings'`` and ``'names'`` arrays from an ``.npz`` archive.

    Args:
        embedding_pw: Path (or file object) handed to ``np.load``.

    Returns:
        tuple: ``(embeddings, names)`` as stored under those two keys.
    """
    # NOTE(security): allow_pickle=True permits unpickling of object arrays,
    # which can run arbitrary code. Only point this at trusted files.
    # Both arrays are read while the archive is open; the context manager
    # closes it on the way out.
    with np.load(embedding_pw, allow_pickle=True) as archive:
        return archive['embeddings'], archive['names']

def plot_embeddings_separate_pca(
    embeddings_list: List[np.ndarray],
    highlight_idxs: List[List[int]],
    labels: Optional[List[str]] = None,
    suptitle: str = "Internal Structure of Embedding Groups via PCA",
) -> None:
    """
    Visualizes the internal structure of embedding groups that have different
    feature dimensions by running PCA on each group separately.

    NOTE: The axes of each subplot are NOT comparable to other subplots.
    This visualization is for analyzing the shape and variance of each
    group independently.

    Side effects: sets the global seaborn theme (style="ticks",
    context="talk") and displays the figure via ``plt.show()``.

    Args:
        embeddings_list (List[np.ndarray]): A non-empty list of numpy arrays
            (one row per point). Each array can have a different number of
            features (shape[1]).
        highlight_idxs (List[List[int]]): Row indices to outline in each
            group, one entry per group. An empty entry (or None) highlights
            nothing in that group.
        labels (Optional[List[str]], optional): Labels for each subplot title.
            Defaults to "Group <n> (D=<number of features>)".
        suptitle (str, optional): The main title for the entire figure.

    Raises:
        ValueError: If ``embeddings_list`` is empty, or if the embeddings,
            highlight indices, and labels do not all have the same length.
    """
    # --- Validate inputs before touching any global plotting state ---
    n_groups = len(embeddings_list)
    if n_groups == 0:
        # Without this guard ncols would be 0 and the row computation below
        # would fail with an unhelpful ZeroDivisionError.
        raise ValueError("embeddings_list must contain at least one group of embeddings.")

    if labels is None:
        labels = [f'Group {i+1} (D={E.shape[1]})' for i, E in enumerate(embeddings_list)]

    if not (n_groups == len(highlight_idxs) == len(labels)):
        raise ValueError(
            "The lists for embeddings, highlight indices, and labels must all be the same length."
        )

    sns.set_theme(style="ticks", context="talk")

    # --- Set up the subplot grid (at most three columns) ---
    ncols = min(3, n_groups)
    nrows = math.ceil(n_groups / ncols)

    # squeeze=False always yields a 2-D array of axes, so the single-plot
    # case needs no special handling.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * 5.5, nrows * 5),
        squeeze=False,
    )
    axes_flat = axes.ravel()

    palette = sns.color_palette("viridis", n_colors=n_groups)

    # --- Loop, running PCA on each group individually ---
    for i, (embeddings, ax) in enumerate(zip(embeddings_list, axes_flat)):
        ax.set_title(labels[i], fontweight='bold')

        # A two-component PCA needs at least two samples AND two features;
        # otherwise show a placeholder instead of crashing inside the fit.
        if embeddings.shape[0] < 2 or embeddings.shape[-1] < 2:
            ax.text(0.5, 0.5, 'Not enough data for PCA', ha='center', va='center')
            continue

        # --- Perform PCA on the current group ---
        pca = PCA(n_components=2, random_state=42)
        reduced_embeddings = pca.fit_transform(embeddings)
        explained_variance = pca.explained_variance_ratio_

        # Plot the points for this group
        ax.scatter(
            reduced_embeddings[:, 0], reduced_embeddings[:, 1],
            color=palette[i], s=20, alpha=0.7
        )

        # Overlay highlights as hollow markers. len() is used instead of
        # truthiness so NumPy index arrays are accepted as well as lists.
        current_highlights = highlight_idxs[i]
        if current_highlights is not None and len(current_highlights) > 0:
            ax.scatter(
                reduced_embeddings[current_highlights, 0],
                reduced_embeddings[current_highlights, 1],
                facecolors='none', edgecolors='#222222',
                s=50, linewidths=2.0
            )

        # --- Label only the outer axes ---
        # x-axis: only when no populated subplot sits below this one in its
        # column (i + ncols would fall outside the groups).
        if i >= n_groups - ncols:
            ax.set_xlabel(f"PC 1 ({explained_variance[0]:.1%})")
        # y-axis: only for plots in the first column
        if i % ncols == 0:
            ax.set_ylabel(f"PC 2 ({explained_variance[1]:.1%})")

    # --- Final Touches ---
    # Hide any unused subplots in the last row
    for unused_ax in axes_flat[n_groups:]:
        unused_ax.axis('off')

    fig.suptitle(suptitle, fontsize=22, y=1.02)
    # Operate on this figure explicitly rather than on the implicit current figure.
    sns.despine(fig=fig, trim=True)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

In [None]:
# Load both embedding sets together with their name arrays:
# index 0 -> text_embeddings.npz, index 1 -> all_embed_tinyllama.npz.
embeddings, names = load_embeddings(EMBEDDING_PWS[0])
embeddings_llama, names_llama = load_embeddings(EMBEDDING_PWS[1])

In [None]:
# Plot at most the first 5000 rows of each embedding set as two side-by-side
# PCA panels (PCA is fitted per panel). Row 0 is outlined in the first panel;
# nothing is highlighted in the second.
plot_embeddings_separate_pca([embeddings[:5000], embeddings_llama[:5000]], [[0],[]])