In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np


from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.metrics import calculate_PED_balanced, compute_prdc

import pathlib
from types import SimpleNamespace
import torch



In [None]:
# parameters
# Size key: forwarded to read_data(dim=...) and used as a sub-directory name under results_path.
dim = 32    

# Number of sampling repeats forwarded to test_model(test_iterations=...).
test_repeats = 10

# Relative locations of the results tree, the downstream labels and the input data.
results_path = '../results'    
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")


In [None]:
# pick mod
# Target modality to generate; all other modalities in modalities_map act as conditioning inputs.
modality = 'rppa'

# Method tag used in the results directory name "<modality>_from_<method>".
method = 'multi'

In [None]:
# read data real
# Load the train and test splits for every modality in modalities_list.
# The result is indexed below as modalities_map[<modality>][<split>].
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train','test'],
    data_dir=data_dir,
    dim=dim,
)


In [None]:
# Real test split of the selected target modality (reference set for all metrics below).
test_real = modalities_map[modality]['test']

In [None]:
# load model

# Checkpoint location: <results_path>/<dim>/<modality>_from_<method>/train/best_by_mse.pth
base_dir = pathlib.Path(results_path) / str(dim) / f"{modality}_from_{method}"
ckpt_path = base_dir.joinpath('train', 'best_by_mse.pth')

# The checkpoint is a dict that stores the training config next to the best weights.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location='cpu')
raw_cfg = ckpt['config']
config = SimpleNamespace(**raw_cfg)
state_dict = ckpt['best_model_mse']

# Number of features of the target modality
x_dim = test_real.shape[1]

# Every modality except the target is a conditioning input; collect their feature counts
cond_datatypes = [name for name in modalities_map.keys() if name != modality]
cond_dim_list = [modalities_map[name]['test'].shape[1] for name in cond_datatypes]

diffusion = GaussianDiffusion(num_timesteps=1000).to(device)

# Rebuild the architecture described by the stored config, then restore the weights
model = get_diffusion_model(
    config.architecture,
    diffusion,
    config,
    x_dim=x_dim,
    cond_dims=cond_dim_list,
)
model = model.to(device)
model.load_state_dict(state_dict)
model.eval()

In [None]:
# generate data unconditionally with full masking

# One all-zero frame per conditioning modality, same shape and columns as its real test split
cond_test_list = [
    pd.DataFrame(
        np.zeros(modalities_map[cond_name]['test'].shape),
        columns=modalities_map[cond_name]['test'].columns,
    )
    for cond_name in cond_datatypes
]

# One all-zero mask vector (one entry per test sample) for every conditioning modality
num_samples = test_real.shape[0]
masks = [np.zeros(num_samples) for _ in cond_datatypes]

# Sample from the model with every condition zeroed; keep the metrics and the samples
test_metrics, uncond_generated_data = test_model(
    test_real,
    cond_test_list,
    model,
    diffusion,
    test_iterations=test_repeats,
    device=device,
    masks=masks,
)

In [None]:
# load generated data
# The CSV name encodes the conditioning modalities joined by "_".
conditioning_string = '_'.join(cond_datatypes)
synth_path = os.path.join(results_path, str(dim), f'{modality}_from_{method}/test/generated_samples_from_{conditioning_string}_best_mse.csv')
# NOTE(review): read without index_col - assumes the CSV holds feature columns only; confirm.
cond_generated_data = pd.read_csv(synth_path)
print(f"Loaded generated data from {synth_path}")

In [None]:
# print shapes
# Quick check of the real test set against both generated sets.
print(f"Test real shape: {test_real.shape}")
print(f"Uncond generated data shape: {uncond_generated_data.shape}")
print(f"Cond generated data shape: {cond_generated_data.shape}")

In [None]:
# Measure the distance of conditional and unconditional generation from real data



In [None]:
# Precision/recall of the unconditional samples against the real test set.
# Only the first len(test_real) generated rows are compared.
precision_uncond, recall_uncond = compute_prdc(
    test_real.values,
    uncond_generated_data.iloc[:len(test_real)].values,
    nearest_k=10,
    only_pr=True
)

# F1 is the harmonic mean of precision and recall. It is defined as 0 when both are 0,
# which avoids a division by zero (nan / ZeroDivisionError) for a degenerate generator.
f1_uncond = (
    2 * (precision_uncond * recall_uncond) / (precision_uncond + recall_uncond)
    if (precision_uncond + recall_uncond) > 0
    else 0.0
)

print(f"Unconditional Generation - Precision: {precision_uncond:.4f}, Recall: {recall_uncond:.4f}, F1: {f1_uncond:.4f}")

In [None]:
# Precision/recall of the conditional samples against the real test set.
# Only the first len(test_real) generated rows are compared.
precision_cond, recall_cond = compute_prdc(
    test_real.values,
    cond_generated_data.iloc[:len(test_real)].values,
    nearest_k=10,
    only_pr=True
)

# F1 is the harmonic mean of precision and recall. It is defined as 0 when both are 0,
# which avoids a division by zero (nan / ZeroDivisionError) for a degenerate generator.
f1_cond = (
    2 * (precision_cond * recall_cond) / (precision_cond + recall_cond)
    if (precision_cond + recall_cond) > 0
    else 0.0
)

print(f"Conditional Generation - Precision: {precision_cond:.4f}, Recall: {recall_cond:.4f}, F1: {f1_cond:.4f}")

In [None]:
# PED/ED distances (L2 metric) between the real test set and the first len(test_real)
# unconditional samples. The six returned values are reported below as
# global, local and balanced PED/ED.
g_ped_uncond, g_ed_uncond, l_ped_uncond, l_ed_uncond, b_ped_uncond, b_ed_uncond = calculate_PED_balanced(
    test_real.values,
    uncond_generated_data.iloc[:len(test_real)].values,
    metric='l2'
)
print(f"Unconditional Generation - Global PED: {g_ped_uncond:.4f}, Global ED: {g_ed_uncond:.4f}, "
      f"Local PED: {l_ped_uncond:.4f}, Local ED: {l_ed_uncond:.4f}, "
      f"Balanced PED: {b_ped_uncond:.4f}, Balanced ED: {b_ed_uncond:.4f}")

In [None]:
# PED/ED distances (L2 metric) between the real test set and the first len(test_real)
# conditional samples. The six returned values are reported below as
# global, local and balanced PED/ED.
g_ped_cond, g_ed_cond, l_ped_cond, l_ed_cond, b_ped_cond, b_ed_cond = calculate_PED_balanced(
    test_real.values,
    cond_generated_data.iloc[:len(test_real)].values,
    metric='l2'
)

print(f"Conditional Generation - Global PED: {g_ped_cond:.4f}, Global ED: {g_ed_cond:.4f}, "
      f"Local PED: {l_ped_cond:.4f}, Local ED: {l_ed_cond:.4f}, "
      f"Balanced PED: {b_ped_cond:.4f}, Balanced ED: {b_ed_cond:.4f}")

In [None]:
# compare to real train

# Training split of the target modality, keeping only rows without missing values
train_real = modalities_map[modality]['train'].dropna()
print(f"Train real shape after dropping NaNs: {train_real.shape}")

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace

import torch

from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# parameters
# Size key: forwarded to read_data(dim=...) and used as a sub-directory name under results_path.
dim = 32    
# Number of generation repeats; also the leading axis of the reshaped sample arrays below.
test_repeats = 10

# Relative locations of the results tree, the downstream labels and the input data.
results_path = '../results'    
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

# list the four modalities you want to run
modalities_to_run = modalities_list  

# read all data once
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train','test'],
    data_dir=data_dir,
    dim=dim,
)

# store a summary of metrics if desired
# One dict per modality is appended by the loop below.
summary = []

def compute_all_metrics(real_arr, gen_arr, test_repeats):
    """Score generated data against reference data, once per repeat.

    For every repeat index i in range(test_repeats) the pair (real_arr[i], gen_arr[i])
    is scored with compute_prdc (precision/recall, nearest_k=10) and
    calculate_PED_balanced (metric='l2'); F1 is derived from precision and recall
    with a 1e-8 term in the denominator.

    Note: compute_prdc and calculate_PED_balanced must already be in scope; in this
    notebook they are imported in the first cell only.

    Args:
        real_arr: Indexable of per-repeat reference arrays.
        gen_arr: Indexable of per-repeat generated arrays.
        test_repeats: Number of repeats to evaluate.

    Returns:
        Dict mapping each metric name ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped',
        'l_ed', 'b_ped', 'b_ed') to a (mean, std) tuple taken over the repeats.
    """
    metric_names = ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped', 'l_ed', 'b_ped', 'b_ed')
    # One row per metric, one column per repeat
    scores = np.zeros((len(metric_names), test_repeats))

    for rep in range(test_repeats):
        real_rep, gen_rep = real_arr[rep], gen_arr[rep]

        # PRDC
        precision, recall = compute_prdc(real_rep, gen_rep, nearest_k=10, only_pr=True)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

        # PED/ED metrics, returned in the order g_ped, g_ed, l_ped, l_ed, b_ped, b_ed
        g_ped, g_ed, l_ped, l_ed, b_ped, b_ed = calculate_PED_balanced(
            real_rep, gen_rep, metric='l2')

        rep_values = (precision, recall, f1, g_ped, g_ed, l_ped, l_ed, b_ped, b_ed)
        for row, value in zip(scores, rep_values):
            row[rep] = value

    return {name: (row.mean(), row.std()) for name, row in zip(metric_names, scores)}

