In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace

import torch

from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.metrics import calculate_PED_balanced, compute_prdc

# --- Experiment configuration ---
# `dim` is forwarded to read_data and is also a component of the results path.
dim, test_repeats = 32, 10

# Locations on disk, relative to the working directory.
results_path = '../results'
data_dir = '../datasets_TCGA/07_normalized/'
labels_dir = "../datasets_TCGA/downstream_labels/"

# Use the currently selected CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# Every modality listed in the project configuration is evaluated.
modalities_to_run = modalities_list

# Load the train and test splits of all modalities once, up front.
modalities_map = read_data(
    modalities=modalities_list, splits=['train', 'test'], data_dir=data_dir, dim=dim)

# The evaluation loop appends one dict per modality to this list.
summary = list()

def compute_all_metrics(real_arr, gen_arr, test_repeats):
    """Score generated samples against reference samples, repeat by repeat.

    For every index ``i`` below ``test_repeats``, ``real_arr[i]`` and
    ``gen_arr[i]`` are handed to ``compute_prdc`` (precision and recall) and to
    ``calculate_PED_balanced`` (six values, stored as g_ped, g_ed, l_ped, l_ed,
    b_ped, b_ed). An F1 score is derived from precision and recall.

    Returns:
        dict mapping each metric name to a ``(mean, std)`` tuple taken over
        the repeats.
    """
    names = ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped', 'l_ed', 'b_ped', 'b_ed')
    # One row per metric, one column per repeat.
    table = np.zeros((len(names), test_repeats))

    for rep in range(test_repeats):
        reference, generated = real_arr[rep], gen_arr[rep]

        precision, recall = compute_prdc(reference, generated, nearest_k=10, only_pr=True)
        # The small epsilon keeps the division finite when both scores are zero.
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)

        g_ped, g_ed, l_ped, l_ed, b_ped, b_ed = calculate_PED_balanced(
            reference, generated, metric='l2')

        values = (precision, recall, f1_score, g_ped, g_ed, l_ped, l_ed, b_ped, b_ed)
        for row, value in enumerate(values):
            table[row, rep] = value

    return {name: (table[row].mean(), table[row].std()) for row, name in enumerate(names)}

def print_metrics(label, metrics):
    """Print one line with the mean of every metric, prefixed by ``label``.

    ``metrics`` maps metric names to ``(mean, std)`` pairs; only the first
    element of each pair is shown, formatted with two decimals.
    """
    means = {}
    for name, stats in metrics.items():
        means[name] = stats[0]

    segments = (
        "{:<8} Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}".format(
            label, means['prec'], means['rec'], means['f1']),
        "   ED_G: {:.2f}, ED_L: {:.2f}, ED_B: {:.2f}".format(
            means['g_ed'], means['l_ed'], means['b_ed']),
        "   PED_G: {:.2f}, PED_L: {:.2f}, PED_B: {:.2f}".format(
            means['g_ped'], means['l_ped'], means['b_ped']),
    )
    print("".join(segments))

def create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics):
    """Flatten four metric dicts into a single summary row for ``modality``.

    Each argument after ``modality`` maps metric names to ``(mean, std)``
    pairs. The row stores the mean under ``<metric>_<condition>`` and the
    standard deviation under ``<metric>_<condition>_std``, where the condition
    is 'uncond', 'cond', 'train_vs_test' or 'train_vs_cond' respectively.
    """
    conditions = ('uncond', 'cond', 'train_vs_test', 'train_vs_cond')
    metric_sets = (uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics)

    entry = {'modality': modality}
    for condition, metrics in zip(conditions, metric_sets):
        for metric, stats in metrics.items():
            mean, std = stats
            entry.update({
                f'{metric}_{condition}': mean,
                f'{metric}_{condition}_std': std,
            })
    return entry

for modality in modalities_to_run:
    print(f"\n=== Running modality: {modality} ===")

    # Real data for the target modality; training rows with missing values are dropped.
    test_real = modalities_map[modality]['test']
    train_real = modalities_map[modality]['train'].dropna()
    n_samples, n_feats = test_real.shape

    # Folder holding both the checkpoint and the pre-generated conditional samples
    # of the model that generates `modality` from all remaining modalities.
    base_dir = pathlib.Path(f"{results_path}/{dim}/{modality}_from_multi")
    ckpt_path = base_dir / 'train' / 'best_by_mse.pth'

    # Rebuild the network from the configuration saved inside the checkpoint.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = SimpleNamespace(**ckpt['config'])

    # Every other modality is a conditioning input; their feature counts are
    # passed to the model factory as `cond_dims`.
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    cond_dim_list = [modalities_map[c]['test'].shape[1] for c in cond_datatypes]

    diffusion = GaussianDiffusion(num_timesteps=1000).to(device)
    model = get_diffusion_model(config.architecture, diffusion, config,
                                x_dim=n_feats, cond_dims=cond_dim_list).to(device)
    model.load_state_dict(ckpt['best_model_mse'])
    model.eval()

    # Unconditional generation: all-zero conditioning frames plus all-zero masks
    # (presumably a zero mask marks a condition as absent -- confirm in lib.test).
    cond_test_list = [pd.DataFrame(np.zeros_like(modalities_map[c]['test']),
                                   columns=modalities_map[c]['test'].columns)
                      for c in cond_datatypes]
    masks = [np.zeros(n_samples) for _ in cond_datatypes]

    _, uncond_generated_data = test_model(test_real, cond_test_list, model, diffusion,
                                          test_iterations=test_repeats, device=device, masks=masks)

    # Conditional samples were generated beforehand and are read back from disk.
    conditioning_string = '_'.join(cond_datatypes)
    synth_path = base_dir / 'test' / f'generated_samples_from_{conditioning_string}_best_mse.csv'
    cond_generated_data = pd.read_csv(synth_path)

    # View both sample tables as (repeat, sample, feature) stacks; the real test
    # set is repeated so that it can be indexed per repeat as well.
    uncond_arr = uncond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    cond_arr = cond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    real_arr = np.tile(test_real.values[np.newaxis], (test_repeats, 1, 1))

    # One independent random subset of the training data, of test-set size, per repeat.
    train_real_arr = np.zeros((test_repeats, n_samples, n_feats))
    for i in range(test_repeats):
        idx = np.random.choice(len(train_real), size=n_samples, replace=False)
        train_real_arr[i] = train_real.values[idx]

    # Generated samples vs. the real test set, then vs. the training subsets.
    uncond_metrics = compute_all_metrics(real_arr, uncond_arr, test_repeats)
    cond_metrics = compute_all_metrics(real_arr, cond_arr, test_repeats)
    # NOTE: despite its name, `train_test_metrics` compares the training subsets
    # with the *unconditional generated* samples, not with the real test set.
    train_test_metrics = compute_all_metrics(train_real_arr, uncond_arr, test_repeats)
    train_cond_metrics = compute_all_metrics(train_real_arr, cond_arr, test_repeats)

    # The third line used to be labelled "Train vs Test", which misdescribed what
    # is being reported; the label now names the actual comparison.
    print_metrics("Uncond", uncond_metrics)
    print_metrics("Cond", cond_metrics)
    print_metrics("Train vs Uncond", train_test_metrics)
    print_metrics("Train vs Cond", train_cond_metrics)

    # Add to summary
    summary.append(create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics))

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace

import torch

from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.metrics import calculate_PED_balanced, compute_prdc

# --- Experiment configuration ---
# `dim` is forwarded to read_data and is also a component of the results path.
dim, test_repeats = 32, 10

# Locations on disk, relative to the working directory.
results_path = '../results'
data_dir = '../datasets_TCGA/07_normalized/'
labels_dir = "../datasets_TCGA/downstream_labels/"

# Use the currently selected CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

# Every modality listed in the project configuration is evaluated.
modalities_to_run = modalities_list

# Load the train and test splits of all modalities once, up front.
modalities_map = read_data(
    modalities=modalities_list, splits=['train', 'test'], data_dir=data_dir, dim=dim)

# The evaluation loop appends one dict per modality to this list.
summary = list()

def compute_all_metrics(real_arr, gen_arr, test_repeats):
    """Score generated samples against reference samples, repeat by repeat.

    For every index ``i`` below ``test_repeats``, ``real_arr[i]`` and
    ``gen_arr[i]`` are handed to ``compute_prdc`` (precision and recall) and to
    ``calculate_PED_balanced`` (six values, stored as g_ped, g_ed, l_ped, l_ed,
    b_ped, b_ed). An F1 score is derived from precision and recall.

    Returns:
        dict mapping each metric name to a ``(mean, std)`` tuple taken over
        the repeats.
    """
    names = ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped', 'l_ed', 'b_ped', 'b_ed')
    # One row per metric, one column per repeat.
    table = np.zeros((len(names), test_repeats))

    for rep in range(test_repeats):
        reference, generated = real_arr[rep], gen_arr[rep]

        precision, recall = compute_prdc(reference, generated, nearest_k=10, only_pr=True)
        # The small epsilon keeps the division finite when both scores are zero.
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)

        g_ped, g_ed, l_ped, l_ed, b_ped, b_ed = calculate_PED_balanced(
            reference, generated, metric='l2')

        values = (precision, recall, f1_score, g_ped, g_ed, l_ped, l_ed, b_ped, b_ed)
        for row, value in enumerate(values):
            table[row, rep] = value

    return {name: (table[row].mean(), table[row].std()) for row, name in enumerate(names)}

def print_metrics(label, metrics):
    """Print one line with the mean of every metric, prefixed by ``label``.

    ``metrics`` maps metric names to ``(mean, std)`` pairs; only the first
    element of each pair is shown, formatted with two decimals.
    """
    means = {}
    for name, stats in metrics.items():
        means[name] = stats[0]

    segments = (
        "{:<8} Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}".format(
            label, means['prec'], means['rec'], means['f1']),
        "   ED_G: {:.2f}, ED_L: {:.2f}, ED_B: {:.2f}".format(
            means['g_ed'], means['l_ed'], means['b_ed']),
        "   PED_G: {:.2f}, PED_L: {:.2f}, PED_B: {:.2f}".format(
            means['g_ped'], means['l_ped'], means['b_ped']),
    )
    print("".join(segments))

def create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics):
    """Flatten four metric dicts into a single summary row for ``modality``.

    Each argument after ``modality`` maps metric names to ``(mean, std)``
    pairs. The row stores the mean under ``<metric>_<condition>`` and the
    standard deviation under ``<metric>_<condition>_std``, where the condition
    is 'uncond', 'cond', 'train_vs_test' or 'train_vs_cond' respectively.
    """
    conditions = ('uncond', 'cond', 'train_vs_test', 'train_vs_cond')
    metric_sets = (uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics)

    entry = {'modality': modality}
    for condition, metrics in zip(conditions, metric_sets):
        for metric, stats in metrics.items():
            mean, std = stats
            entry.update({
                f'{metric}_{condition}': mean,
                f'{metric}_{condition}_std': std,
            })
    return entry

