In [None]:
import os
import json
import torch
import pandas as pd

BASE_DIR = "../results/32"
rows = []
# Checkpoint path -> architecture. Coherent/multi experiments write several JSONs
# that resolve to the same checkpoint, so each .pth is unpickled at most once.
arch_cache = {}

for root, dirs, files in os.walk(BASE_DIR):
    for fname in files:
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        json_path = os.path.join(root, fname)
        # The “experiment” folder is the parent of “test”, e.g. "cna_from_rnaseq"
        experiment = os.path.basename(os.path.dirname(root))
        # Folders that don't follow "<target>_from_<...>" cannot be parsed below
        # (the target/source split would raise ValueError), so skip them.
        if "_from_" not in experiment:
            continue

        # 1) Read the JSON contents (metrics)
        with open(json_path, "r") as f:
            data = json.load(f)

        # 2) Determine sampling_type. "_from_multi_masked" contains "_from_multi",
        #    so it is checked first; otherwise masked runs would be labelled "multi".
        if "_from_coherent" in experiment:
            sampling_type = "coherent"
        elif "_from_multi_masked" in experiment:
            sampling_type = "multi_masked"
        elif "_from_multi" in experiment:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        # 3) Parse target and source
        target, src_part = experiment.split("_from_", 1)
        if sampling_type == "single":
            source = src_part
        else:
            # In coherent/multi, the “source” combo is encoded in the filename:
            # e.g. "test_metrics_from_rnaseq_rppa_best_mse.json" -> "rnaseq_rppa"
            prefix = "test_metrics_from_"
            suffix = "_best"
            source = fname[len(prefix) : fname.rfind(suffix)]

        # 4) Identify which reference metric wrote this file
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        # 5) Locate the checkpoint: for every sampling type it lives at
        #       <BASE_DIR>/<experiment>/train/best_by_<metric_used>.pth
        #    With no recognised metric there is no checkpoint to look up.
        architecture = None
        if metric_used is not None:
            ckpt_path = os.path.join(BASE_DIR, experiment, "train", f"best_by_{metric_used}.pth")
            if ckpt_path not in arch_cache:
                # 6) Load on CPU and extract “architecture” from the stored config.
                #    torch.load unpickles the file: only use it on trusted checkpoints.
                #    The loaded checkpoint is not kept, so its weights can be freed.
                if os.path.exists(ckpt_path):
                    raw_cfg = torch.load(ckpt_path, map_location="cpu").get("config", {})
                    arch_cache[ckpt_path] = raw_cfg.get("architecture", None)
                else:
                    # In case something’s missing, mark as None
                    arch_cache[ckpt_path] = None
            architecture = arch_cache[ckpt_path]

        # 7) Build the row
        rows.append({
            "experiment":     experiment,
            "target":         target,
            "source":         source,
            "sampling_type":  sampling_type,
            "metric_used":    metric_used,
            "architecture":   architecture,
            **data       # This unpacks {mse_mean, mse_std, r2_mean, r2_std, cos_mean, cos_std}
        })

# 8) Finally, turn into a DataFrame
df_results = pd.DataFrame(rows)


In [None]:
# Show the assembled results table: one row per test_metrics JSON, holding the
# parsed experiment fields plus the metric keys read from the file.
df_results

In [None]:
# Tally how often each architecture was selected under each reference metric:
# long-format counts first, then pivoted to a metric_used × architecture table.
counts = (
    df_results
    .groupby("metric_used")["architecture"]
    .value_counts()
    .reset_index(name="count")
)

counts_pivot = (
    counts
    .pivot(index="metric_used", columns="architecture", values="count")
    .fillna(0)      # combinations never observed get a count of 0
    .astype(int)
)

print(counts_pivot)


In [None]:
# List every distinct source string, in first-seen order.
print(pd.unique(df_results["source"]))

In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===

# Make sure BASE_DIR points at the folder that contains “cna_from_…”, “rnaseq_from_…”, etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    for fname in files:
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # JSONs sit in "<experiment>/test/", so the experiment folder is root's parent.
        exp_folder = os.path.basename(os.path.dirname(root))

        # Determine sampling_type. "_from_multi_masked" contains "_from_multi", so it
        # is checked first; otherwise masked runs would be tagged "multi" and collide
        # with the unmasked multi runs in the per-(target, source) lookups below.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        # Parse target and source
        if sampling_type == "single":
            # e.g. “cna_from_rnaseq” → target="cna", source="rnaseq"
            if "_from_" not in exp_folder:
                continue
            target, src_part = exp_folder.split("_from_", 1)
            source = src_part
        else:
            # e.g. “cna_from_coherent” → target="cna", then combo in filename
            target = exp_folder.split("_from_", 1)[0]
            prefix = "test_metrics_from_"
            suffix = "_best"
            combo = fname[len(prefix): fname.rfind(suffix)]
            source = combo  # e.g. "rna_rppa" or "cna_rppa_wsi"

        # Which reference metric selected the checkpoint (parsed from the filename)?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean","mse_std","r2_mean","r2_std","cos_mean","cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")

# === 2) For each target modality, plot a grouped bar chart of Test R² ===

metrics     = ["mse", "cosine", "timestep"]
all_targets = sorted(df_results["target"].unique())

