# Statistical Analysis of Latent Space and Segmentation Metrics

This notebook performs a full statistical analysis of the latent space and segmentation metrics produced by the model trained on "Normal" muscles.

It includes:

* Extraction of latent codes from log files
* PCA and MANOVA to evaluate group separation
* LDA visualization and Fisher scores
* Computation of segmentation metrics (DSC, ASD, HSD95, volume errors, etc.)
* Group comparison (healthy subjects vs. patients with sarcopenia) using:
  * Shapiro–Wilk normality tests
  * Mann–Whitney U tests
  * Linear Mixed-Effects Models (LME)
* ROC curve analysis to study classification performance
* Threshold extraction for DSC to separate sarcopenia patients from healthy subjects
* Boxplots of metrics per fold
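
As a rough illustration of the group-comparison step listed above, the sketch below runs a Shapiro–Wilk test on each group and then a Mann–Whitney U test between the two groups. The DSC values are synthetic, and the names `dsc_healthy` and `dsc_sarcopenia` are placeholders chosen for this example, not outputs of the model.

```python
import numpy as np
from scipy.stats import shapiro, mannwhitneyu

rng = np.random.default_rng(0)
# Synthetic DSC values, for illustration only (not real results)
dsc_healthy = np.clip(rng.normal(0.92, 0.02, size=20), 0, 1)
dsc_sarcopenia = np.clip(rng.normal(0.85, 0.05, size=15), 0, 1)

# Shapiro-Wilk: a small p-value suggests a departure from normality
for name, values in [("healthy", dsc_healthy), ("sarcopenia", dsc_sarcopenia)]:
    w_stat, p_norm = shapiro(values)
    print(f"{name}: W={w_stat:.3f}, p={p_norm:.3f}")

# Mann-Whitney U: rank-based test that does not assume normality
u_stat, p_mwu = mannwhitneyu(dsc_healthy, dsc_sarcopenia, alternative="two-sided")
print(f"Mann-Whitney U={u_stat:.1f}, p={p_mwu:.4f}")
```

A rank-based test is a reasonable default here because DSC is bounded in [0, 1] and is often skewed towards 1.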


# ⚙️ How to use this notebook

1. Set the muscle you want to analyze, and specify the number of folds if doing cross-validation:
   - `muscle` corresponds to the folder name inside `casename_files/DIASEM/`.
   - `n_folds` can be set to 1 for a single dataset or higher for cross-validation.

2. Place your log files in the expected folder structure:
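   - Logs should come directly from `eval.py` and be located in `./output/`, with one `log.txt` per evaluated set (healthy test set, sarcopenia set, training set).
   - The expected paths are defined in `log_patterns` in the configuration cell; adjust them if your folders are named differently.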

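3. Create subject list files in `casename_files/DIASEM/<muscle>/`:
   - One subject per line in plain text files: `train_cases_<fold>.txt`, `test_cases_<fold>.txt`, and `sarcopenia_subj_DIASEM.txt`.
   - Optional: assign groups manually. Otherwise the notebook assigns them from subject names: subjects listed in `sarcopenia_subj_DIASEM.txt` are labeled as patients with sarcopenia, and in the train/test lists, names starting with `sujet` are labeled as healthy subjects and all others as patients without sarcopenia.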

4. Run all cells.

5. Outputs generated:
   - `fisher_scores_and_metrics_<muscle>.csv`
   - LDA plots
   - Boxplots for all metrics (PDF)
   - ROC curve figure
   - Computed DSC threshold for classifying sarcopenia
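
The rule used to pick the DSC threshold is not shown in this section. One common choice, sketched here as an assumption, is to maximize Youden's J (TPR minus FPR) along the ROC curve. The DSC values and labels are synthetic, and lower DSC is assumed to indicate sarcopenia.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Synthetic DSC values and labels, for illustration only
dsc = np.array([0.95, 0.93, 0.91, 0.90, 0.88, 0.86, 0.84, 0.80, 0.78, 0.75])
# 1 = patient with sarcopenia, 0 = healthy subject (hypothetical labels)
is_sarcopenia = np.array([0, 0, 0, 0, 0, 1, 0, 1, 1, 1])

# roc_curve expects higher scores for the positive class, so use -DSC
fpr, tpr, thresholds = roc_curve(is_sarcopenia, -dsc)
roc_auc = auc(fpr, tpr)

# Youden's J = TPR - FPR; take the operating point that maximizes it
best = np.argmax(tpr - fpr)
dsc_threshold = -thresholds[best]

print(f"AUC = {roc_auc:.3f}")
print(f"Classify as sarcopenia when DSC <= {dsc_threshold:.2f}")
```

Other criteria (for example a fixed minimum sensitivity) give different thresholds, so the criterion should be reported alongside the value.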


# Note for clinicians

This notebook provides a quantitative analysis of segmentation quality and group separation.

You can use it to:

* Visualize how well the model distinguishes healthy vs sarcopenic subjects
* Study segmentation accuracy for clinical quality assurance
* Extract thresholds (e.g., DSC) to support diagnosis

No coding knowledge is required: update the text files containing the subject lists, set `muscle` and `n_folds` in the configuration cell, and run the notebook.

⚠️ Note: This analysis provides quantitative metrics and thresholds, but it should not be used as a standalone diagnostic tool.


In [1]:
# Imports: standard library, numerical stack, then statistics packages
import os
import re
import json
import warnings
from collections import Counter
from ast import literal_eval

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, recall_score, precision_score

from scipy.stats import shapiro, mannwhitneyu, combine_pvalues
from statsmodels.multivariate.manova import MANOVA
import statsmodels.api as sm
from statsmodels.formula.api import mixedlm

# Optional: adjust plotting style (clinician-friendly)
sns.set(style="whitegrid")
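# Silence library deprecation noise only, so that the warnings raised by this
# notebook's own MANOVA/LDA failure handlers (warnings.warn) stay visible
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)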


In [2]:
# ---------- CONFIG ----------
# Edit these variables for your run
muscle = 'RF'               # folder name under ../casename_files/DIASEM/
n_folds = 1                 # 1 = single fold (no CV), >1 = cross-validation
output_csv = f"fisher_scores_and_metrics_{muscle}.csv"
output_dir = f"plots_{muscle}"
os.makedirs(output_dir, exist_ok=True)

# Log file name patterns (adjust if your logs are named differently)
log_patterns = {
    'test': './DIASEM_{muscle}_set{fold}_NO_MISMATCH_registered_turned_1500_eval_ax2_x1_healthy/log.txt',
    'sarcopenia': './DIASEM_{muscle}_set{fold}_NO_MISMATCH_registered_turned_1500_eval_ax2_x1_sarcopenia/log.txt',
    'train': './DIASEM_{muscle}_set{fold}_NO_MISMATCH_registered_turned_1500_eval_ax2_x1_train/log.txt'
}
# Paths to subject lists
subject_lists_base = f'../casename_files/DIASEM/{muscle}'
path_sarcopenia = os.path.join(subject_lists_base, 'sarcopenia_subj_DIASEM.txt')
# train/test files are constructed per fold: train_cases_{fold}.txt, test_cases_{fold}.txt
# ----------------------------
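
Before running the analysis, it can help to check that every expected input file exists. The sketch below resolves the paths the same way the configuration above does and reports which ones are missing. The helper name `expected_files` and the value `n_folds = 2` are assumptions made for this example; only two of the three log patterns are repeated to keep it short.

```python
import os

muscle = "RF"
n_folds = 2  # hypothetical value, for illustration
subject_lists_base = f"../casename_files/DIASEM/{muscle}"
log_patterns = {
    "test": "./DIASEM_{muscle}_set{fold}_NO_MISMATCH_registered_turned_1500_eval_ax2_x1_healthy/log.txt",
    "train": "./DIASEM_{muscle}_set{fold}_NO_MISMATCH_registered_turned_1500_eval_ax2_x1_train/log.txt",
}

def expected_files(fold):
    """Return the log and subject-list paths expected for one fold."""
    files = [pattern.format(muscle=muscle, fold=fold) for pattern in log_patterns.values()]
    files.append(os.path.join(subject_lists_base, f"train_cases_{fold}.txt"))
    files.append(os.path.join(subject_lists_base, f"test_cases_{fold}.txt"))
    files.append(os.path.join(subject_lists_base, "sarcopenia_subj_DIASEM.txt"))
    return files

for fold in range(1, n_folds + 1):
    for path in expected_files(fold):
        status = "found" if os.path.isfile(path) else "MISSING"
        print(f"fold {fold}: {status:7s} {path}")
```

A missing log or list file otherwise only shows up later as an empty DataFrame or a skipped fold.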


In [3]:
# Utility functions: parsing, converting string tensors, categorization

def safe_readlines(path):
    """Return lines of file, or empty list if not found."""
    if not os.path.isfile(path):
        return []
    with open(path, 'r') as f:
        return f.readlines()

def load_subjects(file_path, group=None):
    """Load a subject text file (one per line). If group provided, return dataframe with Group."""
    if not os.path.isfile(file_path):
        return pd.DataFrame(columns=['Subject','Group']) if group else []
    with open(file_path, 'r') as f:
        subjects = [line.strip().lower() for line in f if line.strip()]
    if group:
        return pd.DataFrame({'Subject': subjects, 'Group': group})
    return subjects

def categorize_subjects(subjects):
    """Simple rule: 'sujet...' -> healthy_subjects, else patients_wo_sarcopenia."""
    categorized = []
    for subj in subjects:
        if subj.startswith('sujet'):
            categorized.append({'Subject': subj, 'Group': 'healthy_subjects'})
        else:
            categorized.append({'Subject': subj, 'Group': 'patients_wo_sarcopenia'})
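    # fix the columns so an empty subject list still yields a mergeable frame
    return pd.DataFrame(categorized, columns=['Subject', 'Group'])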

def convert_tensor_string_to_array(tensor_str):
    """Parse a tensor string saved in the logs into a numpy array.
       Several formats are possible; try the common ones, else return None."""
    if not isinstance(tensor_str, str) or len(tensor_str.strip()) == 0:
        return None
    # remove common prefixes and trailing attributes (device, dtype, grad_fn, ...)
    s = re.sub(r'Parameter containing:', '', tensor_str).strip()
    s = s.replace("tensor(", "")
    s = re.split(r",\s*(?:device|dtype|grad_fn|requires_grad)=", s)[0].strip()
    # drop the closing parenthesis of tensor(...) when no attribute followed it
    if s.endswith(')'):
        s = s[:-1].strip()
    try:
        arr = literal_eval(s)
        return np.array(arr)
    except Exception:
        # fallback: extract signed numbers, including scientific notation
        nums = re.findall(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?", s)
        if len(nums) == 0:
            return None
        return np.array([float(x) for x in nums])

def parse_log_file(path):
    """Parse a log file produced by eval.py and return DataFrame with metrics and z-latents per subject.
       The parser is tolerant: if patterns differ slightly, it tries best-effort extraction.
    """
    lines = safe_readlines(path)
    if not lines:
        return pd.DataFrame()
    parsed = []
    i = 0
    batch_cases = []
    z_latents = ""
    stats = {}
    # regex to capture metrics line: adapt to your exact print format if necessary
    stats_pattern = re.compile(r"ASD: *([-\d.]+).*?HSD: *([-\d.]+).*?HSD95: *([-\d.]+).*?DSC: *([-\d.]+).*?err_vol_cm_3: *([-\d.]+).*?err_vol_percent: *([-\d.]+)", re.IGNORECASE)
    while i < len(lines):
        line = lines[i]
        if "Batch cases:" in line:
            batch_cases = [x.strip().lower() for x in line.split("Batch cases:")[-1].strip().split(',') if x.strip()]
        elif "z:" in line and batch_cases:
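            # collect subsequent lines containing tensor data (best effort);
            # stop at the next metrics or batch header line so it is not swallowed
            tensor_parts = [line.split('z:')[-1].strip()]
            i += 1
            while (i < len(lines)
                   and 'ASD:' not in lines[i] and 'Batch cases:' not in lines[i]
                   and (lines[i].strip().startswith('tensor') or 'Parameter containing' in lines[i]
                        or re.search(r'[\[\]\d\.\-]', lines[i]))):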
                tensor_parts.append(lines[i].strip())
                i += 1
            z_latents = ' '.join(tensor_parts)
            continue
        elif batch_cases and ("Batch ASD:" in line or "ASD:" in line):
            m = stats_pattern.search(line)
            if m:
                stats = {
                    'ASD': float(m.group(1)),
                    'HSD': float(m.group(2)),
                    'HSD95': float(m.group(3)),
                    'DSC': float(m.group(4)),
                    'Err_Vol_CM3': float(m.group(5)),
                    'Err_Vol_Percent': float(m.group(6))
                }
        if batch_cases and stats:
            for case in batch_cases:
                parsed.append((case.strip().lower(), z_latents, stats.get('ASD'), stats.get('HSD'),
                               stats.get('HSD95'), stats.get('DSC'), stats.get('Err_Vol_CM3'), stats.get('Err_Vol_Percent')))
            batch_cases, z_latents, stats = [], "", {}
        i += 1
    cols = ['Subject','z_latents_batch','ASD','HSD','HSD95','DSC','Err_Vol_CM3','Err_Vol_Percent']
    return pd.DataFrame(parsed, columns=cols)
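
To check that your logs match what `parse_log_file` expects, the metrics regex can be tried on a single line. The sketch below reuses the same pattern; the sample line is hypothetical and only built to match it, since the exact print format of `eval.py` is not shown here.

```python
import re

# Same pattern as stats_pattern in parse_log_file
stats_pattern = re.compile(
    r"ASD: *([-\d.]+).*?HSD: *([-\d.]+).*?HSD95: *([-\d.]+).*?DSC: *([-\d.]+)"
    r".*?err_vol_cm_3: *([-\d.]+).*?err_vol_percent: *([-\d.]+)",
    re.IGNORECASE,
)

# Hypothetical log line, for illustration only
sample_line = ("Batch ASD: 0.82, HSD: 9.4, HSD95: 3.1, DSC: 0.91, "
               "err_vol_cm_3: -4.2, err_vol_percent: -3.5")

keys = ["ASD", "HSD", "HSD95", "DSC", "Err_Vol_CM3", "Err_Vol_Percent"]
match = stats_pattern.search(sample_line)
metrics = dict(zip(keys, map(float, match.groups()))) if match else {}
print(metrics)
```

If `match` is `None` on a real log line, adapt the pattern to your format. Note that `[-\d.]+` does not cover scientific notation: a value printed as `1e-05` would be captured as `1`.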


In [6]:
def process_fold(fold, muscle, log_patterns, subject_lists_base):
    """
    Process a single fold: load subject lists, parse logs, compute PCA/LDA/Fisher/MANOVA,
    and return dicts/dataframes with results and metrics for boxplots/ROC.
    """
    print(f"Processing fold {fold} ...")
    # build paths
    path_train = os.path.join(subject_lists_base, f"train_cases_{fold}.txt")
    path_test = os.path.join(subject_lists_base, f"test_cases_{fold}.txt")
    path_sarc = os.path.join(subject_lists_base, 'sarcopenia_subj_DIASEM.txt')

    # load lists
    sarc_df = load_subjects(path_sarc, 'patient_with_sarcopenia')
    train_df = categorize_subjects(load_subjects(path_train))
    test_df = categorize_subjects(load_subjects(path_test))

    # parse logs
    logs = {}
    for key, pattern in log_patterns.items():
        log_path = pattern.format(muscle=muscle, fold=fold)
        logs[key] = parse_log_file(log_path)
        if not logs[key].empty:
            logs[key]['Subject'] = logs[key]['Subject'].str.strip("[]").str.replace("'", "")
            logs[key]['z_latents_batch'] = logs[key]['z_latents_batch'].apply(convert_tensor_string_to_array)

    # merge subject group info
    if not logs.get('test', pd.DataFrame()).empty:
        logs['test'] = logs['test'].merge(test_df, on='Subject', how='left').assign(Set='Test')
    if not logs.get('sarcopenia', pd.DataFrame()).empty:
        logs['sarcopenia'] = logs['sarcopenia'].merge(sarc_df, on='Subject', how='left').assign(Set='Test')
    if not logs.get('train', pd.DataFrame()).empty:
        logs['train'] = logs['train'].merge(train_df, on='Subject', how='left').assign(Set='Train')

    # combine for latent analysis
    combined_df = pd.concat([df for df in [logs.get('test', pd.DataFrame()), logs.get('sarcopenia', pd.DataFrame()), logs.get('train', pd.DataFrame())] if not df.empty], ignore_index=True)
    # ensure we have latent vectors
    if combined_df.empty:
        print(f"No data for fold {fold}. Skipping.")
        return None

    shapes = combined_df['z_latents_batch'].apply(lambda x: None if x is None else x.shape)
    if shapes.dropna().empty:
        print(f"No latent arrays parsed for fold {fold}. Skipping.")
        return None

    most_common_shape = Counter(shapes.dropna()).most_common(1)[0][0]
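    has_latent = combined_df['z_latents_batch'].apply(lambda x: x is not None and x.shape == most_common_shape)
    # keep only subjects with a usable latent vector and a known group label
    # (a left merge leaves Group as NaN for subjects missing from the lists)
    df_filtered = combined_df[has_latent & combined_df['Group'].notna()]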
    if df_filtered.empty:
        print(f"No filtered latent arrays for fold {fold}. Skipping.")
        return None

    # Build X and y for latent analysis
    X = np.stack(df_filtered['z_latents_batch'].values)
    if X.ndim == 3:
        X = X.reshape(X.shape[0], -1)
    y = df_filtered['Group'].values

    # PCA
    n_components = min(64, X.shape[1], X.shape[0])

    pca = PCA(n_components=n_components)
    X_pca = pca.fit_transform(X)
    pca_cols = [f"PC{i+1}" for i in range(X_pca.shape[1])]
    df_latents = pd.DataFrame(X_pca, columns=pca_cols)
    df_latents["Group"] = y

    # MANOVA (best-effort)
    try:
        formula = " + ".join(pca_cols) + " ~ Group"
        maov = MANOVA.from_formula(formula, data=df_latents)
        manova_res = maov.mv_test()
        p_value = manova_res.results['Group']['stat']['Pr > F']["Wilks' lambda"]
    except Exception as e:
        warnings.warn(f"MANOVA failed on fold {fold}: {e}")
        p_value = np.nan

    # LDA + Fisher
    try:
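        classes = np.unique(y)
        # LDA yields at most (n_classes - 1) discriminant axes
        n_lda = min(2, len(classes) - 1, X.shape[1])
        lda = LDA(n_components=n_lda)
        z_lda = lda.fit_transform(X, y)
        # compute multivariate Fisher score on LDA coordinates
        overall_mean = np.mean(z_lda, axis=0)
        S_B, S_W = np.zeros((n_lda, n_lda)), np.zeros((n_lda, n_lda))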
        for cls in classes:
            X_c = z_lda[y == cls]
            if X_c.shape[0] <= 1:
                continue
            mean_c = np.mean(X_c, axis=0)
            n_c = X_c.shape[0]
            mean_diff = (mean_c - overall_mean).reshape(-1, 1)
            S_B += n_c * (mean_diff @ mean_diff.T)
            S_W += np.cov(X_c, rowvar=False) * (n_c - 1)
        fisher_score_multi = np.trace(S_B) / (np.trace(S_W) if np.trace(S_W) != 0 else np.nan)
    except Exception as e:
        warnings.warn(f"LDA/Fisher failed on fold {fold}: {e}")
        fisher_score_multi = np.nan
        z_lda = None

    # collect test metrics for boxplots/ROC
    df_metrics_list = []
    for name in ['test','sarcopenia']:
        if not logs.get(name, pd.DataFrame()).empty:
            dfm = logs[name].copy()
            # ensure DSC numeric
            for col in ['DSC','ASD','HSD','HSD95','Err_Vol_CM3','Err_Vol_Percent']:
                if col in dfm.columns:
                    dfm[col] = pd.to_numeric(dfm[col], errors='coerce')
            dfm['fold'] = fold
            dfm['source'] = f"Fold_{fold}_Test_Set_{'healthy' if name=='test' else 'sarcopenia'}"
            df_metrics_list.append(dfm)
    df_metrics = pd.concat(df_metrics_list, ignore_index=True) if df_metrics_list else pd.DataFrame()

    return {
        'fold': fold,
        'df_latents': df_latents,
        'p_value_manova': p_value,
        'fisher_score': fisher_score_multi,
        'z_lda': z_lda,
        'df_metrics': df_metrics,
        'df_filtered': df_filtered  # useful for plotting LDA points
    }
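
The Fisher score computed in `process_fold` is the ratio of between-class to within-class scatter, `trace(S_B) / trace(S_W)`, evaluated on the LDA coordinates. The sketch below reproduces that ratio on a tiny hand-made dataset so the value can be checked by hand. The function name `fisher_score` is an assumption for this example, and it works with the traces directly instead of building the full scatter matrices.

```python
import numpy as np

def fisher_score(Z, y):
    """Ratio trace(S_B) / trace(S_W) for points Z (n_samples, n_dims) and labels y."""
    Z, y = np.asarray(Z, dtype=float), np.asarray(y)
    overall_mean = Z.mean(axis=0)
    trace_sb, trace_sw = 0.0, 0.0
    for cls in np.unique(y):
        Z_c = Z[y == cls]
        mean_c = Z_c.mean(axis=0)
        # trace of n_c * (mean_c - mean)(mean_c - mean)^T
        trace_sb += len(Z_c) * np.sum((mean_c - overall_mean) ** 2)
        # trace of the within-class scatter (sum of squared deviations)
        trace_sw += np.sum((Z_c - mean_c) ** 2)
    return trace_sb / trace_sw if trace_sw > 0 else np.nan

# Two classes, 4 units apart along the first axis, each spread over 2 units on the second
Z = [[0, 0], [0, 2], [4, 0], [4, 2]]
y = ["a", "a", "b", "b"]
score = fisher_score(Z, y)
print(score)
```

Here `trace(S_B) = 16` and `trace(S_W) = 4`, so the score is 4. Unlike `process_fold`, this sketch does not skip classes with a single sample; such a class adds to the between-class term and nothing to the within-class term.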


In [7]:
# Run processing for all folds
fold_results = []
df_metrics_all_list = []

for fold in range(1, n_folds + 1):
    res = process_fold(fold, muscle, log_patterns, subject_lists_base)
    if res:
        fold_results.append(res)
        if not res['df_metrics'].empty:
            df_metrics_all_list.append(res['df_metrics'])

# collect fisher + p-values summary
df_results = pd.DataFrame([{'fold': r['fold'], 'fisher_score': r['fisher_score'], 'p_value': r['p_value_manova']} for r in fold_results])
if df_metrics_all_list:
    df_combined = pd.concat(df_metrics_all_list, ignore_index=True)
else:
    df_combined = pd.DataFrame()

print("Processing done.")
display(df_results)


Processing fold 1 ...
Processing done.


| | fold | fisher_score | p_value |
|---|---|---|---|
| 0 | 1 | 0.478506 | NaN |