for modality in modalities_to_run:
    print(f"\n=== Running modality: {modality} ===")

    # Real data for the target modality; training rows with missing values are dropped.
    test_real = modalities_map[modality]['test']
    train_real = modalities_map[modality]['train'].dropna()
    n_samples, n_feats = test_real.shape

    # Folder holding both the checkpoint and the pre-generated conditional samples
    # of the masked multi-condition model for `modality`.
    base_dir = pathlib.Path(f"{results_path}/{dim}/{modality}_from_multi_masked")
    ckpt_path = base_dir / 'train' / 'best_by_mse.pth'

    # Rebuild the network from the configuration saved inside the checkpoint.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = SimpleNamespace(**ckpt['config'])

    # Every other modality is a conditioning input; their feature counts are
    # passed to the model factory as `cond_dims`.
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    cond_dim_list = [modalities_map[c]['test'].shape[1] for c in cond_datatypes]

    diffusion = GaussianDiffusion(num_timesteps=1000).to(device)
    model = get_diffusion_model(config.architecture, diffusion, config,
                                x_dim=n_feats, cond_dims=cond_dim_list).to(device)
    model.load_state_dict(ckpt['best_model_mse'])
    model.eval()

    # Unconditional generation: all-zero conditioning frames plus all-zero masks
    # (presumably a zero mask marks a condition as absent -- confirm in lib.test).
    cond_test_list = [pd.DataFrame(np.zeros_like(modalities_map[c]['test']),
                                   columns=modalities_map[c]['test'].columns)
                      for c in cond_datatypes]
    masks = [np.zeros(n_samples) for _ in cond_datatypes]

    _, uncond_generated_data = test_model(test_real, cond_test_list, model, diffusion,
                                          test_iterations=test_repeats, device=device, masks=masks)

    # Conditional samples were generated beforehand and are read back from disk.
    conditioning_string = '_'.join(cond_datatypes)
    synth_path = base_dir / 'test' / f'generated_samples_from_{conditioning_string}_best_mse.csv'
    cond_generated_data = pd.read_csv(synth_path)

    # View both sample tables as (repeat, sample, feature) stacks; the real test
    # set is repeated so that it can be indexed per repeat as well.
    uncond_arr = uncond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    cond_arr = cond_generated_data.values.reshape(test_repeats, n_samples, n_feats)
    real_arr = np.tile(test_real.values[np.newaxis], (test_repeats, 1, 1))

    # One independent random subset of the training data, of test-set size, per repeat.
    train_real_arr = np.zeros((test_repeats, n_samples, n_feats))
    for i in range(test_repeats):
        idx = np.random.choice(len(train_real), size=n_samples, replace=False)
        train_real_arr[i] = train_real.values[idx]

    # Generated samples vs. the real test set, then vs. the training subsets.
    uncond_metrics = compute_all_metrics(real_arr, uncond_arr, test_repeats)
    cond_metrics = compute_all_metrics(real_arr, cond_arr, test_repeats)
    # NOTE: despite its name, `train_test_metrics` compares the training subsets
    # with the *unconditional generated* samples, not with the real test set.
    train_test_metrics = compute_all_metrics(train_real_arr, uncond_arr, test_repeats)
    train_cond_metrics = compute_all_metrics(train_real_arr, cond_arr, test_repeats)

    # The third line used to be labelled "Train vs Test", which misdescribed what
    # is being reported; the label now names the actual comparison.
    print_metrics("Uncond", uncond_metrics)
    print_metrics("Cond", cond_metrics)
    print_metrics("Train vs Uncond", train_test_metrics)
    print_metrics("Train vs Cond", train_cond_metrics)

    # Add to summary
    summary.append(create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics))

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import itertools

import torch

# Assuming lib functions are in the correct PYTHONPATH
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.metrics import calculate_PED_balanced, compute_prdc

# --- Parameters ---
# `dim` is an int here, matching the other evaluation cells of this notebook,
# which pass an int to the same read_data helper. In this cell the value is
# only forwarded to read_data and interpolated into result paths, so the paths
# are unchanged compared with the former string '32'.
dim = 32
test_repeats = 10
metric_to_use = 'mse' # The metric used to select the best model checkpoint

# Note: The 'folder' argument from your generation script is represented here as 'results_path'
results_path = '../results'
data_dir = '../datasets_TCGA/07_normalized/'

# Use the currently selected CUDA device when one is available, else the CPU.
device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

# Every modality listed in the project configuration is evaluated.
modalities_to_run = modalities_list

# Read the train and test splits of all modalities once.
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train','test'],
    data_dir=data_dir,
    dim=dim,
)

# Store a summary of metrics (one dict per modality, filled by the main loop)
summary = []

def compute_all_metrics(real_arr, gen_arr, test_repeats):
    """Score generated samples against reference samples, repeat by repeat.

    The number of repeats is taken from ``gen_arr.shape[0]``; ``test_repeats``
    is accepted but not used. For each repeat ``i``, ``real_arr[i]`` and
    ``gen_arr[i]`` are handed to ``compute_prdc`` (precision and recall) and to
    ``calculate_PED_balanced`` (six values, stored as g_ped, g_ed, l_ped, l_ed,
    b_ped, b_ed). F1 is set to 0.0 when precision + recall does not exceed 1e-8.

    Returns:
        dict mapping each metric name to a ``(mean, std)`` tuple taken over
        the repeats.
    """
    names = ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped', 'l_ed', 'b_ped', 'b_ed')
    n_repeats = gen_arr.shape[0]
    # One row per metric, one column per repeat.
    table = np.zeros((len(names), n_repeats))

    for rep in range(n_repeats):
        reference, generated = real_arr[rep], gen_arr[rep]

        precision, recall = compute_prdc(reference, generated, nearest_k=10, only_pr=True)
        if (precision + recall) > 1e-8:
            f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
        else:
            f1_score = 0.0

        g_ped, g_ed, l_ped, l_ed, b_ped, b_ed = calculate_PED_balanced(
            reference, generated, metric='l2')

        values = (precision, recall, f1_score, g_ped, g_ed, l_ped, l_ed, b_ped, b_ed)
        for row, value in enumerate(values):
            table[row, rep] = value

    return {name: (table[row].mean(), table[row].std()) for row, name in enumerate(names)}

def print_metrics(label, metrics):
    """Print one line with the mean of every metric, prefixed by ``label``.

    ``metrics`` maps metric names to ``(mean, std)`` pairs; only the first
    element of each pair is shown, formatted with two decimals.
    """
    means = {}
    for name, stats in metrics.items():
        means[name] = stats[0]

    segments = (
        "{:<15} Precision: {:.2f}, Recall: {:.2f}, F1: {:.2f}".format(
            label, means['prec'], means['rec'], means['f1']),
        "  |  ED_G: {:.2f}, ED_L: {:.2f}, ED_B: {:.2f}".format(
            means['g_ed'], means['l_ed'], means['b_ed']),
        "  |  PED_G: {:.2f}, PED_L: {:.2f}, PED_B: {:.2f}".format(
            means['g_ped'], means['l_ped'], means['b_ped']),
    )
    print("".join(segments))

def create_summary_entry(modality, uncond_metrics, cond_metrics, train_test_metrics, train_cond_metrics):
    """Flatten up to four metric dicts into a single summary row for ``modality``.

    Each argument after ``modality`` maps metric names to ``(mean, std)``
    pairs; empty or ``None`` arguments contribute nothing. The row stores the
    mean under ``<metric>_<condition>`` and the standard deviation under
    ``<metric>_<condition>_std``, where the condition is 'uncond', 'cond',
    'train_vs_uncond' or 'train_vs_cond' respectively.
    """
    groups = {
        'uncond': uncond_metrics,
        'cond': cond_metrics,
        'train_vs_uncond': train_test_metrics,
        'train_vs_cond': train_cond_metrics,
    }

    entry = {'modality': modality}
    for condition, metrics in groups.items():
        if not metrics:
            # Nothing was computed for this condition.
            continue
        for metric, stats in metrics.items():
            mean, std = stats
            entry[f'{metric}_{condition}'] = mean
            entry[f'{metric}_{condition}_std'] = std
    return entry

# --- Main Evaluation Loop ---
# For every modality: generate unconditional samples with an ensemble of the
# single-condition models, load the pre-generated conditional samples of the
# "coherent" method, score both against the real test set and against random
# training subsets, and append one row to `summary`.
for modality in modalities_to_run:
    print(f"\n=== Running modality: {modality} ===")

    # Real data for the target modality; training rows with missing values are dropped.
    test_real = modalities_map[modality]['test']
    train_real = modalities_map[modality]['train'].dropna()
    n_samples, n_feats = test_real.shape

    # Every other modality is a conditioning input.
    cond_datatypes = [m for m in modalities_map.keys() if m != modality]
    diffusion = GaussianDiffusion(num_timesteps=1000).to(device)

    # --- Part 1: Generate Unconditional Samples using Coherent Ensemble ---
    print("Generating unconditional samples using coherent ensemble...")
    models_for_uncond = []
    weights_for_uncond = []
    zero_cond_list = []
    all_models_found = True

    # One single-condition model ("<modality>_from_<c>") is loaded per conditioning modality.
    for c in cond_datatypes:
        ckpt_path = pathlib.Path(f"{results_path}/{dim}/{modality}_from_{c}/train/best_by_{metric_to_use}.pth")
        if not ckpt_path.exists():
            print(f"WARNING: Checkpoint not found for model {modality}_from_{c}. Cannot generate unconditional samples.")
            print(f"Path: {ckpt_path}")
            # A single missing model disables unconditional generation for this modality.
            all_models_found = False
            break

        ckpt = torch.load(ckpt_path, map_location='cpu')
        config_c = SimpleNamespace(**ckpt['config'])
        state_dict = ckpt[f'best_model_{metric_to_use}']
        # The checkpoint's stored 'best_loss' is forwarded as this model's ensemble
        # weight (how the helper interprets it is not visible here -- confirm in lib.test).
        weights_for_uncond.append(ckpt['best_loss'])

        cond_dim = modalities_map[c]['test'].shape[1]
        model_c = get_diffusion_model(
            config_c.architecture, diffusion, config_c,
            x_dim=n_feats, cond_dims=cond_dim
        ).to(device)
        model_c.load_state_dict(state_dict)
        model_c.eval()
        models_for_uncond.append(model_c)

        # Create a zeroed-out DataFrame for this condition
        shape = modalities_map[c]['test'].shape
        zero_cond_list.append(pd.DataFrame(np.zeros(shape), columns=modalities_map[c]['test'].columns))

    # Stays None when any checkpoint was missing, which skips the unconditional metrics below.
    uncond_generated_data = None
    if all_models_found:
        # Only the second of the three returned values (the generated samples) is kept.
        _, uncond_generated_data, _ = coherent_test_cos_rejection(
            test_real,
            zero_cond_list,
            models_for_uncond,
            diffusion,
            test_iterations=test_repeats,
            max_retries= 10,
            device=device,
            weights_list=weights_for_uncond
        )
    else:
        print(f"Skipping unconditional evaluation for {modality}.")

    # --- Part 2: Load Pre-Generated Conditional Samples ---
    print("Loading conditional samples...")
    base_dir_coherent = pathlib.Path(f"{results_path}/{dim}/{modality}_from_coherent")
    conditioning_string = '_'.join(cond_datatypes)
    synth_path = base_dir_coherent / 'test' / f'generated_samples_from_{conditioning_string}_best_{metric_to_use}.csv'

    # Stays None when the file is missing, which skips the conditional metrics below.
    cond_generated_data = None
    if not synth_path.exists():
        print(f"WARNING: Coherent conditional file not found, skipping cond eval for {modality}.")
        print(f"Path: {synth_path}")
    else:
        cond_generated_data = pd.read_csv(synth_path)

    # --- Part 3: Reshape Arrays and Compute Metrics ---
    # The real test set is repeated so that it can be indexed per repeat.
    real_arr = np.tile(test_real.values[np.newaxis], (test_repeats, 1, 1))
    
    # One independent random subset of the training data, of test-set size, per repeat.
    train_real_arr = np.zeros((test_repeats, n_samples, n_feats))
    for i in range(test_repeats):
        idx = np.random.choice(len(train_real), size=n_samples, replace=False)
        train_real_arr[i] = train_real.values[idx]

    # Empty dicts are skipped by create_summary_entry, so a missing evaluation
    # simply contributes no columns to the summary row.
    uncond_metrics, cond_metrics, train_uncond_metrics, train_cond_metrics = {}, {}, {}, {}

    if uncond_generated_data is not None:
        # -1 lets the number of repeats follow from the amount of generated data.
        uncond_arr = uncond_generated_data.values.reshape(-1, n_samples, n_feats)
        uncond_metrics = compute_all_metrics(real_arr, uncond_arr, test_repeats)
        train_uncond_metrics = compute_all_metrics(train_real_arr, uncond_arr, test_repeats)
        print_metrics("Uncond vs Real", uncond_metrics)
        print_metrics("Uncond vs Train", train_uncond_metrics)

    if cond_generated_data is not None:
        cond_arr = cond_generated_data.values.reshape(-1, n_samples, n_feats)
        cond_metrics = compute_all_metrics(real_arr, cond_arr, test_repeats)
        train_cond_metrics = compute_all_metrics(train_real_arr, cond_arr, test_repeats)
        print_metrics("Cond vs Real", cond_metrics)
        print_metrics("Cond vs Train", train_cond_metrics)

    summary.append(create_summary_entry(modality, uncond_metrics, cond_metrics, train_uncond_metrics, train_cond_metrics))


# --- Final Summary ---
# Print all per-modality rows as one table.
if summary:
    summary_df = pd.DataFrame(summary)
    print("\n\n--- Evaluation Summary (Coherent Method) ---")
    print(summary_df.to_string())