def print_metrics(label, metrics):
    """Print the mean of every metric on a single line.

    Args:
        label: Row label, left-aligned in a field of width 8.
        metrics: Dict mapping metric name to a (mean, std) pair, as returned by
            compute_all_metrics. Only the means are printed.
    """
    means = dict((name, pair[0]) for name, pair in metrics.items())

    pr_part = (f"{label:<8} Precision: {means['prec']:.2f}, "
               f"Recall: {means['rec']:.2f}, F1: {means['f1']:.2f}")
    ed_part = (f"   ED_G: {means['g_ed']:.2f}, "
               f"ED_L: {means['l_ed']:.2f}, ED_B: {means['b_ed']:.2f}")
    ped_part = (f"   PED_G: {means['g_ped']:.2f}, "
                f"PED_L: {means['l_ped']:.2f}, PED_B: {means['b_ped']:.2f}")

    print(pr_part + ed_part + ped_part)

def create_summary_entry(modality, uncond_metrics, cond_metrics,
                         train_uncond_metrics=None, train_cond_metrics=None):
    """Flatten the metric dictionaries of one modality into a single summary row.

    The two trailing arguments are optional so that both the original three-argument
    call and the five-argument call used by the per-modality loop are accepted.

    Args:
        modality: Name of the modality the metrics belong to.
        uncond_metrics: Dict of metric name -> (mean, std) for unconditional samples
            scored against the test set.
        cond_metrics: Same, for conditional samples scored against the test set.
        train_uncond_metrics: Optional dict, unconditional samples scored against
            training subsets. Skipped when None.
        train_cond_metrics: Optional dict, conditional samples scored against
            training subsets. Skipped when None.

    Returns:
        Dict with key 'modality' plus, for every metric of every supplied group,
        '<metric>_<group>' (mean) and '<metric>_<group>_std' (std), where <group> is
        one of 'uncond', 'cond', 'train_uncond', 'train_cond'.
    """
    entry = {'modality': modality}
    groups = [
        ('uncond', uncond_metrics),
        ('cond', cond_metrics),
        ('train_uncond', train_uncond_metrics),
        ('train_cond', train_cond_metrics),
    ]
    for condition, metrics in groups:
        if metrics is None:
            # Optional group not supplied
            continue
        for metric, (mean, std) in metrics.items():
            entry[f'{metric}_{condition}'] = mean
            entry[f'{metric}_{condition}_std'] = std
    return entry

# Dedicated, seeded generator for the training-subset sampling below, so that the
# train-based metrics are reproducible on Restart & Run All.
subsample_rng = np.random.default_rng(42)

for modality in modalities_to_run:
    print(f"\n=== Running modality: {modality} ===")
    
    # Setup paths and load model
    test_real = modalities_map[modality]['test']
    # Training rows with missing values are excluded from the comparison
    train_real = modalities_map[modality]['train'].dropna()
    n_samples, n_feats = test_real.shape
    base_dir = pathlib.Path(f"{results_path}/{dim}/{modality}_from_multi")
    ckpt_path = base_dir / 'train' / 'best_by_mse.pth'
    
    # Load and setup model
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = SimpleNamespace(**ckpt['config'])
    
    # Every modality except the target is a conditioning input
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    cond_dim_list = [modalities_map[c]['test'].shape[1] for c in cond_datatypes]
    
    diffusion = GaussianDiffusion(num_timesteps=1000).to(device)
    model = get_diffusion_model(config.architecture, diffusion, config, 
                               x_dim=n_feats, cond_dims=cond_dim_list).to(device)
    model.load_state_dict(ckpt['best_model_mse'])
    model.eval()
    
    # Generate unconditional samples: every condition is zeroed and every mask is zero
    cond_test_list = [pd.DataFrame(np.zeros_like(modalities_map[c]['test']), 
                                  columns=modalities_map[c]['test'].columns) 
                     for c in cond_datatypes]
    masks = [np.zeros(n_samples) for _ in cond_datatypes]
    
    _, uncond_generated_data = test_model(test_real, cond_test_list, model, diffusion,
                                         test_iterations=test_repeats, device=device, masks=masks)
    
    # Load conditional samples
    conditioning_string = '_'.join(cond_datatypes)
    synth_path = base_dir / 'test' / f'generated_samples_from_{conditioning_string}_best_mse.csv'
    cond_generated_data = pd.read_csv(synth_path)
    
    # Reshape arrays to (repeat, sample, feature)
    # NOTE(review): assumes both generated tables hold exactly test_repeats * n_samples rows
    # stacked repeat after repeat, and n_feats columns - confirm against the writer.
    uncond_arr = uncond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    cond_arr = cond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    real_arr = np.tile(test_real.values[np.newaxis], (test_repeats, 1, 1))
    
    # Create random subsets of training data for each repeat (same size as the test set,
    # drawn without replacement; raises ValueError if fewer training rows are available)
    train_real_arr = np.zeros((test_repeats, n_samples, n_feats))
    for i in range(test_repeats):
        idx = subsample_rng.choice(len(train_real), size=n_samples, replace=False)
        train_real_arr[i] = train_real.values[idx]
    
    # Compute metrics
    uncond_metrics = compute_all_metrics(real_arr, uncond_arr, test_repeats)
    cond_metrics = compute_all_metrics(real_arr, cond_arr, test_repeats)
    # NOTE(review): despite its name, train_test_metrics compares the training subsets with
    # the *unconditional samples*, not with the real test set - confirm which was intended.
    train_test_metrics = compute_all_metrics(train_real_arr, uncond_arr, test_repeats)
    train_cond_metrics = compute_all_metrics(train_real_arr, cond_arr, test_repeats)
    
    # Print results
    print_metrics("Uncond", uncond_metrics)
    print_metrics("Cond", cond_metrics)
    print_metrics("Train vs Test", train_test_metrics)
    print_metrics("Train vs Cond", train_cond_metrics)
    
    # Add to summary
    summary.append(create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics))

In [None]:
# summary
# One row per modality, built from the entries collected in `summary`.
df_summary = pd.DataFrame(summary)
display(df_summary)

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import seaborn as sns

import torch
import umap

from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# parameters
# Size key: forwarded to read_data(dim=...) and used as a sub-directory name under results_path.
dim = 32    
test_repeats = 10

# Method tag used in the results directory name "<modality>_from_<method>".
method = 'multi'

# Relative locations of the results tree, the downstream labels and the input data.
results_path = '../results'    
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

# list the four modalities you want to run
modalities_to_run = modalities_list  

# read all data once
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train','test'],
    data_dir=data_dir,
    dim=dim,
)

def load_labels(labels_dir, split):
    """Read "<split>_cancer_type.csv" from labels_dir.

    Args:
        labels_dir: Directory that holds the label files.
        split: Split name used as the file-name prefix (e.g. 'train', 'test').

    Returns:
        DataFrame indexed by the first CSV column, or None (after printing a
        warning) when the file does not exist.
    """
    labels_path = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(labels_path):
        print(f"Warning: Labels file not found at {labels_path}")
        return None
    return pd.read_csv(labels_path, index_col=0)

