# TODO

- Check that the column name used to group tiles when computing histograms and related statistics is correct (i.e. 'image_id').

In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import normalize

from histopatseg.visualization.visualization import plot_embeddings
from histopatseg.evaluation.utils import aggregate_tile_embeddings, custom_balanced_group_kfold
from histopatseg.evaluation.prototype_classifier import PrototypeClassifier

In [None]:
# Project root: the parent of the resolved current working directory.
# NOTE(review): presumably the notebook is launched from a folder one level below the repo root — confirm.
project_dir = Path(".").resolve().parent
print(f"Project Directory: {project_dir}")

In [None]:
# Tile embeddings archive and per-tile metadata table, both located under the project root.
embedding_file = project_dir / "data/processed/embeddings/lunghist700_20x_UNI2_embeddings.npz"
# The metadata path is kept relative to project_dir: joining an absolute path with `/`
# silently discards project_dir and ties the notebook to a single machine.
# NOTE(review): assumes project_dir is the histopatseg repository root — confirm.
metadata_file = project_dir / "data/processed/LungHist700_tiled/LungHist700_20x/metadata.csv"
# Index the table by tile_id so rows can later be aligned with the embedding order.
metadata = pd.read_csv(metadata_file).set_index("tile_id")
metadata.head()

In [None]:
# Load the embeddings archive; it stores the arrays "embeddings", "tile_ids" and "embedding_dim"
data = np.load(embedding_file)
embeddings = data["embeddings"]
tile_ids = data["tile_ids"]
embedding_dim = data["embedding_dim"]

# Check if all embedding tile_ids are in the metadata index
missing_ids = [tile_id for tile_id in tile_ids if tile_id not in metadata.index]
if missing_ids:
    print(f"Warning: {len(missing_ids)} tile_ids from embeddings are not in metadata")
    print(f"First few missing IDs: {missing_ids[:5]}")

# Reorder the metadata rows to follow the embedding order
# (tile_ids absent from the metadata become all-NaN rows)
metadata = metadata.reindex(tile_ids)

# Rows of superclass "nor" without a subclass reuse the superclass label as subclass.
# A NumPy boolean mask is used so the assignment is purely positional.
is_unlabelled_normal = (
    metadata["subclass"].isna() & (metadata["superclass"] == "nor")
).to_numpy()
metadata.loc[is_unlabelled_normal, "subclass"] = "nor"

# Print basic information
print(f"Loaded {len(embeddings)} embeddings with dimensionality {embeddings.shape[1]}")
print(f"Embedding dimension from model: {embedding_dim}")

In [None]:
# Rescale the embeddings with scikit-learn's L2 normalization
embeddings = normalize(embeddings, norm="l2")

In [None]:
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.base import BaseEstimator, TransformerMixin