In [None]:
import argparse
import os
import json
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
from abc import ABC, abstractmethod

import torch

# Assuming lib functions are in the correct PYTHONPATH
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.metrics import calculate_PED_balanced, compute_prdc

# --- Global Parameters & Setup ---
# Name of the selection metric; it appears in checkpoint and sample file names.
METRIC_TO_USE = 'mse'

# Use the currently selected CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    DEVICE = torch.device("cpu")

# =============================================================================
# 1. HELPER FUNCTIONS (UNCHANGED CORE LOGIC)
# =============================================================================

def compute_all_metrics(real_arr, gen_arr):
    """Score generated samples against reference samples, repeat by repeat.

    The number of repeats is taken from ``gen_arr.shape[0]``. For each repeat
    ``i``, ``real_arr[i]`` and ``gen_arr[i]`` are handed to ``compute_prdc``
    (precision and recall) and to ``calculate_PED_balanced`` (six values,
    stored as g_ped, g_ed, l_ped, l_ed, b_ped, b_ed). F1 is set to 0.0 when
    precision + recall does not exceed 1e-8.

    Returns:
        dict mapping each metric name to a ``(mean, std)`` tuple taken over
        the repeats.
    """
    names = ('prec', 'rec', 'f1', 'g_ped', 'g_ed', 'l_ped', 'l_ed', 'b_ped', 'b_ed')
    n_repeats = gen_arr.shape[0]
    # One row per metric, one column per repeat.
    table = np.zeros((len(names), n_repeats))

    for rep in range(n_repeats):
        reference, generated = real_arr[rep], gen_arr[rep]

        precision, recall = compute_prdc(reference, generated, nearest_k=10, only_pr=True)
        if (precision + recall) > 1e-8:
            f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
        else:
            f1_score = 0.0

        g_ped, g_ed, l_ped, l_ed, b_ped, b_ed = calculate_PED_balanced(
            reference, generated, metric='l2')

        values = (precision, recall, f1_score, g_ped, g_ed, l_ped, l_ed, b_ped, b_ed)
        for row, value in enumerate(values):
            table[row, rep] = value

    return {name: (table[row].mean(), table[row].std()) for row, name in enumerate(names)}

# =============================================================================
# 2. STRATEGY CLASSES (CORE OF THE REFACTOR)
# =============================================================================

class EvaluationStrategy(ABC):
    """Common interface of the evaluation methods.

    An instance remembers the method name, the root results folder and the
    ``dim`` path component; subclasses supply the unconditional sampling.
    """

    def __init__(self, method_name, results_path, dim):
        self.method_name, self.results_path, self.dim = method_name, results_path, dim

    def get_base_dir(self, modality):
        """Return ``<results_path>/<dim>/<modality>_from_<method_name>`` as a Path."""
        parts = (f"{self.results_path}", f"{self.dim}", f"{modality}_from_{self.method_name}")
        return pathlib.Path("/".join(parts))

    @abstractmethod
    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Produce unconditional samples for ``modality``; implemented by subclasses."""

class MultiModelStrategy(EvaluationStrategy):
    """Strategy for 'multi' and 'multi_masked' models.

    A single checkpoint, stored under the strategy's base folder, is
    conditioned on all modalities other than the target one.
    """

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` with every conditioning input zeroed out.

        Args:
            modality: Name of the modality to generate.
            modalities_map: Mapping ``name -> {'train': df, 'test': df}``; only
                the 'test' frames are read here.
            n_feats: Number of features of the target modality (model ``x_dim``).
            test_repeats: Passed to ``test_model`` as ``test_iterations``.

        Returns:
            The second value returned by ``test_model``, or ``None`` when the
            checkpoint file does not exist.
        """
        print(f"  Generating unconditional samples using '{self.method_name}' strategy...")
        ckpt_path = self.get_base_dir(modality) / 'train' / f'best_by_{METRIC_TO_USE}.pth'
        if not ckpt_path.exists():
            print(f"  WARNING: Checkpoint not found at {ckpt_path}. Skipping.")
            return None

        # Rebuild the network from the configuration saved inside the checkpoint.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        config = SimpleNamespace(**ckpt['config'])
        cond_datatypes = [m for m in modalities_map.keys() if m != modality]
        cond_dim_list = [modalities_map[c]['test'].shape[1] for c in cond_datatypes]

        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)
        model = get_diffusion_model(config.architecture, diffusion, config,
                                    x_dim=n_feats, cond_dims=cond_dim_list).to(DEVICE)
        model.load_state_dict(ckpt[f'best_model_{METRIC_TO_USE}'])
        model.eval()

        test_real = modalities_map[modality]['test']
        n_samples = len(test_real)
        # The zero frames keep the original column labels, as in the standalone
        # evaluation scripts that call the same helper.
        zero_conds = [
            pd.DataFrame(np.zeros_like(modalities_map[c]['test']),
                         columns=modalities_map[c]['test'].columns)
            for c in cond_datatypes
        ]
        # All-zero masks accompany the zero frames (presumably marking every
        # condition as absent -- confirm in lib.test).
        masks = [np.zeros(n_samples) for _ in cond_datatypes]

        _, uncond_data = test_model(
            test_real, zero_conds, model, diffusion,
            test_iterations=test_repeats, device=DEVICE, masks=masks
        )
        return uncond_data

