In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import umap

from histopatseg.visualization.visualization import plot_embeddings
from histopatseg.evaluation.utils import aggregate_tile_embeddings

In [None]:
# Project root = parent of the working directory (assumes the notebook is
# launched from a direct subfolder of the project — TODO confirm).
project_dir = Path().resolve().parent
print("Project Directory:", project_dir)

In [None]:
# Inputs: precomputed UNI2 tile embeddings and the metadata table keyed by tile_id.
processed_dir = project_dir / "data" / "processed"
embedding_file = processed_dir / "embeddings" / "lunghist700_20x_UNI2_embeddings.npz"
metadata = pd.read_csv(
    processed_dir / "LungHist700_tiled" / "LungHist700_20x" / "metadata.csv"
).set_index("tile_id")

In [None]:
# Load the embeddings. ``np.load`` on an .npz returns an archive handle that
# keeps the file open, so use it as a context manager; indexing a key reads
# that array into memory, so the arrays stay valid after the block closes.
with np.load(embedding_file) as data:
    embeddings = data["embeddings"]
    tile_ids = data["tile_ids"]
    embedding_dim = data["embedding_dim"]

# Each embedding row must have exactly one tile id; the metadata join below
# relies on this alignment.
assert len(tile_ids) == len(embeddings), (
    f"{len(tile_ids)} tile_ids vs {len(embeddings)} embeddings"
)

# Print basic information
print(f"Loaded {len(embeddings)} embeddings with dimensionality {embeddings.shape[1]}")
print(f"Embedding dimension from model: {embedding_dim}")

In [None]:
# Check if all embedding tile_ids are in the metadata index (vectorised and
# order-preserving, so the "first few" printed are in embedding order).
missing_mask = ~pd.Index(tile_ids).isin(metadata.index)
missing_ids = list(np.asarray(tile_ids)[missing_mask])
if missing_ids:
    print(f"Warning: {len(missing_ids)} tile_ids from embeddings are not in metadata")
    print(f"First few missing IDs: {missing_ids[:5]}")

# Align metadata rows to the embedding order; missing tile_ids become all-NaN rows.
metadata = metadata.reindex(tile_ids)

# Qualify image_id with patient_id (presumably so image IDs are unique across
# patients). The raw ID is kept in ``source_image_id`` and image_id is always
# rebuilt from it, so re-running this cell does not append "__<patient_id>" twice.
if "source_image_id" not in metadata.columns:
    metadata["source_image_id"] = metadata["image_id"]
metadata["image_id"] = (
    metadata["source_image_id"].astype(str) + "__" + metadata["patient_id"].astype(str)
)
metadata.head()

In [None]:
# Scale every embedding row to unit L2 norm (dot products then equal cosine similarity).
embeddings = normalize(embeddings, norm="l2", axis=1)

In [None]:
# Sanity check on the row norms of the (normalised) embeddings: all should be ~1.0.
norms = np.linalg.norm(embeddings, axis=1)
mean_norm, min_norm, max_norm = norms.mean(), norms.min(), norms.max()
print("Mean norm:", mean_norm)
print("Min norm:", min_norm, "Max norm:", max_norm)