class OrthogonalDeconfounding(BaseEstimator, TransformerMixin):
    """
    Remove group-specific information by projection to orthogonal space.

    This class can be used to remove batch effects, patient-specific,
    or image-specific information from embeddings by projecting them
    onto a space orthogonal to the directions that predict group membership.

    Attributes set by ``fit``:
    --------------------------
    group_directions : ndarray of shape (n_directions, n_features)
        Unit-norm one-vs-rest logistic-regression coefficient vectors,
        one per group whose coefficient vector was not (numerically) zero.
    orthogonal_basis : ndarray of shape (n_features, n_available_directions)
        Orthonormal columns obtained from the QR decomposition of the
        transposed ``group_directions``.
    n_available_directions : int
        Number of columns of ``orthogonal_basis``.
    is_fitted_ : bool
        True once ``fit`` has completed.
    """

    def __init__(self, C=1.0):
        """
        Initialize the deconfounding transformer.

        Parameters:
        -----------
        C : float
            Regularization parameter for LogisticRegression (controls the strength
            of L2 regularization). Lower values encourage more regularization.
        """
        self.C = C
        self.is_fitted_ = False

    def fit(self, X, y=None, groups=None):
        """
        Fit the group predictor and compute orthogonal space.

        Parameters:
        -----------
        X : array-like of shape (n_samples, n_features)
            Embedding vectors
        y : array-like, default=None
            Not used, present for API consistency
        groups : array-like of shape (n_samples,)
            Group identifiers (e.g., patient_ids, image_ids)

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If ``groups`` is None or contains fewer than two distinct values.
        """
        if groups is None:
            raise ValueError("groups must be provided")

        # Plain lists would make `groups == group` a scalar comparison; force an array.
        groups = np.asarray(groups)
        n_features = X.shape[1]

        unique_groups = np.unique(groups)
        if len(unique_groups) < 2:
            # A one-vs-rest task needs at least one "rest" sample.
            raise ValueError("groups must contain at least two distinct values")
        print(f"Identifying predictive directions for {len(unique_groups)} groups")

        # For each group, find the direction in embedding space that best separates
        # that group from others (one-vs-rest approach). Directions are collected in
        # a list and stacked once, instead of re-allocating the matrix per group.
        directions = []
        for group in unique_groups:
            # Create a binary task: this group vs all others
            binary_task = (groups == group).astype(int)

            # Train a linear model for this task
            model = LogisticRegression(penalty='l2', C=self.C, solver='lbfgs', max_iter=1000)
            model.fit(X, binary_task)

            # Extract the coefficient vector and normalize
            direction = model.coef_[0]
            norm = np.linalg.norm(direction)

            # Skip if direction is close to zero (can happen with high regularization)
            if norm < 1e-10:
                continue

            directions.append(direction / norm)

        if directions:
            self.group_directions = np.vstack(directions)
        else:
            self.group_directions = np.zeros((0, n_features))

        # Orthogonalize the directions using QR decomposition
        # This is more stable than individual projections
        # NOTE(review): if the directions are linearly dependent, the trailing columns
        # of q are not confined to their span — confirm whether that matters here.
        q, _ = np.linalg.qr(self.group_directions.T)
        self.orthogonal_basis = q  # Store the full orthogonal basis

        # Count what was actually computed: skipped groups do not contribute a column.
        self.n_available_directions = q.shape[1]
        print(f"Computed {self.n_available_directions} orthogonal directions (use transform with n_components to select)")

        self.is_fitted_ = True
        return self

    def transform(self, X, n_components=None):
        """
        Project embeddings to space orthogonal to group-predictive directions.

        Parameters:
        -----------
        X : array-like of shape (n_samples, n_features)
            Embedding vectors to deconfound
        n_components : int or None
            Number of components to use for deconfounding. If None, will use all available.

        Returns:
        --------
        X_projected : array-like of shape (n_samples, n_features)
            Deconfounded embeddings

        Raises:
        -------
        ValueError
            If the instance is not fitted, or ``n_components`` is negative.
        """
        if not self.is_fitted_:
            raise ValueError("This OrthogonalDeconfounding instance is not fitted yet. Call 'fit' first.")

        # Determine number of components to use
        if n_components is None:
            n_directions = self.n_available_directions
        else:
            if n_components < 0:
                # A negative value would otherwise slice off trailing directions silently.
                raise ValueError("n_components must be non-negative")
            n_directions = min(n_components, self.n_available_directions)

        print(f"Using {n_directions} orthogonal directions for deconfounding")

        # Get the selected number of orthogonal directions
        selected_directions = self.orthogonal_basis[:, :n_directions]

        # Create the projection matrix for the orthogonal complement
        projection_matrix = np.eye(X.shape[1]) - selected_directions @ selected_directions.T

        # Project the embeddings
        X_projected = X @ projection_matrix

        # Re-normalize to unit length
        X_projected = normalize(X_projected, norm='l2')

        return X_projected

# Fit the deconfounder on the embeddings with patient IDs as groups, then replace the
# embeddings with their projection using every available direction (n_components=None).
deconfounder = OrthogonalDeconfounding(C=0.01)
deconfounder.fit(embeddings, groups=metadata["patient_id"].values)
embeddings = deconfounder.transform(embeddings, n_components=None)

In [None]:
# Keep only the tiles whose superclass is "aca", in both the metadata and the embeddings.
mask_aca = metadata["superclass"] == "aca"
# .copy() makes the subset an independent frame, so adding columns to it later
# (e.g. the cluster assignments) does not trigger pandas' SettingWithCopyWarning.
metadata = metadata[mask_aca].copy()
# Index the NumPy array with a plain boolean array rather than a pandas Series.
embeddings = embeddings[mask_aca.to_numpy()]

In [None]:
# Per-tile class labels and the patient IDs that are passed as groups to the fold splitter
y = metadata["class_name"].values
groups = metadata["patient_id"].values