def create_umap_visualization(train_data, test_data, cond_data, train_labels, test_labels, 
                             modality, save_path=None):
    """
    Create UMAP visualization with train, test, and conditional generated data.

    UMAP is fitted on the cleaned training rows only; test and generated rows are
    projected with the fitted model. The three panels share the same axis limits
    and, when labels are available, the same class colours.

    Args:
        train_data: Training data (pandas DataFrame). Rows containing NaN are dropped.
        test_data: Test data (pandas DataFrame)
        cond_data: Conditional generated data (pandas DataFrame). Only the first
            len(test_data) rows are plotted.
        train_labels: Training labels (pandas DataFrame) or None. The first column is
            the class label; rows labelled 'THCA' are removed before fitting. Boolean
            masks derived from train_data are applied to it, so it has to share
            train_data's index.
        test_labels: Test labels (pandas DataFrame) or None. The first column is the
            class label; it is also used to colour the generated rows.
        modality: Name of the modality (used in titles and log messages)
        save_path: Path to save the plot (optional)

    Returns:
        The fitted umap.UMAP model.
    """
    
    # Fit UMAP on training data
    print(f"Fitting UMAP on training data for {modality}...")
    umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    
    # Clean training data (remove NaN)
    train_clean = train_data.dropna()
    train_labels_clean = None
    
    # Remove THCA samples from training data (and labels) if labels are available.
    # The masks are computed once and reused for the data and for the labels.
    if train_labels is not None:
        non_nan_mask = ~train_data.isna().any(axis=1)
        train_labels_temp = train_labels[non_nan_mask]
        thca_mask = train_labels_temp.iloc[:, 0] != 'THCA'
        train_clean = train_clean[thca_mask]
        train_labels_clean = train_labels_temp[thca_mask]
        print(f"Removed THCA samples. Training data shape after filtering: {train_clean.shape}")
    
    train_umap = umap_model.fit_transform(train_clean.values)
    
    # Transform test and conditional data
    print(f"Transforming test and conditional data...")
    test_umap = umap_model.transform(test_data.values)
    
    # For conditional data, take only the first rows, matching the test set size
    n_test_samples = len(test_data)
    cond_umap = umap_model.transform(cond_data.iloc[:n_test_samples].values)
    
    # Encode labels with one shared encoder so that colours are consistent across panels
    train_y_encoded = test_y_encoded = cond_y_encoded = None
    unique_labels = None
    have_labels = (train_labels_clean is not None
                   and test_labels is not None
                   and len(test_labels.columns) > 0)
    if have_labels:
        train_y = train_labels_clean.iloc[:, 0]
        test_y = test_labels.iloc[:, 0]
        le = LabelEncoder()
        le.fit(pd.concat([train_y, test_y]))
        train_y_encoded = le.transform(train_y)
        test_y_encoded = le.transform(test_y)
        # Generated rows are coloured with the labels of the corresponding test rows
        cond_y_encoded = test_y_encoded.copy()
        unique_labels = le.classes_
    
    # One RGBA colour per class. tab20 only has 20 entries, so a continuous colormap is
    # sampled when there are more classes (previously this raised KeyError).
    palette = None
    if unique_labels is not None:
        n_classes = len(unique_labels)
        if n_classes <= 20:
            palette = plt.cm.tab20(np.linspace(0, 1, n_classes))
        else:
            palette = plt.cm.nipy_spectral(np.linspace(0.02, 0.98, n_classes))
    
    # Create the plot
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Calculate common axis limits for all panels (5% margin on each side)
    all_x = np.concatenate([train_umap[:, 0], test_umap[:, 0], cond_umap[:, 0]])
    all_y = np.concatenate([train_umap[:, 1], test_umap[:, 1], cond_umap[:, 1]])
    x_margin = (all_x.max() - all_x.min()) * 0.05
    y_margin = (all_y.max() - all_y.min()) * 0.05
    xlim = [all_x.min() - x_margin, all_x.max() + x_margin]
    ylim = [all_y.min() - y_margin, all_y.max() + y_margin]
    
    # (embedding, encoded labels, title prefix, colour used when no labels are available)
    panels = [
        (train_umap, train_y_encoded, 'Training Data', 'blue'),
        (test_umap, test_y_encoded, 'Test Data', 'orange'),
        (cond_umap, cond_y_encoded, 'Generated Data', 'red'),
    ]
    for ax, (embedding, encoded, title, fallback_color) in zip(axes, panels):
        point_colors = palette[encoded] if encoded is not None else fallback_color
        ax.scatter(embedding[:, 0], embedding[:, 1], c=point_colors, alpha=0.6, s=20)
        ax.set_title(f'{title} - {modality}')
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
    
    # Add legend if we have labels. Proxy handles are used so that no extra (empty)
    # scatter artists are added to the last panel.
    if unique_labels is not None:
        legend_elements = [
            plt.Line2D([], [], linestyle='', marker='o', markersize=7,
                       color=tuple(palette[i]), label=unique_labels[i])
            for i in range(len(unique_labels))
        ]
        fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.05, 0.5))
    
    plt.tight_layout()
    
    # Save plot if path provided
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            # os.makedirs('') raises, so only create a directory when the path has one
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    
    plt.show()
    
    return umap_model

# Load labels
train_labels = load_labels(labels_dir, 'train')
test_labels = load_labels(labels_dir, 'test')

for modality in modalities_to_run:
    print(f"\n=== Processing modality: {modality} ===")
    
    # Real data of the current target modality
    test_real = modalities_map[modality]['test']
    train_real = modalities_map[modality]['train']
    n_samples, n_feats = test_real.shape
    
    # Location of the conditional samples; the file name lists the conditioning modalities
    base_dir = pathlib.Path(f"{results_path}/{dim}/{modality}_from_{method}")
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    conditioning_string = '_'.join(cond_datatypes)
    synth_path = base_dir / 'test' / f'generated_samples_from_{conditioning_string}_best_mse.csv'
    
    # Nothing to plot for this modality when its samples were never generated
    if not synth_path.exists():
        print(f"Warning: Conditional data not found at {synth_path}")
        print("Skipping UMAP visualization for this modality.")
        continue
    
    cond_generated_data = pd.read_csv(synth_path)
    print(f"Loaded conditional data with shape: {cond_generated_data.shape}")
    
    # Create UMAP visualization
    save_path = f'../results/umap/umap_{modality}.png'
    umap_model = create_umap_visualization(
        train_data=train_real,
        test_data=test_real,
        cond_data=cond_generated_data,
        train_labels=train_labels,
        test_labels=test_labels,
        modality=modality,
        save_path=save_path,
    )

print("\n=== UMAP visualization complete ===")

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import seaborn as sns

import torch
import umap

from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# parameters
# Size key: forwarded to read_data(dim=...) and used as a sub-directory name under results_path.
dim = 32    
test_repeats = 10

method = 'multi'

# Relative locations of the results tree, the downstream labels and the input data.
results_path = '../results'    
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

# list the four modalities you want to run
modalities_to_run = modalities_list  

# read all data once
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train','test'],
    data_dir=data_dir,
    dim=dim,
)

def load_labels(labels_dir, split):
    """Read "<split>_cancer_type.csv" from labels_dir.

    Args:
        labels_dir: Directory that holds the label files.
        split: Split name used as the file-name prefix (e.g. 'train', 'test').

    Returns:
        DataFrame indexed by the first CSV column, or None (after printing a
        warning) when the file does not exist.
    """
    labels_path = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(labels_path):
        print(f"Warning: Labels file not found at {labels_path}")
        return None
    return pd.read_csv(labels_path, index_col=0)


def create_umap_visualization(train_data, test_data, cond_multi_data, cond_coherent_data, 
                               train_labels, test_labels, modality, save_path=None):
    """
    Create UMAP visualization with train, test, multi-generated, and coherent-generated data.

    UMAP is fitted on the cleaned training rows only; the other three sets are
    projected with the fitted model. The four panels share the same axis limits
    and, when labels are available, the same class colours.

    Args:
        train_data: Training data (pandas DataFrame). Rows containing NaN are dropped.
        test_data: Test data (pandas DataFrame).
        cond_multi_data: Samples generated by the 'multi' model (pandas DataFrame).
            Only the first len(test_data) rows are plotted.
        cond_coherent_data: Samples generated by the 'coherent' model (pandas
            DataFrame). Only the first len(test_data) rows are plotted.
        train_labels: Training labels (pandas DataFrame) or None. The first column is
            the class label; rows labelled 'THCA' are removed before fitting. Boolean
            masks derived from train_data are applied to it, so it has to share
            train_data's index.
        test_labels: Test labels (pandas DataFrame) or None. The first column is the
            class label; it is also used to colour both generated sets.
        modality: Name of the modality (used in titles and log messages).
        save_path: Path to save the plot (optional).

    Returns:
        The fitted umap.UMAP model.
    """
    print(f"Fitting UMAP on training data for {modality}...")
    umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    train_clean = train_data.dropna()
    train_labels_clean = None

    # Drop THCA rows from data and labels; the masks are computed once and reused.
    if train_labels is not None:
        non_nan_mask = ~train_data.isna().any(axis=1)
        train_labels_temp = train_labels[non_nan_mask]
        thca_mask = train_labels_temp.iloc[:, 0] != 'THCA'
        train_clean = train_clean[thca_mask]
        train_labels_clean = train_labels_temp[thca_mask]
        print(f"Removed THCA samples. Training data shape after filtering: {train_clean.shape}")

    train_umap = umap_model.fit_transform(train_clean.values)
    test_umap = umap_model.transform(test_data.values)
    cond_multi_umap = umap_model.transform(cond_multi_data.iloc[:len(test_data)].values)
    cond_coherent_umap = umap_model.transform(cond_coherent_data.iloc[:len(test_data)].values)

    # Encode labels with one shared encoder so that colours are consistent across panels
    train_y_encoded = test_y_encoded = cond_y_encoded = None
    unique_labels = None
    if train_labels_clean is not None and test_labels is not None:
        train_y = train_labels_clean.iloc[:, 0]
        test_y = test_labels.iloc[:, 0]

        le = LabelEncoder()
        le.fit(pd.concat([train_y, test_y]))
        train_y_encoded = le.transform(train_y)
        test_y_encoded = le.transform(test_y)
        # Generated rows are coloured with the labels of the corresponding test rows
        cond_y_encoded = test_y_encoded.copy()
        unique_labels = le.classes_

    # One RGBA colour per class. tab20 only has 20 entries, so a continuous colormap is
    # sampled when there are more classes (previously this raised KeyError).
    palette = None
    if unique_labels is not None:
        n_classes = len(unique_labels)
        if n_classes <= 20:
            palette = plt.cm.tab20(np.linspace(0, 1, n_classes))
        else:
            palette = plt.cm.nipy_spectral(np.linspace(0.02, 0.98, n_classes))

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))

    # Common axis limits for all panels (5% margin on each side)
    all_x = np.concatenate([
        train_umap[:, 0], test_umap[:, 0], 
        cond_multi_umap[:, 0], cond_coherent_umap[:, 0]
    ])
    all_y = np.concatenate([
        train_umap[:, 1], test_umap[:, 1], 
        cond_multi_umap[:, 1], cond_coherent_umap[:, 1]
    ])
    x_margin = (all_x.max() - all_x.min()) * 0.05
    y_margin = (all_y.max() - all_y.min()) * 0.05
    xlim = [all_x.min() - x_margin, all_x.max() + x_margin]
    ylim = [all_y.min() - y_margin, all_y.max() + y_margin]

    panels = [
        (train_umap, train_y_encoded, 'Training Data'),
        (test_umap, test_y_encoded, 'Test Data'),
        (cond_multi_umap, cond_y_encoded, 'Generated - Multi'),
        (cond_coherent_umap, cond_y_encoded, 'Generated - Coherent'),
    ]

    for ax, (data, labels, title) in zip(axes, panels):
        if labels is not None:
            ax.scatter(data[:, 0], data[:, 1], c=palette[labels], alpha=0.6, s=20)
        else:
            ax.scatter(data[:, 0], data[:, 1], alpha=0.6, s=20)
        ax.set_title(f'{title} - {modality}')
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    # Proxy handles are used for the legend so that no extra (empty) scatter artists
    # are added to the last panel.
    if unique_labels is not None:
        legend_elements = [
            plt.Line2D([], [], linestyle='', marker='o', markersize=7,
                       color=tuple(palette[i]), label=unique_labels[i])
            for i in range(len(unique_labels))
        ]
        fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.05, 0.5))

    plt.tight_layout()
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            # os.makedirs('') raises, so only create a directory when the path has one
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()
    return umap_model

