In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap

from histopatseg.visualization.visualization import plot_embeddings
from histopatseg.evaluation.utils import aggregate_tile_embeddings

In [None]:
# The project root is taken to be the parent of the current working directory.
working_dir = Path(".").resolve()
project_dir = working_dir.parent
print(f"Project Directory: {project_dir}")

In [None]:
# Input locations, both expressed relative to the project root.
embedding_file = project_dir / "data/processed/embeddings/lunghist700_20x_UNI2_embeddings.npz"
# NOTE: joining a Path with an absolute string discards the left-hand side, so the
# previous "/home/..." literal silently ignored `project_dir` and only worked on one
# machine. The relative form below assumes the project root is the
# `.../workspaces/histopatseg` directory that the old literal pointed into.
metadata_file = project_dir / "data/processed/LungHist700_tiled/LungHist700_20x/metadata.csv"

# Per-tile metadata table, indexed by `tile_id`.
metadata = pd.read_csv(metadata_file).set_index("tile_id")
metadata.head()

In [None]:
# Open the .npz archive and pull out the three arrays used by the rest of the notebook.
data = np.load(embedding_file)
embeddings, tile_ids, embedding_dim = (
    data[key] for key in ("embeddings", "tile_ids", "embedding_dim")
)

# Summary: number of vectors, their observed width, and the width stored in the archive.
print(f"Loaded {len(embeddings)} embeddings with dimensionality {embeddings.shape[1]}")
print(f"Embedding dimension from model: {embedding_dim}")

In [None]:
# Report tile_ids that have an embedding but no metadata row.
# Vectorized membership test; avoids a Python-level loop and shadowing the builtin `id`.
tile_index = pd.Index(tile_ids)
missing_ids = tile_index[~tile_index.isin(metadata.index)].tolist()
if missing_ids:
    print(f"Warning: {len(missing_ids)} tile_ids from embeddings are not in metadata")
    print(f"First few missing IDs: {missing_ids[:5]}")

# Reorder metadata rows to follow the embedding order; tile_ids absent from the
# metadata index become all-NaN rows.
aligned_metadata = metadata.reindex(tile_ids)

# Rows with superclass 'nor' and no subclass get 'nor' as their subclass; every other
# row keeps its existing subclass. This is a vectorized replacement for the former
# row-wise `apply`, and it also behaves correctly on an empty frame.
needs_nor_label = aligned_metadata['subclass'].isna() & aligned_metadata['superclass'].eq('nor')
aligned_metadata['subclass'] = aligned_metadata['subclass'].mask(needs_nor_label, 'nor')

In [None]:
# Compute the 2-D t-SNE projection of the tile-level embeddings.
print("Computing t-SNE projection (this may take a few minutes)...")
tsne = TSNE(
    n_components=2,          # Output dimensions
    perplexity=30,           # Balance between local and global structure
    # The iteration budget is left at the library default (1000). The `n_iter`
    # keyword was renamed to `max_iter` in scikit-learn 1.5 and later removed, so
    # passing neither name gives the same value on old and new releases.
    random_state=42,         # For reproducibility
    init='pca'               # Initialize from a PCA projection
)
tsne_embedding = tsne.fit_transform(embeddings)

In [None]:
# Tile-level t-SNE scatter; colour key is the `class_name` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings without Aggregation",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aligned_metadata,
    color_by="class_name",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level t-SNE scatter; colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings without Aggregation",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aligned_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level t-SNE scatter; colour key is the `subclass` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings without Aggregation",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aligned_metadata,
    color_by="subclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level t-SNE scatter; colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings without Aggregation",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aligned_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()

In [None]:
# Collapse tile embeddings into one entry per `original_filename` group.
# The helper returns the aggregated vectors together with matching metadata.
aggregation_result = aggregate_tile_embeddings(
    group_by="original_filename",
    metadata=aligned_metadata,
    tile_ids=tile_ids,
    embeddings=embeddings,
)
aggregated_embeddings, aggregated_metadata = aggregation_result

In [None]:
# Compute the 2-D t-SNE projection of the aggregated embeddings.
tsne_aggregated = TSNE(
    n_components=2,
    perplexity=15,  # Lower perplexity for fewer points
    # The iteration budget is left at the library default (1000). The `n_iter`
    # keyword was renamed to `max_iter` in scikit-learn 1.5 and later removed, so
    # passing neither name gives the same value on old and new releases.
    random_state=42,
    init='pca'
)
# NOTE: this rebinds `tsne_embedding`, replacing the tile-level projection. The plot
# cells that follow rely on that name holding the aggregated result; re-run the
# tile-level t-SNE cell before re-running any tile-level plot.
tsne_embedding = tsne_aggregated.fit_transform(aggregated_embeddings)