class CoherentStrategy(EvaluationStrategy):
    """Strategy for the 'coherent' ensemble model.

    Unconditional samples come from an ensemble of single-condition models
    (one ``<modality>_from_<c>`` checkpoint per conditioning modality ``c``).
    The constructor is inherited unchanged from ``EvaluationStrategy``.
    """

    def get_base_dir(self, modality):
        """Return the folder of the coherent results, independent of ``method_name``."""
        # The coherent results are stored in a differently named folder
        return pathlib.Path(f"{self.results_path}/{self.dim}/{modality}_from_coherent")

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` from the ensemble with every condition zeroed out.

        Args:
            modality: Name of the modality to generate.
            modalities_map: Mapping ``name -> {'train': df, 'test': df}``; only
                the 'test' frames are read here.
            n_feats: Number of features of the target modality (model ``x_dim``).
            test_repeats: Passed to the helper as ``test_iterations``.

        Returns:
            The second of the three values returned by
            ``coherent_test_cos_rejection``, or ``None`` when any of the
            required single-condition checkpoints is missing.
        """
        print("  Generating unconditional samples using 'coherent' strategy...")
        cond_datatypes = [m for m in modalities_map.keys() if m != modality]
        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        # Check every dependency before loading anything, so that no model is
        # loaded (and moved to DEVICE) only to be discarded when a later
        # checkpoint turns out to be missing.
        ckpt_paths = []
        for c in cond_datatypes:
            # Coherent method relies on single-pair models
            ckpt_path = pathlib.Path(f"{self.results_path}/{self.dim}/{modality}_from_{c}/train/best_by_{METRIC_TO_USE}.pth")
            if not ckpt_path.exists():
                print(f"  WARNING: Coherent dependency not found at {ckpt_path}. Skipping.")
                return None
            ckpt_paths.append(ckpt_path)

        models, weights, zero_conds = [], [], []
        for c, ckpt_path in zip(cond_datatypes, ckpt_paths):
            cond_test = modalities_map[c]['test']

            ckpt = torch.load(ckpt_path, map_location='cpu')
            config_c = SimpleNamespace(**ckpt['config'])
            model_c = get_diffusion_model(
                config_c.architecture, diffusion, config_c,
                x_dim=n_feats, cond_dims=cond_test.shape[1]
            ).to(DEVICE)
            model_c.load_state_dict(ckpt[f'best_model_{METRIC_TO_USE}'])
            model_c.eval()

            models.append(model_c)
            # The stored 'best_loss' is forwarded as the ensemble weight of this model.
            weights.append(ckpt['best_loss'])
            # The zero frames keep the original column labels, as in the
            # standalone coherent evaluation script.
            zero_conds.append(pd.DataFrame(np.zeros_like(cond_test), columns=cond_test.columns))

        _, uncond_data, _ = coherent_test_cos_rejection(
            modalities_map[modality]['test'], zero_conds, models, diffusion,
            test_iterations=test_repeats, max_retries=10, device=DEVICE, weights_list=weights
        )
        return uncond_data

# =============================================================================
# 3. MAIN EVALUATION SCRIPT
# =============================================================================

def _print_metrics(label, metrics):
    """Print the mean of every metric in ``metrics`` on one line (stds are not shown)."""
    m = {k: v[0] for k, v in metrics.items()}
    print(f"{label:<15} Precision: {m['prec']:.2f}, Recall: {m['rec']:.2f}, F1: {m['f1']:.2f}"
          f"  |  ED_G: {m['g_ed']:.2f}, ED_L: {m['l_ed']:.2f}, ED_B: {m['b_ed']:.2f}"
          f"  |  PED_G: {m['g_ped']:.2f}, PED_L: {m['l_ped']:.2f}, PED_B: {m['b_ped']:.2f}")


def run_evaluation(args):
    """Evaluate the selected generation methods on every modality.

    For each method and modality, unconditional samples are generated by the
    method's strategy and conditional samples are read from the method's
    results folder. Both are scored against the real test set and against
    random training subsets. All rows are printed and written to
    ``<results_path>/<dim>/distributions_coverage.csv``.

    Args:
        args: Namespace-like object with the attributes ``data_dir``, ``dim``,
            ``results_path``, ``test_repeats`` and ``methods`` (a list of
            'multi', 'multi_masked', 'coherent', or 'all').
    """
    # Load data once
    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    # Instantiate all available strategies
    all_strategies = {
        'multi': MultiModelStrategy('multi', args.results_path, args.dim),
        'multi_masked': MultiModelStrategy('multi_masked', args.results_path, args.dim),
        'coherent': CoherentStrategy('coherent', args.results_path, args.dim),
    }

    # Filter strategies based on user input
    methods_to_run = args.methods if 'all' not in args.methods else all_strategies.keys()
    strategies_to_run = {name: strat for name, strat in all_strategies.items() if name in methods_to_run}

    if not strategies_to_run:
        print("No valid methods selected. Available methods are: 'multi', 'multi_masked', 'coherent', 'all'")
        return

    summary_rows = []
    for method_name, strategy in strategies_to_run.items():
        print(f"\n{'='*20} Evaluating Method: {method_name.upper()} {'='*20}")

        for modality in modalities_list:
            print(f"\n--- Running modality: {modality} ---")

            # --- a) Prepare data ---
            test_real_df = modalities_map[modality]['test']
            train_real_df = modalities_map[modality]['train'].dropna()
            n_samples, n_feats = test_real_df.shape

            # --- b) Generate/Load Samples ---
            uncond_gen_df = strategy.generate_unconditional(modality, modalities_map, n_feats, args.test_repeats)

            cond_datatypes = [m for m in modalities_map.keys() if m != modality]
            cond_string = '_'.join(cond_datatypes)
            cond_path = strategy.get_base_dir(modality) / 'test' / f'generated_samples_from_{cond_string}_best_{METRIC_TO_USE}.csv'

            cond_gen_df = None
            if cond_path.exists():
                cond_gen_df = pd.read_csv(cond_path)
            else:
                print(f"  WARNING: Conditional data not found at {cond_path}. Skipping.")

            # --- c) Compute Metrics ---
            # The real test set is repeated per repeat; the training reference is
            # a fresh random subset of test-set size for every repeat.
            real_arr = np.tile(test_real_df.values[np.newaxis], (args.test_repeats, 1, 1))
            train_arr = np.array([train_real_df.sample(n=n_samples).values for _ in range(args.test_repeats)])

            # Metric dicts keyed by evaluation type, in the order they are computed.
            modality_results = {}

            if uncond_gen_df is not None:
                # -1 lets the repeat count follow the data actually returned;
                # compute_all_metrics iterates over gen_arr.shape[0], so a
                # generator returning fewer repeats than requested no longer
                # makes the reshape fail.
                uncond_arr = uncond_gen_df.values.reshape(-1, n_samples, n_feats)
                modality_results['uncond_vs_test'] = compute_all_metrics(real_arr, uncond_arr)
                modality_results['uncond_vs_train'] = compute_all_metrics(train_arr, uncond_arr)
                print("\n  -- Unconditional Metrics --")
                _print_metrics("vs Test", modality_results['uncond_vs_test'])
                _print_metrics("vs Train", modality_results['uncond_vs_train'])

            if cond_gen_df is not None:
                cond_arr = cond_gen_df.values.reshape(-1, n_samples, n_feats)
                modality_results['cond_vs_test'] = compute_all_metrics(real_arr, cond_arr)
                modality_results['cond_vs_train'] = compute_all_metrics(train_arr, cond_arr)
                print("\n  -- Conditional Metrics --")
                _print_metrics("vs Test", modality_results['cond_vs_test'])
                _print_metrics("vs Train", modality_results['cond_vs_train'])

            # --- d) Format for Summary ---
            flat_summary_row = {'method': method_name, 'modality': modality}
            for eval_type, metrics_dict in modality_results.items():
                for metric_name, (mean, std) in metrics_dict.items():
                    flat_summary_row[f'{eval_type}_{metric_name}_mean'] = mean
                    flat_summary_row[f'{eval_type}_{metric_name}_std'] = std
            summary_rows.append(flat_summary_row)

    # --- Final Summary ---
    if summary_rows:
        summary_df = pd.DataFrame(summary_rows)
        output_path = f'{args.results_path}/{args.dim}/distributions_coverage.csv'
        # Rows are produced even when every checkpoint is missing, so the target
        # folder is not guaranteed to exist yet.
        pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        summary_df.to_csv(output_path, index=False)
        print(f"\n\n{'='*25} FULL SUMMARY {'='*25}")
        print(summary_df.to_string())
        print(f"\nComplete summary saved to: {output_path}")

from types import SimpleNamespace

if __name__ == '__main__':
    # --- Hardcoded Parameters ---
    # No command line is parsed here: a SimpleNamespace carries the same
    # attributes an argparse result would, so run_evaluation can read them
    # without any change to its internals.
    args = SimpleNamespace()
    args.dim = 32
    args.test_repeats = 10
    args.results_path = '../results'
    args.data_dir = '../datasets_TCGA/07_normalized/'
    args.methods = ['all']  # This will run 'multi', 'multi_masked', and 'coherent'

    # Run the evaluation with the settings above.
    run_evaluation(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap

import torch



# --- Global Parameters & Setup ---
# Name of the metric used to select model checkpoints.
METRIC_TO_USE = 'mse'

# Use the currently selected CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    DEVICE = torch.device("cpu")


def plot_distributions(real_df, uncond_df, cond_df, method_name, modality):
    """
    Show side-by-side PCA and UMAP scatter plots of real vs. generated data.

    All supplied datasets are stacked into one table and embedded together,
    so every group shares a single 2D projection per panel.

    Args:
        real_df: DataFrame of real samples.
        uncond_df: DataFrame of unconditionally generated samples.
        cond_df: DataFrame of conditionally generated samples, or None to
            leave that group out of the plot.
        method_name: Name of the generation method (shown in the title).
        modality: Name of the modality (shown in the title).
    """
    print(f"  Generating plots for {modality} using {method_name}...")

    groups = {'Real': real_df, 'Unconditional': uncond_df}
    if cond_df is not None:
        groups['Conditional'] = cond_df

    # Stack every group and remember which group each stacked row came from.
    stacked = pd.concat(groups.values(), ignore_index=True)
    row_group = np.concatenate([np.full(len(frame), key) for key, frame in groups.items()])

    # Each reducer is fitted exactly once, on the stacked data.
    pca_coords = PCA(n_components=2, random_state=42).fit_transform(stacked)
    umap_coords = umap.UMAP(
        n_components=2, n_neighbors=30, min_dist=0.1, random_state=42
    ).fit_transform(stacked)

    fig, axes = plt.subplots(1, 2, figsize=(18, 8))
    fig.suptitle(f'Distribution Embeddings: {modality.upper()} (Method: {method_name.capitalize()})', fontsize=20)

    plot_styles = {
        'Real':          {'color': "#405d72", 'alpha': 0.6, 's': 5, 'zorder': 1},
        'Unconditional': {'color': "#ff0efb", 'alpha': 0.4, 's': 5, 'zorder': 2},
        'Conditional':   {'color': "#2dd42d", 'alpha': 0.4, 's': 5, 'zorder': 3}
    }

    # Left panel: PCA, right panel: UMAP. Both are drawn the same way.
    panels = (('PCA Projection', pca_coords), ('UMAP Projection', umap_coords))
    for ax, (panel_title, coords) in zip(axes, panels):
        ax.set_title(panel_title, fontsize=14)
        for key, style in plot_styles.items():
            if key not in groups:
                continue
            rows = row_group == key
            ax.scatter(coords[rows, 0], coords[rows, 1], label=key, **style)
        ax.grid(True, linestyle='--', alpha=0.6)

    # One shared legend, taken from the first panel's artists.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(legend_handles, legend_labels, loc='upper right', bbox_to_anchor=(0.99, 0.95), fontsize=12)
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # Display the figure interactively.
    plt.show()


def run_plotting(args):
    """
    Plot real vs. generated distributions for every modality and selected method.

    For each selected strategy and each modality in ``modalities_list`` this
    generates one round of unconditional samples, loads the saved conditional
    samples when their CSV exists, and passes everything to
    ``plot_distributions``.

    Args:
        args: Namespace-like object providing ``data_dir``, ``dim``,
            ``results_path`` and ``methods`` (a list of strategy names, or a
            list containing ``'all'`` to run every strategy).
    """

    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    all_strategies = {
        'multi': MultiModelStrategy('multi', args.results_path, args.dim),
        'multi_masked': MultiModelStrategy('multi_masked', args.results_path, args.dim),
        'coherent': CoherentStrategy('coherent', args.results_path, args.dim),
    }

    methods_to_run = args.methods if 'all' not in args.methods else all_strategies.keys()
    strategies_to_run = {name: strat for name, strat in all_strategies.items() if name in methods_to_run}

    if not strategies_to_run:
        print("No valid methods selected.")
        return

    for method_name, strategy in strategies_to_run.items():
        print(f"\n{'='*20} Plotting for Method: {method_name.upper()} {'='*20}")

        for modality in modalities_list:
            print(f"\n--- Processing modality: {modality} ---")

            test_real_df = modalities_map[modality]['test']
            n_samples, n_feats = test_real_df.shape

            uncond_gen_df = strategy.generate_unconditional(modality, modalities_map, n_feats, 1)

            # Bail out before locating/reading the conditional CSV: nothing is
            # plotted for this modality without unconditional samples, so the
            # file read would be wasted work.
            if uncond_gen_df is None:
                print(f"  Skipping plots for {modality} as unconditional data could not be generated.")
                continue

            # The conditional-samples file is named after all other modalities.
            cond_datatypes = [m for m in modalities_map if m != modality]
            cond_string = '_'.join(cond_datatypes)
            cond_path = strategy.get_base_dir(modality) / 'test' / f'generated_samples_from_{cond_string}_best_{METRIC_TO_USE}.csv'

            # Conditional samples are optional; trim them to the test-set size.
            cond_gen_df = None
            if cond_path.exists():
                cond_gen_df = pd.read_csv(cond_path).head(n_samples)

            plot_distributions(
                real_df=test_real_df,
                uncond_df=uncond_gen_df,
                cond_df=cond_gen_df,
                method_name=method_name,
                modality=modality
            )

if __name__ == '__main__':
    # Hardcoded settings take the place of a command-line parser; they are
    # wrapped in a SimpleNamespace so run_plotting can read them as attributes.
    args = SimpleNamespace(**{
        'dim': 32,
        'results_path': '../results',
        'data_dir': '../datasets_TCGA/07_normalized/',
        'methods': ['all'],  # 'all' selects every available strategy
    })

    run_plotting(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap

import torch




# --- Global Parameters & Setup ---
# Prefer the currently selected CUDA device; fall back to the CPU otherwise.
_device_name = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
# Metric name embedded in the filename of the saved conditional samples.
METRIC_TO_USE = 'mse'


def plot_combined_distributions(real_df, uncond_df, cond_df, method_name):
    """
    Draw a 2x2 grid of PCA (top row) and UMAP (bottom row) scatter plots.

    Both reducers are fitted on ``real_df`` alone; every dataset is then
    projected with the fitted reducer. The left column overlays real and
    unconditional samples, the right column real and conditional samples
    (or a placeholder message when ``cond_df`` is None).

    Args:
        real_df: DataFrame of real samples; the only data the reducers see
            during fitting.
        uncond_df: DataFrame of unconditionally generated samples.
        cond_df: DataFrame of conditionally generated samples, or None.
        method_name: Name of the generation method (shown in the title).
    """
    print(f"  Generating combined plot for method '{method_name}'...")

    # --- Data Preparation ---
    datasets = {'Real': real_df, 'Unconditional': uncond_df}
    has_conditional = cond_df is not None
    if has_conditional:
        datasets['Conditional'] = cond_df

    # --- Fit on the real data only, then project every dataset ---
    print("    Fitting PCA and UMAP on REAL data only...")

    pca = PCA(n_components=2, random_state=42).fit(real_df)
    pca_embeddings = {}
    for name, frame in datasets.items():
        pca_embeddings[name] = pca.transform(frame)

    reducer = umap.UMAP(n_components=2, n_neighbors=30, min_dist=0.1, random_state=42).fit(real_df)
    umap_embeddings = {}
    for name, frame in datasets.items():
        umap_embeddings[name] = reducer.transform(frame)
    print("    Transformation complete.")

    # --- Create the 2x2 Plot ---
    fig, axes = plt.subplots(2, 2, figsize=(20, 18))
    fig.suptitle(f'Combined Distribution Embeddings (Method: {method_name.capitalize()})', fontsize=22)

    plot_styles = {
        'Real':          {'color': "#3a7cac", 'alpha': 0.9, 's': 6, 'zorder': 1, 'label': 'Real'},
        'Unconditional': {'color': "#e01122", 'alpha': 0.6, 's': 6, 'zorder': 2, 'label': 'Unconditional'},
        'Conditional':   {'color': "#f2991b", 'alpha': 0.6, 's': 6, 'zorder': 2, 'label': 'Conditional'}
    }

    def scatter_groups(ax, coords_by_name, wanted):
        """Scatter each requested group that has coordinates, then add a grid."""
        for name in wanted:
            if name not in coords_by_name:
                continue
            coords = coords_by_name[name]
            ax.scatter(coords[:, 0], coords[:, 1], **plot_styles[name])
        ax.grid(True, linestyle='--', alpha=0.6)

    def mark_unavailable(ax):
        """Replace a panel's content with a 'not available' note and hide its ticks."""
        ax.text(0.5, 0.5, 'Conditional Data\nNot Available', ha='center', va='center', transform=ax.transAxes, alpha=0.5, fontsize=14)
        ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

    # --- Populate the subplots, one reducer per row ---
    layout = [
        (pca_embeddings, 'PCA: Real vs. Unconditional', 'PCA: Real vs. Conditional'),
        (umap_embeddings, 'UMAP: Real vs. Unconditional', 'UMAP: Real vs. Conditional'),
    ]
    for row, (coords_by_name, uncond_title, cond_title) in enumerate(layout):
        uncond_ax, cond_ax = axes[row, 0], axes[row, 1]

        uncond_ax.set_title(uncond_title, fontsize=16)
        scatter_groups(uncond_ax, coords_by_name, ['Real', 'Unconditional'])

        cond_ax.set_title(cond_title, fontsize=16)
        if has_conditional:
            scatter_groups(cond_ax, coords_by_name, ['Real', 'Conditional'])
        else:
            mark_unavailable(cond_ax)

    # --- One legend for the whole figure, built from proxy markers ---
    handles = []
    for name, style in plot_styles.items():
        if name in datasets:
            handles.append(
                plt.Line2D([0], [0], marker='o', color='w', label=style['label'],
                           markerfacecolor=style['color'], markersize=12, alpha=0.8)
            )
    fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.98, 0.96), fontsize=14)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    plt.show()


def run_plotting(args):
    """
    Build one combined (all-modality) embedding plot per selected method.

    For every modality, unconditional samples are generated and saved
    conditional samples are loaded when present. The per-modality frames are
    then concatenated column-wise and passed to
    ``plot_combined_distributions``. Conditional data is plotted only when it
    was found for every modality that was kept.

    Args:
        args: Namespace-like object providing ``data_dir``, ``dim``,
            ``results_path`` and ``methods`` (strategy names, or a list
            containing ``'all'``).
    """

    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    all_strategies = {
        'multi': MultiModelStrategy('multi', args.results_path, args.dim),
        'multi_masked': MultiModelStrategy('multi_masked', args.results_path, args.dim),
        'coherent': CoherentStrategy('coherent', args.results_path, args.dim),
    }

    # 'all' expands to every known strategy; otherwise keep the requested ones.
    if 'all' in args.methods:
        methods_to_run = all_strategies.keys()
    else:
        methods_to_run = args.methods
    strategies_to_run = {}
    for name, strat in all_strategies.items():
        if name in methods_to_run:
            strategies_to_run[name] = strat

    if not strategies_to_run:
        print("No valid methods selected.")
        return

    for method_name, strategy in strategies_to_run.items():
        print(f"\n{'='*20} Processing Method: {method_name.upper()} {'='*20}")

        real_dfs, uncond_dfs, cond_dfs = [], [], []

        for modality in modalities_list:
            print(f"--- Generating data for modality: {modality} ---")

            test_real_df = modalities_map[modality]['test']
            n_samples, n_feats = test_real_df.shape

            uncond_gen_df = strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
            if uncond_gen_df is None:
                print(f"  Skipping modality {modality} for method {method_name} due to generation failure.")
                continue

            # The conditional-samples file is named after all other modalities.
            cond_string = '_'.join(m for m in modalities_map.keys() if m != modality)
            cond_path = strategy.get_base_dir(modality) / 'test' / f'generated_samples_from_{cond_string}_best_{METRIC_TO_USE}.csv'

            # Conditional samples are optional; keep at most n_samples rows.
            if cond_path.exists():
                cond_dfs.append(pd.read_csv(cond_path).head(n_samples))

            real_dfs.append(test_real_df)
            uncond_dfs.append(uncond_gen_df)

        if not real_dfs:
            print(f"No data was successfully generated for method '{method_name}'. Skipping plot.")
            continue

        final_real_df = pd.concat(real_dfs, axis=1)
        final_uncond_df = pd.concat(uncond_dfs, axis=1)
        # Only plot conditional data when every kept modality contributed some.
        if len(cond_dfs) == len(real_dfs):
            final_cond_df = pd.concat(cond_dfs, axis=1)
        else:
            final_cond_df = None

        plot_combined_distributions(
            real_df=final_real_df,
            uncond_df=final_uncond_df,
            cond_df=final_cond_df,
            method_name=method_name
        )


if __name__ == '__main__':
    # Hardcoded run settings, exposed as attributes for run_plotting.
    args = SimpleNamespace()
    args.dim = 32
    args.results_path = '../results'
    args.data_dir = '../datasets_TCGA/07_normalized/'
    args.methods = ['all']
    run_plotting(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap

import torch



# --- Global Parameters & Setup ---
# Pick the currently selected CUDA device if CUDA is usable, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    DEVICE = torch.device("cpu")
# Metric name embedded in the filename of the saved conditional samples.
METRIC_TO_USE = 'mse'


def plot_combined_distributions(real_df, uncond_df, cond_df, method_name):
    """
    Draw a 2x2 grid comparing generated samples against a training-set sample.

    PCA (top row) and UMAP (bottom row) are fitted on ``real_df`` alone, and
    every dataset is then projected with the fitted reducer. The left column
    overlays the training set and unconditional samples, the right column the
    training set and conditional samples (or a placeholder message when
    ``cond_df`` is None).

    Args:
        real_df: DataFrame used as the reference; shown under the
            'Training Set' label and the only data the reducers are fitted on.
        uncond_df: DataFrame of unconditionally generated samples.
        cond_df: DataFrame of conditionally generated samples, or None.
        method_name: Name of the generation method (shown in the title).
    """
    print(f"  Generating combined plot for method '{method_name}'...")

    # --- Data Preparation ---
    datasets = {'Training Set': real_df, 'Unconditional': uncond_df}
    has_conditional = cond_df is not None
    if has_conditional:
        datasets['Conditional'] = cond_df

    # --- Fit on the reference data only, then project every dataset ---
    print("    Fitting PCA and UMAP on REAL (Training) data only...")

    pca = PCA(n_components=2, random_state=42).fit(real_df)
    pca_embeddings = {}
    for name, frame in datasets.items():
        pca_embeddings[name] = pca.transform(frame)

    reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.5, random_state=42).fit(real_df)
    umap_embeddings = {}
    for name, frame in datasets.items():
        umap_embeddings[name] = reducer.transform(frame)
    print("    Transformation complete.")

    # --- Create the 2x2 Plot ---
    fig, axes = plt.subplots(2, 2, figsize=(20, 18))
    fig.suptitle(f'Generated vs. Training Set Distributions (Method: {method_name.capitalize()})', fontsize=22)

    plot_styles = {
        'Training Set':  {'color': "#3a7cac", 'alpha': 0.9, 's': 6, 'zorder': 1},
        'Unconditional': {'color': "#e01122", 'alpha': 0.6, 's': 6, 'zorder': 2},
        'Conditional':   {'color': "#f2991b", 'alpha': 0.6, 's': 6, 'zorder': 2}
    }

    def scatter_groups(ax, coords_by_name, wanted):
        """Scatter each requested group that has coordinates, then add a grid."""
        for name in wanted:
            if name not in coords_by_name:
                continue
            coords = coords_by_name[name]
            ax.scatter(coords[:, 0], coords[:, 1], label=name, **plot_styles[name])
        ax.grid(True, linestyle='--', alpha=0.6)

    def mark_unavailable(ax):
        """Replace a panel's content with a 'not available' note and hide its ticks."""
        ax.text(0.5, 0.5, 'Conditional Data\nNot Available', ha='center', va='center', transform=ax.transAxes, alpha=0.5, fontsize=14)
        ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

    # --- Populate the subplots, one reducer per row ---
    layout = [
        (pca_embeddings, 'PCA: Training Set vs. Unconditional', 'PCA: Training Set vs. Conditional'),
        (umap_embeddings, 'UMAP: Training Set vs. Unconditional', 'UMAP: Training Set vs. Conditional'),
    ]
    for row, (coords_by_name, uncond_title, cond_title) in enumerate(layout):
        uncond_ax, cond_ax = axes[row, 0], axes[row, 1]

        uncond_ax.set_title(uncond_title, fontsize=16)
        scatter_groups(uncond_ax, coords_by_name, ['Training Set', 'Unconditional'])

        cond_ax.set_title(cond_title, fontsize=16)
        if has_conditional:
            scatter_groups(cond_ax, coords_by_name, ['Training Set', 'Conditional'])
        else:
            mark_unavailable(cond_ax)

    # --- One legend for the whole figure, built from proxy markers ---
    handles = []
    for name, style in plot_styles.items():
        if name in datasets:
            handles.append(
                plt.Line2D([0], [0], marker='o', color='w', label=name,
                           markerfacecolor=style['color'], markersize=12, alpha=0.8)
            )
    fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.98, 0.96), fontsize=14)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    plt.show()

def run_plotting(args):
    """
    Plot generated samples against a sample of complete training rows.

    Training frames of all modalities are joined column-wise and rows with any
    missing value are dropped. For each selected method, every modality then
    contributes a random sample of those complete training rows, its
    unconditional samples, and (when the CSV exists) its saved conditional
    samples. The per-modality frames are concatenated column-wise and passed
    to ``plot_combined_distributions``.

    Args:
        args: Namespace-like object providing ``data_dir``, ``dim``,
            ``results_path`` and ``methods`` (strategy names, or a list
            containing ``'all'``).
    """

    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    # --- Keep only training rows that are complete in every modality ---
    print("Finding common samples with no missing values across all training modalities...")
    combined_train_df = pd.concat(
        [modalities_map[name]['train'] for name in modalities_list], axis=1
    )
    complete_train_df = combined_train_df.dropna()
    print(f"Found {len(complete_train_df)} complete samples common to all modalities.")

    if complete_train_df.empty:
        print("Error: No common samples found across all modalities after dropping NaNs. Cannot proceed.")
        return

    # Slice the joined frame back into one block of columns per modality,
    # walking the columns in the same order they were concatenated.
    complete_train_data_map = {}
    offset = 0
    for name in modalities_list:
        source_train = modalities_map[name]['train']
        width = source_train.shape[1]
        block = complete_train_df.iloc[:, offset:offset + width]
        block.columns = source_train.columns  # restore the original column names
        complete_train_data_map[name] = block
        offset += width

    all_strategies = {
        'multi': MultiModelStrategy('multi', args.results_path, args.dim),
        'multi_masked': MultiModelStrategy('multi_masked', args.results_path, args.dim),
        'coherent': CoherentStrategy('coherent', args.results_path, args.dim),
    }

    # 'all' expands to every known strategy; otherwise keep the requested ones.
    if 'all' in args.methods:
        methods_to_run = all_strategies.keys()
    else:
        methods_to_run = args.methods
    strategies_to_run = {}
    for name, strat in all_strategies.items():
        if name in methods_to_run:
            strategies_to_run[name] = strat

    if not strategies_to_run:
        print("No valid methods selected.")
        return

    for method_name, strategy in strategies_to_run.items():
        print(f"\n{'='*20} Processing Method: {method_name.upper()} {'='*20}")

        real_data_for_plotting_dfs, uncond_dfs, cond_dfs = [], [], []

        for modality in modalities_list:
            print(f"--- Generating data for modality: {modality} ---")

            test_real_df = modalities_map[modality]['test']
            train_real_df = complete_train_data_map[modality]

            # Compare equally sized sets: as many rows as the test split has,
            # capped by the number of complete training rows.
            n_samples_from_test = len(test_real_df)
            n_samples = min(n_samples_from_test, len(train_real_df))
            if n_samples < n_samples_from_test:
                print(f"  Warning: Only {len(train_real_df)} complete training samples available. Using this smaller size for comparison.")

            n_feats = test_real_df.shape[1]
            uncond_gen_df = strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
            if uncond_gen_df is None:
                print(f"  Skipping modality {modality} for method {method_name} due to generation failure.")
                continue

            # The conditional-samples file is named after all other modalities.
            cond_string = '_'.join(m for m in modalities_map.keys() if m != modality)
            cond_path = strategy.get_base_dir(modality) / 'test' / f'generated_samples_from_{cond_string}_best_{METRIC_TO_USE}.csv'

            # Conditional samples are optional; keep at most n_samples rows.
            if cond_path.exists():
                cond_dfs.append(pd.read_csv(cond_path).head(n_samples))

            # Fixed seed keeps the drawn training rows reproducible.
            real_data_for_plotting_dfs.append(
                train_real_df.sample(n=n_samples, random_state=42, replace=False)
            )
            uncond_dfs.append(uncond_gen_df.head(n_samples))

        if not real_data_for_plotting_dfs:
            print(f"No data was successfully generated for method '{method_name}'. Skipping plot.")
            continue

        final_real_df_for_plot = pd.concat(real_data_for_plotting_dfs, axis=1)
        final_uncond_df = pd.concat(uncond_dfs, axis=1)
        # Only plot conditional data when every kept modality contributed some.
        if len(cond_dfs) == len(real_data_for_plotting_dfs):
            final_cond_df = pd.concat(cond_dfs, axis=1)
        else:
            final_cond_df = None

        plot_combined_distributions(
            real_df=final_real_df_for_plot,
            uncond_df=final_uncond_df,
            cond_df=final_cond_df,
            method_name=method_name
        )


if __name__ == '__main__':
    # Hardcoded run settings, wrapped so run_plotting can read them as attributes.
    args = SimpleNamespace(**{
        'dim': 32,
        'results_path': '../results',
        'data_dir': '../datasets_TCGA/07_normalized/',
        'methods': ['all'],
    })
    run_plotting(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
import umap
import torch
from abc import ABC, abstractmethod

# Assuming lib functions are in your PYTHONPATH
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# =============================================================================
# 1. STRATEGY CLASSES (generate_unconditional takes n_feats explicitly)
# =============================================================================

class EvaluationStrategy(ABC):
    """
    Common interface for the sample-generation strategies.

    Holds the method name plus the results root and embedding dimension that
    together locate a strategy's output directory.
    """

    def __init__(self, method_name, results_path, dim):
        self.method_name, self.results_path, self.dim = method_name, results_path, dim

    def get_base_dir(self, modality):
        """Return ``<results_path>/<dim>/<modality>_from_<method_name>`` as a Path."""
        location = f"{self.results_path}/{self.dim}/{modality}_from_{self.method_name}"
        return pathlib.Path(location)

    @abstractmethod
    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """
        Produce unconditional samples for ``modality``.

        Subclasses must override this; the base implementation does nothing.
        """

class MultiModelStrategy(EvaluationStrategy):
    """Strategy for 'multi' and 'multi_masked' models."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """
        Sample ``modality`` from the saved multi-condition model with zeroed conditions.

        Loads ``train/best_by_mse.pth`` from this strategy's base directory,
        rebuilds the model, and runs ``test_model`` with zero-filled frames
        and all-zero masks for every other modality.

        Args:
            modality: Name of the modality to generate.
            modalities_map: Mapping ``name -> {'test': DataFrame, ...}``.
            n_feats: Feature count passed to the model as ``x_dim``.
            test_repeats: Forwarded to ``test_model`` as ``test_iterations``.

        Returns:
            The first ``len(test split)`` rows of the generated data, or None
            when the checkpoint file is missing.
        """
        print(f"  Generating unconditional samples using '{self.method_name}' for '{modality}'...")
        ckpt_path = self.get_base_dir(modality) / 'train' / 'best_by_mse.pth'
        if not ckpt_path.exists():
            print(f"  WARNING: Checkpoint not found at {ckpt_path}. Skipping.")
            return None

        checkpoint = torch.load(ckpt_path, map_location='cpu')
        config = SimpleNamespace(**checkpoint['config'])

        # Every modality other than the target acts as a conditioning input.
        other_modalities = [name for name in modalities_map.keys() if name != modality]
        cond_dim_list = [modalities_map[name]['test'].shape[1] for name in other_modalities]

        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        # x_dim comes from the caller rather than from the checkpoint.
        model = get_diffusion_model(
            config.architecture, diffusion, config,
            x_dim=n_feats,
            cond_dims=cond_dim_list
        ).to(DEVICE)
        model.load_state_dict(checkpoint['best_model_mse'])
        model.eval()

        target_test = modalities_map[modality]['test']
        n_samples = len(target_test)

        # Zero-filled conditions with all-zero masks; presumably this switches
        # conditioning off inside test_model (confirm against lib.test).
        zero_conds = []
        for name in other_modalities:
            zero_conds.append(pd.DataFrame(np.zeros_like(modalities_map[name]['test'])))
        masks = [np.zeros(n_samples) for _ in other_modalities]

        _, uncond_data = test_model(
            target_test, zero_conds, model, diffusion,
            test_iterations=test_repeats, device=DEVICE, masks=masks
        )
        return uncond_data.iloc[:n_samples]

class CoherentStrategy(EvaluationStrategy):
    """Strategy for the 'coherent' ensemble model."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """
        Sample ``modality`` from the ensemble of single-condition models.

        One model is loaded per other modality from
        ``<results_path>/<dim>/<modality>_from_<other>/train/best_by_mse.pth``;
        the models, their stored ``best_loss`` values (passed as
        ``weights_list``) and zero-filled condition frames are handed to
        ``coherent_test_cos_rejection``.

        Args:
            modality: Name of the modality to generate.
            modalities_map: Mapping ``name -> {'test': DataFrame, ...}``.
            n_feats: Feature count passed to each model as ``x_dim``.
            test_repeats: Forwarded as ``test_iterations``.

        Returns:
            The first ``len(test split)`` rows of the generated data, or None
            as soon as any required checkpoint is missing.
        """
        print(f"  Generating unconditional samples using 'coherent' for '{modality}'...")
        partner_modalities = [name for name in modalities_map.keys() if name != modality]
        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        models, weights, zero_conds = [], [], []
        for partner in partner_modalities:
            ckpt_path = pathlib.Path(f"{self.results_path}/{self.dim}/{modality}_from_{partner}/train/best_by_mse.pth")
            if not ckpt_path.exists():
                print(f"  WARNING: Coherent dependency not found at {ckpt_path}. Skipping.")
                return None

            checkpoint = torch.load(ckpt_path, map_location='cpu')
            partner_config = SimpleNamespace(**checkpoint['config'])
            partner_test = modalities_map[partner]['test']

            # x_dim comes from the caller rather than from the checkpoint.
            partner_model = get_diffusion_model(
                partner_config.architecture, diffusion, partner_config,
                x_dim=n_feats,
                cond_dims=partner_test.shape[1]
            ).to(DEVICE)
            partner_model.load_state_dict(checkpoint['best_model_mse'])
            partner_model.eval()

            models.append(partner_model)
            weights.append(checkpoint['best_loss'])
            zero_conds.append(pd.DataFrame(np.zeros_like(partner_test)))

        target_test = modalities_map[modality]['test']
        _, uncond_data, _ = coherent_test_cos_rejection(
            target_test, zero_conds, models, diffusion,
            test_iterations=test_repeats, max_retries=10, device=DEVICE, weights_list=weights
        )
        return uncond_data.iloc[:len(target_test)]

# =============================================================================
# 2. PLOTTING FUNCTION
# =============================================================================
def plot_unconditional_comparison(train_df, uncond_coherent_df, uncond_multi_df, save_path=None):
    """
    Show a two-panel UMAP plot of training data vs. unconditional generations.

    UMAP is fitted on ``train_df`` only; both generated datasets are projected
    with the fitted reducer. The left panel overlays the 'coherent' samples on
    the training data, the right panel the 'multi' samples. Both panels share
    identical, square axis limits so they can be compared directly.

    Note: this switches the global matplotlib style to 'seaborn-v0_8-white'.

    Args:
        train_df: Training data the UMAP reducer is fitted on.
        uncond_coherent_df: Unconditional samples from the coherent method.
        uncond_multi_df: Unconditional samples from the multi-condition method.
        save_path: Optional output file path. Missing parent directories are
            created; a bare filename saves into the current directory.
    """
    print("\n>>> Generating final 2-panel UMAP plot...")
    plt.style.use('seaborn-v0_8-white')
    plot_styles = {
        'Training Set': {'color': '#0072b2', 'alpha': 0.7, 's': 10, 'label': 'Training Set'},
        'Generated':    {'color': '#e66000', 'alpha': 0.7, 's': 10, 'label': 'Generated Unconditional'}
    }

    print("  Fitting UMAP on training data...")
    reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.8, spread=2.0, random_state=42)
    train_emb = reducer.fit_transform(train_df)
    coherent_emb = reducer.transform(uncond_coherent_df)
    multi_emb = reducer.transform(uncond_multi_df)
    print("  Transformation complete.")

    fig, axes = plt.subplots(1, 2, figsize=(18, 8))
    fig.suptitle('UMAP: Training Set vs. Unconditional Generation', fontsize=22, weight='bold')

    # Shared square limits: centre on the data of all three embeddings and use
    # the larger of the two extents (plus 10% padding) on both axes.
    all_x = np.concatenate([train_emb[:, 0], coherent_emb[:, 0], multi_emb[:, 0]])
    all_y = np.concatenate([train_emb[:, 1], coherent_emb[:, 1], multi_emb[:, 1]])
    x_range = all_x.max() - all_x.min()
    y_range = all_y.max() - all_y.min()
    max_range = max(x_range, y_range) * 1.1
    x_center = (all_x.max() + all_x.min()) / 2
    y_center = (all_y.max() + all_y.min()) / 2
    xlims = (x_center - max_range / 2, x_center + max_range / 2)
    ylims = (y_center - max_range / 2, y_center + max_range / 2)

    axes[0].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axes[0].scatter(coherent_emb[:, 0], coherent_emb[:, 1], **plot_styles['Generated'])
    axes[0].set_title("Coherent Denoising", fontsize=16, pad=15)
    axes[1].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axes[1].scatter(multi_emb[:, 0], multi_emb[:, 1], **plot_styles['Generated'])
    axes[1].set_title("Multi-Condition", fontsize=16, pad=15)

    for ax in axes:
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)
        ax.set_xlabel('UMAP 1', fontsize=12)
        ax.set_ylabel('UMAP 2', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.6)
        for spine in ax.spines.values():
            spine.set_color('#B0B0B0')
        ax.tick_params(axis='both', colors='#505050')

    # Figure-level legend built from proxy markers, placed right of the panels.
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=style['label'],
                          markerfacecolor=style['color'], markersize=12, alpha=0.8)
               for style in plot_styles.values()]
    fig.legend(handles=handles, loc='center left', bbox_to_anchor=(0.98, 0.5), fontsize=14)
    fig.tight_layout(rect=[0, 0, 0.90, 0.95])

    if save_path:
        # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is named.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"   -> Plot saved to {save_path}")
    plt.show()

# =============================================================================
# 3. MAIN SCRIPT LOGIC (passes each modality's n_feats to the strategies)
# =============================================================================

def run_analysis_and_plotting(args):
    """
    Generate unconditional samples with both methods and plot them against training data.

    Training frames of all modalities are joined column-wise and rows with any
    missing value are dropped. For every modality, unconditional samples are
    generated with the 'coherent' and the 'multi' strategy; the run aborts as
    soon as either generation returns None. The per-modality outputs are
    concatenated column-wise and passed to ``plot_unconditional_comparison``,
    which also saves the figure under ``<results_path>/images``.

    Args:
        args: Namespace-like object providing ``data_dir``, ``dim`` and
            ``results_path``.
    """

    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    print("Finding common samples with no missing values across all training modalities...")
    complete_train_df = pd.concat(
        [modalities_map[name]['train'] for name in modalities_list], axis=1
    ).dropna()

    n_samples = len(complete_train_df)
    print(f"Found {n_samples} complete samples common to all modalities.")

    if complete_train_df.empty:
        print("Error: No common samples found. Cannot proceed.")
        return

    coherent_strategy = CoherentStrategy('coherent', args.results_path, args.dim)
    multi_strategy = MultiModelStrategy('multi', args.results_path, args.dim)

    uncond_coherent_dfs = []
    uncond_multi_dfs = []

    for modality in modalities_list:
        print(f"\n--- Processing modality: {modality} ---")

        # The strategies take the target modality's feature count explicitly.
        n_feats = modalities_map[modality]['train'].shape[1]

        coherent_df = coherent_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if coherent_df is None:
            print(f"Fatal: Could not generate Coherent data for {modality}. Aborting.")
            return
        # Keep at most as many generated rows as there are complete training rows.
        uncond_coherent_dfs.append(coherent_df.head(n_samples))

        multi_df = multi_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if multi_df is None:
            print(f"Fatal: Could not generate Multi data for {modality}. Aborting.")
            return
        uncond_multi_dfs.append(multi_df.head(n_samples))

    # --- Join the per-modality outputs column-wise and plot ---
    plot_unconditional_comparison(
        train_df=complete_train_df,
        uncond_coherent_df=pd.concat(uncond_coherent_dfs, axis=1),
        uncond_multi_df=pd.concat(uncond_multi_dfs, axis=1),
        save_path=os.path.join(args.results_path, 'images/unconditional_umap_comparison.png')
    )


# Defined at module level rather than only under the __main__ guard: the
# strategy classes in this cell read DEVICE as a global at call time, so it
# must also exist when this code is imported instead of run as a script.
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")
METRIC_TO_USE = 'mse'

if __name__ == '__main__':
    # Hardcoded run settings, exposed as attributes for run_analysis_and_plotting.
    args = SimpleNamespace(
        dim=32,
        results_path='../results',
        data_dir='../datasets_TCGA/07_normalized/',
    )
    run_analysis_and_plotting(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
import umap
import torch
from abc import ABC, abstractmethod

# Assuming these are available in your PYTHONPATH
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# =============================================================================
# 1. STRATEGY CLASSES
# =============================================================================

class EvaluationStrategy(ABC):
    """Common interface for the unconditional-generation strategies.

    Keeps the method name, the results root and the ``dim`` value that the
    concrete strategies use to locate their checkpoints.
    """

    def __init__(self, method_name, results_path, dim):
        self.method_name, self.results_path, self.dim = method_name, results_path, dim

    def get_base_dir(self, modality):
        """Return ``<results_path>/<dim>/<modality>_from_<method_name>`` as a Path."""
        location = "{}/{}/{}_from_{}".format(
            self.results_path, self.dim, modality, self.method_name
        )
        return pathlib.Path(location)

    @abstractmethod
    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Generate unconditional samples for ``modality``; implemented by subclasses."""

class MultiModelStrategy(EvaluationStrategy):
    """Generation strategy for the 'multi' and 'multi_masked' models."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` with every conditioning input zeroed out.

        Loads ``<base_dir>/train/best_by_mse.pth``, rebuilds the model from the
        config stored in that checkpoint and runs ``test_model`` with all-zero
        condition frames and all-zero masks.

        Args:
            modality: Key of the target modality in ``modalities_map``.
            modalities_map: Mapping ``modality -> {'test': frame, ...}``.
            n_feats: Passed to the model factory as ``x_dim``.
            test_repeats: Forwarded to ``test_model`` as ``test_iterations``.

        Returns:
            The first ``len(test frame)`` rows of the generated data returned
            by ``test_model``, or ``None`` when the checkpoint file is missing.
        """
        print(f"  Generating unconditional samples using '{self.method_name}' for '{modality}'...")
        checkpoint_file = self.get_base_dir(modality) / 'train' / 'best_by_mse.pth'
        if not checkpoint_file.exists():
            print(f"  ERROR: Checkpoint not found at {checkpoint_file}. Cannot generate data.")
            return None

        state = torch.load(checkpoint_file, map_location='cpu')
        model_cfg = SimpleNamespace(**state['config'])

        # Every modality except the target one is a conditioning input.
        other_names = [name for name in modalities_map.keys() if name != modality]
        cond_frames = [modalities_map[name]['test'] for name in other_names]

        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        # Model dimensions come from the data, not from the checkpoint.
        model = get_diffusion_model(
            model_cfg.architecture,
            diffusion,
            model_cfg,
            x_dim=n_feats,
            cond_dims=[frame.shape[1] for frame in cond_frames],
        )
        model = model.to(DEVICE)
        model.load_state_dict(state['best_model_mse'])
        model.eval()

        target_frame = modalities_map[modality]['test']
        n_samples = len(target_frame)

        # One zero frame and one zero mask per conditioning modality.
        zero_conds = [pd.DataFrame(np.zeros_like(frame)) for frame in cond_frames]
        masks = [np.zeros(n_samples) for _ in cond_frames]

        _, uncond_data = test_model(
            target_frame, zero_conds, model, diffusion,
            test_iterations=test_repeats, device=DEVICE, masks=masks,
        )
        return uncond_data.iloc[:n_samples]

class CoherentStrategy(EvaluationStrategy):
    """Generation strategy for the 'coherent' ensemble model."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` from the per-condition models with zeroed conditions.

        For every other modality ``c`` the checkpoint
        ``<results_path>/<dim>/<modality>_from_<c>/train/best_by_mse.pth`` is
        loaded; the resulting models, their stored ``best_loss`` values and
        all-zero condition frames are handed to ``coherent_test_cos_rejection``.

        Args:
            modality: Key of the target modality in ``modalities_map``.
            modalities_map: Mapping ``modality -> {'test': frame, ...}``.
            n_feats: Passed to the model factory as ``x_dim``.
            test_repeats: Forwarded as ``test_iterations``.

        Returns:
            The first ``len(test frame)`` rows of the generated data, or
            ``None`` as soon as one of the required checkpoints is missing.
        """
        print(f"  Generating unconditional samples using 'coherent' for '{modality}'...")
        other_names = [name for name in modalities_map.keys() if name != modality]
        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        models = []
        weights = []
        zero_conds = []
        for source in other_names:
            checkpoint_file = pathlib.Path(
                f"{self.results_path}/{self.dim}/{modality}_from_{source}/train/best_by_mse.pth"
            )
            if not checkpoint_file.exists():
                print(f"  ERROR: Coherent dependency not found at {checkpoint_file}. Cannot generate data.")
                return None

            state = torch.load(checkpoint_file, map_location='cpu')
            source_cfg = SimpleNamespace(**state['config'])
            source_frame = modalities_map[source]['test']

            member = get_diffusion_model(
                source_cfg.architecture,
                diffusion,
                source_cfg,
                x_dim=n_feats,
                cond_dims=source_frame.shape[1],
            )
            member = member.to(DEVICE)
            member.load_state_dict(state['best_model_mse'])
            member.eval()

            models.append(member)
            # The checkpoint's best loss is what gets passed on as this model's weight.
            weights.append(state['best_loss'])
            zero_conds.append(pd.DataFrame(np.zeros_like(source_frame)))

        target_frame = modalities_map[modality]['test']
        _, uncond_data, _ = coherent_test_cos_rejection(
            target_frame, zero_conds, models, diffusion,
            test_iterations=test_repeats, max_retries=10, device=DEVICE, weights_list=weights,
        )
        return uncond_data.iloc[:len(target_frame)]

# =============================================================================
# 2. PLOTTING FUNCTION (UMAP)
# =============================================================================

def plot_unconditional_comparison(train_df, coherent_df, multi_df, save_path=None):
    """
    Draw a 2-panel UMAP scatter plot: train vs. coherent and train vs. multi.

    UMAP is fitted on ``train_df`` only; both generated frames are projected
    with the fitted mapper. The left panel overlays the coherent samples on
    the training embedding, the right panel overlays the multi-condition
    samples. Both panels share the same square axis limits.

    Args:
        train_df: DataFrame of training samples; its ``.values`` are embedded.
        coherent_df: DataFrame of samples from the coherent method. Only the
            first ``len(train_df)`` rows are used.
            NOTE(review): assumed to have the same feature columns as
            ``train_df`` -- confirm against the caller.
        multi_df: DataFrame of samples from the multi-condition method,
            truncated the same way as ``coherent_df``.
        save_path: Optional output file. When given, the parent directory (if
            the path names one) is created and the figure is written at 300 dpi.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Side effects:
        Switches the global Matplotlib style to ``'seaborn-v0_8-white'``.
    """
    print("\n>>> Generating final 2-panel UMAP plot...")
    plt.style.use('seaborn-v0_8-white')
    plot_styles = {
        'Training Set':      {'color': "#c59b7d", 'alpha': 0.6, 's': 10, 'label': 'Training Set'},
        'Coherent Denoising': {'color': '#56b4e9', 'alpha': 0.9, 's': 10, 'label': 'Generated (Coherent Denoising)'},
        'Multi-Condition':   {'color': '#0072b2', 'alpha': 0.9, 's': 10, 'label': 'Generated (Multi-condition)'}
    }

    print("  Fitting UMAP on training data and transforming all sets...")
    # Fit on the training data only so that both generated sets are projected
    # into the same, training-defined embedding.
    mapper = umap.UMAP(n_neighbors=15, min_dist=0.8, spread=2, n_components=2, random_state=23)
    train_emb = mapper.fit_transform(train_df.values)

    n_train = train_df.shape[0]
    coherent_emb = mapper.transform(coherent_df.iloc[:n_train].values)
    multi_emb = mapper.transform(multi_df.iloc[:n_train].values)

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    fig.suptitle("UMAP Comparison: Training Data vs. Unconditional Generation", fontsize=28, weight='bold', y=0.98)

    # Shared square limits: take the larger of the two coordinate ranges over
    # all three embeddings, pad it by 5 %, and centre it on each axis.
    all_x = np.concatenate([train_emb[:, 0], coherent_emb[:, 0], multi_emb[:, 0]])
    all_y = np.concatenate([train_emb[:, 1], coherent_emb[:, 1], multi_emb[:, 1]])

    x_range = all_x.max() - all_x.min()
    y_range = all_y.max() - all_y.min()
    max_range = max(x_range, y_range) * 1.05

    x_center = (all_x.max() + all_x.min()) / 2
    y_center = (all_y.max() + all_y.min()) / 2

    xlims = (x_center - max_range / 2, x_center + max_range / 2)
    ylims = (y_center - max_range / 2, y_center + max_range / 2)

    # Panel 1: Train vs Coherent
    axs[0].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axs[0].scatter(coherent_emb[:, 0], coherent_emb[:, 1], **plot_styles['Coherent Denoising'])
    axs[0].set_title('Unconditional Coherent Denoising Generation', fontsize=22, weight='bold', pad=20)

    # Panel 2: Train vs Multi
    axs[1].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axs[1].scatter(multi_emb[:, 0], multi_emb[:, 1], **plot_styles['Multi-Condition'])
    axs[1].set_title('Unconditional Multi-Condition Generation', fontsize=22, weight='bold', pad=20)

    for ax in axs:
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlabel('UMAP 1', fontsize=12, labelpad=5, color='#303030')
        ax.set_ylabel('UMAP 2', fontsize=12, labelpad=5, color='#303030')
        ax.tick_params(axis='both', colors='#505050')
        for spine in ax.spines.values():
            spine.set_color('#B0B0B0')

    # Proxy handles so the figure-level legend shows one marker per data origin.
    legend_elems = [plt.Line2D([0], [0], marker='o', color='w',
                               markerfacecolor=style['color'], markersize=16,
                               label=style['label']) for style in plot_styles.values()]

    fig.legend(handles=legend_elems, loc='center left', bbox_to_anchor=(0.96, 0.5),
               frameon=False, title='Data Origin', fontsize=16, title_fontsize=18)

    fig.subplots_adjust(left=0.05, right=0.98, bottom=0.05, top=0.85, wspace=0.15)

    if save_path:
        # os.path.dirname() is '' for a bare file name and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is named.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save through the figure object rather than the implicit current figure.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()

# =============================================================================
# 3. MAIN SCRIPT LOGIC
# =============================================================================

def run_generation_and_plotting(args):
    """Generate unconditional samples per modality and plot them against the training data.

    ``args`` must provide ``data_dir``, ``dim`` and ``results_path``. The
    function returns ``None``; it stops early (after printing a message) when
    no complete training rows exist or when either strategy fails to generate
    data for a modality.
    """
    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    print(">>> Finding common samples with no missing values across all training modalities...")
    # Column-wise join of all training frames; rows with any NaN are dropped.
    train_all_complete = pd.concat(
        [modalities_map[name]['train'] for name in modalities_list],
        axis=1,
    ).dropna()
    n_samples_to_use = len(train_all_complete)
    print(f"Found {n_samples_to_use} complete training samples.")

    if train_all_complete.empty:
        print("Error: No common samples found. Cannot proceed.")
        return

    print("\n>>> Starting data generation for all modalities...")
    coherent_strategy = CoherentStrategy('coherent', args.results_path, args.dim)
    multi_strategy = MultiModelStrategy('multi', args.results_path, args.dim)
    coherent_parts = []
    multi_parts = []

    for modality in modalities_list:
        n_feats = modalities_map[modality]['train'].shape[1]

        generated = coherent_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if generated is None:
            print(f"FATAL: Could not generate Coherent data for {modality}. Aborting.")
            return
        coherent_parts.append(generated.head(n_samples_to_use))

        generated = multi_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if generated is None:
            print(f"FATAL: Could not generate Multi data for {modality}. Aborting.")
            return
        multi_parts.append(generated.head(n_samples_to_use))

    # Put the per-modality frames side by side, mirroring the training layout.
    output_file = os.path.join(args.results_path, 'images', 'umap_train_vs_uncond_comparison.png')
    plot_unconditional_comparison(
        train_df=train_all_complete,
        coherent_df=pd.concat(coherent_parts, axis=1),
        multi_df=pd.concat(multi_parts, axis=1),
        save_path=output_file,
    )


if __name__ == "__main__":
    # DEVICE is read as a global by the strategy classes: the currently
    # selected CUDA device when CUDA is available, otherwise the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        DEVICE = torch.device("cpu")
    METRIC_TO_USE = 'mse'

    # Run configuration handed to the entry point as attribute-style args.
    args = SimpleNamespace(dim=32,
                           results_path='../results',
                           data_dir='../datasets_TCGA/07_normalized/')
    run_generation_and_plotting(args)

In [None]:
import os
import pandas as pd
import numpy as np
import pathlib
from types import SimpleNamespace
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA  # <-- IMPORT PCA
import torch
from abc import ABC, abstractmethod

# Assuming these are available in your PYTHONPATH
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# =============================================================================
# 1. STRATEGY CLASSES
# =============================================================================

class EvaluationStrategy(ABC):
    """Common interface for the unconditional-generation strategies.

    Keeps the method name, the results root and the ``dim`` value that the
    concrete strategies use to locate their checkpoints.
    """

    def __init__(self, method_name, results_path, dim):
        self.method_name, self.results_path, self.dim = method_name, results_path, dim

    def get_base_dir(self, modality):
        """Return ``<results_path>/<dim>/<modality>_from_<method_name>`` as a Path."""
        location = "{}/{}/{}_from_{}".format(
            self.results_path, self.dim, modality, self.method_name
        )
        return pathlib.Path(location)

    @abstractmethod
    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Generate unconditional samples for ``modality``; implemented by subclasses."""