# Load labels
train_labels = load_labels(labels_dir, 'train')
test_labels = load_labels(labels_dir, 'test')

for modality in modalities_to_run:
    print(f"\n=== Processing modality: {modality} ===")
    
    # Real data of the current target modality
    train_real = modalities_map[modality]['train']
    test_real = modalities_map[modality]['test']
    
    # The generated-sample file name lists the conditioning modalities joined by "_"
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    conditioning_string = '_'.join(cond_datatypes)
    
    def load_generated(method_name):
        """Read the samples generated by `method_name` for the current modality, or None."""
        method_dir = pathlib.Path(f"{results_path}/{dim}/{modality}_from_{method_name}")
        path = method_dir / 'test' / f'generated_samples_from_{conditioning_string}_best_mse.csv'
        if path.exists():
            return pd.read_csv(path)
        print(f"Warning: {method_name} data not found at {path}")
        return None

    cond_multi_data = load_generated('multi')
    cond_coherent_data = load_generated('coherent')

    # Both generated sets are required for the four-panel figure
    if cond_multi_data is None or cond_coherent_data is None:
        print("Skipping modality due to missing data.")
        continue

    # NOTE(review): this is the same file name the three-panel cell writes to, so running
    # both cells overwrites the earlier figure - confirm that this is intended.
    save_path = f'../results/umap/umap_{modality}.png'
    create_umap_visualization(
        train_data=train_real,
        test_data=test_real,
        cond_multi_data=cond_multi_data,
        cond_coherent_data=cond_coherent_data,
        train_labels=train_labels,
        test_labels=test_labels,
        modality=modality,
        save_path=save_path,
    )

print("\n=== UMAP visualization complete ===")

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import umap

from lib.config import modalities_list
from lib.read_data import read_data

# parameters
# Size key: forwarded to read_data(dim=...) and used as a sub-directory name under results_path.
# Note the capitalised name: earlier cells call the same setting `dim`.
Dim = 32
results_path = '../results'
labels_dir = '../datasets_TCGA/downstream_labels/'
data_dir = '../datasets_TCGA/07_normalized/'
# Use the current CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

# list of modalities to run
modalities_to_run = modalities_list

# utility: load labels
def load_labels(labels_dir, split):
    """Read "<split>_cancer_type.csv" from labels_dir, or return None if it is absent.

    The first CSV column becomes the index of the returned DataFrame.
    """
    csv_path = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(csv_path):
        return None
    return pd.read_csv(csv_path, index_col=0)

# utility: load generated samples for a given modality and method
def load_generated(modality, method_name):
    """Read the samples generated for `modality` by `method_name`.

    The file is looked up at
    <results_path>/<Dim>/<modality>_from_<method_name>/test/
    generated_samples_from_<conditioning modalities joined by "_">_best_mse.csv,
    where the conditioning modalities are all entries of modalities_to_run except
    `modality`.

    Returns:
        The CSV as a DataFrame, or None when the file does not exist.
    """
    cond_str = '_'.join(m for m in modalities_to_run if m != modality)
    csv_path = pathlib.Path(results_path).joinpath(
        str(Dim),
        f"{modality}_from_{method_name}",
        'test',
        f"generated_samples_from_{cond_str}_best_mse.csv",
    )
    if not csv_path.exists():
        return None
    return pd.read_csv(csv_path)

# read all real data
modalities_map = read_data(
    modalities=modalities_to_run,
    splits=['train','test'],
    data_dir=data_dir,
    dim=Dim,
)

# load labels
train_labels = load_labels(labels_dir, 'train')
test_labels = load_labels(labels_dir, 'test')

# collect per-data-type lists
t_list, te_list, m_list, c_list = [], [], [], []
missing_generated = []
for mod in modalities_to_run:
    t_list.append(modalities_map[mod]['train'])
    te_list.append(modalities_map[mod]['test'])
    gm = load_generated(mod, 'multi')
    gc = load_generated(mod, 'coherent')
    if gm is not None:
        m_list.append(gm)
    else:
        missing_generated.append(f"{mod} (multi)")
    if gc is not None:
        c_list.append(gc)
    else:
        missing_generated.append(f"{mod} (coherent)")

# Fail fast: silently skipping a modality would leave the generated tables with fewer
# feature columns than the concatenated real data (or make pd.concat fail on an empty
# list), which only surfaces later as an obscure error.
if missing_generated:
    raise FileNotFoundError(
        "Generated samples not found for: " + ", ".join(missing_generated)
    )

# concatenate across modalities (features side by side)
train_all = pd.concat(t_list, axis=1)
test_all = pd.concat(te_list, axis=1)
multi_all = pd.concat(m_list, axis=1)
coherent_all = pd.concat(c_list, axis=1)

print(f"Training data shape: {train_all.shape}")
print(f"Test data shape: {test_all.shape}")
print(f"Multi-generated data shape: {multi_all.shape}")
print(f"Coherent-generated data shape: {coherent_all.shape}")