In [None]:
def visualize_embeddings(embeddings, metadata, method="t-SNE", aggregated=False):
    """Project embeddings to 2-D and plot them coloured by several metadata columns.

    Args:
        embeddings: The embedding vectors, one row per sample.
        metadata: Associated metadata, passed to ``plot_embeddings``; it is
            coloured by the ``class_name``, ``superclass`` and ``patient_id``
            columns.
        method: Dimensionality reduction method ("t-SNE", "UMAP", or "PCA").
        aggregated: Whether these are mean-aggregated embeddings; this only
            changes the reducer hyper-parameters and the plot title.

    Returns:
        The 2-D projection returned by the reducer's ``fit_transform``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ("t-SNE", "UMAP", "PCA")
    # Validate first: an unknown method used to fall through every branch and
    # fail later with an UnboundLocalError on ``reducer``.
    if method not in supported_methods:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {supported_methods}"
        )

    suffix = "with Mean aggregation" if aggregated else "without Aggregation"

    # Perform dimensionality reduction
    if method == "t-SNE":
        # The iteration count is left at scikit-learn's default (1000).
        # ``n_iter`` was deprecated in favour of ``max_iter`` and later removed,
        # so omitting it keeps this call working on old and new versions.
        reducer = TSNE(
            n_components=2,
            perplexity=15 if aggregated else 30,
            random_state=42,
            init='pca'
        )
    elif method == "UMAP":
        reducer = umap.UMAP(
            n_neighbors=10 if aggregated else 15,
            min_dist=0.2 if aggregated else 0.1,
            n_components=2,
            metric='euclidean',
            random_state=42
        )
    else:  # "PCA"
        reducer = PCA(n_components=2, random_state=42)

    reduced_data = reducer.fit_transform(embeddings)

    # One figure per metadata column used for colouring.
    for color_by in ('class_name', 'superclass', 'patient_id'):
        plot_embeddings(
            reduced_data=reduced_data,
            metadata=metadata,
            color_by=color_by,
            method_name=method,
            title=f'{method} Projection of LungHist700 Embeddings {suffix}',
            palette_name='tab10'
        )
        plt.show()

    return reduced_data

In [None]:
# t-SNE of the tile-level (non-aggregated) embeddings; the 2-D coordinates are kept.
tsne_embedding = visualize_embeddings(embeddings, metadata, method="t-SNE")

In [None]:
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.base import BaseEstimator, TransformerMixin


class OrthogonalDeconfounding(BaseEstimator, TransformerMixin):
    """
    Remove group-specific information by projection to orthogonal space.

    This class can be used to remove batch effects, patient-specific,
    or image-specific information from embeddings by projecting them
    onto a space orthogonal to the directions that predict group membership.

    ``fit`` trains one L2-regularised one-vs-rest logistic regression per
    group and keeps each unit-normalised coefficient vector as a group
    direction. It then orthonormalises those directions with a reduced QR
    decomposition. ``transform`` removes the first ``n_components`` basis
    vectors from the inputs and re-normalises every row to unit length.

    Fitted attributes
    -----------------
    group_directions : ndarray of shape (n_directions, n_features)
        Unit-norm coefficient vectors, in ``np.unique(groups)`` order. Groups
        whose coefficient vector has norm < 1e-10 are skipped.
    orthogonal_basis : ndarray of shape (n_features, n_available_directions)
        Q factor of the reduced QR decomposition of ``group_directions.T``.
    n_available_directions : int
        Number of columns of ``orthogonal_basis`` that ``transform`` can use.
    """

    def __init__(self, C=1.0):
        """
        Initialize the deconfounding transformer.

        Parameters:
        -----------
        C : float
            Inverse regularization strength passed to LogisticRegression
            (L2 penalty). Lower values mean stronger regularization.
        """
        self.C = C
        self.is_fitted_ = False

    def fit(self, X, y=None, groups=None):
        """
        Fit the group predictor and compute orthogonal space.

        Parameters:
        -----------
        X : array-like of shape (n_samples, n_features)
            Embedding vectors
        y : array-like, default=None
            Not used, present for API consistency
        groups : array-like of shape (n_samples,)
            Group identifiers (e.g., patient_ids, image_ids)

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If ``groups`` is missing or its length differs from ``X``.
        """
        if groups is None:
            raise ValueError("groups must be provided")

        X = np.asarray(X)
        # Use an ndarray so ``groups == group`` is element-wise even when a
        # plain list is passed (``list == scalar`` is a single False).
        groups = np.asarray(groups)
        if groups.shape[0] != X.shape[0]:
            raise ValueError(
                f"groups has {groups.shape[0]} entries but X has {X.shape[0]} rows"
            )

        n_features = X.shape[1]

        # For each group, find the direction in embedding space that best separates
        # that group from others (one-vs-rest approach)
        unique_groups = np.unique(groups)
        print(f"Identifying predictive directions for {len(unique_groups)} groups")

        directions = []
        for group in unique_groups:
            # Binary task: this group vs all others
            binary_task = (groups == group).astype(int)

            model = LogisticRegression(penalty='l2', C=self.C, solver='lbfgs', max_iter=1000)
            model.fit(X, binary_task)

            direction = model.coef_[0]
            norm = np.linalg.norm(direction)

            # Skip if direction is close to zero (can happen with high regularization)
            if norm < 1e-10:
                continue

            directions.append(direction / norm)

        self.group_directions = (
            np.vstack(directions) if directions else np.zeros((0, n_features))
        )

        # Orthogonalize with a reduced QR: the first k columns of Q span the first
        # k group directions (when those are linearly independent).
        if self.group_directions.shape[0] > 0:
            q, _ = np.linalg.qr(self.group_directions.T)
        else:
            q = np.zeros((n_features, 0))
        self.orthogonal_basis = q

        # Count the columns QR actually produced. This is smaller than
        # len(unique_groups) when groups were skipped above or when there are
        # more groups than features.
        self.n_available_directions = q.shape[1]
        print(f"Computed {self.n_available_directions} orthogonal directions (use transform with n_components to select)")

        self.is_fitted_ = True
        return self

    def transform(self, X, n_components=None):
        """
        Project embeddings to space orthogonal to group-predictive directions.

        Parameters:
        -----------
        X : array-like of shape (n_samples, n_features)
            Embedding vectors to deconfound
        n_components : int or None
            Number of leading orthogonal directions to remove. If None, all
            available directions are used. Values above the number available
            are clipped.

        Returns:
        --------
        X_projected : array-like of shape (n_samples, n_features)
            Deconfounded embeddings, re-normalized to unit L2 norm

        Raises:
        -------
        ValueError
            If the transformer is not fitted or ``n_components`` is negative.
        """
        if not getattr(self, "is_fitted_", False):
            raise ValueError("This OrthogonalDeconfounding instance is not fitted yet. Call 'fit' first.")

        # Determine number of components to use
        if n_components is None:
            n_directions = self.n_available_directions
        else:
            # A negative value would otherwise slice from the end of the basis.
            if n_components < 0:
                raise ValueError(f"n_components must be non-negative, got {n_components}")
            n_directions = min(n_components, self.n_available_directions)

        print(f"Using {n_directions} orthogonal directions for deconfounding")

        X = np.asarray(X)
        selected_directions = self.orthogonal_basis[:, :n_directions]

        # X @ (I - Q Q^T) == X - (X Q) Q^T. The right-hand form avoids building
        # the (n_features, n_features) projection matrix and the O(n * d^2) product.
        X_projected = X - (X @ selected_directions) @ selected_directions.T

        # Re-normalize to unit length
        return normalize(X_projected, norm='l2')


In [None]:
def remove_image_pcs_for_normalized(embeddings, groups, n_components=2):
    """
    Remove principal components that capture group-level variation,
    specially adapted for L2-normalized embeddings.

    The embeddings of each group are averaged, and the mean is re-normalized
    to unit length. The top ``n_components`` principal directions of these
    group means (centred, as in PCA) are then projected out of every
    embedding. Finally every row is re-normalized to unit length.

    Parameters:
    -----------
    embeddings : numpy array of shape (n_samples, n_features)
        L2-normalized embedding vectors
    groups : array-like of shape (n_samples,)
        Group ID for each row (e.g. image or patient IDs) whose effect is removed
    n_components : int
        Number of principal components to remove; must satisfy
        ``0 <= n_components <= min(n_groups, n_features)``

    Returns:
    --------
    corrected_embeddings : numpy array of shape (n_samples, n_features)
        Embeddings with the group-level PCs removed. Rows that become all-zero
        after the projection are returned as zero vectors instead of NaN.

    Raises:
    -------
    ValueError
        If ``groups`` and ``embeddings`` differ in length, or ``n_components``
        is out of range.
    """
    embeddings = np.asarray(embeddings)
    # An ndarray makes ``groups == group_id`` element-wise even for a plain list.
    groups = np.asarray(groups)
    if groups.shape[0] != embeddings.shape[0]:
        raise ValueError(
            f"groups has {groups.shape[0]} entries but embeddings has "
            f"{embeddings.shape[0]} rows"
        )

    def _unit_rows(matrix):
        """Scale each row to unit L2 norm, leaving all-zero rows as zeros."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms > 0, norms, 1.0)

    unique_group_ids = np.unique(groups)
    max_components = min(len(unique_group_ids), embeddings.shape[1])
    if not 0 <= n_components <= max_components:
        raise ValueError(
            f"n_components must be between 0 and {max_components}, got {n_components}"
        )

    # The mean of unit vectors is shorter than 1, so re-normalize each group mean.
    group_means = _unit_rows(np.stack([
        embeddings[groups == group_id].mean(axis=0) for group_id in unique_group_ids
    ]))

    # Exact PCA: the principal axes are the right singular vectors of the
    # centred data. NumPy's SVD is used instead of sklearn's PCA because PCA's
    # 'auto' solver can choose a randomized algorithm (no random_state was set),
    # which made the removed directions vary slightly between runs.
    centred = group_means - group_means.mean(axis=0)
    _, _, vt = np.linalg.svd(centred, full_matrices=False)
    group_pcs = vt[:n_components]

    # The rows of group_pcs are orthonormal, so subtracting all projections at
    # once is the same as removing them one at a time (vectorised over rows).
    corrected = _unit_rows(embeddings - (embeddings @ group_pcs.T) @ group_pcs)

    # Keep the input's floating dtype (e.g. float32), as the per-row copy did.
    if np.issubdtype(embeddings.dtype, np.floating):
        out_dtype = embeddings.dtype
    else:
        out_dtype = np.float64
    return corrected.astype(out_dtype, copy=False)