class MultiModelStrategy(EvaluationStrategy):
    """Generation strategy for the 'multi' and 'multi_masked' models."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` with every conditioning input zeroed out.

        Loads ``<base_dir>/train/best_by_mse.pth``, rebuilds the model from the
        config stored in that checkpoint and runs ``test_model`` with all-zero
        condition frames and all-zero masks.

        Args:
            modality: Key of the target modality in ``modalities_map``.
            modalities_map: Mapping ``modality -> {'test': frame, ...}``.
            n_feats: Passed to the model factory as ``x_dim``.
            test_repeats: Forwarded to ``test_model`` as ``test_iterations``.

        Returns:
            The first ``len(test frame)`` rows of the generated data returned
            by ``test_model``, or ``None`` when the checkpoint file is missing.
        """
        print(f"  Generating unconditional samples using '{self.method_name}' for '{modality}'...")
        checkpoint_file = self.get_base_dir(modality) / 'train' / 'best_by_mse.pth'
        if not checkpoint_file.exists():
            print(f"  ERROR: Checkpoint not found at {checkpoint_file}. Cannot generate data.")
            return None

        state = torch.load(checkpoint_file, map_location='cpu')
        model_cfg = SimpleNamespace(**state['config'])

        # Every modality except the target one is a conditioning input.
        other_names = [name for name in modalities_map.keys() if name != modality]
        cond_frames = [modalities_map[name]['test'] for name in other_names]

        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        model = get_diffusion_model(
            model_cfg.architecture,
            diffusion,
            model_cfg,
            x_dim=n_feats,
            cond_dims=[frame.shape[1] for frame in cond_frames],
        )
        model = model.to(DEVICE)
        model.load_state_dict(state['best_model_mse'])
        model.eval()

        target_frame = modalities_map[modality]['test']
        n_samples = len(target_frame)

        # One zero frame and one zero mask per conditioning modality.
        zero_conds = [pd.DataFrame(np.zeros_like(frame)) for frame in cond_frames]
        masks = [np.zeros(n_samples) for _ in cond_frames]

        _, uncond_data = test_model(
            target_frame, zero_conds, model, diffusion,
            test_iterations=test_repeats, device=DEVICE, masks=masks,
        )
        return uncond_data.iloc[:n_samples]

