In [2]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
# import umap

from histopatseg.visualization.visualization import plot_embeddings
from histopatseg.evaluation.utils import aggregate_tile_embeddings

In [3]:
# The notebook is expected to run from a sub-folder of the project, so the
# project root is the parent of the resolved working directory.
notebook_dir = Path(".").resolve()
project_dir = notebook_dir.parent
print(f"Project Directory: {project_dir}")

Project Directory: /home/val/workspaces/histopatseg


In [6]:
# Inputs: the precomputed embedding archive and the per-tile metadata table.
embedding_file = project_dir / "data/processed/embeddings/lunghist700_20x_UNI2_embeddings.npz"
metadata_file = project_dir / "data/processed/LungHist700_tiled/LungHist700_20x/metadata.csv"

# Index the table by tile_id so rows can be looked up with the ids stored
# alongside the embeddings.
metadata = pd.read_csv(metadata_file).set_index("tile_id")
metadata.head()

Unnamed: 0_level_0,patient_id,superclass,subclass,resolution,image_id,class_name,label,original_filename,tile_path
tile_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
aca_bd_20x_0_tile_0_0,5,aca,bd,20x,0,aca_bd,0,aca_bd_20x_0,data/processed/LungHist700_tiled/LungHist700_2...
aca_bd_20x_0_tile_0_1,5,aca,bd,20x,0,aca_bd,0,aca_bd_20x_0,data/processed/LungHist700_tiled/LungHist700_2...
aca_bd_20x_0_tile_0_2,5,aca,bd,20x,0,aca_bd,0,aca_bd_20x_0,data/processed/LungHist700_tiled/LungHist700_2...
aca_bd_20x_0_tile_0_3,5,aca,bd,20x,0,aca_bd,0,aca_bd_20x_0,data/processed/LungHist700_tiled/LungHist700_2...
aca_bd_20x_0_tile_0_4,5,aca,bd,20x,0,aca_bd,0,aca_bd_20x_0,data/processed/LungHist700_tiled/LungHist700_2...


In [7]:
# Load the embeddings. np.load on an .npz archive returns a lazy NpzFile that
# keeps the file open; the context manager closes it once the arrays have been
# read into memory.
with np.load(embedding_file) as data:
    embeddings = data["embeddings"]
    tile_ids = data["tile_ids"]
    embedding_dim = data["embedding_dim"]

# Each embedding row must have exactly one tile id, otherwise aligning the
# metadata to the embeddings by tile id is meaningless.
assert len(embeddings) == len(tile_ids), (
    f"{len(embeddings)} embeddings but {len(tile_ids)} tile_ids"
)

# Print basic information
print(f"Loaded {len(embeddings)} embeddings with dimensionality {embeddings.shape[1]}")
print(f"Embedding dimension from model: {embedding_dim}")

Loaded 21216 embeddings with dimensionality 1536
Embedding dimension from model: 1536


In [8]:
# Check if all embedding tile_ids are in the metadata index.
# Vectorized membership test (also avoids shadowing the builtin `id`).
is_missing = ~pd.Index(tile_ids).isin(metadata.index)
missing_ids = np.asarray(tile_ids)[is_missing].tolist()
if missing_ids:
    print(f"Warning: {len(missing_ids)} tile_ids from embeddings are not in metadata")
    print(f"First few missing IDs: {missing_ids[:5]}")

# Reorder the metadata rows to follow the embedding order; tile ids that are
# absent from the metadata become all-NaN rows.
aligned_metadata = metadata.reindex(tile_ids)

# Rows of the 'nor' superclass that have no subclass inherit the superclass
# label as their subclass; every other row keeps its subclass unchanged.
normal_without_subclass = (
    aligned_metadata["subclass"].isna() & aligned_metadata["superclass"].eq("nor")
)
aligned_metadata["subclass"] = aligned_metadata["subclass"].mask(
    normal_without_subclass, aligned_metadata["superclass"]
)

