In [None]:
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import combinations
import umap  # Add this import
import torch
import glob


from lib.read_data import read_data
from lib.config import data_dir

def visualize_generated_samples(metric_used):
    """
    Plot real vs. generated samples in 2D UMAP space for every modality and experiment.

    For each modality returned by ``read_data``, every file under
    ``BASE_DIR/<modality>_*`` whose name starts with ``generated_samples`` and
    contains ``best_<metric_used>.csv`` is loaded. A UMAP reducer is fitted on the
    modality's training split only (rows containing NaN are dropped first). For up
    to 10 randomly chosen test indices, one panel shows the embedded training
    data, the real test point, and the generated samples for that index.

    Args:
        metric_used: Metric tag embedded in the generated-sample file names
            (this notebook uses "mse", "cosine" and "timestep").

    Returns:
        None. Figures are shown with ``plt.show()`` and progress is printed.
    """

    BASE_DIR = "../results/32"
    MAX_PANELS = 10  # one panel per sampled test index; matches the 2x5 grid below

    # Modalities map: indexed below by modality name, then by split ('train' / 'test')
    modalities_map = read_data(
        modalities=['cna','rnaseq','rppa','wsi'],
        splits=['train','test'],
        data_dir=data_dir,
        dim='32'
    )  # adjust arguments as needed

    for modality in modalities_map.keys():
        # 1) Collect all generated_sample files for this modality + metric
        gen_files = []
        for mod_dir in glob.glob(os.path.join(BASE_DIR, f'{modality}_*')):
            for root, _dirs, files in os.walk(mod_dir):
                for fname in files:
                    if fname.startswith("generated_samples") and f"best_{metric_used}.csv" in fname:
                        gen_files.append(os.path.join(root, fname))
        # Filesystem traversal order is arbitrary; sort so figures appear in a stable order
        gen_files.sort()

        print(gen_files)

        if not gen_files:
            print(f"No generated samples found for modality '{modality}' with metric '{metric_used}'")
            continue

        # 2) Load train and test data for this modality
        x_train = modalities_map[modality]['train'].dropna()  # remove any rows with NaN
        # The test split is deliberately NOT filtered: generated rows are matched to
        # test rows by position, so dropping rows here would shift that alignment.
        # NOTE(review): transform() may fail if the test split contains NaN - confirm.
        x_test = modalities_map[modality]['test']
        N = x_test.shape[0]
        if N == 0:
            print(f"Test split for modality '{modality}' is empty, skipping")
            continue

        # Fit UMAP on x_train only
        umap_reducer = umap.UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=15,
            min_dist=0.1
        ).fit(x_train.values)

        X_train_umap = umap_reducer.transform(x_train.values)
        X_test_umap = umap_reducer.transform(x_test.values)

        # Choose up to MAX_PANELS random test indices (fixed seed => same indices every run)
        rng = np.random.default_rng(seed=42)
        sample_indices = rng.choice(N, size=min(MAX_PANELS, N), replace=False)

        # 3) For each experiment (method)
        for gen_path in gen_files:
            # Determine method name from folder structure
            # e.g. "./results/32/rnaseq_from_cna/test/generated_samples_best_mse.csv"
            # -> the experiment folder sits two levels above the file
            parts = gen_path.split(os.sep)
            method = parts[-3]  # something like "rnaseq_from_cna" or "wsi_from_coherent"

            # Conditioning combo, when encoded in the file name as
            # "generated_samples_from_<combo>_best_<metric>.csv"
            fname = os.path.basename(gen_path)
            prefix = "generated_samples_from_"
            combo = None
            if prefix in fname:
                start = fname.find(prefix) + len(prefix)
                end = fname.find(f"_best_{metric_used}.csv")
                # find() returns -1 when the suffix is absent; never slice with that
                if end >= start:
                    combo = fname[start:end]  # e.g. "rnaseq_rppa_wsi"

            # Load generated samples: shape ~ (N * test_repeats, D)
            if method.endswith('multi'):
                x_gen = pd.read_csv(gen_path)
            else:
                x_gen = pd.read_csv(gen_path, index_col=0) # NOTE: For 'coherent' or 'simple', generated samples may be stored with index

            subtitle = f"Method = {method}"
            if combo:
                subtitle += f" (conds: {combo})"
            subtitle += f", Metric = {metric_used}"

            print(f'data from: {gen_path}')

            # The stride-N lookup below assumes whole repeats of the test set
            if len(x_gen) % N != 0:
                print(f"Warning: {len(x_gen)} generated rows is not a multiple of "
                      f"the {N} test rows; generated points may be misaligned")

            # 4) Create figure with 10 subplots
            fig, axes = plt.subplots(2, 5, figsize=(20, 8))
            fig.suptitle(f"UMAP Real vs Generated for '{modality}'\n{subtitle}",
                         fontsize=16)
            panel_axes = axes.ravel()

            for panel, (idx, ax) in enumerate(zip(sample_indices, panel_axes)):
                # Pick out generated points for this test index:
                # rows at idx, idx+N, idx+2N, ... until end
                indices = np.arange(idx, len(x_gen), N)

                # Plot distribution of the training split in grey
                ax.scatter(
                    X_train_umap[:, 0],
                    X_train_umap[:, 1],
                    color="grey",
                    alpha=0.2,
                    s=5,
                    label="Train Distribution"
                )
                # Plot the real test point
                ax.scatter(
                    X_test_umap[idx, 0],
                    X_test_umap[idx, 1],
                    color="blue",
                    marker="D",
                    s=50,
                    label="Real Test Point"
                )
                # Plot generated points (none exist when the file has fewer rows than idx)
                if indices.size:
                    gen_sample_umap = umap_reducer.transform(x_gen.iloc[indices].values)
                    ax.scatter(
                        gen_sample_umap[:, 0],
                        gen_sample_umap[:, 1],
                        color="darkorange",
                        alpha=0.5,
                        s=20,
                        label="Generated Samples"
                    )

                ax.set_title(f"Test Index = {idx}", fontsize=10)
                ax.set_xlabel("UMAP Component 1", fontsize=8)
                ax.set_ylabel("UMAP Component 2", fontsize=8)
                ax.tick_params(labelsize=6)

                # One legend per figure is enough; put it on the first panel
                if panel == 0:
                    ax.legend(loc="upper left", fontsize=8)

            # Hide grid cells left over when there are fewer test rows than panels
            for ax in panel_axes[len(sample_indices):]:
                ax.set_visible(False)

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()
            # Release the figure so long runs do not accumulate open figures
            plt.close(fig)