class CoherentStrategy(EvaluationStrategy):
    """Generation strategy for the 'coherent' ensemble model."""

    def generate_unconditional(self, modality, modalities_map, n_feats, test_repeats):
        """Sample ``modality`` from the per-condition models with zeroed conditions.

        For every other modality ``c`` the checkpoint
        ``<results_path>/<dim>/<modality>_from_<c>/train/best_by_mse.pth`` is
        loaded; the resulting models, their stored ``best_loss`` values and
        all-zero condition frames are handed to ``coherent_test_cos_rejection``.

        Args:
            modality: Key of the target modality in ``modalities_map``.
            modalities_map: Mapping ``modality -> {'test': frame, ...}``.
            n_feats: Passed to the model factory as ``x_dim``.
            test_repeats: Forwarded as ``test_iterations``.

        Returns:
            The first ``len(test frame)`` rows of the generated data, or
            ``None`` as soon as one of the required checkpoints is missing.
        """
        print(f"  Generating unconditional samples using 'coherent' for '{modality}'...")
        other_names = [name for name in modalities_map.keys() if name != modality]
        diffusion = GaussianDiffusion(num_timesteps=1000).to(DEVICE)

        models = []
        weights = []
        zero_conds = []
        for source in other_names:
            checkpoint_file = pathlib.Path(
                f"{self.results_path}/{self.dim}/{modality}_from_{source}/train/best_by_mse.pth"
            )
            if not checkpoint_file.exists():
                print(f"  ERROR: Coherent dependency not found at {checkpoint_file}. Cannot generate data.")
                return None

            state = torch.load(checkpoint_file, map_location='cpu')
            source_cfg = SimpleNamespace(**state['config'])
            source_frame = modalities_map[source]['test']

            member = get_diffusion_model(
                source_cfg.architecture,
                diffusion,
                source_cfg,
                x_dim=n_feats,
                cond_dims=source_frame.shape[1],
            )
            member = member.to(DEVICE)
            member.load_state_dict(state['best_model_mse'])
            member.eval()

            models.append(member)
            # The checkpoint's best loss is what gets passed on as this model's weight.
            weights.append(state['best_loss'])
            zero_conds.append(pd.DataFrame(np.zeros_like(source_frame)))

        target_frame = modalities_map[modality]['test']
        _, uncond_data, _ = coherent_test_cos_rejection(
            target_frame, zero_conds, models, diffusion,
            test_iterations=test_repeats, max_retries=10, device=DEVICE, weights_list=weights,
        )
        return uncond_data.iloc[:len(target_frame)]