In [9]:
# Collapse the tile-level embeddings into one embedding per distinct
# `original_filename` value, with a matching metadata table.
# NOTE(review): the pooling statistic is decided inside
# aggregate_tile_embeddings; plot titles in this notebook call it
# "Mean aggregation" -- confirm against the helper.
aggregated_embeddings, aggregated_metadata = aggregate_tile_embeddings(
    embeddings=embeddings,
    tile_ids=tile_ids,
    metadata=aligned_metadata,
    group_by="original_filename",
)

Aggregated 21216 individual tile embeddings into 691 original_filename-level embeddings


In [10]:
def visualize_embeddings(embeddings, metadata, method="t-SNE", aggregated=False):
    """Project embeddings to 2-D and plot the projection under several colorings.

    Args:
        embeddings: The embedding vectors, shape (n_samples, n_features).
        metadata: Associated metadata, forwarded to ``plot_embeddings``. It is
            colored by 'class_name', 'superclass', 'subclass', 'resolution'
            and 'patient_id' (presumably a DataFrame with those columns in the
            same row order as ``embeddings`` -- confirm against
            ``plot_embeddings``).
        method: Dimensionality reduction method ("t-SNE", "UMAP", or "PCA").
        aggregated: Whether these are aggregated embeddings. This only changes
            the title suffix and the neighborhood hyperparameters (smaller
            t-SNE perplexity / UMAP n_neighbors, larger UMAP min_dist).

    Returns:
        The 2-D projection produced by the reducer's ``fit_transform``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
        ImportError: If ``method`` is "UMAP" and umap-learn is not installed.
    """
    # Validate first so an unsupported name fails with a clear message instead
    # of an unbound `reducer` further down.
    supported_methods = ("t-SNE", "UMAP", "PCA")
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of {supported_methods}"
        )

    suffix = "with Mean aggregation" if aggregated else "without Aggregation"

    # Perform dimensionality reduction
    if method == "t-SNE":
        # The iteration count is left at the library default (1000): the
        # keyword was renamed from `n_iter` to `max_iter` in newer
        # scikit-learn releases, so omitting it works on both.
        reducer = TSNE(
            n_components=2,
            perplexity=15 if aggregated else 30,
            random_state=42,
            init='pca'
        )
    elif method == "UMAP":
        # umap-learn is an optional dependency; import it lazily so the other
        # methods keep working when it is not installed.
        try:
            import umap
        except ImportError as exc:
            raise ImportError(
                "method='UMAP' requires the 'umap-learn' package"
            ) from exc
        reducer = umap.UMAP(
            n_neighbors=10 if aggregated else 15,
            min_dist=0.2 if aggregated else 0.1,
            n_components=2,
            metric='euclidean',
            random_state=42
        )
    else:  # method == "PCA"
        reducer = PCA(n_components=2, random_state=42)

    reduced_data = reducer.fit_transform(embeddings)

    # Plot the same projection once per metadata coloring
    title = f'{method} Projection of LungHist700 Embeddings {suffix}'
    for color_by in ['class_name', 'superclass', 'subclass', 'resolution', 'patient_id']:
        plot_embeddings(
            reduced_data=reduced_data,
            metadata=metadata,
            color_by=color_by,
            method_name=method,
            title=title,
            palette_name='tab10'
        )
        plt.show()

    return reduced_data