for modality in all_targets:
    # 2a) Identify other modalities
    others = [m for m in all_targets if m != modality]
    # Combination sizes run up to len(others) so the "all other modalities" combo
    # is always included, whatever the number of modalities.
    max_r = len(others)

    # 2b) Build label lists, prefixing coherent/multi to distinguish
    single_labels   = others.copy()  # e.g. ["cna","rnaseq","rppa"]

    # e.g. ["coh_cna_rnaseq","coh_cna_rppa","coh_rnaseq_rppa","coh_cna_rnaseq_rppa"]
    coherent_labels = [f"coh_{'_'.join(combo)}"
                       for r in range(2, max_r + 1)
                       for combo in combinations(others, r)]

    # e.g. ["mult_cna","mult_rnaseq","mult_rppa", ..., "mult_cna_rnaseq_rppa"]
    multi_labels = [f"mult_{'_'.join(combo)}"
                    for r in range(1, max_r + 1)
                    for combo in combinations(others, r)]

    # Singles, then coherent combos, then multi combos (14 labels for 3 "others")
    labels = single_labels + coherent_labels + multi_labels

    # 2c) Create data matrices filled with NaN (rows = labels, cols = metrics)
    data_mat = np.full((len(labels), len(metrics)), np.nan)
    err_mat  = np.full((len(labels), len(metrics)), np.nan)

    # 2d) Fill in Test R² values
    for i, lbl in enumerate(labels):
        if lbl in single_labels:
            sampling_type = "single"
            source        = lbl
        elif lbl in coherent_labels:
            sampling_type = "coherent"
            source        = lbl.replace("coh_", "", 1)
        else:
            sampling_type = "multi"
            source        = lbl.replace("mult_", "", 1)

        for j, m in enumerate(metrics):
            df_row = df_results[
                (df_results["target"] == modality) &
                (df_results["sampling_type"] == sampling_type) &
                (df_results["source"] == source) &
                (df_results["metric_used"] == m)
            ]
            if not df_row.empty:
                data_mat[i, j] = df_row.iloc[0]["r2_mean"]
                err_mat[i, j]  = df_row.iloc[0]["r2_std"]

    # 2e) Plot grouped bars; offsets centre the group of len(metrics) bars on each tick
    x = np.arange(len(labels))
    width = 0.25
    center = (len(metrics) - 1) / 2

    fig, ax = plt.subplots(figsize=(14, 6))
    for j, m in enumerate(metrics):
        ax.bar(
            x + (j - center) * width,
            data_mat[:, j],
            width,
            yerr=err_mat[:, j],
            capsize=5,
            label=m.capitalize()
        )

    ax.set_ylabel("Test R² (Mean ± Std)")
    ax.set_title(f"Target = '{modality}': R² by Source (bars grouped by reference metric)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.legend(title="Picked by")
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===

# Make sure BASE_DIR points at the folder that contains "cna_from_…", "rnaseq_from_…", etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    for fname in files:
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # JSONs sit in "<experiment>/test/", so the experiment folder is root's parent.
        exp_folder = os.path.basename(os.path.dirname(root))

        # Determine sampling_type. "_from_multi_masked" contains "_from_multi", so it
        # is checked first; otherwise masked runs would be tagged "multi" and collide
        # with the unmasked multi runs in the per-(target, source) lookups below.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        # Parse target and source
        if sampling_type == "single":
            # e.g. "cna_from_rnaseq" → target="cna", source="rnaseq"
            if "_from_" not in exp_folder:
                continue
            target, src_part = exp_folder.split("_from_", 1)
            source = src_part
        else:
            # e.g. "cna_from_coherent" → target="cna", then combo in filename
            target = exp_folder.split("_from_", 1)[0]
            prefix = "test_metrics_from_"
            suffix = "_best"
            combo = fname[len(prefix): fname.rfind(suffix)]
            source = combo  # e.g. "rna_rppa" or "cna_rppa_wsi"

        # Which reference metric selected the checkpoint (parsed from the filename)?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean","mse_std","r2_mean","r2_std","cos_mean","cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")

# === 2) For each target modality, plot a bar chart of Test R² (MSE metric only) ===

all_targets = sorted(df_results["target"].unique())

for modality in all_targets:
    # 2a) Identify other modalities
    others = [m for m in all_targets if m != modality]
    max_r = len(others)
    # Source key of the combo that conditions on every other modality. combinations()
    # keeps the order of `others`, so only the largest combo produces this key.
    full_source = '_'.join(others)

    # 2b) Build label lists, prefixing coherent/multi to distinguish.
    #     Combo sizes run up to len(others) so the all-modality combo is always present.
    single_labels   = others.copy()  # e.g. ["cna","rnaseq","rppa"]

    # e.g. ["coh_cna_rnaseq","coh_cna_rppa","coh_rnaseq_rppa","coh_cna_rnaseq_rppa"]
    coherent_labels = [f"coh_{'_'.join(combo)}"
                       for r in range(2, max_r + 1)
                       for combo in combinations(others, r)]

    # e.g. ["mult_cna","mult_rnaseq","mult_rppa", ..., "mult_cna_rnaseq_rppa"]
    multi_labels = [f"mult_{'_'.join(combo)}"
                    for r in range(1, max_r + 1)
                    for combo in combinations(others, r)]

    # Singles, then coherent combos, then multi combos
    labels = single_labels + coherent_labels + multi_labels

    # 2c) Create data arrays for the best-by-MSE checkpoints only
    data_values = np.full(len(labels), np.nan)
    err_values = np.full(len(labels), np.nan)

    # Indices of coherent/multi bars that condition on all other modalities
    highlight_bars = []

    # 2d) Fill in Test R² values of the best-by-MSE checkpoints
    for i, lbl in enumerate(labels):
        if lbl in single_labels:
            sampling_type = "single"
            source        = lbl
        elif lbl in coherent_labels:
            sampling_type = "coherent"
            source        = lbl.replace("coh_", "", 1)
        else:
            sampling_type = "multi"
            source        = lbl.replace("mult_", "", 1)

        if sampling_type != "single" and source == full_source:
            highlight_bars.append(i)

        df_row = df_results[
            (df_results["target"] == modality) &
            (df_results["sampling_type"] == sampling_type) &
            (df_results["source"] == source) &
            (df_results["metric_used"] == "mse")
        ]
        if not df_row.empty:
            data_values[i] = df_row.iloc[0]["r2_mean"]
            err_values[i] = df_row.iloc[0]["r2_std"]

    # 2e) Plot uniformly coloured bars, outlining the highlighted ones
    x = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(14, 6))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test R² (Mean ± Std)")
    ax.set_title(f"Target = '{modality}': R² by Source (best models by MSE)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    # Bold the tick labels of the highlighted bars
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlight_bars else 'normal')

    ax.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations

# === assume df_results is already built ===
all_targets = sorted(df_results["target"].unique())
n = len(all_targets)

# 1) One figure, n rows. squeeze=False always returns a 2-D array of Axes,
#    so the single-target case needs no special handling.
fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(14, 6 * n), sharex=False, squeeze=False)

for ax, modality in zip(axes[:, 0], all_targets):
    # 2a) Identify others
    others = [m for m in all_targets if m != modality]
    max_r = len(others)
    # Source key of the combo conditioning on every other modality
    full_source = '_'.join(others)

    # 2b) Build labels; combo sizes go up to len(others) so the all-modality combo is included
    single_labels   = others.copy()
    coherent_labels = [f"coh_{'_'.join(combo)}"
                       for r in range(2, max_r + 1)
                       for combo in combinations(others, r)]
    multi_labels    = [f"mult_{'_'.join(combo)}"
                       for r in range(1, max_r + 1)
                       for combo in combinations(others, r)]
    labels = single_labels + coherent_labels + multi_labels

    # 2c) Prepare data arrays for R²
    data_values = np.full(len(labels), np.nan)
    err_values  = np.full(len(labels), np.nan)
    highlight_bars = []

    # 2d) Fill in R² mean & std
    for i, lbl in enumerate(labels):
        if lbl in single_labels:
            stype, src = "single", lbl
        elif lbl in coherent_labels:
            stype, src = "coherent", lbl.replace("coh_", "", 1)
        else:
            stype, src = "multi", lbl.replace("mult_", "", 1)
        # Highlight the coherent/multi bars that condition on all other modalities
        if stype != "single" and src == full_source:
            highlight_bars.append(i)

        sel = df_results[
            (df_results["target"]        == modality) &
            (df_results["sampling_type"] == stype)     &
            (df_results["source"]        == src)       &
            (df_results["metric_used"]   == "mse")     # still best-MSE models
        ]
        if not sel.empty:
            data_values[i] = sel.iloc[0]["r2_mean"]
            err_values[i]  = sel.iloc[0]["r2_std"]

    # 2e) Plot on this Ax
    x = np.arange(len(labels))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test R² (Mean ± Std)")
    ax.set_title(f"Target = '{modality}'")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlight_bars else 'normal')
    ax.grid(axis="y", linestyle="--", alpha=0.7)

plt.tight_layout(rect=[0, 0, 1, 0.95])     # leave room at top for the main title
fig.suptitle("Overall Comparison of Models: R² Across All Targets",
             fontsize=18, weight='bold')

# 3) Save all subplots in one image, creating the output folder if it is missing
out_path = "../results/images/all_modalities_r2.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, dpi=300)


In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===

# Make sure BASE_DIR points at the folder that contains "cna_from_…", "rnaseq_from_…", etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    for fname in files:
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # JSONs sit in "<experiment>/test/", so the experiment folder is root's parent.
        exp_folder = os.path.basename(os.path.dirname(root))

        # Determine sampling_type. "_from_multi_masked" contains "_from_multi", so it
        # is checked first; otherwise masked runs would be tagged "multi" and collide
        # with the unmasked multi runs in the per-(target, source) lookups below.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        # Parse target and source
        if sampling_type == "single":
            if "_from_" not in exp_folder:
                continue
            target, src_part = exp_folder.split("_from_", 1)
            source = src_part
        else:
            # Coherent/multi: the source combo is encoded in the filename,
            # e.g. "test_metrics_from_rnaseq_rppa_best_mse.json" -> "rnaseq_rppa"
            target = exp_folder.split("_from_", 1)[0]
            prefix = "test_metrics_from_"
            suffix = "_best"
            combo = fname[len(prefix): fname.rfind(suffix)]
            source = combo

        # Which reference metric selected the checkpoint (parsed from the filename)?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean", "mse_std", "r2_mean", "r2_std", "cos_mean", "cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")