# =============================================================================
# 2. PLOTTING FUNCTION (PCA)
# =============================================================================

def plot_unconditional_pca_comparison(train_df, coherent_df, multi_df, save_path=None):
    """
    Draw a 2-panel PCA scatter plot: train vs. coherent and train vs. multi.

    A 2-component PCA is fitted on ``train_df`` only; both generated frames
    are projected with the fitted reducer. The left panel overlays the
    coherent samples on the training projection, the right panel overlays the
    multi-condition samples. Both panels share the same square axis limits.

    Args:
        train_df: DataFrame of training samples; its ``.values`` are projected.
        coherent_df: DataFrame of samples from the coherent method. Only the
            first ``len(train_df)`` rows are used.
            NOTE(review): assumed to have the same feature columns as
            ``train_df`` -- confirm against the caller.
        multi_df: DataFrame of samples from the multi-condition method,
            truncated the same way as ``coherent_df``.
        save_path: Optional output file. When given, the parent directory (if
            the path names one) is created and the figure is written at 300 dpi.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Side effects:
        Switches the global Matplotlib style to ``'seaborn-v0_8-white'``.
    """
    print("\n>>> Generating final 2-panel PCA plot...")
    plt.style.use('seaborn-v0_8-white')
    plot_styles = {
        'Training Set':       {'color': "#c59b7d", 'alpha': 0.6, 's': 10, 'label': 'Training Set'},
        'Coherent Denoising': {'color': '#56b4e9', 'alpha': 0.9, 's': 10, 'label': 'Generated (Coherent Denoising)'},
        'Multi-Condition':    {'color': '#0072b2', 'alpha': 0.9, 's': 10, 'label': 'Generated (Multi-condition)'}
    }

    print("  Fitting PCA on training data and transforming all sets...")
    # Fit on the training data only so that both generated sets are projected
    # onto the same, training-defined principal components.
    reducer = PCA(n_components=2, random_state=42)
    train_emb = reducer.fit_transform(train_df.values)

    n_train = train_df.shape[0]
    coherent_emb = reducer.transform(coherent_df.iloc[:n_train].values)
    multi_emb = reducer.transform(multi_df.iloc[:n_train].values)

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    fig.suptitle("PCA Comparison: Training Data vs. Unconditional Generation", fontsize=28, weight='bold', y=0.98)

    # Shared square limits: take the larger of the two coordinate ranges over
    # all three projections, pad it by 5 %, and centre it on each axis.
    all_x = np.concatenate([train_emb[:, 0], coherent_emb[:, 0], multi_emb[:, 0]])
    all_y = np.concatenate([train_emb[:, 1], coherent_emb[:, 1], multi_emb[:, 1]])

    x_range = all_x.max() - all_x.min()
    y_range = all_y.max() - all_y.min()
    max_range = max(x_range, y_range) * 1.05

    x_center = (all_x.max() + all_x.min()) / 2
    y_center = (all_y.max() + all_y.min()) / 2

    xlims = (x_center - max_range / 2, x_center + max_range / 2)
    ylims = (y_center - max_range / 2, y_center + max_range / 2)

    # Panel 1: Train vs Coherent
    axs[0].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axs[0].scatter(coherent_emb[:, 0], coherent_emb[:, 1], **plot_styles['Coherent Denoising'])
    axs[0].set_title('Unconditional Coherent Denoising Generation', fontsize=22, weight='bold', pad=20)

    # Panel 2: Train vs Multi
    axs[1].scatter(train_emb[:, 0], train_emb[:, 1], **plot_styles['Training Set'])
    axs[1].scatter(multi_emb[:, 0], multi_emb[:, 1], **plot_styles['Multi-Condition'])
    axs[1].set_title('Unconditional Multi-Condition Generation', fontsize=22, weight='bold', pad=20)

    for ax in axs:
        ax.set_xlim(*xlims)
        ax.set_ylim(*ylims)
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlabel('Principal Component 1', fontsize=12, labelpad=5, color='#303030')
        ax.set_ylabel('Principal Component 2', fontsize=12, labelpad=5, color='#303030')
        ax.tick_params(axis='both', colors='#505050')
        for spine in ax.spines.values():
            spine.set_color('#B0B0B0')

    # Proxy handles so the figure-level legend shows one marker per data origin.
    legend_elems = [plt.Line2D([0], [0], marker='o', color='w',
                               markerfacecolor=style['color'], markersize=16,
                               label=style['label']) for style in plot_styles.values()]

    fig.legend(handles=legend_elems, loc='center left', bbox_to_anchor=(0.96, 0.5),
               frameon=False, title='Data Origin', fontsize=16, title_fontsize=18)

    fig.subplots_adjust(left=0.05, right=0.98, bottom=0.05, top=0.85, wspace=0.15)

    if save_path:
        # os.path.dirname() is '' for a bare file name and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is named.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save through the figure object rather than the implicit current figure.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()