In [None]:
# Sanity check: mean, min and max of the per-row L2 norms of the embeddings
norms = np.linalg.norm(embeddings, axis=1)
print("Mean norm:", norms.mean())
print("Min norm:", norms.min(), "Max norm:", norms.max())

In [None]:
# Materialize the 4 splits produced by the project helper into a list so they can be indexed.
# NOTE(review): presumably the helper keeps each patient in a single fold and balances classes — confirm against custom_balanced_group_kfold.
cv = list(custom_balanced_group_kfold(
    embeddings,
    y,
    groups,
    n_splits=4,
))

In [None]:
# Display the distinct class labels present in y
np.unique(y)

In [None]:
# Use the split at index 1 as the train/test partition for the cells below
train_idx, test_idx = cv[1]

In [None]:
def _l2_normalize_rows(matrix):
    """Scale each row of *matrix* to unit L2 norm; all-zero rows are left as zeros."""
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Dividing by 1 instead of 0 keeps zero rows at zero rather than turning them into NaN.
    return matrix / np.where(norms > 0, norms, 1.0)


def remove_image_pcs_for_normalized(embeddings, image_ids, n_components=2):
    """
    Remove principal components that capture image-level variation,
    specially adapted for L2-normalized embeddings.

    The per-image mean embeddings are re-normalized to unit length, a PCA is
    fitted on them, and the projection of every embedding onto the resulting
    components is subtracted before the rows are re-normalized.

    Parameters:
    -----------
    embeddings : numpy array of shape (n_samples, n_features)
        L2-normalized embedding vectors
    image_ids : numpy array of shape (n_samples,)
        Image ID for each tile
    n_components : int
        Number of principal components to remove

    Returns:
    --------
    corrected_embeddings : numpy array of shape (n_samples, n_features)
        Embeddings with image-level PCs removed, re-normalized to unit length
        (rows that become all-zero stay zero). Same dtype as the input.
    """
    embeddings = np.asarray(embeddings)
    image_ids = np.asarray(image_ids)
    unique_image_ids = np.unique(image_ids)

    # Compute image means
    image_means = np.zeros((len(unique_image_ids), embeddings.shape[1]))
    for i, img_id in enumerate(unique_image_ids):
        image_means[i] = embeddings[image_ids == img_id].mean(axis=0)
    # For L2-normalized vectors, take the mean and re-normalize
    image_means = _l2_normalize_rows(image_means)

    # Compute PCA on image means to identify image-specific directions
    pca = PCA(n_components=n_components)
    pca.fit(image_means)
    image_pcs = pca.components_

    # PCA components are orthonormal, so removing them one after another is the same
    # as subtracting the projection onto their span in a single matrix operation.
    residual = embeddings - (embeddings @ image_pcs.T) @ image_pcs

    # Renormalize to unit length and keep the caller's dtype
    corrected_embeddings = _l2_normalize_rows(residual)
    return corrected_embeddings.astype(embeddings.dtype, copy=False)

# Apply to your normalized embeddings: remove 8 image-level principal components,
# using the "image_id" metadata column to identify the source image of each tile
image_ids = metadata["image_id"].values
embeddings = remove_image_pcs_for_normalized(embeddings, image_ids, n_components=8)

In [None]:
def plot_elbow_method(embeddings, max_clusters=20):
    """
    Plot k-means inertia against k and mark the sharpest bend of the curve.

    K-means is fitted for every k from 1 to ``max_clusters``. The elbow is the
    interior k at which the direction of the inertia curve changes the most
    between the incoming and outgoing line segments. The figure is shown and
    the selected k is returned.
    """
    candidate_ks = range(1, max_clusters + 1)

    # Within-cluster sum of squares for each candidate k
    wcss = [
        KMeans(n_clusters=n, random_state=42, n_init="auto").fit(embeddings).inertia_
        for n in candidate_ks
    ]

    # Draw the inertia curve
    plt.figure(figsize=(10, 6))
    plt.plot(candidate_ks, wcss, 'o-', markersize=8)
    plt.xlabel('Number of Clusters (k)')
    plt.ylabel('Inertia (Within-Cluster Sum of Squares)')
    plt.title('Elbow Method for Optimal k')
    plt.grid(True)

    # Slide a window of three consecutive points along the curve and record,
    # for the middle point, how much the segment direction changes there.
    curve = list(zip(candidate_ks, wcss))
    bends = []
    for (k_prev, w_prev), (k_mid, w_mid), (k_next, w_next) in zip(curve, curve[1:], curve[2:]):
        heading_in = np.arctan2(w_mid - w_prev, k_mid - k_prev)
        heading_out = np.arctan2(w_next - w_mid, k_next - k_mid)
        bends.append((k_mid, np.abs(heading_out - heading_in)))

    # The elbow is the k with the largest change of direction
    elbow_point = max(bends, key=lambda entry: entry[1])[0]
    plt.axvline(x=elbow_point, color='r', linestyle='--', label=f'Elbow point: k={elbow_point}')
    plt.legend()
    plt.show()

    return elbow_point

