## UMAPs ##

In [1]:
import numpy as np
import pandas as pd
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import math
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import os

In [2]:
# dimplot for leiden clusters
def dim_plot(merged_df, umap_1_col='umap_1', umap_2_col='umap_2', feature_col='leiden_2.0', save_dir='.'):
    """
    Draw a UMAP scatter plot coloured by one column and save it as a PNG.

    Each row of ``merged_df`` is plotted at (``umap_1_col``, ``umap_2_col``)
    and coloured by ``feature_col`` with the ``tab20`` palette; no legend is
    drawn. When the column name starts with ``'leiden'``, every group's label
    is written at the mean UMAP position of that group's points.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Table containing the two coordinate columns and ``feature_col``.
    umap_1_col, umap_2_col : str
        Names of the columns used for the x and y axes.
    feature_col : str
        Name of the column used for colouring (and for the output file name).
    save_dir : str
        Directory the image is written to. It is created if it does not exist.

    Returns
    -------
    None
        The figure is saved to ``<save_dir>/umap_by_<feature_col>.png`` and
        then closed.
    """
    print(f"Plotting UMAP by {feature_col}...")

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(save_dir, exist_ok=True)

    # Explicit figure/axes: nothing depends on whatever pyplot considers "current".
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.scatterplot(
            x=umap_1_col, y=umap_2_col, hue=feature_col,
            data=merged_df, palette='tab20', legend=False, s=5, alpha=0.7, ax=ax
        )

        # annotate cluster numbers at the mean position of each group
        if str(feature_col).startswith('leiden'):
            centroids = merged_df.groupby(feature_col)[[umap_1_col, umap_2_col]].mean()
            for cluster, (x, y) in centroids.iterrows():
                ax.text(x, y, str(cluster), fontsize=9, weight='bold', ha='center', va='center', color='black')

        ax.set_title(f'UMAP Colored by {feature_col}')
        fig.savefig(os.path.join(save_dir, f"umap_by_{feature_col}.png"), bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

# feature plot to highlight one feature at a time
def feature_plot(merged_df, umap_1_col='umap_1', umap_2_col='umap_2', feature_cols='leiden_2.0', save_dir='.'):
    """
    Save one multi-panel UMAP figure per feature, highlighting one value per panel.

    For every column in ``feature_cols`` a grid with two panels per row is
    created. Each panel shows all points, with the rows whose value equals the
    panel's value drawn in red and every other row in light gray. Values are
    ordered as returned by ``value_counts()`` on the non-null entries. Unused
    panels in the last row are hidden.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Table containing the two coordinate columns and every feature column.
        It is not modified.
    umap_1_col, umap_2_col : str
        Names of the columns used for the x and y axes.
    feature_cols : str or iterable of str
        Column name(s) to plot. A single name is accepted and treated as a
        one-element list.
    save_dir : str
        Directory the images are written to. It is created if it does not exist.

    Returns
    -------
    None
        Each figure is saved to ``<save_dir>/umap_by_<feature>.png`` and closed.
    """
    print(f"Plotting UMAP by {feature_cols}...")

    # A bare string would otherwise be iterated character by character
    # (this is what the default value used to trigger).
    if isinstance(feature_cols, str):
        feature_cols = [feature_cols]

    # savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    # Only the coordinates are needed for drawing; select them once instead of
    # copying the whole table for every panel.
    coords = merged_df[[umap_1_col, umap_2_col]]

    # Fixed internal labels so a category that happens to be called 'Other'
    # cannot collide with the background group.
    highlight_label, background_label = 'Selected', 'Other'
    panel_palette = {highlight_label: 'red', background_label: 'lightgray'}

    n_cols = 2
    # Loop over each feature column and plot
    for feature in feature_cols:
        unique_values = merged_df[feature].dropna().value_counts().index.tolist()
        n_vals = len(unique_values)
        if n_vals == 0:
            # A zero-row subplot grid cannot be created; nothing to draw.
            print(f"Skipping {feature}: no non-null values to plot.")
            continue

        n_rows = math.ceil(n_vals / n_cols)
        # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
        axes = axes.flatten()
        try:
            for ax, val in zip(axes, unique_values):
                is_selected = (merged_df[feature] == val).to_numpy()
                panel_df = coords.assign(
                    plot_group=np.where(is_selected, highlight_label, background_label)
                )
                sns.scatterplot(
                    data=panel_df, x=umap_1_col, y=umap_2_col, hue='plot_group',
                    palette=panel_palette, ax=ax, s=5, alpha=0.7, legend=False
                )
                ax.set_title(f'{val}')
                ax.set_xlabel('')
                ax.set_ylabel('')

            # Hide the leftover panel when the number of values is odd.
            for unused_ax in axes[n_vals:]:
                unused_ax.axis('off')

            fig.suptitle(f'UMAP by {feature}', fontsize=16)
            fig.tight_layout()
            fig.savefig(os.path.join(save_dir, f"umap_by_{feature}.png"), bbox_inches='tight')
        finally:
            # Close this figure even on failure so the loop does not leak figures.
            plt.close(fig)

In [4]:
# NOTE(review): metadata_file, embedding_file and filenames_file are not read in
# this cell; they are kept as globals in case other cells rely on them — confirm
# and remove if unused.
metadata_file = "/gpfs/scratch/yb2612/dl4med_25/dl_project/scratch_data/hpl-clip/lung_subsample_clinical_clusters.csv"

epoch = 27

# Common prefix of every per-model results path used below.
results_root = "/gpfs/data/pmedlab/Users/mottej02/dl_project/pipeline/Histomorphological-Phenotype-Learning/results"

for model in ["BarlowTwins_3"]:
    # Named `split` rather than `set` so the builtin set() is not shadowed for
    # the rest of the kernel session.
    for split in ["test"]:
        print(f"Processing {model}/{split} data...")
        split_dir = f"{results_root}/{model}/epoch_{epoch}/dataframes/{split}"
        embedding_file = f"{split_dir}/image_embeddings.npy"
        filenames_file = f"{split_dir}/image_filenames.npy"
        save_dir = f"{split_dir}/leiden"

        # The plots are written next to the CSV they are built from.
        merged_df = pd.read_csv(f"{save_dir}/umap_leiden_results.csv")
        dim_plot(merged_df, feature_col='leiden_2.0', save_dir=save_dir)
        dim_plot(merged_df, feature_col='sampleID', save_dir=save_dir)
        feature_plot(merged_df, feature_cols=['_primary_disease'], save_dir=save_dir)

Processing BarlowTwins_3/test data...
Plotting UMAP by leiden_2.0...
Plotting UMAP by sampleID...
Plotting UMAP by ['_primary_disease']...