In [11]:
def remove_image_pcs_for_normalized(embeddings, groups, n_components=2):
    """
    Remove the principal directions of group-level variation from embeddings.

    The mean embedding of every group is computed and re-normalized to unit
    length, PCA is fitted on those group means, and the projection of every
    embedding onto the leading ``n_components`` principal axes is subtracted.
    The result is re-normalized to unit length.

    Parameters:
    -----------
    embeddings : numpy array of shape (n_samples, n_features)
        Embedding vectors, expected to be L2-normalized (not checked).
    groups : numpy array of shape (n_samples,)
        Group ID of each sample (e.g. image or patient) whose effect should
        be removed.
    n_components : int
        Number of principal components to remove

    Returns:
    --------
    corrected_embeddings : numpy array of shape (n_samples, n_features)
        Embeddings with the group-level PCs removed, in the dtype of
        ``embeddings``. A row whose residual is exactly zero is returned as
        zeros instead of NaN.

    Raises:
    -------
    ValueError
        If ``embeddings`` is not 2-D or ``groups`` does not hold exactly one
        ID per embedding row.
    """
    embeddings = np.asarray(embeddings)
    groups = np.asarray(groups)
    if embeddings.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_samples, n_features), got shape {embeddings.shape}"
        )
    if groups.shape != (embeddings.shape[0],):
        raise ValueError(
            f"groups must have shape ({embeddings.shape[0]},) to match embeddings, "
            f"got {groups.shape}"
        )

    unique_group_ids = np.unique(groups)

    # Compute group means; for L2-normalized vectors, take the mean and
    # re-normalize (an all-zero mean is left as is to avoid dividing by zero).
    group_means = np.zeros((len(unique_group_ids), embeddings.shape[1]))
    for i, group_id in enumerate(unique_group_ids):
        mean_vector = embeddings[groups == group_id].mean(axis=0)
        mean_norm = np.linalg.norm(mean_vector)
        group_means[i] = mean_vector / mean_norm if mean_norm > 0 else mean_vector

    # Compute PCA on the group means to identify group-specific directions.
    # The exact solver is requested explicitly: the default 'auto' policy may
    # select the randomized solver, which is not reproducible without a seed,
    # and the group-mean matrix is small enough for a full SVD.
    pca = PCA(n_components=n_components, svd_solver="full")
    pca.fit(group_means)
    group_pcs = pca.components_  # (n_components, n_features), orthonormal rows

    # Because the principal axes are orthonormal, removing the projections one
    # axis at a time equals subtracting the projection onto their span, which
    # is a single matrix product for all rows at once.
    coefficients = embeddings @ group_pcs.T
    residuals = embeddings - coefficients @ group_pcs

    # Renormalize to unit length, leaving all-zero residuals untouched.
    norms = np.linalg.norm(residuals, axis=1, keepdims=True)
    np.divide(residuals, norms, out=residuals, where=norms > 0)

    return residuals.astype(embeddings.dtype, copy=False)

In [12]:
# Strip the 3 leading principal directions of the per-patient mean embeddings
# from every tile embedding.
# NOTE(review): corrected_embeddings is not referenced by any of the cells
# shown after this one.
patient_ids = aligned_metadata["patient_id"].values
corrected_embeddings = remove_image_pcs_for_normalized(
    embeddings,
    groups=patient_ids,
    n_components=3,
)

In [13]:
# t-SNE projection of the tile-level (non-aggregated) embeddings.
# NOTE(review): the recorded run of this cell ended in a KeyboardInterrupt and
# the following cell repeats the same call; one of the two can be dropped.
tsne_embedding = visualize_embeddings(
    embeddings,
    aligned_metadata,
    method="t-SNE",
)



KeyboardInterrupt: 

In [None]:
# t-SNE projection of the tile-level (non-aggregated) embeddings.
tsne_embedding = visualize_embeddings(
    embeddings,
    aligned_metadata,
    method="t-SNE",
)

In [None]:
# UMAP projection of the tile-level embeddings; needs the `umap` module
# (umap-learn) to be available to visualize_embeddings.
umap_embedding = visualize_embeddings(
    embeddings,
    aligned_metadata,
    method="UMAP",
)

In [None]:
# PCA projection of the tile-level (non-aggregated) embeddings.
pca_embedding = visualize_embeddings(
    embeddings,
    aligned_metadata,
    method="PCA",
)

In [None]:
# t-SNE projection of the embeddings aggregated per original_filename.
tsne_agg_embedding = visualize_embeddings(
    aggregated_embeddings,
    aggregated_metadata,
    method="t-SNE",
    aggregated=True,
)

In [None]:
# UMAP projection of the embeddings aggregated per original_filename; needs
# the `umap` module (umap-learn) to be available to visualize_embeddings.
umap_agg_embedding = visualize_embeddings(
    aggregated_embeddings,
    aggregated_metadata,
    method="UMAP",
    aggregated=True,
)

In [None]:
# PCA projection of the embeddings aggregated per original_filename.
pca_agg_embedding = visualize_embeddings(
    aggregated_embeddings,
    aggregated_metadata,
    method="PCA",
    aggregated=True,
)