# === 2) Pre-computation for synchronized Y-axis ===

all_targets = sorted(df_results["target"].unique())
conditions_to_plot = []

# Identify all data points that will be part of the plot: every single source plus
# the coherent and multi runs that condition on all other modalities.
for modality in all_targets:
    others = sorted([m for m in all_targets if m != modality])
    full_source = '_'.join(others)

    for s in others:
        conditions_to_plot.append((modality, s, 'single'))
    conditions_to_plot.append((modality, full_source, 'coherent'))
    conditions_to_plot.append((modality, full_source, 'multi'))

# Create a boolean mask to filter the DataFrame
mask = pd.Series(False, index=df_results.index)
for target, source, sampling in conditions_to_plot:
    mask |= (
        (df_results["target"] == target) &
        (df_results["source"] == source) &
        (df_results["sampling_type"] == sampling) &
        (df_results["metric_used"] == "mse")
    )
df_plot_data = df_results[mask].copy()

# Global lower edge of the error bars, shared by all panels. A missing std counts
# as 0 so that bar still constrains the limit. Fall back to 0 when nothing is
# plottable, because set_ylim rejects NaN.
min_y_val = (df_plot_data['r2_mean'] - df_plot_data['r2_std'].fillna(0)).min()
if pd.isna(min_y_val):
    min_y_val = 0.0
# Padded bottom limit. Keep 0 in view: bars are drawn from 0, so a bottom above 0
# would visually truncate them (and hide the reference line at 0).
y_bottom_limit = min(min_y_val, 0.0) - 0.05


# === 3) One panel per target modality on a 2-column grid ===