# create a single UMAP visualization
def create_umap_concat(train_df, test_df, multi_df, coherent_df,
                       train_lbls, test_lbls, save_path=None):
    """Plot one four-panel UMAP figure for data concatenated across modalities.

    UMAP is fitted on the (mean-imputed) training rows; test, multi-generated and
    coherent-generated rows are projected with the fitted model. All panels share
    axis limits and, when labels are available, class colours.

    Args:
        train_df: Training data (pandas DataFrame).
        test_df: Test data (pandas DataFrame).
        multi_df: Samples generated by the 'multi' models (pandas DataFrame); only
            the first len(test_df) rows are plotted.
        coherent_df: Samples generated by the 'coherent' models (pandas DataFrame);
            only the first len(test_df) rows are plotted.
        train_lbls: Training labels (pandas DataFrame) or None. Re-indexed to
            train_df's index; rows labelled 'THCA' in the first column are dropped.
        test_lbls: Test labels (pandas DataFrame) or None. Used as-is.
            NOTE(review): unlike train_lbls it is not re-indexed, so it is assumed
            to be in the same row order as test_df - confirm.
        save_path: Path to save the plot (optional).
    """
    # initialize UMAP
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)

    # align and filter labels (drop only explicit THCA)
    if train_lbls is not None:
        train_lbls = train_lbls.reindex(train_df.index)
        mask = ~(train_lbls.iloc[:, 0] == 'THCA')
        mask = mask.fillna(True)
        train_df = train_df.loc[mask]
        train_lbls = train_lbls.loc[mask]

    # impute missing values using train mean
    train_mean = train_df.mean(axis=0)
    train_imp = train_df.fillna(train_mean)
    test_imp = test_df.fillna(train_mean)
    multi_imp = multi_df.fillna(train_mean)
    coherent_imp = coherent_df.fillna(train_mean)

    # fit and transform
    train_emb = mapper.fit_transform(train_imp.values)
    test_emb = mapper.transform(test_imp.values)
    # ensure consistent lengths
    n_test = test_imp.shape[0]
    multi_emb = mapper.transform(multi_imp.iloc[:n_test].values)
    coherent_emb = mapper.transform(coherent_imp.iloc[:n_test].values)

    # encode labels if available
    if train_lbls is not None and test_lbls is not None:
        le = LabelEncoder().fit(pd.concat([train_lbls.iloc[:, 0], test_lbls.iloc[:, 0]]))
        ty_train = le.transform(train_lbls.iloc[:, 0])
        ty_test = le.transform(test_lbls.iloc[:, 0])
        # generated rows are coloured with the labels of the corresponding test rows
        ty_cond = ty_test.copy()
        classes = le.classes_
    else:
        ty_train = ty_test = ty_cond = None
        classes = None

    # One RGBA colour per class, built once and shared by the panels and the legend.
    # tab20 only has 20 entries, so a continuous colormap is sampled for more classes
    # (previously this raised IndexError, and the legend colours did not match the points).
    palette = None
    if classes is not None:
        n_classes = len(classes)
        if n_classes <= 20:
            palette = plt.cm.tab20(np.linspace(0, 1, n_classes))
        else:
            palette = plt.cm.nipy_spectral(np.linspace(0.02, 0.98, n_classes))

    # plot panels
    fig, axs = plt.subplots(1, 4, figsize=(24, 6))
    groups = [
        (train_emb, ty_train, 'Training (All)'),
        (test_emb, ty_test, 'Test (All)'),
        (multi_emb, ty_cond, 'Generated - Multi'),
        (coherent_emb, ty_cond, 'Generated - Coherent'),
    ]
    # unified axis limits (5% margin on each side)
    all_x = np.hstack([e[:, 0] for e, *_ in groups])
    all_y = np.hstack([e[:, 1] for e, *_ in groups])
    x_margin, y_margin = (all_x.max() - all_x.min()) * 0.05, (all_y.max() - all_y.min()) * 0.05
    xlims = (all_x.min() - x_margin, all_x.max() + x_margin)
    ylims = (all_y.min() - y_margin, all_y.max() + y_margin)

    for ax, (emb, labels, title) in zip(axs, groups):
        if labels is not None:
            ax.scatter(emb[:, 0], emb[:, 1], c=palette[labels], alpha=0.6, s=20)
        else:
            ax.scatter(emb[:, 0], emb[:, 1], color='gray', alpha=0.6, s=20)
        ax.set_title(title)
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')

    # legend
    if classes is not None:
        legend_elems = [plt.Line2D([0],[0], marker='o', color='w', markerfacecolor=tuple(palette[i]), label=lbl)
                        for i, lbl in enumerate(classes)]
        fig.legend(handles=legend_elems, loc='center right', bbox_to_anchor=(1.05, 0.5))

    plt.tight_layout()
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            # os.makedirs('') raises, so only create a directory when the path has one
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# execute
# Figure is written to <results_path>/umap/umap_all_modalities.png.
save_path = os.path.join(results_path, 'umap', 'umap_all_modalities.png')
create_umap_concat(train_all, test_all, multi_all, coherent_all, train_labels, test_labels, save_path)
print('UMAP concatenation complete.')


In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import umap

from lib.config import modalities_list
from lib.read_data import read_data

# --- configuration ---
# Dim is forwarded to read_data(dim=...) and is also a directory level under results_path.
Dim = 32
results_path = '../results'
labels_dir = '../datasets_TCGA/downstream_labels/'
data_dir = '../datasets_TCGA/07_normalized/'

# Use the currently selected CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# Modalities processed by this cell (all configured modalities).
modalities_to_run = modalities_list

# utility: load labels
def load_labels(labels_dir, split):
    """Read ``<labels_dir>/<split>_cancer_type.csv`` with its first column as index.

    Returns the resulting DataFrame, or None when the file does not exist.
    """
    csv_file = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(csv_file):
        return None
    return pd.read_csv(csv_file, index_col=0)

# utility: load generated samples for a given modality and method
def load_generated(modality, method_name):
    """Read the generated-sample CSV for ``modality`` produced by ``method_name``.

    The file is looked up at
    ``<results_path>/<Dim>/<modality>_from_<method_name>/test/generated_samples_from_<conds>_best_mse.csv``,
    where ``<conds>`` is every other entry of ``modalities_to_run`` joined by
    underscores. Returns a DataFrame, or None when the file does not exist.
    """
    other_mods = '_'.join(m for m in modalities_to_run if m != modality)
    run_dir = pathlib.Path(results_path) / str(Dim) / f"{modality}_from_{method_name}" / 'test'
    csv_file = run_dir / f"generated_samples_from_{other_mods}_best_mse.csv"
    if not csv_file.exists():
        return None
    return pd.read_csv(csv_file)

# read all real data
modalities_map = read_data(
    modalities=modalities_to_run,
    splits=['train','test'],
    data_dir=data_dir,
    dim=Dim,
)

# load labels (None when the label CSV is absent)
test_labels = load_labels(labels_dir, 'test')

# collect per-data-type lists
te_list, m_list, c_list = [], [], []
missing_runs = []
for mod in modalities_to_run:
    te_list.append(modalities_map[mod]['test'])
    gm = load_generated(mod, 'multi')
    gc = load_generated(mod, 'coherent')
    if gm is None:
        missing_runs.append(f"{mod}_from_multi")
    else:
        m_list.append(gm)
    if gc is None:
        missing_runs.append(f"{mod}_from_coherent")
    else:
        c_list.append(gc)

# The generated frames are later projected with a UMAP fitted on test_all, so
# they must contain every modality's columns. Silently skipping a missing file
# would leave them with fewer features than test_all (or give pd.concat an
# empty list), so fail here with the names of the missing runs instead.
if missing_runs:
    raise FileNotFoundError(
        "Generated samples not found for: " + ", ".join(missing_runs)
    )

# concatenate across modalities (column-wise)
test_all = pd.concat(te_list, axis=1)
multi_all = pd.concat(m_list, axis=1)
coherent_all = pd.concat(c_list, axis=1)


print(f"Test data shape: {test_all.shape}")
print(f"Multi-generated data shape: {multi_all.shape}")
print(f"Coherent-generated data shape: {coherent_all.shape}")

# create a single UMAP visualization
def create_umap_concat(test_df, multi_df, coherent_df,
                       test_lbls, save_path=None):
    """Plot a 3-panel UMAP of the real test set and two generated sets.

    A 2-D UMAP is fitted on ``test_df`` only. ``multi_df`` and ``coherent_df``
    are truncated to the first ``len(test_df)`` rows and projected with the
    fitted mapper. All panels share the same axis limits.

    Parameters
    ----------
    test_df, multi_df, coherent_df : pandas.DataFrame
        Feature matrices; only ``.values`` is used.
    test_lbls : pandas.DataFrame or None
        Labels whose first column colours the points of all three panels
        (the encoded test labels are reused for the generated panels). Rows
        are assumed to be in the same order as ``test_df`` -- TODO confirm.
        When None, points are drawn in gray and no legend is added.
    save_path : str or None
        When given, the figure is saved there at 300 dpi; the parent
        directory is created if needed.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.
    """
    # initialize UMAP
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.8, spread=2, n_components=2, random_state=23)

    # fit on the real test data, then project the generated data into that space
    test_emb = mapper.fit_transform(test_df.values)
    n_test = test_df.shape[0]
    multi_emb = mapper.transform(multi_df.iloc[:n_test].values)
    coherent_emb = mapper.transform(coherent_df.iloc[:n_test].values)

    # encode labels if available
    if test_lbls is not None:
        le = LabelEncoder().fit(test_lbls.iloc[:, 0])
        ty_test = le.transform(test_lbls.iloc[:, 0])
        ty_cond = ty_test.copy()
        classes = le.classes_
        # One RGBA row per class, built once and shared by the panels and the
        # legend so their colours always agree. Sampling tab20 at len(classes)
        # points also works for more than 20 classes (colours then repeat).
        palette = plt.cm.tab20(np.linspace(0, 1, len(classes)))
    else:
        ty_test = ty_cond = None
        classes = None
        palette = None

    # plot panels
    fig, axs = plt.subplots(1, 3, figsize=(24, 6))
    groups = [
        (test_emb, ty_test, 'Test Set'),
        (multi_emb, ty_cond, 'Generated - Multi-Condition'),
        (coherent_emb, ty_cond, 'Generated - Coherent Denoising'),
    ]
    # unified axis limits with a 5% margin on each side
    all_x = np.hstack([e[:, 0] for e, *_ in groups])
    all_y = np.hstack([e[:, 1] for e, *_ in groups])
    x_margin, y_margin = (all_x.max() - all_x.min()) * 0.05, (all_y.max() - all_y.min()) * 0.05
    xlims = (all_x.min() - x_margin, all_x.max() + x_margin)
    ylims = (all_y.min() - y_margin, all_y.max() + y_margin)

    for ax, (emb, labels, title) in zip(axs, groups):
        if labels is not None:
            # index the palette with the integer class codes -> (n_points, 4) RGBA
            ax.scatter(emb[:, 0], emb[:, 1], c=palette[labels], alpha=0.8, s=10)
        else:
            ax.scatter(emb[:, 0], emb[:, 1], color='gray', alpha=0.6, s=20)
        ax.set_title(title)
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')

    # legend
    if classes is not None:
        legend_elems = [plt.Line2D([0],[0], marker='o', color='w', markerfacecolor=palette[i], label=lbl)
                        for i, lbl in enumerate(classes)]
        fig.legend(handles=legend_elems, loc='center right', bbox_to_anchor=(1.05, 0.5), title='Cancer Types', markerscale=2)

    plt.tight_layout()
    if save_path:
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Run the test-set-only UMAP figure and store it under <results_path>/umap/.
save_path = os.path.join(results_path, 'umap', 'umap_all_modalities_test_set_only.png')
create_umap_concat(
    test_df=test_all,
    multi_df=multi_all,
    coherent_df=coherent_all,
    test_lbls=test_labels,
    save_path=save_path,
)
print('UMAP concatenation complete.')