# Metric tag selecting which generated-sample files are plotted.
# Valid choices: "mse", "cosine", "timestep".
metric = "mse"
visualize_generated_samples(metric_used=metric)

In [None]:
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import combinations
import umap  # Add this import
import torch
import glob


from lib.read_data import read_data
from lib.config import data_dir

def visualize_generated_samples(metric_used):
    """
    Plot real vs. generated samples in 2D UMAP space for every modality and experiment.

    For each modality returned by ``read_data``, every file under
    ``BASE_DIR/<modality>_*`` whose name starts with ``generated_samples`` and
    contains ``best_<metric_used>.csv`` is loaded. A UMAP reducer is fitted on the
    modality's training split only (rows containing NaN are dropped first). For up
    to 10 randomly chosen test indices, one panel shows the embedded training
    data, the real test point, and the generated samples for that index.

    Args:
        metric_used: Metric tag embedded in the generated-sample file names
            (this notebook uses "mse", "cosine" and "timestep").

    Returns:
        None. Figures are shown with ``plt.show()`` and progress is printed.
    """

    BASE_DIR = "../results/32"
    MAX_PANELS = 10  # one panel per sampled test index; matches the 2x5 grid below

    # Modalities map: indexed below by modality name, then by split ('train' / 'test')
    modalities_map = read_data(
        modalities=['cna','rnaseq','rppa','wsi'],
        splits=['train','test'],
        data_dir=data_dir,
        dim='32'
    )  # adjust arguments as needed

    for modality in modalities_map.keys():
        # 1) Collect all generated_sample files for this modality + metric
        gen_files = []
        for mod_dir in glob.glob(os.path.join(BASE_DIR, f'{modality}_*')):
            for root, _dirs, files in os.walk(mod_dir):
                for fname in files:
                    if fname.startswith("generated_samples") and f"best_{metric_used}.csv" in fname:
                        gen_files.append(os.path.join(root, fname))
        # Filesystem traversal order is arbitrary; sort so figures appear in a stable order
        gen_files.sort()

        print(gen_files)

        if not gen_files:
            print(f"No generated samples found for modality '{modality}' with metric '{metric_used}'")
            continue

        # 2) Load train and test data for this modality
        x_train = modalities_map[modality]['train'].dropna()  # remove any rows with NaN
        # The test split is deliberately NOT filtered: generated rows are matched to
        # test rows by position, so dropping rows here would shift that alignment.
        # NOTE(review): transform() may fail if the test split contains NaN - confirm.
        x_test = modalities_map[modality]['test']
        N = x_test.shape[0]
        if N == 0:
            print(f"Test split for modality '{modality}' is empty, skipping")
            continue

        # Fit UMAP on x_train only
        umap_reducer = umap.UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=15,
            min_dist=0.1
        ).fit(x_train.values)

        X_train_umap = umap_reducer.transform(x_train.values)
        X_test_umap = umap_reducer.transform(x_test.values)

        # Choose up to MAX_PANELS random test indices (fixed seed => same indices every run)
        rng = np.random.default_rng(seed=42)
        sample_indices = rng.choice(N, size=min(MAX_PANELS, N), replace=False)

        # 3) For each experiment (method)
        for gen_path in gen_files:
            # Determine method name from folder structure
            # e.g. "./results/32/rnaseq_from_cna/test/generated_samples_best_mse.csv"
            # -> the experiment folder sits two levels above the file
            parts = gen_path.split(os.sep)
            method = parts[-3]  # something like "rnaseq_from_cna" or "wsi_from_coherent"

            # Conditioning combo, when encoded in the file name as
            # "generated_samples_from_<combo>_best_<metric>.csv"
            fname = os.path.basename(gen_path)
            prefix = "generated_samples_from_"
            combo = None
            if prefix in fname:
                start = fname.find(prefix) + len(prefix)
                end = fname.find(f"_best_{metric_used}.csv")
                # find() returns -1 when the suffix is absent; never slice with that
                if end >= start:
                    combo = fname[start:end]  # e.g. "rnaseq_rppa_wsi"

            # Load generated samples: shape ~ (N * test_repeats, D)
            if method.endswith('multi'):
                x_gen = pd.read_csv(gen_path)
            else:
                x_gen = pd.read_csv(gen_path, index_col=0) # NOTE: For 'coherent' or 'simple', generated samples may be stored with index

            subtitle = f"Method = {method}"
            if combo:
                subtitle += f" (conds: {combo})"
            subtitle += f", Metric = {metric_used}"

            print(f'data from: {gen_path}')

            # The stride-N lookup below assumes whole repeats of the test set
            if len(x_gen) % N != 0:
                print(f"Warning: {len(x_gen)} generated rows is not a multiple of "
                      f"the {N} test rows; generated points may be misaligned")

            # 4) Create figure with 10 subplots
            fig, axes = plt.subplots(2, 5, figsize=(20, 8))
            fig.suptitle(f"UMAP Real vs Generated for '{modality}'\n{subtitle}",
                         fontsize=16)
            panel_axes = axes.ravel()

            for panel, (idx, ax) in enumerate(zip(sample_indices, panel_axes)):
                # Pick out generated points for this test index:
                # rows at idx, idx+N, idx+2N, ... until end
                indices = np.arange(idx, len(x_gen), N)

                # Plot distribution of the training split in grey
                ax.scatter(
                    X_train_umap[:, 0],
                    X_train_umap[:, 1],
                    color="grey",
                    alpha=0.2,
                    s=5,
                    label="Train Distribution"
                )
                # Plot the real test point
                ax.scatter(
                    X_test_umap[idx, 0],
                    X_test_umap[idx, 1],
                    color="blue",
                    marker="D",
                    s=50,
                    label="Real Test Point"
                )
                # Plot generated points (none exist when the file has fewer rows than idx)
                if indices.size:
                    gen_sample_umap = umap_reducer.transform(x_gen.iloc[indices].values)
                    ax.scatter(
                        gen_sample_umap[:, 0],
                        gen_sample_umap[:, 1],
                        color="darkorange",
                        alpha=0.5,
                        s=20,
                        label="Generated Samples"
                    )

                ax.set_title(f"Test Index = {idx}", fontsize=10)
                ax.set_xlabel("UMAP Component 1", fontsize=8)
                ax.set_ylabel("UMAP Component 2", fontsize=8)
                ax.tick_params(labelsize=6)

                # One legend per figure is enough; put it on the first panel
                if panel == 0:
                    ax.legend(loc="upper left", fontsize=8)

            # Hide grid cells left over when there are fewer test rows than panels
            for ax in panel_axes[len(sample_indices):]:
                ax.set_visible(False)

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()
            # Release the figure so long runs do not accumulate open figures
            plt.close(fig)