n_cols = 2
n_rows = max(1, -(-len(all_targets) // n_cols))  # ceiling division

color_map = {
    'single':          "#8a969e",
    'full_coherent':   '#56b4e9',
    'full_multimodal': '#0072b2',
}
legend_labels = {
    'single': 'Single-conditioning',
    'full_coherent': 'Coherent Denoising',
    'full_multimodal': 'Multi-conditioning'
}

# Scope the seaborn style to this figure instead of changing matplotlib's global
# style for every later cell in the kernel.
with plt.style.context('seaborn-v0_8-white'):
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 6 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, modality in enumerate(all_targets):
        ax = axes[i]
        others = sorted([m for m in all_targets if m != modality])
        full_source = '_'.join(others)

        labels = others + ["Coherent", "Multi"]
        # (source key, sampling_type, colour key) for each bar, in label order
        data_sources = [(s, 'single', 'single') for s in others]
        data_sources.append((full_source, 'coherent', 'full_coherent'))
        data_sources.append((full_source, 'multi', 'full_multimodal'))

        data_values = np.full(len(labels), np.nan)
        err_values = np.full(len(labels), np.nan)
        bar_colors = []

        for j, (source, sampling, bar_type) in enumerate(data_sources):
            df_row = df_results[
                (df_results["target"] == modality) &
                (df_results["sampling_type"] == sampling) &
                (df_results["source"] == source) &
                (df_results["metric_used"] == "mse")
            ]
            if not df_row.empty:
                data_values[j] = df_row.iloc[0]["r2_mean"]
                err_values[j] = df_row.iloc[0]["r2_std"]
            bar_colors.append(color_map[bar_type])

        x = np.arange(len(labels))
        ax.bar(x, data_values, yerr=err_values, capsize=0, color=bar_colors, width=0.6, zorder=10)

        ax.set_title(f"Target: {modality.upper()}", fontsize=16, weight='bold')
        ax.set_ylabel("Test R² (Mean ± Std)", fontsize=12)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=11)

        # Shared y-limits so panels are directly comparable
        ax.set_ylim(bottom=y_bottom_limit, top=1.0)

        ax.axhline(0, color='black', linewidth=1.2, zorder=5)
        ax.yaxis.grid(True, linestyle='--', linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.xaxis.grid(False)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_color('#B0B0B0')

    # Hide panels left over when the number of targets is not a multiple of n_cols
    for ax in axes[len(all_targets):]:
        ax.set_visible(False)

    # === 4) Final figure-level adjustments ===

    handles = [plt.Rectangle((0,0),1,1, color=color_map[key]) for key in legend_labels]
    fig.legend(handles, legend_labels.values(),
               title='Generative Method',
               loc='center left',
               bbox_to_anchor=(0.91, 0.5),
               frameon=False,
               fontsize=12,
               title_fontsize=14)

    fig.suptitle("Overall Reconstruction Performance Across Target Modalities", fontsize=22, weight='bold')

    fig.subplots_adjust(
        left=0.08,
        right=0.88,   # leave room on the right for the figure legend
        bottom=0.15,
        top=0.9,
        wspace=0.3,   # Horizontal space between plots
        hspace=0.5    # Vertical space between plots (room for rotated tick labels)
    )

    # === 5) Save the final image (creating the output folder if needed) ===
    save_path = "../results/images/r2_compact.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

print(f"Plot saved to {save_path}")

In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===

# Make sure BASE_DIR points at the folder that contains "cna_from_…", "rnaseq_from_…", etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    for fname in files:
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # JSONs sit in "<experiment>/test/", so the experiment folder is root's parent.
        exp_folder = os.path.basename(os.path.dirname(root))

        # Determine sampling_type. "_from_multi_masked" contains "_from_multi", so it
        # is checked first; otherwise masked runs would be tagged "multi" and collide
        # with the unmasked multi runs in the per-(target, source) lookups below.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        # Parse target and source
        if sampling_type == "single":
            # e.g. "cna_from_rnaseq" → target="cna", source="rnaseq"
            if "_from_" not in exp_folder:
                continue
            target, src_part = exp_folder.split("_from_", 1)
            source = src_part
        else:
            # e.g. "cna_from_coherent" → target="cna", then combo in filename
            target = exp_folder.split("_from_", 1)[0]
            prefix = "test_metrics_from_"
            suffix = "_best"
            combo = fname[len(prefix): fname.rfind(suffix)]
            source = combo  # e.g. "rna_rppa" or "cna_rppa_wsi"

        # Which reference metric selected the checkpoint (parsed from the filename)?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean","mse_std","r2_mean","r2_std","cos_mean","cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")

# === 2) For each target modality, plot a bar chart of Test MSE (best-by-MSE checkpoints) ===

all_targets = sorted(df_results["target"].unique())

for modality in all_targets:
    # 2a) Identify other modalities
    others = [m for m in all_targets if m != modality]
    max_r = len(others)
    # Source key of the combo that conditions on every other modality. combinations()
    # keeps the order of `others`, so only the largest combo produces this key.
    full_source = '_'.join(others)

    # 2b) Build label lists, prefixing coherent/multi to distinguish.
    #     Combo sizes run up to len(others) so the all-modality combo is always present.
    single_labels   = others.copy()  # e.g. ["cna","rnaseq","rppa"]

    # e.g. ["coh_cna_rnaseq","coh_cna_rppa","coh_rnaseq_rppa","coh_cna_rnaseq_rppa"]
    coherent_labels = [f"coh_{'_'.join(combo)}"
                       for r in range(2, max_r + 1)
                       for combo in combinations(others, r)]

    # e.g. ["mult_cna","mult_rnaseq","mult_rppa", ..., "mult_cna_rnaseq_rppa"]
    multi_labels = [f"mult_{'_'.join(combo)}"
                    for r in range(1, max_r + 1)
                    for combo in combinations(others, r)]

    # Singles, then coherent combos, then multi combos
    labels = single_labels + coherent_labels + multi_labels

    # 2c) Create data arrays for the best-by-MSE checkpoints only
    data_values = np.full(len(labels), np.nan)
    err_values = np.full(len(labels), np.nan)

    # Indices of coherent/multi bars that condition on all other modalities
    highlight_bars = []

    # 2d) Fill in Test MSE values of the best-by-MSE checkpoints
    for i, lbl in enumerate(labels):
        if lbl in single_labels:
            sampling_type = "single"
            source        = lbl
        elif lbl in coherent_labels:
            sampling_type = "coherent"
            source        = lbl.replace("coh_", "", 1)
        else:
            sampling_type = "multi"
            source        = lbl.replace("mult_", "", 1)

        if sampling_type != "single" and source == full_source:
            highlight_bars.append(i)

        df_row = df_results[
            (df_results["target"] == modality) &
            (df_results["sampling_type"] == sampling_type) &
            (df_results["source"] == source) &
            (df_results["metric_used"] == "mse")
        ]
        if not df_row.empty:
            data_values[i] = df_row.iloc[0]["mse_mean"]
            err_values[i] = df_row.iloc[0]["mse_std"]

    # 2e) Plot uniformly coloured bars, outlining the highlighted ones
    x = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(14, 6))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test MSE (Mean ± Std)")
    ax.set_title(f"Target = '{modality}': MSE by Source (best models by MSE)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    # Bold the tick labels of the highlighted bars
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlight_bars else 'normal')

    ax.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations

# === assume df_results is already built ===
all_targets = sorted(df_results["target"].unique())
n = len(all_targets)

# 1) Make a figure with one row per target. squeeze=False always returns a 2-D
#    array of Axes, so the single-target case needs no special handling.
fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(14, 6 * n), sharex=False, squeeze=False)