In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import umap

from lib.config import modalities_list
from lib.read_data import read_data

# --- configuration ---
# Dim is forwarded to read_data(dim=...) and is also a directory level under results_path.
Dim = 32
results_path = '../results'
labels_dir = '../datasets_TCGA/downstream_labels/'
data_dir = '../datasets_TCGA/07_normalized/'

# Use the currently selected CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# Modalities processed by this cell (all configured modalities).
modalities_to_run = modalities_list

# utility: load labels
def load_labels(labels_dir, split):
    """Read ``<labels_dir>/<split>_cancer_type.csv`` with its first column as index.

    Returns the resulting DataFrame, or None when the file does not exist.
    """
    csv_file = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(csv_file):
        return None
    return pd.read_csv(csv_file, index_col=0)

# utility: load generated samples for a given modality and method
def load_generated(modality, method_name):
    """Read the generated-sample CSV for ``modality`` produced by ``method_name``.

    The file is looked up at
    ``<results_path>/<Dim>/<modality>_from_<method_name>/test/generated_samples_from_<conds>_best_mse.csv``,
    where ``<conds>`` is every other entry of ``modalities_to_run`` joined by
    underscores. Returns a DataFrame, or None when the file does not exist.
    """
    other_mods = '_'.join(m for m in modalities_to_run if m != modality)
    run_dir = pathlib.Path(results_path) / str(Dim) / f"{modality}_from_{method_name}" / 'test'
    csv_file = run_dir / f"generated_samples_from_{other_mods}_best_mse.csv"
    if not csv_file.exists():
        return None
    return pd.read_csv(csv_file)

# read all real data
modalities_map = read_data(
    modalities=modalities_to_run,
    splits=['train','test'],
    data_dir=data_dir,
    dim=Dim,
)

# load labels (None when the label CSV is absent)
test_labels = load_labels(labels_dir, 'test')

# collect per-data-type lists
te_list, m_list, c_list = [], [], []
missing_runs = []
for mod in modalities_to_run:
    te_list.append(modalities_map[mod]['test'])
    gm = load_generated(mod, 'multi')
    gc = load_generated(mod, 'coherent')
    if gm is None:
        missing_runs.append(f"{mod}_from_multi")
    else:
        m_list.append(gm)
    if gc is None:
        missing_runs.append(f"{mod}_from_coherent")
    else:
        c_list.append(gc)

# The generated frames are later projected with a UMAP fitted on test_all, so
# they must contain every modality's columns. Silently skipping a missing file
# would leave them with fewer features than test_all (or give pd.concat an
# empty list), so fail here with the names of the missing runs instead.
if missing_runs:
    raise FileNotFoundError(
        "Generated samples not found for: " + ", ".join(missing_runs)
    )

# concatenate across modalities (column-wise)
test_all = pd.concat(te_list, axis=1)
multi_all = pd.concat(m_list, axis=1)
coherent_all = pd.concat(c_list, axis=1)


print(f"Test data shape: {test_all.shape}")
print(f"Multi-generated data shape: {multi_all.shape}")
print(f"Coherent-generated data shape: {coherent_all.shape}")


def create_polished_umap(test_df, multi_df, coherent_df,
                         test_lbls,
                         main_title="UMAP Projection of Real vs. Generated Data",
                         save_path=None):
    """
    Draw a 3-panel UMAP figure (real test data, multi-condition samples,
    coherent-denoising samples) on shared, square axis limits.

    A 2-D UMAP is fitted on ``test_df`` only. ``multi_df`` and ``coherent_df``
    are truncated to the first ``len(test_df)`` rows and projected with the
    fitted mapper.

    Parameters
    ----------
    test_df, multi_df, coherent_df : pandas.DataFrame
        Feature matrices; only ``.values`` is used.
    test_lbls : pandas.DataFrame or None
        Labels whose first column colours the points of all three panels
        (the encoded test labels are reused for the generated panels). Rows
        are assumed to be in the same order as ``test_df`` -- TODO confirm.
        When None, points are drawn in gray and no legend is added.
    main_title : str
        Figure-level title.
    save_path : str or None
        When given, the figure is saved there at 300 dpi; the parent
        directory is created if needed.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.

    Notes
    -----
    Calls ``plt.style.use('seaborn-v0_8-white')``, which changes the global
    Matplotlib style for subsequent plots as well.
    """
    # --- 1. AESTHETIC SETUP ---
    plt.style.use('seaborn-v0_8-white')

    # --- 2. UMAP TRANSFORMATION (fitted on the test set) ---
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.8, spread=2, n_components=2, random_state=23)

    test_emb = mapper.fit_transform(test_df.values)
    n_test = test_df.shape[0]
    multi_emb = mapper.transform(multi_df.iloc[:n_test].values)
    coherent_emb = mapper.transform(coherent_df.iloc[:n_test].values)

    if test_lbls is not None:
        le = LabelEncoder().fit(test_lbls.iloc[:, 0])
        ty_test = le.transform(test_lbls.iloc[:, 0])
        ty_cond = ty_test.copy()
        classes = le.classes_
        # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the
        # same (name, lut) arguments. Built once for the panels and the legend.
        class_cmap = plt.get_cmap('tab20', len(classes))
    else:
        ty_test = ty_cond = None
        classes = None
        class_cmap = None

    # --- 3. PLOTTING SETUP ---
    fig, axs = plt.subplots(1, 3, figsize=(22, 8))
    fig.suptitle(main_title, fontsize=26, weight='bold', y=0.96)

    groups = [
        (test_emb, ty_test, 'Real Test Data'),
        (multi_emb, ty_cond, 'Generated (Multi-Condition)'),
        (coherent_emb, ty_cond, 'Generated (Coherent Denoising)'),
    ]

    # --- Square axis limits shared by all panels ---
    # Flat concatenation: only the global min/max are needed, and unlike
    # np.vstack this does not require every panel to have the same point count.
    all_x = np.concatenate([e[:, 0] for e, *_ in groups])
    all_y = np.concatenate([e[:, 1] for e, *_ in groups])

    # Original limits
    xlims_orig = (all_x.min(), all_x.max())
    ylims_orig = (all_y.min(), all_y.max())

    # The larger of the two data ranges defines the side of the square
    x_range = xlims_orig[1] - xlims_orig[0]
    y_range = ylims_orig[1] - ylims_orig[0]
    max_range = max(x_range, y_range)

    # Add a 5% margin on each side
    margin = max_range * 0.05
    max_range += margin * 2

    # Center the data within the square range
    x_center = (xlims_orig[0] + xlims_orig[1]) / 2
    y_center = (ylims_orig[0] + ylims_orig[1]) / 2

    xlims = (x_center - max_range / 2, x_center + max_range / 2)
    ylims = (y_center - max_range / 2, y_center + max_range / 2)

    # --- 4. DRAW AND REFINE PLOTS ---
    for ax, (emb, labels, title) in zip(axs, groups):
        if labels is not None:
            # vmin/vmax pin class code k to colour k of class_cmap, which is
            # exactly what the legend below uses.
            ax.scatter(emb[:, 0], emb[:, 1], c=labels, cmap=class_cmap,
                       vmin=0, vmax=len(classes) - 1,
                       alpha=0.9, s=10, zorder=10)
        else:
            ax.scatter(emb[:, 0], emb[:, 1], color='gray', alpha=0.6, s=20, zorder=10)

        ax.set_title(title, fontsize=20, weight='bold', pad=20)

        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)

        # Set aspect to equal AFTER setting limits to ensure it's visually square
        ax.set_aspect('equal', adjustable='box')

        ax.set_xlabel('UMAP 1', fontsize=10, labelpad=5, color='#303030')
        ax.set_ylabel('UMAP 2', fontsize=10, labelpad=5, color='#303030')
        ax.tick_params(axis='x', colors='#505050')
        ax.tick_params(axis='y', colors='#505050')

        for spine in ax.spines.values():
            spine.set_color('#B0B0B0')

    # --- 5. LEGEND ---
    if classes is not None:
        legend_elems = [plt.Line2D([0], [0], marker='o', color='w',
                                 markerfacecolor=class_cmap(i), markersize=16,
                                 label=lbl) for i, lbl in enumerate(classes)]

        fig.legend(handles=legend_elems,
                   loc='center left',
                   bbox_to_anchor=(0.93, 0.5),
                   frameon=False,
                   title='Cancer Types',
                   fontsize=16,
                   title_fontsize=18)

    # --- 6. FINAL LAYOUT & SAVE ---
    plt.subplots_adjust(wspace=0.4)
    fig.tight_layout(rect=[0, 0.02, 0.92, 0.92])

    if save_path:
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()