# =============================================================================
# 3. MAIN SCRIPT LOGIC
# =============================================================================

def run_generation_and_plotting(args):
    """Generate unconditional samples per modality and plot them (PCA) against the training data.

    ``args`` must provide ``data_dir``, ``dim`` and ``results_path``. The
    function returns ``None``; it stops early (after printing a message) when
    no complete training rows exist or when either strategy fails to generate
    data for a modality.
    """
    modalities_map = read_data(
        modalities=modalities_list,
        splits=['train', 'test'],
        data_dir=args.data_dir,
        dim=args.dim,
    )

    print(">>> Finding common samples with no missing values across all training modalities...")
    # Column-wise join of all training frames; rows with any NaN are dropped.
    train_all_complete = pd.concat(
        [modalities_map[name]['train'] for name in modalities_list],
        axis=1,
    ).dropna()
    n_samples_to_use = len(train_all_complete)
    print(f"Found {n_samples_to_use} complete training samples.")

    if train_all_complete.empty:
        print("Error: No common samples found. Cannot proceed.")
        return

    print("\n>>> Starting data generation for all modalities...")
    coherent_strategy = CoherentStrategy('coherent', args.results_path, args.dim)
    multi_strategy = MultiModelStrategy('multi', args.results_path, args.dim)
    coherent_parts = []
    multi_parts = []

    for modality in modalities_list:
        n_feats = modalities_map[modality]['train'].shape[1]

        generated = coherent_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if generated is None:
            print(f"FATAL: Could not generate Coherent data for {modality}. Aborting.")
            return
        coherent_parts.append(generated.head(n_samples_to_use))

        generated = multi_strategy.generate_unconditional(modality, modalities_map, n_feats, 1)
        if generated is None:
            print(f"FATAL: Could not generate Multi data for {modality}. Aborting.")
            return
        multi_parts.append(generated.head(n_samples_to_use))

    # Put the per-modality frames side by side, mirroring the training layout.
    output_file = os.path.join(args.results_path, 'images', 'pca_train_vs_uncond_comparison.png')
    plot_unconditional_pca_comparison(
        train_df=train_all_complete,
        coherent_df=pd.concat(coherent_parts, axis=1),
        multi_df=pd.concat(multi_parts, axis=1),
        save_path=output_file,
    )


if __name__ == "__main__":
    # DEVICE is read as a global by the strategy classes: the currently
    # selected CUDA device when CUDA is available, otherwise the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        DEVICE = torch.device("cpu")
    METRIC_TO_USE = 'mse'

    # Run configuration handed to the entry point as attribute-style args.
    args = SimpleNamespace(dim=32,
                           results_path='../results',
                           data_dir='../datasets_TCGA/07_normalized/')
    run_generation_and_plotting(args)