for ax, modality in zip(axes[:, 0], all_targets):
    # Identify other modalities
    others = [m for m in all_targets if m != modality]
    max_r = len(others)
    # Source key of the combo conditioning on every other modality
    full_source = '_'.join(others)

    # Build labels; combo sizes go up to len(others) so the all-modality combo is included
    single_labels   = others.copy()
    coherent_labels = [f"coh_{'_'.join(combo)}"
                       for r in range(2, max_r + 1)
                       for combo in combinations(others, r)]
    multi_labels    = [f"mult_{'_'.join(combo)}"
                       for r in range(1, max_r + 1)
                       for combo in combinations(others, r)]
    labels = single_labels + coherent_labels + multi_labels

    # Prepare data arrays
    data_values = np.full(len(labels), np.nan)
    err_values  = np.full(len(labels), np.nan)
    highlight_bars = []

    # Fill in MSE mean & std of the best-by-MSE checkpoints
    for i, lbl in enumerate(labels):
        if lbl in single_labels:
            stype, src = "single", lbl
        elif lbl in coherent_labels:
            stype, src = "coherent", lbl.replace("coh_", "", 1)
        else:
            stype, src = "multi", lbl.replace("mult_", "", 1)
        # Highlight the coherent/multi bars that condition on all other modalities
        if stype != "single" and src == full_source:
            highlight_bars.append(i)

        sel = df_results[
            (df_results["target"] == modality) &
            (df_results["sampling_type"] == stype) &
            (df_results["source"] == src) &
            (df_results["metric_used"] == "mse")
        ]
        if not sel.empty:
            data_values[i] = sel.iloc[0]["mse_mean"]
            err_values[i]  = sel.iloc[0]["mse_std"]

    # 2) Plot onto this subplot
    x = np.arange(len(labels))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test MSE (Mean ± Std)")
    ax.set_title(f"Target = '{modality}'")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlight_bars else 'normal')
    ax.grid(axis="y", linestyle="--", alpha=0.7)

plt.tight_layout(rect=[0, 0, 1, 0.95])     # leave room at top for the main title
fig.suptitle("Overall Comparison of Models: MSE Across All Targets",
             fontsize=18, weight='bold')

# 3) Save the whole figure once, creating the output folder if it is missing
out_path = "../results/images/all_modalities_mse.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, dpi=300)


In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===
#
# Each JSON sits one directory below its experiment folder (e.g. "<experiment>/test/").
# The experiment folder is named "<target>_from_<...>". It yields the target modality and
# the sampling mode (single / coherent / multi / multi_masked).

# Make sure BASE_DIR points at the folder that contains "cna_from_…", "rnaseq_from_…", etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    # os.walk order is filesystem-dependent. Sorting `dirs` in place fixes the traversal
    # order, so row order (and which duplicate `.iloc[0]` picks downstream) is reproducible.
    dirs.sort()
    for fname in sorted(files):
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # Experiment folder = parent of the directory that holds the JSON.
        exp_folder = os.path.basename(os.path.dirname(root))
        if "_from_" not in exp_folder:
            # Not a "<target>_from_<...>" experiment folder; nothing to parse.
            continue

        # Determine sampling_type. "_from_multi_masked" must be tested before
        # "_from_multi", because the latter is a substring of the former.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        target, src_part = exp_folder.split("_from_", 1)
        if sampling_type == "single":
            # e.g. "cna_from_rnaseq" → target="cna", source="rnaseq"
            source = src_part
        else:
            # Coherent/multi runs encode the source combination in the filename, e.g.
            # "test_metrics_from_rnaseq_rppa_best_mse.json" → "rnaseq_rppa".
            # Work on the stem (".json" removed) so a filename without "_best" keeps
            # its full combination instead of being sliced up to index -1.
            stem = fname[: -len(".json")]
            prefix = "test_metrics_from_"
            if stem.startswith(prefix):
                stem = stem[len(prefix):]
            best_at = stem.rfind("_best")
            source = stem[:best_at] if best_at != -1 else stem

        # Which checkpoint-selection metric produced this file?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r", encoding="utf-8") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean","mse_std","r2_mean","r2_std","cos_mean","cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")

# === 2) For each target modality, plot a bar chart of Test R² (MSE metric only) ===
#
# One figure per target. Bars are ordered: single-source runs, then coherent, multi and
# multi_masked sampling over every combination of the *other* modalities. Bars that
# condition on all other modalities get a thick outline and a bold tick label. Only
# checkpoints selected by best MSE (metric_used == "mse") are shown.

all_targets = sorted(df_results["target"].unique())

for modality in all_targets:
    # 2a) Conditioning candidates: every modality except the target itself
    others = [m for m in all_targets if m != modality]
    n_others = len(others)

    # 2b) Bar specs as (label, sampling_type, source, number_of_sources), in plot order.
    #     Combination sizes run up to len(others), so the "all conditionings" bar exists
    #     however many modalities there are. Tracking the size explicitly (rather than
    #     counting '_' in the label) stays correct if a modality name contains '_'.
    bar_specs = [(m, "single", m, 1) for m in others]
    for label_prefix, stype, min_r in (("coh_", "coherent", 2),
                                       ("mult_", "multi", 1),
                                       ("mult_masked_", "multi_masked", 1)):
        for r in range(min_r, n_others + 1):
            for combo in combinations(others, r):
                src = "_".join(combo)
                bar_specs.append((f"{label_prefix}{src}", stype, src, r))
    labels = [spec[0] for spec in bar_specs]

    # 2c) Look up R² mean/std for each bar (NaN → empty bar when a run is missing)
    target_rows = df_results[(df_results["target"] == modality) &
                             (df_results["metric_used"] == "mse")]
    data_values = np.full(len(bar_specs), np.nan)
    err_values = np.full(len(bar_specs), np.nan)
    highlight_bars = []

    for i, (_, stype, src, n_src) in enumerate(bar_specs):
        # Highlight multi-source runs that condition on every other modality
        if stype != "single" and n_src == n_others:
            highlight_bars.append(i)
        df_row = target_rows[(target_rows["sampling_type"] == stype) &
                             (target_rows["source"] == src)]
        if not df_row.empty:
            # If several files map to the same key, the first one wins
            data_values[i] = df_row.iloc[0]["r2_mean"]
            err_values[i] = df_row.iloc[0]["r2_std"]

    # 2d) Plot bars with highlighting
    x = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(16, 7))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test R² (Mean ± Std)")
    ax.set_title(f"Target = '{modality}': R² by Source (best models by MSE)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=60, ha="right")
    highlighted = set(highlight_bars)
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlighted else 'normal')

    ax.grid(axis="y", linestyle="--", alpha=0.7)
    # Horizontal reference line at R² = 0
    ax.axhline(0, color='black', linewidth=0.8)
    fig.tight_layout()
    plt.show()
    # Release the figure; one is created per target inside this loop
    plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from itertools import combinations