# Project the top 3 patient-level principal directions out of every tile embedding.
pca_embeddings = remove_image_pcs_for_normalized(
    embeddings,
    groups=metadata['patient_id'].values,
    n_components=3,
)

In [None]:
def _plot_deconfounding_results(results_df, c_values, group_col, orig_class,
                                orig_group, n_groups):
    """Draw the 2x2 summary figure for ``analyze_deconfounding_parameters``.

    Panels: class accuracy, group accuracy (with chance level), class-to-group
    ratio, and the group-drop vs class-drop trade-off scatter.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_class, ax_group, ax_ratio, ax_tradeoff = axes.ravel()

    # Curves versus the number of removed directions, one line per C value.
    line_panels = (
        (ax_class, 'class_score', orig_class, 'Class Prediction Accuracy', 'Accuracy'),
        (ax_group, 'group_score', orig_group,
         f'{group_col.capitalize()} Prediction Accuracy', 'Accuracy'),
        (ax_ratio, 'class_group_ratio', orig_class / orig_group,
         'Class-to-Group Ratio (higher is better)', 'Ratio'),
    )
    for ax, column, baseline, title, ylabel in line_panels:
        sns.lineplot(data=results_df, x='n_components', y=column, hue='C', marker='o', ax=ax)
        ax.axhline(y=baseline, linestyle='--', color='black', label='Original')
        if ax is ax_group:
            chance = 1.0 / n_groups
            ax.axhline(y=chance, linestyle=':', color='red', label=f'Random ({chance:.4f})')
        ax.set(title=title, xlabel='Number of Components', ylabel=ylabel)
        ax.legend(title='C value')

    # Trade-off scatter: group removal vs class preservation.
    for c_val in c_values:
        subset = results_df[results_df['C'] == c_val]
        ax_tradeoff.scatter(subset['group_drop'], subset['class_drop'],
                            label=f'C={c_val}', s=80, alpha=0.7)
        # Label each point with its number of removed directions.
        for n_comp, group_drop, class_drop in zip(
                subset['n_components'], subset['group_drop'], subset['class_drop']):
            point_label = 'all' if pd.isna(n_comp) else str(int(n_comp))
            ax_tradeoff.annotate(point_label, (group_drop, class_drop), fontsize=9)

    # Quadrant lines and the equal-drop diagonal.
    ax_tradeoff.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    ax_tradeoff.axvline(x=0, color='black', linestyle='--', alpha=0.3)
    max_drop = max(results_df['group_drop'].max(), results_df['class_drop'].max())
    ax_tradeoff.plot([0, max_drop], [0, max_drop], 'k--', alpha=0.3)

    ax_tradeoff.set(
        xlabel=f'Drop in {group_col} prediction accuracy',
        ylabel='Drop in class prediction accuracy',
        title='Tradeoff between group removal and class preservation',
    )
    ax_tradeoff.legend(title='C value')

    # Shade the region below the diagonal: group drop exceeds class drop.
    ax_tradeoff.fill_between([0, max_drop], [0, 0], [0, max_drop], color='green', alpha=0.1)

    fig.tight_layout()
    plt.show()


def analyze_deconfounding_parameters(embeddings, metadata, deconf_type='patient',
                                     n_comp_range=None, c_values=None, cv=5):
    """
    Analyze effect of different parameters for orthogonal deconfounding.

    For every ``C`` an ``OrthogonalDeconfounding`` is fitted once on the group
    labels, then ``transform`` is evaluated for each number of removed
    directions. For each setting, the ``cv``-fold cross-validated
    logistic-regression accuracy is reported for three targets: the class,
    the deconfounded group column, and the other ID column. The results are
    plotted, and the best settings by two criteria are printed.

    NOTE(review): the deconfounder is fitted on all rows before
    cross-validation, and the folds are not grouped by patient/image. The
    accuracies are therefore likely optimistic; compare them relative to
    each other.

    Parameters
    ----------
    embeddings : numpy array
        Original L2-normalized embeddings (rows aligned with ``metadata``)
    metadata : DataFrame
        Metadata containing class_name, patient_id and image_id
    deconf_type : {'patient', 'image'}
        Whether to analyze 'patient' or 'image' deconfounding
    n_comp_range : sequence of int, optional
        Numbers of directions to remove; values above the number of groups are
        skipped. Defaults to [5, 10, 20, 30, 40, 45, 60, 120, 240, 360, 481].
    c_values : sequence of float, optional
        Regularization values ``C`` for the deconfounder. Defaults to
        [0.01, 0.1, 1.0, 10.0].
    cv : int
        Number of folds passed to ``cross_val_score``.

    Returns
    -------
    results_df : DataFrame
        One row per evaluated (n_components, C) combination, with the scores,
        the drops, and the ``norm_*_drop`` / ``removal_score`` columns. The
        frame is empty if no combination could be evaluated.

    Raises
    ------
    ValueError
        If ``deconf_type`` is neither 'patient' nor 'image'.
    """
    from sklearn.model_selection import cross_val_score
    from sklearn.linear_model import LogisticRegression

    if n_comp_range is None:
        n_comp_range = [5, 10, 20, 30, 40, 45, 60, 120, 240, 360, 481]
    if c_values is None:
        c_values = [0.01, 0.1, 1.0, 10.0]

    # Set up group column (previously any other value was silently treated as 'image')
    if deconf_type == 'patient':
        group_col, other_col = 'patient_id', 'image_id'
    elif deconf_type == 'image':
        group_col, other_col = 'image_id', 'patient_id'
    else:
        raise ValueError(f"deconf_type must be 'patient' or 'image', got {deconf_type!r}")

    class_labels = metadata['class_name'].values
    group_labels = metadata[group_col].values
    other_labels = metadata[other_col].values
    n_groups = len(np.unique(group_labels))

    def _cv_accuracy(features, labels):
        """Mean cross-validated accuracy of a logistic regression (cloned per fold)."""
        clf = LogisticRegression(solver='lbfgs', max_iter=2000)
        return cross_val_score(clf, features, labels, cv=cv).mean()

    # Original baseline
    orig_class = _cv_accuracy(embeddings, class_labels)
    orig_group = _cv_accuracy(embeddings, group_labels)
    orig_other = _cv_accuracy(embeddings, other_labels)

    print("Original embeddings:")
    print(f"  Class prediction: {orig_class:.4f}")
    print(f"  {group_col} prediction: {orig_group:.4f}")
    print(f"  {other_col} prediction: {orig_other:.4f}")
    print(f"  Class-to-{group_col} ratio: {orig_class/orig_group:.4f}")

    results = []
    for c_val in c_values:
        # Fit once per C; transform with different n_components reuses the basis.
        deconfounder = OrthogonalDeconfounding(C=c_val)
        deconfounder.fit(embeddings, groups=group_labels)
        for n_comp in n_comp_range:
            # More directions than groups cannot exist, so skip those settings.
            if n_comp is not None and n_comp > n_groups:
                continue

            print(f"\nEvaluating: n_components={n_comp}, C={c_val}")

            # Best-effort sweep: a failing setting is reported and skipped.
            try:
                deconfounded = deconfounder.transform(embeddings, n_components=n_comp)

                class_score = _cv_accuracy(deconfounded, class_labels)
                group_score = _cv_accuracy(deconfounded, group_labels)
                other_score = _cv_accuracy(deconfounded, other_labels)

                print(f"  Class prediction: {class_score:.4f}")
                print(f"  {group_col} prediction: {group_score:.4f}")
                print(f"  {other_col} prediction: {other_score:.4f}")
                print(f"  Class-to-{group_col} ratio: {class_score/group_score:.4f}")

                results.append({
                    'n_components': n_comp,
                    'C': c_val,
                    'class_score': class_score,
                    'group_score': group_score,
                    'other_score': other_score,
                    'class_group_ratio': class_score / group_score,
                    'class_drop': orig_class - class_score,
                    'group_drop': orig_group - group_score
                })

            except Exception as e:
                print(f"Error with parameters n_components={n_comp}, C={c_val}: {e}")

    results_df = pd.DataFrame(results)
    if results_df.empty:
        # Nothing to plot or rank; the code below would fail on missing columns.
        print("No parameter combination could be evaluated.")
        return results_df

    _plot_deconfounding_results(results_df, c_values, group_col,
                                orig_class, orig_group, n_groups)

    # Find best parameters based on highest ratio
    best_row = results_df.loc[results_df['class_group_ratio'].idxmax()]
    print("\nBest parameters based on highest class-to-group ratio:")
    print(f"n_components: {best_row['n_components']}, C: {best_row['C']}")
    print(f"Class accuracy: {best_row['class_score']:.4f}")
    print(f"{group_col} accuracy: {best_row['group_score']:.4f}")
    print(f"Class-to-{group_col} ratio: {best_row['class_group_ratio']:.4f}")

    # Scale both drops into [-1, 1] by their largest magnitude. Dividing by the
    # plain max flips the sign when every drop is negative and yields inf/NaN
    # when the max is 0.
    for drop_col, norm_col in (('group_drop', 'norm_group_drop'),
                               ('class_drop', 'norm_class_drop')):
        scale = results_df[drop_col].abs().max()
        results_df[norm_col] = results_df[drop_col] / (scale if scale > 0 else 1.0)

    # Combined score: maximize group drop, minimize class drop
    results_df['removal_score'] = results_df['norm_group_drop'] - results_df['norm_class_drop']

    best_removal_row = results_df.loc[results_df['removal_score'].idxmax()]
    print("\nBest parameters based on optimal removal:")
    print(f"n_components: {best_removal_row['n_components']}, C: {best_removal_row['C']}")
    print(f"Class accuracy: {best_removal_row['class_score']:.4f}")
    print(f"{group_col} accuracy: {best_removal_row['group_score']:.4f}")
    print(f"Class-to-{group_col} ratio: {best_removal_row['class_group_ratio']:.4f}")

    return results_df


In [None]:

# Patient-level deconfounding sweep (image_id is tracked as the secondary ID).
patient_results = analyze_deconfounding_parameters(
    embeddings, metadata, deconf_type='patient'
)


In [None]:
# Image-level deconfounding sweep (patient_id is tracked as the secondary ID).
image_results = analyze_deconfounding_parameters(
    embeddings, metadata, deconf_type='image'
)