# results_path is (re)bound here so this call works even if the config lines above were skipped.
results_path = '../results'
save_path = os.path.join(results_path, 'umap', 'umap_all_3_panel.png')
create_polished_umap(
    test_all,
    multi_all,
    coherent_all,
    test_labels,
    main_title="UMAP Projection of Real vs. Generated Data",
    save_path=save_path,
)
print('Polished UMAP visualization complete.')

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import umap

# Assuming lib.config and lib.read_data are available
from lib.config import modalities_list
from lib.read_data import read_data

# --- Parameters ---
# dim is forwarded to read_data(dim=...) and is also a directory level under results_path.
dim = 32
results_path = '../results'
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the currently selected CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# Modalities that get their own figure (all configured modalities).
modalities_to_run = modalities_list

# --- Data Loading Utilities ---
def load_labels(labels_dir, split):
    """Read ``<labels_dir>/<split>_cancer_type.csv`` with its first column as index.

    Prints a warning and returns None when the file does not exist.
    """
    labels_path = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(labels_path):
        print(f"Warning: Labels file not found at {labels_path}")
        return None
    return pd.read_csv(labels_path, index_col=0)

# Load the train and test splits of every configured modality in one call.
modalities_map = read_data(modalities=modalities_list, splits=['train', 'test'], data_dir=data_dir, dim=dim)

# --- Plotting function: 3-panel UMAP figure for one set of real/generated frames ---
def create_polished_umap(test_df, multi_df, coherent_df,
                         test_lbls,
                         main_title="UMAP Projection of Real vs. Generated Data",
                         save_path=None):
    """
    Draw a 3-panel UMAP figure (real test data, multi-condition samples,
    coherent-denoising samples) on shared, square axis limits.

    A 2-D UMAP is fitted on ``test_df`` only. ``multi_df`` and ``coherent_df``
    are truncated to the first ``len(test_df)`` rows and projected with the
    fitted mapper.

    Parameters
    ----------
    test_df, multi_df, coherent_df : pandas.DataFrame
        Feature matrices; only ``.values`` is used.
    test_lbls : pandas.DataFrame or None
        Labels whose first column colours the points of all three panels
        (the encoded test labels are reused for the generated panels). Rows
        are assumed to be in the same order as ``test_df`` -- TODO confirm.
        When None, points are drawn in gray and no legend is added.
    main_title : str
        Figure-level title.
    save_path : str or None
        When given, the figure is saved there at 300 dpi (the parent
        directory is created if needed) and the path is printed.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.

    Notes
    -----
    Calls ``plt.style.use('seaborn-v0_8-white')``, which changes the global
    Matplotlib style for subsequent plots as well.
    """
    # --- 1. AESTHETIC SETUP ---
    plt.style.use('seaborn-v0_8-white')

    # --- 2. UMAP TRANSFORMATION (Fitted on Test Set) ---
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.8, spread=2, n_components=2, random_state=23)

    test_emb = mapper.fit_transform(test_df.values)
    n_test = test_df.shape[0]
    multi_emb = mapper.transform(multi_df.iloc[:n_test].values)
    coherent_emb = mapper.transform(coherent_df.iloc[:n_test].values)

    if test_lbls is not None:
        le = LabelEncoder().fit(test_lbls.iloc[:, 0])
        ty_test = le.transform(test_lbls.iloc[:, 0])
        ty_cond = ty_test.copy()
        classes = le.classes_
        # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the
        # same (name, lut) arguments. Built once for the panels and the legend.
        class_cmap = plt.get_cmap('tab20', len(classes))
    else:
        ty_test = ty_cond = None
        classes = None
        class_cmap = None

    # --- 3. PLOTTING SETUP ---
    fig, axs = plt.subplots(1, 3, figsize=(22, 8))
    fig.suptitle(main_title, fontsize=26, weight='bold', y=0.96)

    groups = [
        (test_emb, ty_test, 'Real Test Data'),
        (multi_emb, ty_cond, 'Generated (Multi-Condition)'),
        (coherent_emb, ty_cond, 'Generated (Coherent Denoising)'),
    ]

    # --- Square axis limits shared by all panels ---
    # Flat concatenation: only the global min/max are needed, and unlike
    # np.vstack this does not require every panel to have the same point count.
    all_x = np.concatenate([e[:, 0] for e, *_ in groups])
    all_y = np.concatenate([e[:, 1] for e, *_ in groups])

    # Original limits
    xlims_orig = (all_x.min(), all_x.max())
    ylims_orig = (all_y.min(), all_y.max())

    # The larger of the two data ranges defines the side of the square
    x_range = xlims_orig[1] - xlims_orig[0]
    y_range = ylims_orig[1] - ylims_orig[0]
    max_range = max(x_range, y_range)

    # Add a 5% margin on each side
    margin = max_range * 0.05
    max_range += margin * 2

    # Center the data within the square range
    x_center = (xlims_orig[0] + xlims_orig[1]) / 2
    y_center = (ylims_orig[0] + ylims_orig[1]) / 2

    xlims = (x_center - max_range / 2, x_center + max_range / 2)
    ylims = (y_center - max_range / 2, y_center + max_range / 2)

    # --- 4. DRAW AND REFINE PLOTS ---
    for ax, (emb, labels, title) in zip(axs, groups):
        if labels is not None:
            # vmin/vmax pin class code k to colour k of class_cmap, which is
            # exactly what the legend below uses.
            ax.scatter(emb[:, 0], emb[:, 1], c=labels, cmap=class_cmap,
                       vmin=0, vmax=len(classes) - 1,
                       alpha=0.9, s=10, zorder=10)
        else:
            ax.scatter(emb[:, 0], emb[:, 1], color='gray', alpha=0.6, s=20, zorder=10)

        ax.set_title(title, fontsize=20, weight='bold', pad=20)

        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)

        # Set aspect to equal AFTER setting limits to ensure it's visually square
        ax.set_aspect('equal', adjustable='box')

        ax.set_xlabel('UMAP 1', fontsize=10, labelpad=5, color='#303030')
        ax.set_ylabel('UMAP 2', fontsize=10, labelpad=5, color='#303030')
        ax.tick_params(axis='x', colors='#505050')
        ax.tick_params(axis='y', colors='#505050')

        for spine in ax.spines.values():
            spine.set_color('#B0B0B0')

    # --- 5. LEGEND ---
    if classes is not None:
        legend_elems = [plt.Line2D([0], [0], marker='o', color='w',
                                 markerfacecolor=class_cmap(i), markersize=16,
                                 label=lbl) for i, lbl in enumerate(classes)]

        fig.legend(handles=legend_elems,
                   loc='center left',
                   bbox_to_anchor=(0.93, 0.5),
                   frameon=False,
                   title='Cancer Types',
                   fontsize=16,
                   title_fontsize=18)

    # --- 6. FINAL LAYOUT & SAVE ---
    plt.subplots_adjust(wspace=0.4)
    fig.tight_layout(rect=[0, 0.02, 0.92, 0.92])

    if save_path:
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()

# --- MAIN EXECUTION LOOP ---
print("\n=== Starting UMAP visualization process ===")

# Load test labels (train labels are not needed for this plot)
test_labels = load_labels(labels_dir, 'test')


def _load_generated(modality, method_name, conditioning_string):
    """Read the generated-sample CSV of one run, or return None with a warning.

    The file is looked up at
    ``<results_path>/<dim>/<modality>_from_<method_name>/test/generated_samples_from_<conditioning_string>_best_mse.csv``.
    Defined once here instead of being redefined on every loop iteration.
    """
    path = pathlib.Path(f"{results_path}/{dim}/{modality}_from_{method_name}") / 'test' / \
           f'generated_samples_from_{conditioning_string}_best_mse.csv'
    if not path.exists():
        print(f"Warning: {method_name} data not found at {path}")
        return None
    return pd.read_csv(path)


for modality in modalities_to_run:
    print(f"\n--- Processing modality: {modality.upper()} ---")

    test_real = modalities_map[modality]['test']

    # The generated file name lists every other loaded modality, joined by '_'
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    conditioning_string = '_'.join(cond_datatypes)

    cond_multi_data = _load_generated(modality, 'multi', conditioning_string)
    cond_coherent_data = _load_generated(modality, 'coherent', conditioning_string)

    # Create the plot if all required data is present
    if all(df is not None for df in [test_real, cond_multi_data, cond_coherent_data]):
        # Write next to the inputs: derive the output root from results_path
        # rather than hard-coding '../results'.
        save_path = os.path.join(results_path, 'umap', f'umap_{modality}.png')
        main_title = f"UMAP Comparison for {modality.upper()} Modality"

        create_polished_umap(
            test_df=test_real,
            multi_df=cond_multi_data,
            coherent_df=cond_coherent_data,
            test_lbls=test_labels,
            main_title=main_title,
            save_path=save_path
        )
    else:
        print(f"Skipping modality {modality.upper()} due to missing data.")

print("\n=== UMAP visualization complete ===")

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import umap

# Assuming lib.config and lib.read_data are available
from lib.config import modalities_list
from lib.read_data import read_data

# --- Parameters ---
# dim is forwarded to read_data(dim=...) and is also a directory level under results_path.
dim = 32
results_path = '../results'
labels_dir = "../datasets_TCGA/downstream_labels/"
data_dir = '../datasets_TCGA/07_normalized/'