# === assume df_results is already built ===
# For testing, you might need to create a dummy df_results if running this standalone.
# Example:
# import pandas as pd
# data = {'target': ['cna', 'cna', 'rnaseq'], 'sampling_type': ['single', 'multi', 'multi_masked'], 
#         'source': ['rnaseq', 'rnaseq', 'cna'], 'metric_used': ['mse', 'mse', 'mse'], 
#         'r2_mean': [0.8, 0.9, 0.85], 'r2_std': [0.05, 0.04, 0.06]}
# df_results = pd.DataFrame(data)


import os

all_targets = sorted(df_results["target"].unique())
n = len(all_targets)

# 1) One figure, one row per target modality. squeeze=False always returns a 2-D array
#    of Axes, so the single-target case needs no special handling.
fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(16, 6 * n), sharex=False,
                         squeeze=False)
axes = axes[:, 0]

for ax, modality in zip(axes, all_targets):
    # 2a) Conditioning candidates: every modality except the target itself
    others = [m for m in all_targets if m != modality]
    n_others = len(others)

    # 2b) Bar specs as (label, sampling_type, source, number_of_sources), in plot order:
    #     single, coherent, multi, multi_masked. Sizes run up to len(others), so the
    #     "all conditionings" bar always exists. The explicit size avoids counting '_'
    #     in labels, which breaks if a modality name contains '_'.
    bar_specs = [(m, "single", m, 1) for m in others]
    for label_prefix, stype, min_r in (("coh_", "coherent", 2),
                                       ("mult_", "multi", 1),
                                       ("mult_masked_", "multi_masked", 1)):
        for r in range(min_r, n_others + 1):
            for combo in combinations(others, r):
                src = "_".join(combo)
                bar_specs.append((f"{label_prefix}{src}", stype, src, r))
    labels = [spec[0] for spec in bar_specs]

    # 2c) R² mean & std per bar, from the best-MSE checkpoints only
    target_rows = df_results[(df_results["target"] == modality) &
                             (df_results["metric_used"] == "mse")]
    data_values = np.full(len(bar_specs), np.nan)
    err_values  = np.full(len(bar_specs), np.nan)
    highlight_bars = []

    for i, (_, stype, src, n_src) in enumerate(bar_specs):
        if stype != "single" and n_src == n_others:
            highlight_bars.append(i)
        sel = target_rows[(target_rows["sampling_type"] == stype) &
                          (target_rows["source"] == src)]
        if not sel.empty:
            data_values[i] = sel.iloc[0]["r2_mean"]
            err_values[i]  = sel.iloc[0]["r2_std"]

    # 2d) Plot on this Axes
    x = np.arange(len(labels))
    bars = ax.bar(x, data_values, yerr=err_values, capsize=5, color='steelblue')
    for i in highlight_bars:
        bars[i].set_edgecolor('black')
        bars[i].set_linewidth(3)

    ax.set_ylabel("Test R² (Mean ± Std)")
    ax.set_title(f"Target = '{modality}'")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=60, ha="right")
    highlighted = set(highlight_bars)
    for i, tick in enumerate(ax.get_xticklabels()):
        tick.set_weight('bold' if i in highlighted else 'normal')
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    # Horizontal reference line at R² = 0
    ax.axhline(0, color='black', linewidth=0.8)


fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room at top for the main title
fig.suptitle("Overall Comparison of Models: R² Across All Targets",
             fontsize=18, weight='bold')

# 3) Save all subplots in one image (savefig does not create missing directories)
r2_figure_path = "../results/images/all_modalities_r2_masked.png"
os.makedirs(os.path.dirname(r2_figure_path), exist_ok=True)
fig.savefig(r2_figure_path, dpi=300, bbox_inches='tight')

In [None]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# === 1) Gather all test_metrics JSONs into a single DataFrame ===
#
# Each JSON sits one directory below its experiment folder (e.g. "<experiment>/test/").
# The experiment folder is named "<target>_from_<...>". It yields the target modality and
# the sampling mode (single / coherent / multi / multi_masked).

# Make sure BASE_DIR points at the folder that contains "cna_from_…", "rnaseq_from_…", etc.
BASE_DIR = "../results/32"
rows = []