In [None]:
# Aggregated t-SNE scatter (`tsne_embedding` now holds the aggregated projection);
# colour key is the `class_name` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings with Aggregation by Original File",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aggregated_metadata,
    color_by="class_name",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated t-SNE scatter (`tsne_embedding` now holds the aggregated projection);
# colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings with Aggregation by Original File",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aggregated_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated t-SNE scatter (`tsne_embedding` now holds the aggregated projection);
# colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="t-SNE Projection of LungHist700 Embeddings with Aggregation by Original File",
    method_name="t-SNE",
    reduced_data=tsne_embedding,
    metadata=aggregated_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()

In [None]:
# Fit UMAP on the tile-level embeddings and keep the 2-D coordinates.
print("Computing UMAP projection...")
reducer = umap.UMAP(
    n_components=2,      # dimensionality of the output
    n_neighbors=15,      # local neighbourhood size
    min_dist=0.1,        # minimum spacing between embedded points
    metric="euclidean",  # distance used in the input space
    random_state=42,     # fixed seed
)
umap_embedding = reducer.fit_transform(embeddings)

In [None]:
# Tile-level UMAP scatter; colour key is the `class_name` metadata column.
fig = plot_embeddings(
    reduced_data=umap_embedding,
    metadata=aligned_metadata,
    color_by='class_name',  # Or 'superclass' for higher-level grouping
    method_name='UMAP',
    # Title typo fixed ("withtout" -> "without") so it matches the sibling UMAP plots.
    title='UMAP Projection of LungHist700 Embeddings without Aggregation',
    palette_name='tab10'
)
plt.show()

In [None]:
# Tile-level UMAP scatter; colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings without Aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aligned_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level UMAP scatter; colour key is the `subclass` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings without Aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aligned_metadata,
    color_by="subclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level UMAP scatter; colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings without Aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aligned_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()

In [None]:
# Fit UMAP on the aggregated embeddings.
# This rebinds `reducer` and `umap_embedding`, replacing the tile-level results.
print("Computing UMAP projection...")
reducer = umap.UMAP(
    n_components=2,      # dimensionality of the output
    n_neighbors=10,      # local neighbourhood size
    min_dist=0.2,        # minimum spacing between embedded points
    metric="euclidean",  # distance used in the input space
    random_state=42,     # fixed seed
)
umap_embedding = reducer.fit_transform(aggregated_embeddings)

In [None]:
# Aggregated UMAP scatter; colour key is the `class_name` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aggregated_metadata,
    color_by="class_name",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated UMAP scatter; colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aggregated_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated UMAP scatter; colour key is the `subclass` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aggregated_metadata,
    color_by="subclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated UMAP scatter; colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="UMAP Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="UMAP",
    reduced_data=umap_embedding,
    metadata=aggregated_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()

In [None]:
# Project the tile-level embeddings onto two principal components.
print("Computing PCA projection...")
pca = PCA(n_components=2, random_state=42)  # two output dimensions, fixed seed
pca_embedding = pca.fit_transform(embeddings)

# Report the variance ratio of each retained component and their total.
explained_variance = pca.explained_variance_ratio_
print(f"Explained variance ratio: {explained_variance[0]:.4f}, {explained_variance[1]:.4f}")
print(f"Total explained variance: {sum(explained_variance):.4f}")

In [None]:
# Tile-level PCA scatter; colour key is the `class_name` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings without Aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aligned_metadata,
    color_by="class_name",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level PCA scatter; colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings without Aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aligned_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level PCA scatter; colour key is the `subclass` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings without Aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aligned_metadata,
    color_by="subclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Tile-level PCA scatter; colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings without Aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aligned_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()

In [None]:
# Project the aggregated embeddings onto two principal components.
# This rebinds `pca`, `pca_embedding` and `explained_variance`, replacing the tile-level results.
print("Computing PCA projection...")
pca = PCA(n_components=2, random_state=42)  # two output dimensions, fixed seed
pca_embedding = pca.fit_transform(aggregated_embeddings)

# Report the variance ratio of each retained component and their total.
explained_variance = pca.explained_variance_ratio_
print(f"Explained variance ratio: {explained_variance[0]:.4f}, {explained_variance[1]:.4f}")
print(f"Total explained variance: {sum(explained_variance):.4f}")

In [None]:
# Aggregated PCA scatter; colour key is the `class_name` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aggregated_metadata,
    color_by="class_name",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated PCA scatter; colour key is the `superclass` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aggregated_metadata,
    color_by="superclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated PCA scatter; colour key is the `subclass` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aggregated_metadata,
    color_by="subclass",
    palette_name="tab10",
)
plt.show()

In [None]:
# Aggregated PCA scatter; colour key is the `resolution` metadata column.
fig = plot_embeddings(
    title="PCA Projection of LungHist700 Embeddings with Mean aggregation",
    method_name="PCA",
    reduced_data=pca_embedding,
    metadata=aggregated_metadata,
    color_by="resolution",
    palette_name="tab10",
)
plt.show()