# Normalize embeddings before clustering
# NOTE(review): embeddings_norm is only used for the elbow plot; the k-means fit in the
# following cell is run on `embeddings` directly — confirm this is intended.
embeddings_norm = normalize(embeddings, norm="l2")
elbow_k = plot_elbow_method(embeddings_norm)

In [None]:
# Fit k-means on the tile embeddings with a fixed seed and get one cluster label per tile
n_clusters = 8  # you can tune this (e.g., 5-8 for LUAD patterns)
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")
cluster_labels = kmeans.fit_predict(embeddings)

# Store cluster assignments in metadata (relies on metadata rows and embeddings sharing the same order)
metadata["cluster"] = cluster_labels

In [None]:
# Preview the first rows of the metadata table
metadata.head()

In [None]:
# %%
# Get LUAD metadata for training and test
# (positional row selection with the indices of the chosen split)
metadata_train = metadata.iloc[train_idx]
metadata_test = metadata.iloc[test_idx]


# Helper function to compute histograms per WSI
def compute_wsi_histograms(meta, n_clusters):
    """
    Build one normalized cluster histogram and one label per image.

    Rows of ``meta`` are grouped by "image_id". For each image the "cluster"
    values are counted into at least ``n_clusters`` bins and divided by the
    number of tiles; the label is the "class_name" of the image's first row.
    Returns ``(histograms, labels)``, two dicts keyed by image id.
    """
    tiles_per_image = list(meta.groupby("image_id"))

    # Raw cluster counts, then relative frequencies
    counts_per_image = {
        wsi_key: np.bincount(tiles["cluster"], minlength=n_clusters)
        for wsi_key, tiles in tiles_per_image
    }
    histograms = {
        wsi_key: counts / counts.sum()
        for wsi_key, counts in counts_per_image.items()
    }

    # All tiles of an image are assumed to carry the same label; take the first
    labels = {
        wsi_key: tiles["class_name"].iloc[0]
        for wsi_key, tiles in tiles_per_image
    }

    return histograms, labels

# Compute histograms
# (one normalized cluster histogram and one label per image, separately for the train and test tiles)
train_histograms, train_labels = compute_wsi_histograms(metadata_train, n_clusters)
test_histograms, test_labels = compute_wsi_histograms(metadata_test, n_clusters)


In [None]:
# %%
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

# Convert dicts to arrays
# (histograms and labels are built in the same loop, so rows of X and entries of y line up)
X_train = np.stack(list(train_histograms.values()))
y_train = np.array(list(train_labels.values()))

X_test = np.stack(list(test_histograms.values()))
y_test = np.array(list(test_labels.values()))

# Train logistic regression on the per-image cluster histograms
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases (1.5+) — confirm against the pinned version.
clf = LogisticRegression(multi_class="multinomial", max_iter=1000)
clf.fit(X_train, y_train)

# Predict and evaluate
y_pred = clf.predict(X_test)

print("Classification report:")
print(classification_report(y_test, y_pred))

print("Confusion matrix:")
print(confusion_matrix(y_test, y_pred))

In [None]:
# Select the tile-level embeddings and labels of the chosen train/test split
X_train = embeddings[train_idx]
y_train = y[train_idx]

X_test = embeddings[test_idx]
y_test = y[test_idx]

# Train an L1-penalized logistic regression directly on the tile embeddings
clf = LogisticRegression(max_iter=1000, penalty="l1", C=10, solver="saga")
clf.fit(X_train, y_train)

# Predict and evaluate
y_pred = clf.predict(X_test)

print("Classification report:")
print(classification_report(y_test, y_pred))

print("Confusion matrix:")
print(confusion_matrix(y_test, y_pred))