for root, dirs, files in os.walk(BASE_DIR):
    # os.walk order is filesystem-dependent. Sorting `dirs` in place fixes the traversal
    # order, so row order (and which duplicate `.iloc[0]` picks downstream) is reproducible.
    dirs.sort()
    for fname in sorted(files):
        if not (fname.startswith("test_metrics") and fname.endswith(".json")):
            continue

        fullpath = os.path.join(root, fname)
        # Experiment folder = parent of the directory that holds the JSON.
        exp_folder = os.path.basename(os.path.dirname(root))
        if "_from_" not in exp_folder:
            # Not a "<target>_from_<...>" experiment folder; nothing to parse.
            continue

        # Determine sampling_type. "_from_multi_masked" must be tested before
        # "_from_multi": the latter is a substring of the former, and without this branch
        # masked runs were mislabelled "multi" and collided with the genuine multi runs.
        if "_from_coherent" in exp_folder:
            sampling_type = "coherent"
        elif "_from_multi_masked" in exp_folder:
            sampling_type = "multi_masked"
        elif "_from_multi" in exp_folder:
            sampling_type = "multi"
        else:
            sampling_type = "single"

        target, src_part = exp_folder.split("_from_", 1)
        if sampling_type == "single":
            # e.g. "cna_from_rnaseq" → target="cna", source="rnaseq"
            source = src_part
        else:
            # Coherent/multi runs encode the source combination in the filename, e.g.
            # "test_metrics_from_rnaseq_rppa_best_mse.json" → "rnaseq_rppa".
            # Work on the stem (".json" removed) so a filename without "_best" keeps
            # its full combination instead of being sliced up to index -1.
            stem = fname[: -len(".json")]
            prefix = "test_metrics_from_"
            if stem.startswith(prefix):
                stem = stem[len(prefix):]
            best_at = stem.rfind("_best")
            source = stem[:best_at] if best_at != -1 else stem

        # Which checkpoint-selection metric produced this file?
        if "best_mse" in fname:
            metric_used = "mse"
        elif "best_cosine" in fname:
            metric_used = "cosine"
        elif "best_timestep" in fname:
            metric_used = "timestep"
        else:
            metric_used = None

        with open(fullpath, "r", encoding="utf-8") as f:
            data = json.load(f)

        rows.append({
            "experiment":    exp_folder,
            "target":        target,
            "source":        source,
            "sampling_type": sampling_type,
            "metric_used":   metric_used,
            **data
        })

df_results = pd.DataFrame(rows)
if df_results.empty:
    raise RuntimeError(f"No test_metrics JSON found under {BASE_DIR}.")

# Convert metric columns to numeric (unparseable values become NaN)
for col in ["mse_mean","mse_std","r2_mean","r2_std","cos_mean","cos_std"]:
    if col in df_results:
        df_results[col] = pd.to_numeric(df_results[col], errors="coerce")


# === 2) For each target modality, print a results table (MSE metric only) ===
#
# Rows follow the same order as the R² plots: single-source runs, then coherent, multi and
# multi_masked sampling over every combination of the other modalities. Runs that
# condition on all other modalities are marked "(All Sources)". Missing runs show "N/A".

all_targets = sorted(df_results["target"].unique())
table_columns = ["Source Combination", "R² Mean", "R² Std", "Note"]

for modality in all_targets:
    # 2a) Conditioning candidates: every modality except the target itself
    others = [m for m in all_targets if m != modality]
    n_others = len(others)

    # 2b) Row specs as (label, sampling_type, source, number_of_sources). Sizes run up to
    #     len(others), so the "all sources" row always exists. The explicit size avoids
    #     counting '_' in labels, which breaks if a modality name contains '_'.
    row_specs = [(m, "single", m, 1) for m in others]
    for label_prefix, stype, min_r in (("coh_", "coherent", 2),
                                       ("mult_", "multi", 1),
                                       ("mult_masked_", "multi_masked", 1)):
        for r in range(min_r, n_others + 1):
            for combo in combinations(others, r):
                src = "_".join(combo)
                row_specs.append((f"{label_prefix}{src}", stype, src, r))

    # 2c) Look up Test R² for each row, from the best-MSE checkpoints only
    target_rows = df_results[(df_results["target"] == modality) &
                             (df_results["metric_used"] == "mse")]
    table_rows = []
    for lbl, stype, src, n_src in row_specs:
        match = target_rows[(target_rows["sampling_type"] == stype) &
                            (target_rows["source"] == src)]
        found = not match.empty
        table_rows.append({
            "Source Combination": lbl,
            # If several files map to the same key, the first one wins
            "R² Mean": match.iloc[0]["r2_mean"] if found else np.nan,
            "R² Std":  match.iloc[0]["r2_std"] if found else np.nan,
            "Note": "(All Sources)" if stype != "single" and n_src == n_others else "",
        })

    # 2d) Build the table. Passing explicit columns keeps the schema even when there
    #     are no rows (single-target case), so the rounding below cannot KeyError.
    df_table = pd.DataFrame(table_rows, columns=table_columns)
    df_table = df_table.round({"R² Mean": 4, "R² Std": 4})

    # --- Displaying the output ---
    print("=" * 60)
    print(f"    Target = '{modality}': R² by Source")
    print("=" * 60)
    print(df_table.to_string(index=False, na_rep="N/A"))
    print("\n")