# Metric tag selecting which generated-sample files are plotted.
# Valid choices: "mse", "cosine", "timestep".
metric = "timestep"
visualize_generated_samples(metric_used=metric)

In [None]:
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import combinations
import umap  # Add this import
import torch
import glob


from lib.read_data import read_data
from lib.config import data_dir

def visualize_generated_samples(metric_used):
    """
    Plot real vs. generated samples in 2D UMAP space for fully-conditioned experiments.

    For each modality returned by ``read_data``, files under
    ``BASE_DIR/<modality>_*`` whose name starts with ``generated_samples`` and
    contains ``best_<metric_used>.csv`` are collected. Only files are kept whose
    experiment folder name ends with ``coherent`` or ``multi`` AND whose file name
    encodes a conditioning combo equal to all other modalities. A UMAP reducer is
    fitted on the modality's training split only (rows containing NaN are dropped
    first). For up to 10 randomly chosen test indices, one panel shows the embedded
    training data, the real test point, and the generated samples for that index.

    Args:
        metric_used: Metric tag embedded in the generated-sample file names
            (this notebook uses "mse", "cosine" and "timestep").

    Returns:
        None. Figures are shown with ``plt.show()`` and progress is printed.
    """

    BASE_DIR = "../results/32"
    MAX_PANELS = 10  # one panel per sampled test index; matches the 2x5 grid below
    # Single source of truth for both the data loading and the "all conditioning" check
    MODALITIES = ['cna', 'rnaseq', 'rppa', 'wsi']

    # Modalities map: indexed below by modality name, then by split ('train' / 'test')
    modalities_map = read_data(
        modalities=MODALITIES,
        splits=['train','test'],
        data_dir=data_dir,
        dim='32'
    )  # adjust arguments as needed

    for modality in modalities_map.keys():
        # 1) Collect all generated_sample files for this modality + metric
        gen_files = []
        for mod_dir in glob.glob(os.path.join(BASE_DIR, f'{modality}_*')):
            for root, _dirs, files in os.walk(mod_dir):
                for fname in files:
                    if fname.startswith("generated_samples") and f"best_{metric_used}.csv" in fname:
                        gen_files.append(os.path.join(root, fname))
        # Filesystem traversal order is arbitrary; sort so figures appear in a stable order
        gen_files.sort()

        print(gen_files)

        if not gen_files:
            print(f"No generated samples found for modality '{modality}' with metric '{metric_used}'")
            continue

        # 2) Keep only coherent/multi experiments conditioned on ALL other modalities.
        # This runs before the UMAP fit so no embedding is computed when nothing qualifies.
        target_modality_set = set(MODALITIES) - {modality}  # e.g. {cna, rnaseq, rppa} for wsi

        selected = []  # (gen_path, method, combo) triples, parsed once and reused for plotting
        for gen_path in gen_files:
            # The experiment folder sits two levels above the file,
            # e.g. ".../wsi_from_coherent/test/generated_samples_from_..._best_mse.csv"
            parts = gen_path.split(os.sep)
            method = parts[-3]

            # Conditioning combo, when encoded in the file name as
            # "generated_samples_from_<combo>_best_<metric>.csv"
            fname = os.path.basename(gen_path)
            prefix = "generated_samples_from_"
            combo = None
            if prefix in fname:
                start = fname.find(prefix) + len(prefix)
                end = fname.find(f"_best_{metric_used}.csv")
                # find() returns -1 when the suffix is absent; never slice with that
                if end >= start:
                    combo = fname[start:end]  # e.g. "rnaseq_rppa_wsi"

            if not method.endswith(('coherent', 'multi')):
                continue
            if combo and set(combo.split('_')) == target_modality_set:
                selected.append((gen_path, method, combo))

        if not selected:
            print(f"No coherent/multi methods with all conditioning found for modality '{modality}'")
            continue

        # 3) Load train and test data for this modality
        x_train = modalities_map[modality]['train'].dropna()  # remove any rows with NaN
        # The test split is deliberately NOT filtered: generated rows are matched to
        # test rows by position, so dropping rows here would shift that alignment.
        # NOTE(review): transform() may fail if the test split contains NaN - confirm.
        x_test = modalities_map[modality]['test']
        N = x_test.shape[0]
        if N == 0:
            print(f"Test split for modality '{modality}' is empty, skipping")
            continue

        # Fit UMAP on x_train only
        umap_reducer = umap.UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=15,
            min_dist=0.1
        ).fit(x_train.values)

        X_train_umap = umap_reducer.transform(x_train.values)
        X_test_umap = umap_reducer.transform(x_test.values)

        # Choose up to MAX_PANELS random test indices (fixed seed => same indices every run)
        rng = np.random.default_rng(seed=42)
        sample_indices = rng.choice(N, size=min(MAX_PANELS, N), replace=False)

        # 4) For each filtered experiment (method)
        for gen_path, method, combo in selected:
            # Load generated samples: shape ~ (N * test_repeats, D)
            if method.endswith('multi'):
                x_gen = pd.read_csv(gen_path)
            else:
                x_gen = pd.read_csv(gen_path, index_col=0) # NOTE: For 'coherent' or 'simple', generated samples may be stored with index

            subtitle = f"Method = {method}"
            if combo:
                subtitle += f" (conds: {combo})"
            subtitle += f", Metric = {metric_used}"

            print(f'data from: {gen_path}')

            # The stride-N lookup below assumes whole repeats of the test set
            if len(x_gen) % N != 0:
                print(f"Warning: {len(x_gen)} generated rows is not a multiple of "
                      f"the {N} test rows; generated points may be misaligned")

            # 5) Create figure with 10 subplots
            fig, axes = plt.subplots(2, 5, figsize=(20, 8))
            fig.suptitle(f"UMAP Real vs Generated for '{modality}'\n{subtitle}",
                         fontsize=16)
            panel_axes = axes.ravel()

            for panel, (idx, ax) in enumerate(zip(sample_indices, panel_axes)):
                # Pick out generated points for this test index:
                # rows at idx, idx+N, idx+2N, ... until end
                indices = np.arange(idx, len(x_gen), N)

                # Plot distribution of the training split in grey
                ax.scatter(
                    X_train_umap[:, 0],
                    X_train_umap[:, 1],
                    color="grey",
                    alpha=0.2,
                    s=5,
                    label="Train Distribution"
                )
                # Plot the real test point
                ax.scatter(
                    X_test_umap[idx, 0],
                    X_test_umap[idx, 1],
                    color="blue",
                    marker="D",
                    s=50,
                    label="Real Test Point"
                )
                # Plot generated points (none exist when the file has fewer rows than idx)
                if indices.size:
                    gen_sample_umap = umap_reducer.transform(x_gen.iloc[indices].values)
                    ax.scatter(
                        gen_sample_umap[:, 0],
                        gen_sample_umap[:, 1],
                        color="darkorange",
                        alpha=0.5,
                        s=20,
                        label="Generated Samples"
                    )

                ax.set_title(f"Test Index = {idx}", fontsize=10)
                ax.set_xlabel("UMAP Component 1", fontsize=8)
                ax.set_ylabel("UMAP Component 2", fontsize=8)
                ax.tick_params(labelsize=6)

                # One legend per figure is enough; put it on the first panel
                if panel == 0:
                    ax.legend(loc="upper left", fontsize=8)

            # Hide grid cells left over when there are fewer test rows than panels
            for ax in panel_axes[len(sample_indices):]:
                ax.set_visible(False)

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()
            # Release the figure so long runs do not accumulate open figures
            plt.close(fig)


# Metric tag selecting which generated-sample files are plotted.
# Valid choices: "mse", "cosine", "timestep".
metric = "mse"
visualize_generated_samples(metric_used=metric)