# Use the currently selected CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# --- Select the modalities to plot here ---
modalities_to_run = modalities_list # e.g., ['cna', 'rnaseq', 'rppa', 'wsi']

# --- Data Loading Utilities ---
def load_labels(labels_dir, split):
    """Read ``<labels_dir>/<split>_cancer_type.csv`` with its first column as index.

    Prints a warning and returns None when the file does not exist.
    """
    labels_path = os.path.join(labels_dir, f"{split}_cancer_type.csv")
    if not os.path.exists(labels_path):
        print(f"Warning: Labels file not found at {labels_path}")
        return None
    return pd.read_csv(labels_path, index_col=0)

# Load the train and test splits of every configured modality in one call.
modalities_map = read_data(modalities=modalities_list, splits=['train', 'test'], data_dir=data_dir, dim=dim)

# --- Plotting function: grid with one row of three UMAP panels per modality ---
def create_aggregate_umap_grid(all_modalities_data, le, save_path=None):
    """
    Draw one figure with a row of three UMAP panels per modality and a
    single legend shared by all rows.

    Parameters
    ----------
    all_modalities_data : dict
        Maps a modality name to a dict with the keys
        ``'groups'`` (list of three ``(embedding, labels, subplot_title)``
        tuples, where ``embedding`` is indexed as ``[:, 0]`` / ``[:, 1]`` and
        ``labels`` is either None or an array of class codes) and
        ``'xlims'`` / ``'ylims'`` (axis limits applied to every panel of
        that row).
    le : fitted label encoder
        Only ``le.classes_`` is read; it defines the number of colours and
        the legend entries. Class codes in ``labels`` are assumed to be
        indices into ``le.classes_`` -- TODO confirm against caller.
    save_path : str or None
        When given, the figure is saved there at 300 dpi (the parent
        directory is created if needed) and the path is printed.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``all_modalities_data`` is empty.

    Notes
    -----
    Calls ``plt.style.use('seaborn-v0_8-white')``, which changes the global
    Matplotlib style for subsequent plots as well.
    """
    n_modalities = len(all_modalities_data)
    n_cols = 3 # Real, Multi-gen, Coherent-gen
    if n_modalities == 0:
        # plt.subplots cannot build a grid with zero rows; say why explicitly.
        raise ValueError("all_modalities_data is empty; nothing to plot.")

    plt.style.use('seaborn-v0_8-white')
    # Figure height grows with the number of modalities. squeeze=False always
    # returns a 2-D axes array, so a single modality needs no special case.
    fig, axs = plt.subplots(n_modalities, n_cols, figsize=(22, 8 * n_modalities), squeeze=False)

    classes = le.classes_
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the
    # same (name, lut) arguments. Built once for every panel and the legend.
    class_cmap = plt.get_cmap('tab20', len(classes))

    # --- Loop through each modality's pre-computed data ---
    for row_idx, (modality, data) in enumerate(all_modalities_data.items()):

        groups = data['groups']
        xlims = data['xlims']
        ylims = data['ylims']

        # Row title drawn as text above the middle axis so that it does not
        # replace that panel's own subplot title.
        row_title_ax = axs[row_idx, 1]
        row_title = f"UMAP Comparison for {modality.upper()} Modality"
        row_title_ax.text(0.5, 1.2, row_title,
                          transform=row_title_ax.transAxes,
                          ha='center', va='center',
                          fontsize=26, weight='bold')

        # --- Draw the 3 panels for the modality ---
        for col_idx, (emb, labels, subplot_title) in enumerate(groups):
            ax = axs[row_idx, col_idx]

            if labels is not None:
                # vmin/vmax pin class code k to colour k of class_cmap, so every
                # panel matches the shared legend even if a class is absent.
                ax.scatter(emb[:, 0], emb[:, 1], c=labels, cmap=class_cmap,
                           vmin=0, vmax=len(classes) - 1,
                           alpha=0.9, s=10, zorder=10)
            else:
                ax.scatter(emb[:, 0], emb[:, 1], color='gray', alpha=0.6, s=20, zorder=10)

            # Set the individual subplot title for EACH panel
            ax.set_title(subplot_title, fontsize=20, weight='bold', pad=20)

            ax.set_xlim(*xlims)
            ax.set_ylim(*ylims)
            ax.set_aspect('equal', adjustable='box')

            ax.set_xlabel('UMAP 1', fontsize=10, labelpad=5, color='#303030')
            ax.set_ylabel('UMAP 2', fontsize=10, labelpad=5, color='#303030')
            ax.tick_params(axis='x', colors='#505050')
            ax.tick_params(axis='y', colors='#505050')

            for spine in ax.spines.values():
                spine.set_color('#B0B0B0')

    # --- COMMON LEGEND ---
    legend_elems = [plt.Line2D([0], [0], marker='o', color='w',
                             markerfacecolor=class_cmap(i), markersize=16,
                             label=lbl) for i, lbl in enumerate(classes)]

    fig.legend(handles=legend_elems,
               loc='center right',
               frameon=False,
               title='Cancer Types',
               fontsize=16,
               title_fontsize=18)

    # --- FINAL LAYOUT & SAVE ---
    # Reserve space on the right for the legend and at the top for the row
    # titles; h_pad adds vertical spacing between rows.
    fig.tight_layout(rect=[0.02, 0.02, 0.88, 0.96], h_pad=4)

    if save_path:
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()

# --- MAIN EXECUTION SCRIPT ---
print("\n=== Starting UMAP data processing for all modalities ===")

test_labels = load_labels(labels_dir, 'test')
if test_labels is None:
    raise ValueError("Test labels are required but not found.")

le = LabelEncoder().fit(test_labels.iloc[:, 0])
# The encoded label vector does not depend on the modality: compute it once.
ty_test = le.transform(test_labels.iloc[:, 0])


def _load_generated(modality, method_name, conditioning_string):
    """Read the generated-sample CSV of one run, or return None if it is absent.

    The file is looked up at
    ``<results_path>/<dim>/<modality>_from_<method_name>/test/generated_samples_from_<conditioning_string>_best_mse.csv``.
    Defined once here instead of being redefined on every loop iteration.
    """
    path = pathlib.Path(f"{results_path}/{dim}/{modality}_from_{method_name}") / 'test' / \
           f'generated_samples_from_{conditioning_string}_best_mse.csv'
    if not path.exists():
        return None
    return pd.read_csv(path)


all_modalities_data = {}

for modality in modalities_to_run:
    print(f"\n--- Processing modality: {modality.upper()} ---")

    test_real = modalities_map[modality]['test']

    # The generated file name lists every other loaded modality, joined by '_'
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    conditioning_string = '_'.join(cond_datatypes)

    cond_multi_data = _load_generated(modality, 'multi', conditioning_string)
    cond_coherent_data = _load_generated(modality, 'coherent', conditioning_string)

    if not all(df is not None for df in [test_real, cond_multi_data, cond_coherent_data]):
        print(f"Skipping modality {modality.upper()} due to missing data.")
        continue

    # Fit UMAP on the real test data, then project the generated samples
    # (truncated to the test-set row count) into the same 2-D space.
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.8, spread=2, n_components=2, random_state=23)
    test_emb = mapper.fit_transform(test_real.values)
    multi_emb = mapper.transform(cond_multi_data.iloc[:len(test_real)].values)
    coherent_emb = mapper.transform(cond_coherent_data.iloc[:len(test_real)].values)

    # The generated panels reuse the test labels (row-aligned by position).
    ty_cond = ty_test.copy()

    groups = [
        (test_emb, ty_test, 'Real Test Data'),
        (multi_emb, ty_cond, 'Generated (Multi-Condition)'),
        (coherent_emb, ty_cond, 'Generated (Coherent Denoising)'),
    ]

    # Flat concatenation: only the global min/max are needed, and unlike
    # np.vstack this does not require every panel to have the same point count.
    all_x = np.concatenate([e[:, 0] for e, *_ in groups])
    all_y = np.concatenate([e[:, 1] for e, *_ in groups])

    xlims_orig = (all_x.min(), all_x.max())
    ylims_orig = (all_y.min(), all_y.max())

    # Square limits: side = larger data range plus a 5% margin on each side,
    # centred on the data.
    x_range = xlims_orig[1] - xlims_orig[0]
    y_range = ylims_orig[1] - ylims_orig[0]
    max_range = max(x_range, y_range)

    margin = max_range * 0.05
    max_range += margin * 2

    x_center = (xlims_orig[0] + xlims_orig[1]) / 2
    y_center = (ylims_orig[0] + ylims_orig[1]) / 2

    all_modalities_data[modality] = {
        'groups': groups,
        'xlims': (x_center - max_range / 2, x_center + max_range / 2),
        'ylims': (y_center - max_range / 2, y_center + max_range / 2),
    }

if all_modalities_data:
    print("\n=== Generating aggregated UMAP grid plot ===")
    # Write next to the inputs: derive the output root from results_path
    # rather than hard-coding '../results'.
    save_path = os.path.join(results_path, 'umap', 'umap_single_modalities.png')
    create_aggregate_umap_grid(all_modalities_data, le, save_path)
else:
    print("\nNo data was processed. Skipping plot generation.")

print("\n=== UMAP visualization complete ===")