# Configs

In [23]:
# Safety gate: review every setting in this cell, then set config_checked = True.
config_checked = False

if not config_checked:
  raise RuntimeError("Configuration check failed! Please check the config and set config_checked to True before proceeding.")

RETRAIN = False                       # True: train new models; False: use the downloaded pre-trained weights
number_of_run = 5                     # number of training runs; one derived seed per run
DATA_DIRECTORY = 'ptbxl_data'         # folder holding the extracted PTB-XL dataset
MODEL_SAVE_DIRECTORY = 'saved_models' # folder for trained / downloaded model weights
BATCH_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 0.001
NUM_WORKERS = 2 # Adjust based on your system
PIN_MEMORY = True
DA = False                            # data-augmentation flag; currently only adds '_DA' to saved model filenames
if RETRAIN:
  assert number_of_run < 6, f"number_of_run must be at most 5 when RETRAIN is True (got {number_of_run})."



# Download prerequisite libraries

In [24]:
# Fetch the raw file: a github.com/.../blob/... URL returns the HTML viewer page,
# which pip cannot parse as a requirements file.
!wget -q -O requirments.txt https://github.com/thisistayeb/ECG-PTB-XL/raw/main/requirments.txt
%pip install -q -r ./requirments.txt

# Download pre-trained weights

In [52]:
# Replace the local saved_models folder with the published pre-trained weights.
# Use the /raw/ endpoint (/blob/ serves an HTML page that unzip cannot read) and
# chain with && so existing weights are only removed after a successful download.
!wget -q -O saved_models.zip https://github.com/thisistayeb/ECG-PTB-XL/raw/main/saved_models.zip && rm -rf saved_models && unzip -q saved_models.zip -d . && rm saved_models.zip

In [25]:
import gdown
import os
import requests
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wfdb
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, hamming_loss, classification_report, precision_recall_fscore_support, multilabel_confusion_matrix
import wfdb
import glob
import pickle
from sklearn.metrics import f1_score
from tqdm.notebook import tqdm
from scipy import signal as sp_signal


# Download PTB-XL

In [26]:
# Fetch the PTB-XL v1.0.3 archive: try the Google Drive mirror first, then fall back
# to the official PhysioNet URL. Nothing is downloaded if the archive already exists.
gdrive_url = 'https://drive.google.com/file/d/1Rmjr43WqzOYv0EsWduHIJipLh61iQLPa/view?usp=drive_link'
fallback_url = 'https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip'
output_filename = 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip' # Using specific name


def _remove_partial_download(path):
    """Best-effort deletion of a (possibly partial) download at `path`; reports but never raises OSError."""
    if os.path.exists(path):
        try:
            os.remove(path)
            print(f"Removed potentially partially downloaded file: {path}")
        except OSError as remove_err:
            print(f"Error removing partial file {path}: {remove_err}")


def _download_with_requests(url, path, chunk_size=8192, timeout=60):
    """Stream `url` into `path` with a tqdm progress bar.

    Raises:
        requests.exceptions.RequestException: on connection errors or, via
            raise_for_status(), on 4xx/5xx responses.
    """
    # Using the response as a context manager releases the connection even on error.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # content-length may be missing; tqdm then shows an open-ended counter.
        total_size = int(response.headers.get('content-length', 0))
        print(f"Saving to: {path}")
        with open(path, 'wb') as f, tqdm(
            desc=path,
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                bar.update(len(chunk))


if os.path.exists(output_filename):
    print(f"File '{output_filename}' already exists. Skipping download.")
else:
    gdown_succeeded = False
    try:
        print(f"Attempting download using gdown from: {gdrive_url}")
        # Some gdown versions signal failure by returning None instead of raising,
        # so check the return value before declaring success.
        gdown_result = gdown.download(url=gdrive_url, output=output_filename, quiet=False, fuzzy=True)
        gdown_succeeded = gdown_result is not None and os.path.exists(output_filename)
        if gdown_succeeded:
            print(f"\nFile downloaded successfully via gdown and saved as: {output_filename}")
        else:
            print("\nGoogle Drive download failed (gdown returned no file).")
    except Exception as gdown_exception:
        print(f"\nGoogle Drive download failed (link might be broken or inaccessible): {gdown_exception}")

    if not gdown_succeeded:
        # --- Fallback to Direct Link Download ---
        # Discard anything gdown may have left behind before retrying.
        _remove_partial_download(output_filename)
        print("\nAttempting fallback download from PhysioNet Server:")
        print(f"Direct link: {fallback_url}")
        try:
            _download_with_requests(fallback_url, output_filename)
            print(f"\nFile '{output_filename}' downloaded successfully via fallback link.")
        except requests.exceptions.RequestException as fallback_request_exception:
            print("\nFallback download using direct link also failed:")
            print(fallback_request_exception)
            print("\nPlease check the fallback URL and your internet connection.")
            _remove_partial_download(output_filename)
        except Exception as general_exception:
            # Catch any other unexpected errors during fallback
            print(f"\nAn unexpected error occurred during fallback download: {general_exception}")
            _remove_partial_download(output_filename)


# Final check to confirm if the file exists after all attempts
if os.path.exists(output_filename):
    print(f"\nFinal Status: File '{output_filename}' is available.")
else:
    print(f"\nFinal Status: Failed to download '{output_filename}' from either source.")

Attempting download using gdown from: https://drive.google.com/file/d/1Rmjr43WqzOYv0EsWduHIJipLh61iQLPa/view?usp=drive_link


Downloading...
From (original): https://drive.google.com/uc?id=1Rmjr43WqzOYv0EsWduHIJipLh61iQLPa
From (redirected): https://drive.google.com/uc?id=1Rmjr43WqzOYv0EsWduHIJipLh61iQLPa&confirm=t&uuid=f77ac624-183c-49a2-bd4c-d3a5155dffb9
To: /content/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip
100%|██████████| 1.84G/1.84G [00:42<00:00, 42.9MB/s]


File downloaded successfully via gdown and saved as: ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip

Final Status: File 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip' is available.





In [27]:
# Extract the archive and rename the extracted folder to DATA_DIRECTORY.
# The steps are chained with && so an existing dataset folder is only replaced after
# a successful unzip; removing it first stops `mv` from nesting the fresh copy
# inside the old folder when this cell is re-run.
!unzip -q -o {output_filename} -d . && rm -rf {DATA_DIRECTORY} && mv ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3 {DATA_DIRECTORY} && rm {output_filename}

# Generating Seeds

In [28]:
import numpy as np
import torch
import random

def set_seed(seed=None, seed_torch=True):

    """
    Seed Python's `random`, NumPy's global RNG and (optionally) PyTorch.

    Args:
    seed : Integer
    A non-negative integer (Python int or a NumPy integer such as the np.uint32
    values returned by `generate_sub_seeds`) that defines the random state.
    If `None`, a seed in [0, 2**32) is drawn with NumPy and used. Default is `None`.
    seed_torch : Boolean
    If `True`, also seeds the torch CPU and CUDA generators and puts cuDNN in
    deterministic, non-benchmark mode. Default is `True`.

    Returns:
    Nothing. Prints the seed that was applied.
    """
    if seed is None:
        seed = np.random.choice(2 ** 32)
    # Normalise NumPy integer types to a plain int: random.seed() rejects them on
    # Python >= 3.11, and every RNG below accepts an int.
    seed = int(seed)

    # Seed every RNG whether the seed was supplied or drawn above.
    random.seed(seed)
    np.random.seed(seed)
    if seed_torch:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print(f'Random seed {seed} has been set.')


def generate_sub_seeds(initial_seed, num_seeds=5, max_val=2**10):
    """
    Derive `num_seeds` distinct pseudo-random integers from `initial_seed`,
    suitable for use as per-run seeds.

    Uses an isolated NumPy Generator (`np.random.default_rng`), so the result is
    repeatable for a given `initial_seed` and the global NumPy RNG is untouched.

    Args:
        initial_seed (int): The non-negative integer seed to initialize
                            the random number generator.
        num_seeds (int): The number of sub-seeds to generate. Default is 5.
        max_val (int): Exclusive upper bound of the sub-seeds (at most 2**32,
                       the range of the uint32 dtype). Defaults to 2**10, the
                       value this notebook's seeds (and therefore its seed-named
                       model files) are derived with; changing it changes every
                       generated seed.

    Returns:
        numpy.ndarray: uint32 array of shape (num_seeds,) with values in
                       [0, max_val).

    Raises:
        ValueError: If the draw contains duplicates. With the default 2**10
                    values, a collision among 5 seeds has roughly a 1% chance,
                    and a repeated seed would silently repeat a run. Choose
                    another `initial_seed` or a larger `max_val`.
    """
    rng = np.random.default_rng(initial_seed)
    sub_seeds = rng.integers(low=0, high=max_val, size=num_seeds, dtype=np.uint32)
    assert len(sub_seeds) == num_seeds
    if len(np.unique(sub_seeds)) != num_seeds:
        raise ValueError(
            f"generate_sub_seeds produced duplicate seeds {sub_seeds.tolist()} "
            f"for initial_seed={initial_seed}; use another initial_seed or a larger max_val."
        )
    return sub_seeds

# Master seed from which every per-run seed is derived.
master_seed = 2025

# One sub-seed per run; each seed is also embedded in that run's saved model filenames.
all_seeds = generate_sub_seeds(master_seed, num_seeds=number_of_run)
print(f"Initial Master Seed: {master_seed}")
print(f"Generated {len(all_seeds)} Sub-Seeds: {all_seeds}")

Initial Master Seed: 2025
Generated 5 Sub-Seeds: [ 458 1018 1016  391  976]


# Loading Data and Defining the Model

In [29]:
def get_diagnostic_classes(scp_codes_dict_str, scp_map, valid_classes):
    """ Safely parses SCP codes string and maps them to diagnostic classes.

    Args:
        scp_codes_dict_str: A record's ``scp_codes`` value. This is either a string
            containing a Python dict literal keyed by SCP code (the values are
            ignored) or an already-parsed dict.
        scp_map (dict): Mapping from SCP code to diagnostic class.
        valid_classes: Classes to keep; any other mapped class is dropped.

    Returns:
        list: The distinct matching classes in sorted order, which is
        deterministic across runs (unlike iterating a set of strings). Returns []
        and prints a warning if the value cannot be parsed.
    """
    import ast  # local import keeps this cell self-contained

    try:
        if isinstance(scp_codes_dict_str, dict):
            scp_codes_dict = scp_codes_dict_str
        else:
            # literal_eval only accepts Python literals; eval() would execute any
            # code that ended up in the CSV cell.
            scp_codes_dict = ast.literal_eval(scp_codes_dict_str)
        # .keys() (rather than plain iteration) so non-dict literals still fail -> [].
        mapped = {scp_map.get(code) for code in scp_codes_dict.keys()}
        return sorted(c for c in mapped if c is not None and c in valid_classes)
    except Exception as e:
        print(f"Warning: Error processing scp_codes '{scp_codes_dict_str}': {e}")
        return []

class ECGMultiLabelDataset(Dataset):
    """Map-style dataset returning ``(signal, label)`` float32 tensors for ECG records.

    For each item, the record is read with ``wfdb.rdrecord`` and processed in
    three steps:

    1. It is band-pass filtered with a zero-phase Butterworth filter (``filtfilt``).
    2. Each lead is z-scored.
    3. It is zero-padded or truncated on the right to ``signal_length`` samples.

    If the record cannot be read or processed, a ``(12, signal_length)`` tensor
    of zeros is returned instead, and the label is returned unchanged.

    Args:
        df: DataFrame with a ``label_vector`` column and a ``filename_lr``
            (preferred) or ``filename_hr`` column of record paths relative to
            ``data_dir``. It is copied, not modified.
        data_dir: Directory prepended to each record path.
        mlb_instance: Fitted MultiLabelBinarizer, kept as ``self.mlb`` (not used
            by this class itself).
        signal_length: Number of samples per lead in the returned signal.
        filter_lowcut, filter_highcut: Pass-band edges in Hz.
        filter_order: Order passed to ``scipy.signal.butter``.

    Raises:
        ValueError: If a required column is missing, or if the label and path
            counts differ.
    """

    def __init__(self, df, data_dir, mlb_instance, signal_length=1000,
                 filter_lowcut=0.5, filter_highcut=45.0, filter_order=4):
        self.df = df.copy()
        self.data_dir = data_dir
        self.mlb = mlb_instance
        self.signal_length = signal_length

        # Band-pass settings plus a per-instance cache: int(fs) -> (b, a), or
        # (None, None) when filtering must be skipped for that rate.
        self.filter_lowcut = filter_lowcut
        self.filter_highcut = filter_highcut
        self.filter_order = filter_order
        self.filter_coeffs = {}

        available = self.df.columns
        if 'label_vector' not in available:
            raise ValueError("'label_vector' column not found in DataFrame for Dataset.")
        # One label row per record, stacked into a float32 matrix.
        self.labels = np.array(self.df['label_vector'].tolist(), dtype=np.float32)

        if 'filename_lr' in available:
            path_column = 'filename_lr'
        elif 'filename_hr' in available:
            print("Using high-resolution filenames ('filename_hr')")
            path_column = 'filename_hr'
        else:
            raise ValueError("Filename column ('filename_lr' or 'filename_hr') not found.")
        self.filepaths = self.df[path_column].values

        if len(self.labels) != len(self.filepaths):
            raise ValueError(f"Mismatch between number of labels ({len(self.labels)}) and filepaths ({len(self.filepaths)}).")

    def __len__(self):
        return len(self.df)

    def _get_filter_coeffs(self, fs):
        """Return cached band-pass ``(b, a)`` for sampling rate ``fs``.

        ``fs`` is truncated to int and used as the cache key. ``(None, None)`` is
        returned (and cached) when the normalised band is empty or ``butter``
        rejects it.
        """
        fs = int(fs)
        if fs in self.filter_coeffs:
            return self.filter_coeffs[fs]

        nyquist = 0.5 * fs
        # Clamp the normalised edges into the open interval (0, 1) that butter() accepts.
        low = max(self.filter_lowcut / nyquist, 1e-6)
        high = min(self.filter_highcut / nyquist, 1.0 - 1e-6)

        coeffs = (None, None)
        if low >= high:
            print(f"Warning: Filter lowcut ({self.filter_lowcut} Hz) >= highcut ({self.filter_highcut} Hz) for fs={fs}. Skipping filter.")
        else:
            try:
                b, a = sp_signal.butter(self.filter_order, [low, high], btype='bandpass')
                coeffs = (b, a)
            except ValueError as e:
                print(f"Error creating Butterworth filter for fs={fs}, Wn=[{low}, {high}]: {e}. Skipping filter.")
        self.filter_coeffs[fs] = coeffs
        return coeffs

    def _bandpass(self, raw, fs, record_filename):
        """Zero-phase filter every lead (axis 1); return ``raw`` unchanged if no filter is available or filtfilt fails."""
        b, a = self._get_filter_coeffs(fs)
        if b is None or a is None:
            return raw
        try:
            return sp_signal.filtfilt(b, a, raw, axis=1)
        except ValueError as e:
            print(f"Warning: filtfilt error on {record_filename} (perhaps constant lead?): {e}. Using unfiltered signal.")
            return raw

    @staticmethod
    def _zscore_per_lead(leads):
        """Standardise each row; the 1e-8 guards against division by zero for flat leads."""
        centred = leads - np.mean(leads, axis=1, keepdims=True)
        return centred / (np.std(leads, axis=1, keepdims=True) + 1e-8)

    def _fit_length(self, leads):
        """Right-pad with zeros or truncate along axis 1 to ``self.signal_length`` samples."""
        shortfall = self.signal_length - leads.shape[1]
        if shortfall > 0:
            return np.pad(leads, ((0, 0), (0, shortfall)), 'constant', constant_values=0)
        return leads[:, :self.signal_length]

    def __getitem__(self, idx):
        record_filename = self.filepaths[idx]
        record_path = os.path.join(self.data_dir, record_filename)
        label = self.labels[idx]

        try:
            # rdrecord is given the record path with its extension stripped.
            record = wfdb.rdrecord(os.path.splitext(record_path)[0])
            # Transposed so axis 1 is time (assumes p_signal is samples x leads -- TODO confirm).
            raw = record.p_signal.T
            filtered = self._bandpass(raw, record.fs, record_filename)
            shaped = self._fit_length(self._zscore_per_lead(filtered))
            signal_tensor = torch.tensor(shaped, dtype=torch.float32)
        except FileNotFoundError:
            print(f"Error: Record file not found at {record_path}. Returning zeros.")
            # Fallback assumes 12 leads.
            signal_tensor = torch.zeros((12, self.signal_length), dtype=torch.float32)
        except Exception as e:
            print(f"Error loading or processing record: {record_path} - {e}")
            signal_tensor = torch.zeros((12, self.signal_length), dtype=torch.float32)

        return signal_tensor, torch.tensor(label, dtype=torch.float32)


class ECGCNN(nn.Module):
    """Three-block 1-D CNN producing one logit per class for multi-channel ECG input.

    Each block is Conv1d(kernel 5, stride 1, padding 2, so the length is preserved)
    -> BatchNorm1d -> ReLU, with channels input_channels -> 32 -> 64 -> 128.
    The blocks are followed by global average pooling over time, dropout (p=0.3)
    and a linear layer. Raw logits are returned; no sigmoid is applied here.

    The attribute names conv1/bn1/.../conv3/bn3/fc are the state_dict keys.
    Layers are registered in the same order as before, so seeded initialisation
    is unchanged.
    """

    def __init__(self, num_classes, input_channels=12):
        super().__init__()
        widths = (input_channels, 32, 64, 128)
        # Register conv{i} then bn{i} for each block: the order matters for seeded init.
        for idx in range(1, len(widths)):
            setattr(self, f'conv{idx}',
                    nn.Conv1d(widths[idx - 1], widths[idx], kernel_size=5, stride=1, padding=2))
            setattr(self, f'bn{idx}', nn.BatchNorm1d(widths[idx]))
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(widths[-1], num_classes)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()

    def forward(self, x):
        for idx in (1, 2, 3):
            conv = getattr(self, f'conv{idx}')
            norm = getattr(self, f'bn{idx}')
            x = self.relu(norm(conv(x)))
        pooled = self.pool(x).squeeze(-1)  # (batch, 128, 1) -> (batch, 128)
        return self.fc(self.dropout(pooled))


In [30]:
def load_and_prepare_data(data_dir, batch_size=32, num_workers=2, pin_memory=True):
    """
    Loads PTB-XL metadata, performs preprocessing (age bins, multi-label encoding),
    splits data, and creates PyTorch DataLoaders.

    The split uses the `strat_fold` column: folds <= 8 for training, fold 9 for
    validation and fold 10 for testing.

    Args:
        data_dir (str): Path to the directory containing PTB-XL files
                        (ptbxl_database.csv, scp_statements.csv, and record data).
        batch_size (int): Batch size for DataLoaders.
        num_workers (int): Number of subprocesses for data loading.
        pin_memory (bool): If True, copies Tensors into CUDA pinned memory.

    Returns:
        tuple: An 8-tuple
               (train_loader, val_loader, test_loader, val_df, test_df, mlb,
                num_classes, target_classes).
               If loading fails at any stage, every element is None, so the
               result can always be unpacked into eight names.
    """
    # Same arity as the success value, so 8-way unpacking also works on failure.
    failure = (None,) * 8

    # --- Load Metadata ---
    database_path = os.path.join(data_dir, 'ptbxl_database.csv')
    scp_path = os.path.join(data_dir, 'scp_statements.csv')

    if not os.path.exists(database_path):
        print(f"Error: Database file not found at {database_path}")
        return failure
    if not os.path.exists(scp_path):
        print(f"Error: SCP statements file not found at {scp_path}")
        return failure

    try:
        ptbxl_df = pd.read_csv(database_path, index_col='ecg_id')
        # Index scp_statements by its first column, assumed to be the SCP code string.
        scp_df = pd.read_csv(scp_path, index_col=0)
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return failure

    print("Metadata loaded.")

    # --- Add Age Bin ---
    # Left-closed decade bins; the last bin [90, inf) is open-ended.
    # NOTE(review): its label is '90' (not '90s'); kept as-is in case later cells rely on it.
    bins = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, np.inf]
    labels = ['0s', '10s', '20s', '30s', '40s', '50s', '60s', '70s', '80s', '90']
    if 'age' in ptbxl_df.columns:
        ptbxl_df['age'] = pd.to_numeric(ptbxl_df['age'], errors='coerce')
        ptbxl_df['age_bin'] = pd.cut(ptbxl_df['age'], bins=bins, labels=labels, right=False)
        print("Added 'age_bin' column.")
    else:
        print("Warning: 'age' column not found. Could not create 'age_bin'.")

    # --- Multi-Label Preprocessing ---
    if 'scp_codes' not in ptbxl_df.columns:
        print("Error: 'scp_codes' column is required in ptbxl_database.csv for label processing.")
        return failure

    if 'diagnostic_class' not in scp_df.columns:
        print("Warning: 'diagnostic_class' column not found in scp_statements.csv.")
        print("Error: Cannot determine target classes without 'diagnostic_class' column.")
        return failure

    # Only SCP statements that have a diagnostic class contribute labels.
    scp_df_filtered = scp_df.dropna(subset=['diagnostic_class'])
    # String index so lookups match the string keys parsed from scp_codes.
    scp_df_filtered.index = scp_df_filtered.index.astype(str)
    scp_to_class_map = scp_df_filtered['diagnostic_class'].to_dict()
    target_classes = sorted(scp_df_filtered['diagnostic_class'].unique().tolist())
    print(f"Using diagnostic superclasses: {target_classes}")

    # Per-record list of diagnostic classes.
    ptbxl_df['diagnostic_classes_list'] = ptbxl_df['scp_codes'].apply(
        lambda x: get_diagnostic_classes(x, scp_to_class_map, target_classes)
    )

    # Fixed class order so label columns are the same on every run.
    mlb = MultiLabelBinarizer(classes=target_classes)
    multi_hot_labels = mlb.fit_transform(ptbxl_df['diagnostic_classes_list'])
    num_classes = len(mlb.classes_)

    # One multi-hot row (numpy array) per record.
    ptbxl_df['label_vector'] = list(multi_hot_labels)
    print(f"Processed multi-label vectors for {num_classes} classes.")

    # --- Split dataset ---
    if 'strat_fold' not in ptbxl_df.columns:
        print("Error: 'strat_fold' column required for splitting.")
        return failure

    train_df = ptbxl_df[ptbxl_df.strat_fold <= 8].copy()
    val_df = ptbxl_df[ptbxl_df.strat_fold == 9].copy()
    test_df = ptbxl_df[ptbxl_df.strat_fold == 10].copy()

    if len(train_df) == 0 or len(val_df) == 0 or len(test_df) == 0:
        print("Error: Data split resulted in empty training, validation, or test set.")
        return failure

    print(f"Data split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

    # --- Create Datasets and DataLoaders ---
    train_dataset = ECGMultiLabelDataset(train_df, data_dir, mlb)
    val_dataset = ECGMultiLabelDataset(val_df, data_dir, mlb)
    test_dataset = ECGMultiLabelDataset(test_df, data_dir, mlb)

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    print("--- DataLoaders created successfully ---")

    return train_loader, val_loader, test_loader, val_df, test_df, mlb, num_classes, target_classes


In [31]:
def _evaluate_epoch(model, loader, criterion, device):
    """Evaluate `model` on `loader` with gradients disabled.

    Returns:
        tuple: (mean per-batch loss, Hamming loss of `sigmoid(logits) > 0.5` vs. labels).
    """
    model.eval()
    total_loss = 0.0
    batch_preds, batch_labels = [], []
    with torch.no_grad():
        for signals, labels in loader:
            signals, labels = signals.to(device), labels.to(device)
            outputs = model(signals)
            total_loss += criterion(outputs, labels).item()
            batch_preds.append((torch.sigmoid(outputs) > 0.5).float().cpu().numpy())
            batch_labels.append(labels.cpu().numpy())
    avg_loss = total_loss / len(loader)
    hamming = hamming_loss(np.concatenate(batch_labels, axis=0),
                           np.concatenate(batch_preds, axis=0))
    return avg_loss, hamming


def train_ecg_model(train_loader, val_loader, num_classes, seed, device,
                    num_epochs=30, learning_rate=0.001, DA=False,
                    model_save_dir='saved_models'):
    """
    Trains the ECG CNN model. The BEST model by validation loss is saved to
    one fixed path per run (overwritten on each improvement), and the FINAL
    model state is saved after the last epoch.

    Args:
        train_loader (DataLoader): DataLoader for the training set.
        val_loader (DataLoader): DataLoader for the validation set.
        num_classes (int): Number of output classes for the model.
        seed (int): Random seed for reproducibility.
        device (torch.device): Device to train on ('cuda' or 'cpu').
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the Adam optimizer.
        DA (bool): Data-augmentation flag. No augmentation is implemented yet;
                   it only adds '_DA' to the saved model filenames.
        model_save_dir (str): Directory to save the trained model weights.

    Returns:
        tuple: (best_model_path, final_model_path) for this run. best_model_path
               is None if no epoch improved on an infinite loss (e.g. NaN losses).
               Returns (None, None) if the model could not be created or
               num_epochs <= 0.

    Raises:
        ValueError: If num_epochs > 0 and either loader yields no batches.
    """
    print("\n--- Starting Training ---")
    print(f"Seed: {seed}, Data Augmentation (DA) Flag: {DA}, Epochs: {num_epochs}, LR: {learning_rate}")
    print(f"Models will be saved in: {model_save_dir}")

    # Fail early with a clear message instead of a ZeroDivisionError / empty concatenate later.
    if num_epochs > 0 and (len(train_loader) == 0 or len(val_loader) == 0):
        raise ValueError("train_loader and val_loader must each yield at least one batch.")

    # --- Setup ---
    try:
        set_seed(seed)
    except NameError:
        print("Warning: 'set_seed' function not found. Reproducibility may not be guaranteed.")

    os.makedirs(model_save_dir, exist_ok=True)

    try:
        model = ECGCNN(num_classes).to(device)
    except NameError:
        print("ERROR: 'ECGCNN' class not found. Cannot initialize model.")
        return None, None

    criterion = nn.BCEWithLogitsLoss()  # multi-label: independent sigmoid per class
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Identifies this run (seed + DA flag) in every saved filename.
    base_name_prefix = f"model_seed_{seed}{'_DA' if DA else ''}"
    # A single fixed path per run: each improvement simply overwrites it.
    best_checkpoint_path = os.path.join(model_save_dir, f"best_{base_name_prefix}.pth")
    best_val_loss = float('inf')
    best_model_path = None

    # --- Training Loop ---
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for signals, labels in train_loader:
            signals, labels = signals.to(device), labels.to(device)
            # Data-augmentation hook: nothing is applied yet, even when DA is True.
            outputs = model(signals)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        avg_train_loss = train_loss / len(train_loader)

        avg_val_loss, val_hamming = _evaluate_epoch(model, val_loader, criterion, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Hamming: {val_hamming:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_checkpoint_path)
            best_model_path = best_checkpoint_path
            print(f"  * New best val loss. Saved best model to: {os.path.basename(best_model_path)}")

    if num_epochs <= 0:
        print("--- Training Warning ---")
        print("num_epochs was 0. No training performed, no models saved.")
        return None, None

    # --- Save Final Model ---
    final_model_path = os.path.join(model_save_dir, f"final_{base_name_prefix}_epoch_{num_epochs}.pth")
    torch.save(model.state_dict(), final_model_path)

    print(f"\n--- Training Complete for Seed {seed} ---")
    if best_model_path:
        print(f"Best model saved to : {os.path.basename(best_model_path)} (Val Loss: {best_val_loss:.4f})")
    else:
        print("Warning: No improvement in validation loss observed during training. 'Best' model not saved.")
    print(f"Final model saved to: {os.path.basename(final_model_path)}")
    return best_model_path, final_model_path

# Train Loop

In [53]:
# Model factory and config reused by later cells. actual_num_classes is fixed
# here and is checked against the classes found in the data after loading.
model_class_factory = ECGCNN
actual_num_classes = 5

model_config = {
    'num_classes': actual_num_classes
}

# Defined unconditionally so the device also exists when RETRAIN is False.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Step 1: Load and Prepare Data ---
print("--- Loading and Preparing Data ---")
train_loader, val_loader, test_loader, val_df, test_df, mlb, num_classes, target_classes = \
    load_and_prepare_data(
        data_dir=DATA_DIRECTORY,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY
    )

# Raise instead of calling exit(): in Jupyter, exit() asks the kernel to shut down
# rather than raising, so it is not a reliable way to stop a cell.
if any(obj is None for obj in (train_loader, val_loader, test_loader, val_df, test_df, mlb)):
  print("Failed to load data completely (check loaders, val_df, test_df, mlb). Exiting.")
  raise RuntimeError("Data loading failed; see the messages above.")

if num_classes != actual_num_classes:
  raise RuntimeError(f"Data has {num_classes} classes but model_config expects {actual_num_classes}.")

print(f"Number of classes detected: {num_classes}")
print(f"Class names: {mlb.classes_}")

if RETRAIN:
  # --- Step 2: Training Loop ---
  # Maps run key -> path of the FINAL model saved for that run.
  final_model_paths = {}

  print("\n--- Starting Training Phase ---")
  for current_seed in all_seeds:
    print(f"\n-- Training SEED={current_seed} on device={DEVICE} | DA={DA} --")
    if DEVICE.type != "cuda":
      print("Warning: CUDA not available, using CPU.")

    best_path, final_path = train_ecg_model(
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=num_classes,
        seed=current_seed,
        device=DEVICE,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        DA=DA,
        model_save_dir=MODEL_SAVE_DIRECTORY
    )

    run_key = f'seed_{current_seed}_DA_{DA}'  # DA status is part of the key
    if final_path:
      final_model_paths[run_key] = final_path
      print(f"Stored final model path for {run_key}")
    else:
      print(f"Training might have failed for {run_key}, no final path returned.")

  print(f"\n--- All DA={DA} Training Runs Complete ---")
  print("Summary of FINAL model paths saved:")
  if final_model_paths:
    for key, path in final_model_paths.items():
      print(f"- {key}: {os.path.basename(path) if path else 'Path not saved/returned'}")
  else:
    print("No final model paths were recorded.")

  print("\nTraining script finished.")
else:
  print("\nLoad pre-trained weights")

--- Loading and Preparing Data ---
Metadata loaded.
Added 'age_bin' column.
Using diagnostic superclasses: ['CD', 'HYP', 'MI', 'NORM', 'STTC']
Processed multi-label vectors for 5 classes.
Data split: Train=17418, Val=2183, Test=2198
--- DataLoaders created successfully ---
Number of classes detected: 5
Class names: ['CD' 'HYP' 'MI' 'NORM' 'STTC']

Load pre-trained weights


# Per-class optimal thresholds that maximize F1 on the validation set

In [54]:
import numpy as np
import torch
from sklearn.metrics import f1_score
import os

def find_optimal_thresholds_val(model_path, val_loader, num_classes, device, target_classes=None,
                                model_class_factory=None):
    """
    Find, independently for each class, the probability threshold that
    maximizes F1 on the validation set.

    Candidates are ``np.linspace(0.01, 0.99, 99)``; on ties the lowest
    threshold wins. A class with no positive validation samples keeps the
    default threshold 0.5: every candidate would score F1 = 0 for it, and the
    search would otherwise select the lowest candidate (0.01).

    Args:
        model_path (str): Path to the trained model state dictionary (.pth file).
        val_loader (iterable): Yields ``(signals, labels)`` tensor batches for the
            validation set (e.g. a DataLoader).
        num_classes (int): Number of classes.
        device (torch.device): Device to run inference on ('cuda' or 'cpu').
        target_classes (sequence, optional): Class names for the printout; a list
            or a NumPy array such as ``mlb.classes_``. Defaults to None.
        model_class_factory (callable, optional): Called as
            ``model_class_factory(num_classes=num_classes)`` to build the model.
            Defaults to the notebook's ``ECGCNN`` class.

    Returns:
        np.ndarray: Array of shape (num_classes,) with the chosen threshold for
        each class. Returns an array of 0.5 if the model cannot be loaded or
        the loader yields no data.
    """
    print(f"\n--- Finding Optimal Thresholds using Validation Set ---")
    print(f"Loading model: {os.path.basename(model_path)}")

    # --- Load Model ---
    try:
        # ECGCNN is looked up only when no factory is given, so a missing
        # definition surfaces as the NameError handled below.
        factory = ECGCNN if model_class_factory is None else model_class_factory
        model = factory(num_classes=num_classes)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()  # evaluation mode for inference
        print("Model loaded successfully.")
    except NameError:
        print("ERROR: 'ECGCNN' class not defined. Cannot load model.")
        return np.full(num_classes, 0.5)
    except FileNotFoundError:
        print(f"ERROR: Model file not found at {model_path}")
        return np.full(num_classes, 0.5)
    except Exception as e:
        print(f"Error loading model state_dict: {e}")
        return np.full(num_classes, 0.5)

    # --- Get Probabilities and Labels from Validation Set ---
    prob_batches, label_batches = [], []
    print("Running inference on validation set to get probabilities...")
    with torch.no_grad():
        for signals, labels in val_loader:
            logits = model(signals.to(device))
            # The model emits logits; sigmoid turns them into per-class probabilities.
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    # Concatenate results from all batches (fails with ValueError if there were none).
    try:
        y_true_val = np.concatenate(label_batches, axis=0)
        y_prob_val = np.concatenate(prob_batches, axis=0)
        print(f"Inference complete. Found {y_true_val.shape[0]} validation samples.")
    except ValueError:
        print("Error: Validation loader might be empty or data format issue.")
        return np.full(num_classes, 0.5)

    # --- Find Best Threshold per Class ---
    threshold_candidates = np.linspace(0.01, 0.99, 99)  # 0.01, 0.02, ..., 0.99
    optimal_thresholds = np.full(num_classes, 0.5)
    print(f"Optimizing thresholds for {num_classes} classes...")

    for i in range(num_classes):
        # `is not None` (not truthiness) so NumPy arrays of names are accepted.
        class_name = target_classes[i] if target_classes is not None and i < len(target_classes) else f"Class {i}"
        is_positive = y_true_val[:, i] == 1
        num_positive = int(is_positive.sum())
        if num_positive == 0:
            # Every candidate would score F1 = 0 and the search would pick 0.01.
            print(f"  {class_name:<10}: No positive validation samples. Using default threshold {optimal_thresholds[i]:.1f}")
            continue

        # Confusion counts for all candidates at once: rows = thresholds, cols = samples.
        predicted = y_prob_val[:, i][np.newaxis, :] >= threshold_candidates[:, np.newaxis]
        tp = (predicted & is_positive).sum(axis=1)
        fp = predicted.sum(axis=1) - tp
        fn = num_positive - tp
        # F1 = 2TP / (2TP + FP + FN); the denominator is >= num_positive > 0 here.
        f1_scores = 2.0 * tp / (2 * tp + fp + fn)
        best_idx = int(np.argmax(f1_scores))  # first maximum -> lowest threshold on ties
        optimal_thresholds[i] = threshold_candidates[best_idx]
        print(f"  {class_name:<10}: Best Threshold = {optimal_thresholds[i]:.3f} (Validation F1 = {f1_scores[best_idx]:.4f})")

    print("--- Optimal Threshold finding complete ---")
    return optimal_thresholds

# Optimize threshold for each seed

In [55]:
# Global (non-subgroup) threshold optimization: for every seed, find the per-class
# thresholds that maximize validation F1 using that seed's final checkpoint.
optimized_thresholds_per_seed = {}
OPTIMIZATION_RESULTS_DIR = 'optimization_results'
os.makedirs(OPTIMIZATION_RESULTS_DIR, exist_ok=True)
optimized_thresholds_filename = os.path.join(OPTIMIZATION_RESULTS_DIR, 'optimized_global_thresholds_per_seed.pkl')
DEVICE_OPT = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device for threshold optimization: {DEVICE_OPT}")


def _final_checkpoint_path(seed, da_flag):
    """Return the final-epoch checkpoint path for `seed` ('_DA' suffix when da_flag is True)."""
    prefix = f"model_seed_{seed}{'_DA' if da_flag else ''}"
    return os.path.join(MODEL_SAVE_DIRECTORY, f"final_{prefix}_epoch_{NUM_EPOCHS}.pth")


# Defined before the branch so the save step below never raises NameError.
seeds_to_compute = []

if globals().get('val_loader') is None:
    print("ERROR: 'val_loader' not found or is None. Cannot optimize thresholds.")
    for seed in all_seeds:
        optimized_thresholds_per_seed[seed] = np.full(num_classes, 0.5)
    print("WARNING: Using default threshold 0.5 for all seeds due to missing val_loader.")
else:
    # Only seeds whose checkpoint exists are optimized; the rest fall back to 0.5.
    for current_seed in all_seeds:
        if current_seed in optimized_thresholds_per_seed:
            continue
        model_path = _final_checkpoint_path(current_seed, DA)
        if os.path.exists(model_path):
            seeds_to_compute.append(current_seed)
        else:
            print(f"WARNING: Model file not found for seed {current_seed}: {model_path}. Cannot optimize. Using default 0.5.")
            optimized_thresholds_per_seed[current_seed] = np.full(num_classes, 0.5)

    if seeds_to_compute:
        print(f"Computing optimal thresholds for seeds: {seeds_to_compute}")
        for current_seed in seeds_to_compute:
            print(f"\n-- Optimizing for SEED={current_seed} --")
            # Same DA flag as the existence check above, so the checked file is the one loaded.
            optimized_thresholds_per_seed[current_seed] = find_optimal_thresholds_val(
                model_path=_final_checkpoint_path(current_seed, DA),
                val_loader=val_loader,
                num_classes=num_classes,
                device=DEVICE_OPT,
                target_classes=target_classes,  # class names for the printout
            )
            print(f"Optimization complete for seed {current_seed}.")
    else:
        print("All required thresholds were loaded or previously computed.")


# Save the thresholds only when new values were computed.
if seeds_to_compute:
    print(f"Saving computed thresholds to: {optimized_thresholds_filename}")
    try:
        with open(optimized_thresholds_filename, 'wb') as f:
            pickle.dump(optimized_thresholds_per_seed, f)
        print("Optimized thresholds saved.")
    except Exception as e:
        print(f"Error saving optimized thresholds: {e}")

# --- Verify thresholds are available for all seeds needed for evaluation ---
missing_thresholds = [seed for seed in all_seeds if seed not in optimized_thresholds_per_seed]
for seed in missing_thresholds:
    print(f"ERROR: Thresholds for seed {seed} are missing after optimization/loading phase.")
if missing_thresholds:
    # Raise rather than call exit(): exit() would stop the Jupyter kernel, not just this cell.
    raise RuntimeError("Exiting due to missing optimized thresholds.")

print(f"\n{'='*20} Threshold Optimization Phase Complete {'='*20}")



Using device for threshold optimization: cuda
Computing optimal thresholds for seeds: [np.uint32(458), np.uint32(1018), np.uint32(1016), np.uint32(391), np.uint32(976)]

-- Optimizing for SEED=458 --

--- Finding Optimal Thresholds using Validation Set ---
Loading model: final_model_seed_458_epoch_30.pth
Model loaded successfully.
Running inference on validation set to get probabilities...
Inference complete. Found 2183 validation samples.
Optimizing thresholds for 5 classes...
  CD        : Best Threshold = 0.430 (Validation F1 = 0.7441)
  HYP       : Best Threshold = 0.200 (Validation F1 = 0.4665)
  MI        : Best Threshold = 0.540 (Validation F1 = 0.7452)
  NORM      : Best Threshold = 0.360 (Validation F1 = 0.8556)
  STTC      : Best Threshold = 0.340 (Validation F1 = 0.7609)
--- Optimal Threshold finding complete ---
Optimization complete for seed 458.

-- Optimizing for SEED=1018 --

--- Finding Optimal Thresholds using Validation Set ---
Loading model: final_model_seed_1018_ep

# Implement Subgroup Optimization

In [56]:
# --- Helper Function: Calculate Subgroup Metrics
def calculate_subgroup_metrics(group_labels, group_preds, target_classes):
    """
    Summarize multi-label classification quality for a single subgroup.

    Args:
        group_labels: True label matrix; rows are samples, columns follow the
            order of `target_classes`.
        group_preds: Predicted label matrix with the same layout.
        target_classes: Class names, in column order.

    Returns:
        dict with, in order: 'num_samples'; 'exact_match_ratio' (fraction of
        rows whose predicted vector equals the true one); 'mean_true_labels'
        (mean number of true labels per row); precision/recall/F1 under macro,
        micro and weighted averaging; and per-class precision/recall/F1/support
        dicts keyed by class name. For an empty subgroup every score is NaN and
        every support is 0.
    """
    sample_count = len(group_labels)
    averaged_keys = (
        ('macro', ('precision_macro', 'recall_macro', 'f1_macro')),
        ('micro', ('precision_micro', 'recall_micro', 'f1_micro')),
        ('weighted', ('precision_weighted', 'recall_weighted', 'f1_weighted')),
    )
    per_class_keys = ('precision_per_class', 'recall_per_class', 'f1_per_class', 'support_per_class')

    if sample_count == 0:
        # Nothing to score: NaN for every metric, zero support for every class.
        empty = {'num_samples': 0, 'exact_match_ratio': np.nan, 'mean_true_labels': np.nan}
        for _, keys in averaged_keys:
            empty.update(dict.fromkeys(keys, np.nan))
        for key in per_class_keys[:-1]:
            empty[key] = {cls: np.nan for cls in target_classes}
        empty['support_per_class'] = {cls: 0 for cls in target_classes}
        return empty

    labels_2d, preds_2d = group_labels, group_preds
    if sample_count == 1:
        labels_2d = labels_2d.reshape(1, -1)
        preds_2d = preds_2d.reshape(1, -1)

    summary = {
        'num_samples': sample_count,
        'exact_match_ratio': np.sum(np.all(labels_2d == preds_2d, axis=1)) / sample_count,
        'mean_true_labels': np.mean(np.sum(labels_2d, axis=1)),
    }
    for average, keys in averaged_keys:
        precision, recall, f1, _ = precision_recall_fscore_support(labels_2d, preds_2d, average=average, zero_division=0)
        summary.update(zip(keys, (precision, recall, f1)))

    class_indices = list(range(len(target_classes)))
    per_class_scores = precision_recall_fscore_support(labels_2d, preds_2d, average=None, zero_division=0, labels=class_indices)
    for key, values in zip(per_class_keys, per_class_scores):
        summary[key] = dict(zip(target_classes, values))
    return summary


def find_optimal_thresholds_subgroup(
    model_path,
    model_class_factory, # Factory function for your model (e.g., ECGCNN or xresnet)
    model_config,        # Dictionary with model config, MUST include 'num_classes'
    val_loader,          # DataLoader for the validation set (MUST NOT SHUFFLE)
    val_df,              # DataFrame with metadata, ordered consistently with val_loader
    subgroup_col,        # Column name in val_df (e.g., 'sex', 'age_bin', 'device')
    global_optimal_thresholds, # Numpy array (shape: num_classes) of pre-calculated global thresholds
    device,              # Torch device ('cuda' or 'cpu')
    target_classes=None, # Optional: class names (list or NumPy array) for printing
    threshold_step=0.01, # Granularity for threshold search
    log=False,
    min_positive_samples=1, # Fall back to the global threshold below this many positives
):
    """
    Find per-class probability thresholds optimized *within* each subgroup
    defined by the unique values of `subgroup_col` in `val_df`.

    For each subgroup and class, the F1-maximizing threshold is searched over
    ``np.arange(threshold_step, 1.0, threshold_step)`` (lowest threshold wins on
    ties). When the class has fewer than `min_positive_samples` positive samples
    in the subgroup (always when it has none), the pre-computed global
    threshold is used instead.

    Args:
        model_path (str): Path to the trained model state dictionary (.pth file).
        model_class_factory (callable): Builds the model as ``model_class_factory(**model_config)``.
        model_config (dict): Configuration for the factory; must include 'num_classes'.
        val_loader (DataLoader): Validation DataLoader (MUST NOT SHUFFLE); must expose `.dataset`.
        val_df (pd.DataFrame): Validation metadata, row-aligned with val_loader.
        subgroup_col (str): Column of val_df to group by (e.g. 'sex').
        global_optimal_thresholds (np.ndarray): Shape (num_classes,) global thresholds used as fallback.
        device (torch.device): Device to run inference on ('cuda' or 'cpu').
        target_classes (sequence, optional): Class names for printing.
        threshold_step (float): Search step; must lie strictly between 0 and 1.
        log (bool): If True, also print the header and model-loading messages.
            Progress and per-class results are printed regardless.
        min_positive_samples (int): Minimum positives a class needs in a subgroup
            to be optimized there. The default of 1 only skips classes with none.

    Returns:
        dict | None: Maps each subgroup value (as a string when the column could
        be converted) to an array of shape (num_classes,) with that subgroup's
        thresholds, e.g. {'0': array([0.45, ...]), '1': array([0.51, ...])}.
        Returns None on invalid input or failed model loading / inference.
    """
    # Lambdas / partials may lack __name__; don't let the error path itself crash.
    factory_name = getattr(model_class_factory, '__name__', repr(model_class_factory))
    if log:
        print(f"\n--- Finding Optimal Thresholds per Subgroup: '{subgroup_col}' ---")
        print(f"Using Model Factory: {factory_name}")
        print(f"Loading model state: {os.path.basename(model_path)}")

    # --- Validate Inputs ---
    if subgroup_col not in val_df.columns:
        print(f"ERROR: Subgroup column '{subgroup_col}' not found in val_df.")
        return None
    if len(val_loader.dataset) != len(val_df):
        print(f"ERROR: Mismatch between val_loader dataset size ({len(val_loader.dataset)}) "
              f"and val_df length ({len(val_df)}). Ensure order consistency.")
        return None
    if not isinstance(global_optimal_thresholds, np.ndarray):
        print(f"ERROR: global_optimal_thresholds must be a NumPy array.")
        return None

    # --- Extract num_classes from config ---
    try:
        num_classes = model_config['num_classes']
    except KeyError:
        print("ERROR: 'num_classes' not found in model_config.")
        return None
    except TypeError:
        print("ERROR: model_config is not a dictionary or is None.")
        return None
    if global_optimal_thresholds.shape != (num_classes,):
        print(f"ERROR: Shape mismatch. global_optimal_thresholds shape {global_optimal_thresholds.shape} "
              f"does not match num_classes ({num_classes}) from model_config.")
        return None

    # Validate the search grid before spending time on inference.
    if not 0 < threshold_step < 1:
        print(f"ERROR: threshold_step ({threshold_step}) must lie strictly between 0 and 1.")
        return None
    threshold_candidates = np.arange(threshold_step, 1.0, threshold_step)

    # --- Subgroup values as strings so dict keys are consistent ---
    try:
        subgroup_values_array = val_df[subgroup_col].astype(str).to_numpy()
    except Exception as e:
        print(f"Warning: Could not convert column '{subgroup_col}' to string: {e}")
        subgroup_values_array = val_df[subgroup_col].to_numpy()

    # --- Load Model ---
    try:
        model = model_class_factory(**model_config)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        if log:
            print("Model loaded successfully.")
    except FileNotFoundError:
        print(f"ERROR: Model file not found at {model_path}")
        return None
    except Exception as e:
        print(f"Error loading model or state_dict for {factory_name}: {e}")
        import traceback
        traceback.print_exc()
        return None

    # --- Get ALL Probabilities and Labels from Validation Set (ONCE) ---
    prob_batches, label_batches = [], []
    if log:
        print("Running inference on entire validation set...")
    with torch.no_grad():
        for signals, labels in val_loader:
            logits = model(signals.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not prob_batches:
        print("Error: No data collected from validation loader. Is it empty?")
        return None
    try:
        all_y_true_val = np.concatenate(label_batches, axis=0).astype(int)
        all_y_prob_val = np.concatenate(prob_batches, axis=0)
    except ValueError as e:
        print(f"Error during concatenation: {e}. Check data shapes in validation loader.")
        return None
    if all_y_true_val.shape[0] == 0:
        print("Error: Concatenated arrays are empty.")
        return None
    if (all_y_true_val.ndim != 2 or all_y_prob_val.ndim != 2
            or all_y_true_val.shape[1] != num_classes or all_y_prob_val.shape[1] != num_classes):
        print(f"Error: Shape mismatch after concatenation. Expected {num_classes} classes.")
        return None
    print(f"Inference complete. Found {all_y_true_val.shape[0]} validation samples.")

    # --- Optimize Thresholds per Subgroup Value ---
    optimal_thresholds_per_subgroup = {}
    unique_subgroup_values = np.unique(subgroup_values_array)
    print(f"Found unique subgroup values in '{subgroup_col}': {list(unique_subgroup_values)}")
    print(f"Optimizing thresholds using {len(threshold_candidates)} steps (step={threshold_step:.3f})...")
    min_required = max(int(min_positive_samples), 1)  # a class with 0 positives always falls back

    for value_str in unique_subgroup_values:
        print(f"\n-- Optimizing for Subgroup: {subgroup_col} = {value_str} --")
        mask = subgroup_values_array == value_str
        num_samples_subgroup = int(np.sum(mask))
        if num_samples_subgroup == 0:
            # Possible in the non-string fallback path, e.g. NaN never equals itself.
            print("  No samples found for this subgroup. Using global thresholds as fallback.")
            optimal_thresholds_per_subgroup[value_str] = global_optimal_thresholds.copy()
            continue

        y_true_subgroup = all_y_true_val[mask]
        y_prob_subgroup = all_y_prob_val[mask]
        print(f"  Optimizing using {num_samples_subgroup} samples.")

        subgroup_thresholds = np.zeros(num_classes)
        for i in range(num_classes):
            # `is not None` (not truthiness) so NumPy arrays of names are accepted.
            class_name = target_classes[i] if target_classes is not None and i < len(target_classes) else f"Class {i}"
            is_positive = y_true_subgroup[:, i] == 1
            num_positive = int(is_positive.sum())

            if num_positive < min_required:
                fallback_threshold = global_optimal_thresholds[i]
                if num_positive == 0:
                    print(f"  {class_name:<15}: No positive samples in subgroup. Using global threshold {fallback_threshold:.3f}")
                else:
                    print(f"  {class_name:<15}: Only {num_positive} positive samples in subgroup "
                          f"(< {min_required}). Using global threshold {fallback_threshold:.3f}")
                subgroup_thresholds[i] = fallback_threshold
                continue

            # Confusion counts for all candidates at once: rows = thresholds, cols = samples.
            predicted = y_prob_subgroup[:, i][np.newaxis, :] >= threshold_candidates[:, np.newaxis]
            tp = (predicted & is_positive).sum(axis=1)
            fp = predicted.sum(axis=1) - tp
            fn = num_positive - tp
            # F1 = 2TP / (2TP + FP + FN); the denominator is >= num_positive > 0 here.
            f1_scores = 2.0 * tp / (2 * tp + fp + fn)
            best_idx = int(np.argmax(f1_scores))  # first maximum -> lowest threshold on ties
            subgroup_thresholds[i] = threshold_candidates[best_idx]
            print(f"  {class_name:<15}: Best Threshold = {subgroup_thresholds[i]:.3f} (Subgroup Val F1 = {f1_scores[best_idx]:.4f})")

        optimal_thresholds_per_subgroup[value_str] = subgroup_thresholds

    print(f"\n--- Subgroup Threshold Optimization for '{subgroup_col}' Complete ---")
    return optimal_thresholds_per_subgroup


In [57]:
# Per-seed subgroup threshold optimization: for every seed with global thresholds
# and a saved final checkpoint, tune per-class thresholds within each subgroup of
# the columns below. The results are pickled at the end of the cell.
OPTIMIZATION_RESULTS_DIR = 'optimization_results'
os.makedirs(OPTIMIZATION_RESULTS_DIR, exist_ok=True)

# val_df columns whose subgroups get their own thresholds (also the per-seed result keys).
SUBGROUP_COLUMNS = ('sex', 'age_bin', 'device')

optimized_subgroup_thresholds_per_seed_result = {}

# --- Loop through seeds ---
for current_seed in all_seeds:
    print(f"\n{'='*20} Optimizing Subgroup Thresholds for SEED={current_seed} {'='*20}")

    current_global_thresholds = optimized_thresholds_per_seed.get(current_seed)
    if current_global_thresholds is None:
        print(f"WARNING: Global thresholds for seed {current_seed} not found. Skipping subgroup optimization.")
        # None placeholder for every subgroup column.
        optimized_subgroup_thresholds_per_seed_result[current_seed] = dict.fromkeys(SUBGROUP_COLUMNS)
        continue

    base_name_prefix = f"model_seed_{current_seed}{'_DA' if DA else ''}"
    final_model_filename = f"final_{base_name_prefix}_epoch_{NUM_EPOCHS}.pth"
    model_path = os.path.join(MODEL_SAVE_DIRECTORY, final_model_filename)

    if not os.path.exists(model_path):
        print(f"WARNING: Model file not found for seed {current_seed}: {model_path}. Skipping.")
        optimized_subgroup_thresholds_per_seed_result[current_seed] = dict.fromkeys(SUBGROUP_COLUMNS)
        continue

    # One optimization per subgroup column. Each call reloads the model and
    # re-runs validation inference. A None entry means that call failed.
    optimized_subgroup_thresholds_per_seed_result[current_seed] = {
        column: find_optimal_thresholds_subgroup(
            model_path=model_path,
            model_class_factory=model_class_factory,
            model_config=model_config,
            val_loader=val_loader,
            val_df=val_df,
            subgroup_col=column,
            global_optimal_thresholds=current_global_thresholds,
            device=DEVICE_OPT,
            target_classes=target_classes,
        )
        for column in SUBGROUP_COLUMNS
    }

# --- Save the results ---
subgroup_thresholds_filename = os.path.join(OPTIMIZATION_RESULTS_DIR, f'optimized_subgroup_thresholds_per_seed_DA_{DA}.pkl')
print(f"\nSaving computed subgroup thresholds per seed to: {subgroup_thresholds_filename}")
try:
    with open(subgroup_thresholds_filename, 'wb') as f:
        pickle.dump(optimized_subgroup_thresholds_per_seed_result, f)
    print("Optimized subgroup thresholds per seed saved.")
except Exception as e:
    print(f"Error saving optimized subgroup thresholds per seed: {e}")


Inference complete. Found 2183 validation samples.
Found unique subgroup values in 'sex': ['0', '1']
Optimizing thresholds using 99 steps (step=0.010)...

-- Optimizing for Subgroup: sex = 0 --
  Optimizing using 1133 samples.
  CD             : Best Threshold = 0.430 (Subgroup Val F1 = 0.7559)
  HYP            : Best Threshold = 0.160 (Subgroup Val F1 = 0.5116)
  MI             : Best Threshold = 0.540 (Subgroup Val F1 = 0.7700)
  NORM           : Best Threshold = 0.450 (Subgroup Val F1 = 0.8496)
  STTC           : Best Threshold = 0.340 (Subgroup Val F1 = 0.7361)

-- Optimizing for Subgroup: sex = 1 --
  Optimizing using 1050 samples.
  CD             : Best Threshold = 0.450 (Subgroup Val F1 = 0.7320)
  HYP            : Best Threshold = 0.340 (Subgroup Val F1 = 0.4490)
  MI             : Best Threshold = 0.540 (Subgroup Val F1 = 0.7112)
  NORM           : Best Threshold = 0.330 (Subgroup Val F1 = 0.8661)
  STTC           : Best Threshold = 0.340 (Subgroup Val F1 = 0.7836)

--- Subg

In [58]:
# ==============================================================================
# ===== START OF EVALUATION AND AGGREGATION CODE ===============================
# ==============================================================================
# --- Configuration for Evaluation ---
# MODEL_SAVE_DIRECTORY and NUM_EPOCHS come from the config cell at the top of the
# notebook. They are deliberately not re-assigned here, so evaluation looks for
# the same checkpoint filenames the threshold-optimization cells used.
RESULTS_DIR = 'evaluation_results' # Directory to save results files
os.makedirs(RESULTS_DIR, exist_ok=True)

# --- Validate Test Data (Ensure variables from loading step exist) ---
# Raise rather than call exit(): exit() would stop the Jupyter kernel, not just this cell.
for _required_name in ('test_df', 'mlb', 'test_loader'):
    if globals().get(_required_name) is None:
        raise RuntimeError(f"Error: '{_required_name}' not available.")

required_cols = ['sex', 'age_bin', 'device']
missing_cols = [col for col in required_cols if col not in test_df.columns]
if missing_cols:
    raise RuntimeError(f"Error: test_df missing columns: {missing_cols}")
# Cast subgroup columns to str so they match the string keys of the subgroup
# thresholds (find_optimal_thresholds_subgroup converts its column with astype(str)).
for _subgroup_col in required_cols:
    test_df[_subgroup_col] = test_df[_subgroup_col].astype(str)

target_classes = list(mlb.classes_)
num_classes = len(target_classes)
print(f"Number of classes for evaluation: {num_classes}")
print(f"Target classes: {target_classes}")

# Condition name -> evaluation parameters.
#   DA_flag        : which checkpoints (with / without '_DA' suffix) to evaluate
#   threshold_mode : 'global_optimized' (one array per seed) or 'subgroup'
#   subgroup_focus : which subgroup thresholds to apply ('combined' = sex + age + device)
EXPERIMENTAL_CONDITIONS = {
    'Optimized_Global': { # Experiment 1
        'DA_flag': False,
        'threshold_mode': 'global_optimized',
        'subgroup_focus': None # Not applicable, uses global array
    },
    'Optimized_Subgroup_Combined': { # Experiment 2 (Sex + Age + Device)
        'DA_flag': False,
        'threshold_mode': 'subgroup',
        'subgroup_focus': 'combined' # Special keyword for combined effect
    },
    'Optimized_Subgroup_Sex': { # Experiment 3 (Sex Only)
        'DA_flag': False,
        'threshold_mode': 'subgroup',
        'subgroup_focus': 'sex' # Focus only on sex thresholds
    },
    'Optimized_Subgroup_Age': { # Experiment 4 (Age Only)
        'DA_flag': False,
        'threshold_mode': 'subgroup',
        'subgroup_focus': 'age_bin' # Focus only on age thresholds
    },
    'Optimized_Subgroup_Device': { # Experiment 5 (Device Only)
        'DA_flag': False,
        'threshold_mode': 'subgroup',
        'subgroup_focus': 'device' # Focus only on device thresholds
    },
}

# --- Flexible Detailed Evaluation Function ---
def _load_eval_model(model_path, num_classes, device):
    """Build an ECGCNN, load the state_dict at `model_path` onto `device` and switch to eval mode.

    Returns the model, or None (after printing the error) if construction/loading fails.
    """
    try:
        model = ECGCNN(num_classes=num_classes)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    print("Model loaded successfully.")
    return model


def _predict_probabilities(model, data_loader, device):
    """Run `model` over every batch of `data_loader`.

    Returns (probs, labels): sigmoid outputs stacked along axis 0, and the labels cast to int.
    """
    prob_batches, label_batches = [], []
    print("Running inference...")
    with torch.no_grad():
        for signals, labels in data_loader:
            logits = model(signals.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            label_batches.append(labels.cpu().numpy())
    probs = np.concatenate(prob_batches, axis=0)
    labels_arr = np.concatenate(label_batches, axis=0).astype(int)
    print("Inference complete. Calculating metrics...")
    return probs, labels_arr


def _resolve_thresholds(thresholds_arg, num_classes):
    """Interpret `thresholds_arg` (see evaluate_model_detailed).

    Returns (mode, fallback_thresholds, thresholds_dict):
      - mode is 'subgroup', 'global_array' or 'global_single';
      - fallback_thresholds is a (num_classes,) array used for the overall metrics and for
        any subgroup without a specific threshold;
      - thresholds_dict is the input dict in subgroup mode, else None.
    Returns None (after printing the reason) when the argument is invalid.
    """
    if isinstance(thresholds_arg, dict):
        print("Using Subgroup-Specific Threshold Dictionary.")
        if 'all' not in thresholds_arg:
            print("Error: Threshold dict must contain 'all' key.")
            return None
        fallback = thresholds_arg['all']
        if not isinstance(fallback, np.ndarray) or fallback.shape != (num_classes,):
            print("Error: thresholds_dict['all'] invalid.")
            return None
        return 'subgroup', fallback, thresholds_arg
    if isinstance(thresholds_arg, (np.ndarray, list)):
        print("Using Global Per-Class Threshold Array.")
        fallback = np.array(thresholds_arg)
        if fallback.shape != (num_classes,):
            print("Error: Global threshold array shape invalid.")
            return None
        return 'global_array', fallback, None
    # numpy scalars such as np.float32 are not Python floats, so accept them explicitly.
    if isinstance(thresholds_arg, (float, int, np.integer, np.floating)):
        print(f"Using Single Global Threshold Value: {thresholds_arg}")
        return 'global_single', np.full(num_classes, float(thresholds_arg)), None
    print("Error: Invalid thresholds argument type.")
    return None


def _subgroup_masks(sg_type, test_df, all_labels, target_classes):
    """Yield (subgroup_name, boolean row mask) pairs for one subgroup type.

    'diagnostic_class' subgroups are the samples whose ground-truth labels include that class.
    Other types group the rows of `test_df` by the column value, compared as str so the names
    match string-keyed threshold dicts even if the column was not cast beforehand.
    """
    if sg_type == 'diagnostic_class':
        for class_index, class_name in enumerate(target_classes):
            yield str(class_name), all_labels[:, class_index] == 1
        return
    column_as_str = test_df[sg_type].astype(str)
    column_values = column_as_str.values
    for value_str in column_as_str.unique():
        yield value_str, column_values == value_str


def _subgroup_thresholds(thresholds_dict, fallback, sg_type, value_str, num_classes):
    """Pick the thresholds for one subgroup in subgroup mode.

    Returns (thresholds, "Specific") when a valid (num_classes,) array is stored under
    thresholds_dict[sg_type][value_str]; otherwise (fallback, "Fallback/Global"), printing a
    warning if a stored entry exists but is invalid.
    """
    table = thresholds_dict.get(sg_type) or {}
    if value_str not in table:
        return fallback, "Fallback/Global"
    candidate = table[value_str]
    if not isinstance(candidate, np.ndarray) or candidate.shape != (num_classes,):
        print(f"Warning: Invalid stored threshold for {sg_type}/{value_str}. Using fallback.")
        return fallback, "Fallback/Global"
    return candidate, "Specific"


def _male_female_averages(sex_results, male_key='1', female_key='0'):
    """Average male/female macro recall, precision and F1 (NaN where unavailable) and print a summary."""
    averages = {
        'avg_macro_recall_male_female': np.nan,
        'avg_macro_precision_male_female': np.nan,
        'avg_macro_f1_male_female': np.nan,
    }
    male = sex_results.get(male_key) or {}
    female = sex_results.get(female_key) or {}
    if male.get('num_samples', 0) > 0 and female.get('num_samples', 0) > 0:
        for stat in ('recall', 'precision', 'f1'):
            male_val = male.get(f'{stat}_macro', np.nan)
            female_val = female.get(f'{stat}_macro', np.nan)
            if not np.isnan(male_val) and not np.isnan(female_val):
                averages[f'avg_macro_{stat}_male_female'] = (male_val + female_val) / 2
        print("\n--- Cross-Group Metrics ---")
        print(f"  Avg Macro F1 (M/F): {averages['avg_macro_f1_male_female']:.4f}")
    else:
        print("\nWarning: Cannot calculate average male/female metrics.")
    return averages


def evaluate_model_detailed(model_path, test_loader, test_df, mlb, device, thresholds_arg):
    """
    Evaluate one saved ECGCNN checkpoint on the test set, overall and per subgroup.

    Args:
        model_path: Path to a state_dict for ``ECGCNN(num_classes=len(mlb.classes_))``.
        test_loader: Loader yielding ``(signals, labels)`` batches. Its sample order must match
            the row order of ``test_df``, because subgroup masks are built from ``test_df``
            rows and applied positionally to the stacked predictions.
        test_df: One row per test sample; must have 'sex', 'age_bin' and 'device' columns.
        mlb: Fitted MultiLabelBinarizer; ``mlb.classes_`` defines the class order.
        device: torch device used for inference.
        thresholds_arg: One of
            - a float/int (numpy scalars accepted): one threshold for every class;
            - a list/ndarray of shape (num_classes,): per-class global thresholds;
            - a dict with key 'all' -> (num_classes,) ndarray used as the fallback, plus
              optional keys 'sex'/'age_bin'/'device'/'diagnostic_class' mapping a subgroup
              value (str) to a (num_classes,) ndarray of subgroup-specific thresholds.

    Returns:
        dict with keys 'all', 'sex', 'age_bin', 'device', 'diagnostic_class' and
        'overall_metrics', or None (after printing the reason) on any validation/loading error.
    """
    print(f"\n--- Starting Detailed Evaluation for Model: {os.path.basename(model_path)} ---")
    if not os.path.exists(model_path):
        print(f"Error: Model not found: {model_path}")
        return None
    if len(test_loader.dataset) != len(test_df):
        print("Error: Mismatch test_loader/test_df")
        return None
    if len(test_df) == 0:
        print("Error: Empty test set.")
        return None

    target_classes_eval = list(mlb.classes_)
    num_classes_eval = len(target_classes_eval)

    model = _load_eval_model(model_path, num_classes_eval, device)
    if model is None:
        return None
    all_probs, all_labels = _predict_probabilities(model, test_loader, device)

    resolved = _resolve_thresholds(thresholds_arg, num_classes_eval)
    if resolved is None:
        return None
    threshold_mode, fallback_thresholds, thresholds_dict = resolved

    results = {'all': {}, 'sex': {}, 'age_bin': {}, 'device': {}, 'diagnostic_class': {}, 'overall_metrics': {}}

    # Overall metrics always use the fallback/global thresholds.
    print("\n--- Overall Metrics (All Samples) ---")
    all_preds_overall = (all_probs >= fallback_thresholds).astype(int)
    results['all'] = calculate_subgroup_metrics(all_labels, all_preds_overall, target_classes_eval)
    for metric, value in results['all'].items():
        if not isinstance(value, dict):
            print(f"  {metric}: {value:.4f}" if isinstance(value, float) else f"  {metric}: {value}")

    for sg_type in ('sex', 'age_bin', 'device', 'diagnostic_class'):
        print(f"\n--- Metrics by Subgroup Type: {sg_type.upper()} ---")
        for value_str, mask in _subgroup_masks(sg_type, test_df, all_labels, target_classes_eval):
            print(f" -- Subgroup: {value_str} --")
            if not mask.any():
                print("   (N=0) Skipping.")
                results[sg_type][value_str] = calculate_subgroup_metrics(np.array([]), np.array([]), target_classes_eval)
                continue

            if threshold_mode == 'subgroup':
                current_thresholds, source = _subgroup_thresholds(
                    thresholds_dict, fallback_thresholds, sg_type, value_str, num_classes_eval)
            else:
                current_thresholds, source = fallback_thresholds, "Fallback/Global"

            group_preds = (all_probs[mask] >= current_thresholds).astype(int)
            subgroup_metrics = calculate_subgroup_metrics(all_labels[mask], group_preds, target_classes_eval)
            results[sg_type][value_str] = subgroup_metrics

            print(f"   (N={subgroup_metrics['num_samples']}) Applied thresholds: {source}")
            if subgroup_metrics['num_samples'] > 0:
                print(f"   Macro F1: {subgroup_metrics['f1_macro']:.4f}; Exact Match: {subgroup_metrics['exact_match_ratio']:.4f}")

    results['overall_metrics'].update(_male_female_averages(results['sex']))

    print("\n--- Detailed Evaluation Complete ---")
    return results


# --- Main Evaluation Loop ---
# Which optimized subgroup threshold tables each `subgroup_focus` uses. Subgroup types not
# listed are omitted from the thresholds dict, so evaluate_model_detailed falls back to the
# optimized global ('all') thresholds for them.
_SUBGROUP_FOCUS_KEYS = {
    'combined': ('sex', 'age_bin', 'device'),  # Experiment 2
    'sex': ('sex',),                           # Experiment 3
    'age_bin': ('age_bin',),                   # Experiment 4
    'device': ('device',),                     # Experiment 5
}


def _build_thresholds_arg(threshold_mode, subgroup_focus, seed, global_by_seed, subgroup_by_seed, n_classes):
    """Build the `thresholds_arg` passed to evaluate_model_detailed for one (condition, seed).

    Returns the optimized global (n_classes,) array for 'global_optimized', or a dict
    {'all': global array, <focus subgroup types>: {...}} for 'subgroup'.
    Raises KeyError when required thresholds are missing, ValueError when they are invalid or
    the mode/focus is unknown.
    """
    if seed not in global_by_seed:
        raise KeyError(f"Optimized global thresholds for seed {seed} not found!")
    global_thresholds = global_by_seed[seed]
    if not isinstance(global_thresholds, np.ndarray) or global_thresholds.shape != (n_classes,):
        raise ValueError("Invalid global threshold array found.")

    if threshold_mode == 'global_optimized':
        print(f"Using optimized global thresholds for seed {seed}.")
        return global_thresholds
    if threshold_mode != 'subgroup':
        raise ValueError(f"Unknown threshold_mode '{threshold_mode}'.")
    if subgroup_focus not in _SUBGROUP_FOCUS_KEYS:
        raise ValueError(f"Invalid subgroup_focus '{subgroup_focus}'")
    if subgroup_by_seed is None or seed not in subgroup_by_seed:
        raise KeyError(f"Required optimized subgroup thresholds missing for seed {seed}")

    seed_tables = subgroup_by_seed[seed]
    focus_keys = _SUBGROUP_FOCUS_KEYS[subgroup_focus]
    print(f"Using optimized subgroup thresholds for {'+'.join(focus_keys)} (focus: {subgroup_focus}) for seed {seed}.")
    thresholds_arg = {'all': global_thresholds}  # fallback for every subgroup type not in focus
    for key in focus_keys:
        thresholds_arg[key] = seed_tables.get(key, {})
    return thresholds_arg


# Structure: {condition_name: {seed: evaluation_dict or None}}
all_evaluation_results = {name: {} for name in EXPERIMENTAL_CONDITIONS}

# Raise instead of exit(): in an IPython/Jupyter kernel exit() shuts the kernel down.
if 'all_seeds' not in globals():
    raise RuntimeError("Error: `all_seeds` not defined.")
if 'optimized_thresholds_per_seed' not in globals():
    raise RuntimeError("Error: `optimized_thresholds_per_seed` not defined.")
# Only the subgroup conditions need this. If it is missing, those seeds are skipped with a
# message instead of the loop dying with a NameError part-way through.
subgroup_thresholds_by_seed = globals().get('optimized_subgroup_thresholds_per_seed_result')

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for condition_name, condition_params in EXPERIMENTAL_CONDITIONS.items():
    print(f"\n{'='*20} Evaluating Condition: {condition_name} {'='*20}")
    current_da_flag = condition_params['DA_flag']
    threshold_mode = condition_params['threshold_mode']
    subgroup_focus = condition_params.get('subgroup_focus')
    print(f"Using device: {DEVICE}, DA Flag: {current_da_flag}, Thresh Mode: {threshold_mode}, Focus: {subgroup_focus}")

    for current_seed in all_seeds:
        print(f"\n-- Evaluating SEED={current_seed} --")
        base_name_prefix = f"model_seed_{current_seed}{'_DA' if current_da_flag else ''}"
        final_model_filename = f"final_{base_name_prefix}_epoch_{NUM_EPOCHS}.pth"
        model_path = os.path.join(MODEL_SAVE_DIRECTORY, final_model_filename)

        if not os.path.exists(model_path):
            print(f"Warning: Model file not found: {model_path}. Skipping.")
            all_evaluation_results[condition_name][current_seed] = None
            continue

        try:
            current_thresholds_arg = _build_thresholds_arg(
                threshold_mode, subgroup_focus, current_seed,
                optimized_thresholds_per_seed, subgroup_thresholds_by_seed, num_classes)
        except (KeyError, ValueError) as e:
            # Skip rather than substituting 0.5 thresholds: those results would be stored under an
            # "Optimized_*" condition name and silently bias the cross-seed aggregates.
            print(f"ERROR fetching/validating thresholds for seed {current_seed}, condition {condition_name}: {e}. Skipping this seed.")
            all_evaluation_results[condition_name][current_seed] = None
            continue

        results_this_seed = evaluate_model_detailed(
            model_path=model_path, test_loader=test_loader, test_df=test_df, mlb=mlb, device=DEVICE,
            thresholds_arg=current_thresholds_arg,
        )

        all_evaluation_results[condition_name][current_seed] = results_this_seed
        if results_this_seed:
            print(f"Evaluation complete for seed {current_seed}.")
        else:
            print(f"Evaluation failed for seed {current_seed}.")


# --- Save Raw Evaluation Results ---
# Default RESULTS_DIR if no earlier cell defined it (the LaTeX cell below uses the same
# 'evaluation_results'), so this cell also works on a fresh kernel; create it before writing.
RESULTS_DIR = globals().get('RESULTS_DIR', 'evaluation_results')
os.makedirs(RESULTS_DIR, exist_ok=True)
results_save_filename = 'all_evaluation_results_raw.pkl'
results_save_path = os.path.join(RESULTS_DIR, results_save_filename)
print(f"\nSaving detailed raw evaluation results to: {results_save_path}")
try:
    with open(results_save_path, 'wb') as f:
        pickle.dump(all_evaluation_results, f)
    print("Raw results saved successfully.")
except Exception as e:
    print(f"Error saving raw results: {e}")



# --- Step 5: Aggregation and Reporting ---
print(f"\n{'='*20} Aggregating Evaluation Results Across Seeds {'='*20}")

# Stop with a clear error when there is nothing to aggregate. Otherwise the cell would hit a
# NameError below, or save an empty pickle over previously aggregated stats.
if not globals().get('all_evaluation_results'):
    raise RuntimeError("Error: 'all_evaluation_results' dictionary not found or is empty. Cannot aggregate.")
print("Using existing 'all_evaluation_results' dictionary.")

# Metrics printed while aggregating subgroups ('all'/'overall_metrics' print every scalar metric).
_KEY_SUBGROUP_METRICS = ('f1_macro', 'recall_macro', 'precision_macro', 'exact_match_ratio', 'mean_true_labels', 'num_samples')
# numpy scalars (e.g. np.int64 counts) are not Python ints, so list them explicitly.
_NUMERIC_TYPES = (int, float, np.integer, np.floating)


def _ordered_union(key_lists):
    """Union of several key iterables, preserving first-seen order."""
    return list(dict.fromkeys(key for keys in key_lists for key in keys))


def _mean_std(values):
    """Return {'mean', 'std'} over the numeric, non-NaN entries of `values` (NaNs if there are none).

    Uses np.std's default population std (ddof=0).
    """
    valid = [v for v in values if isinstance(v, _NUMERIC_TYPES) and not np.isnan(v)]
    if not valid:
        return {'mean': np.nan, 'std': np.nan}
    return {'mean': np.mean(valid), 'std': np.std(valid)}


def _aggregate_metric_dicts(metric_dicts):
    """Aggregate per-seed ``{metric: scalar or {class: scalar}}`` dicts into mean/std stats.

    Scalar metrics become {'mean', 'std'}; per-class metrics become {class: {'mean', 'std'}}.
    Metric and class names are the union over seeds, so a key missing from some seeds is
    aggregated over the seeds that have it.
    """
    aggregated = {}
    for metric in _ordered_union(d.keys() for d in metric_dicts):
        sample = next(d[metric] for d in metric_dicts if metric in d)
        if isinstance(sample, dict):
            per_class = [d[metric] if isinstance(d.get(metric), dict) else {} for d in metric_dicts]
            aggregated[metric] = {
                cls_name: _mean_std([pc.get(cls_name) for pc in per_class])
                for cls_name in _ordered_union(pc.keys() for pc in per_class)
            }
        else:
            aggregated[metric] = _mean_std([d.get(metric) for d in metric_dicts])
    return aggregated


def _print_scalar_stats(stats, only=None):
    """Print 'metric: Mean=..., Std=...' for scalar aggregated metrics (optionally only those in `only`)."""
    for metric, entry in stats.items():
        if only is not None and metric not in only:
            continue
        if 'mean' not in entry or isinstance(entry['mean'], dict):
            continue  # per-class breakdown, not printed
        if metric == 'num_samples':
            print(f"        {metric}: Mean={entry['mean']:.1f}, Std={entry['std']:.2f}")
        else:
            print(f"        {metric}: Mean={entry['mean']:.4f}, Std={entry['std']:.4f}")


# Structure: {condition_name: {top_level_key: {subgroup_name_or_metric: stats}}}
aggregated_stats = {cond_name: {} for cond_name in all_evaluation_results}

for condition_name, seed_results_dict in all_evaluation_results.items():
    print(f"\n--- Aggregating for Condition: {condition_name} ---")

    # Seeds whose evaluation failed or was skipped are stored as None.
    valid_seed_results = [res for res in seed_results_dict.values() if isinstance(res, dict)]
    if not valid_seed_results:
        print("  No valid evaluation results found for this condition. Skipping aggregation.")
        continue
    print(f"  Aggregating from {len(valid_seed_results)} valid run(s).")
    condition_stats = aggregated_stats[condition_name]

    for top_key in _ordered_union(res.keys() for res in valid_seed_results):  # 'all', 'sex', ...
        print(f"\n    Processing Top Key: {top_key.upper()}")
        condition_stats[top_key] = {}

        level_dicts = [res[top_key] for res in valid_seed_results if isinstance(res.get(top_key), dict)]
        if not level_dicts:
            print(f"      Skipping {top_key}: Expected dict, found {type(valid_seed_results[0].get(top_key))}.")
            continue
        if not any(level_dicts):
            print(f"      Skipping {top_key}: Contains no data.")
            continue

        # Subgroup levels ('sex', 'age_bin', ...) map subgroup names to metric dicts containing
        # 'f1_macro'; 'all' and 'overall_metrics' map metric names directly to values.
        contains_subgroups = any(isinstance(v, dict) and 'f1_macro' in v
                                 for level in level_dicts for v in level.values())

        if contains_subgroups:
            for sg_name in _ordered_union(level.keys() for level in level_dicts):
                print(f"      Subgroup: {sg_name}")
                sg_dicts = [level[sg_name] for level in level_dicts if isinstance(level.get(sg_name), dict)]
                condition_stats[top_key][sg_name] = _aggregate_metric_dicts(sg_dicts)
                _print_scalar_stats(condition_stats[top_key][sg_name], only=_KEY_SUBGROUP_METRICS)
        else:
            print(f"      Metrics for: {top_key.upper()}")
            condition_stats[top_key] = _aggregate_metric_dicts(level_dicts)
            _print_scalar_stats(condition_stats[top_key])


# --- Save Aggregated Stats ---
# Default RESULTS_DIR if no earlier cell defined it, and make sure the directory exists.
RESULTS_DIR = globals().get('RESULTS_DIR', 'evaluation_results')
os.makedirs(RESULTS_DIR, exist_ok=True)
agg_save_filename = 'aggregated_evaluation_stats.pkl'
agg_save_path = os.path.join(RESULTS_DIR, agg_save_filename)
print(f"\nSaving aggregated evaluation statistics to: {agg_save_path}")
try:
    with open(agg_save_path, 'wb') as f:
        pickle.dump(aggregated_stats, f)
    print("Aggregated stats saved successfully.")
except Exception as e:
    print(f"Error saving aggregated stats: {e}")

print("\n--- Aggregation Script Finished ---")

Number of classes for evaluation: 5
Target classes: ['CD', 'HYP', 'MI', 'NORM', 'STTC']

Using device: cuda, DA Flag: False, Thresh Mode: global_optimized, Focus: None

-- Evaluating SEED=458 --
Using optimized global thresholds for seed 458.

--- Starting Detailed Evaluation for Model: final_model_seed_458_epoch_30.pth ---
Model loaded successfully.
Running inference...
Inference complete. Calculating metrics...
Using Global Per-Class Threshold Array.

--- Overall Metrics (All Samples) ---
  num_samples: 2198
  exact_match_ratio: 0.5682
  mean_true_labels: 1.2702
  precision_macro: 0.6741
  recall_macro: 0.7455
  f1_macro: 0.7062
  precision_micro: 0.7109
  recall_micro: 0.7855
  f1_micro: 0.7463
  precision_weighted: 0.7180
  recall_weighted: 0.7855
  f1_weighted: 0.7490

--- Metrics by Subgroup Type: SEX ---
 -- Subgroup: 0 --
   (N=1132) Applied thresholds: Fallback/Global
   Macro F1: 0.7226; Exact Match: 0.5936
 -- Subgroup: 1 --
   (N=1066) Applied thresholds: Fallback/Global
  

# Generate LaTeX Tables

In [59]:
import numpy as np
import os
import pickle

# --- Optional Imports (Needed for PDF generation and Notebook display) ---
# pylatex drives the LaTeX/PDF table generation below; without it, PDF output is skipped.
try:
    from pylatex import Document, Section, Subsection, Tabular, MultiColumn, Command
    from pylatex.utils import NoEscape  # passes raw LaTeX (e.g. \pm) through unescaped
except ImportError:
    PYLATEX_AVAILABLE = False
    print("Warning: pylatex not found. PDF generation will be skipped.")
else:
    PYLATEX_AVAILABLE = True
    print("pylatex library loaded successfully.")

# IPython.display is only needed to preview generated PDFs inline in Jupyter/Lab.
try:
    from IPython.display import display, IFrame
except ImportError:
    IPYTHON_AVAILABLE = False
    print("Warning: IPython.display not found. PDF preview in notebook will be skipped.")
else:
    IPYTHON_AVAILABLE = True

# --- Configuration Variables ---
# Folder that holds the aggregated-results pickle and receives the generated PDFs.
# It must match the directory the evaluation/aggregation cell wrote to.
RESULTS_DIR = 'evaluation_results'
os.makedirs(RESULTS_DIR, exist_ok=True)  # no-op if it already exists

# Aggregated statistics produced by the aggregation step.
AGGREGATED_STATS_PATH = os.path.join(RESULTS_DIR, 'aggregated_evaluation_stats.pkl')

# --- Helper Function Definitions ---

def get_stat(agg_stats, condition, sg_type, sg_name, metric, stat_type, precision=3, default='N/A'):
    """
    Look up one aggregated statistic and format it as a string.

    Args:
        agg_stats: Nested aggregated-stats dict. The overall group is read from
            ``agg_stats[condition]['all'][metric][stat_type]``; subgroups are read from
            ``agg_stats[condition][sg_type][sg_name][metric][stat_type]``.
        condition: Experimental condition name (e.g. 'Optimized_Global').
        sg_type: 'all' for the overall group, otherwise the subgroup type (e.g. 'sex').
        sg_name: Subgroup key; ignored when ``sg_type == 'all'``.
        metric: Metric name, e.g. 'f1_macro'.
        stat_type: Statistic key, e.g. 'mean' or 'std'.
        precision: Number of decimals. 0 gives a value rounded to the nearest integer.
        default: Returned when the entry is missing, NaN or non-numeric.

    Returns:
        The formatted value, or ``default``.
    """
    try:
        if sg_type == 'all':
            # The overall group stores metrics directly under 'all' (no subgroup level).
            value = agg_stats[condition]['all'][metric][stat_type]
        else:
            value = agg_stats[condition][sg_type][sg_name][metric][stat_type]

        # Accept numpy scalars as well: np.int64 / np.float32 are not Python int/float.
        if not isinstance(value, (int, float, np.integer, np.floating)):
            return default
        value = float(value)
        if np.isnan(value):
            return default
        # Round for every precision, including 0 (int() would truncate e.g. 1098.6 to 1098).
        return f"{value:.{precision}f}"

    except (KeyError, TypeError, IndexError):
        # Path doesn't exist (e.g. subgroup missing for a condition).
        return default
    except Exception as e:
        print(f"Unexpected error in get_stat({condition}, {sg_type}, {sg_name}, {metric}, {stat_type}): {e}")
        return default

def _format_metric_cell(mean_val, std_val, precision=3, default='N/A'):
    """Formats a table cell as 'mean ± std' handling N/A."""
    # Ensure mean/std are floats for isnan check, handle potential non-numeric types gracefully
    try:
        mean_val_f = float(mean_val) if mean_val is not None else np.nan
    except (ValueError, TypeError):
        mean_val_f = np.nan
    try:
        std_val_f = float(std_val) if std_val is not None else np.nan
    except (ValueError, TypeError):
        std_val_f = np.nan

    # Format based on availability of valid mean and std
    if not np.isnan(mean_val_f) and not np.isnan(std_val_f):
        mean_str = f"{mean_val_f:.{precision}f}"
        std_str = f"{std_val_f:.{precision}f}"
        # Use NoEscape only if pylatex is available
        if PYLATEX_AVAILABLE:
             return NoEscape(f"{mean_str} $\\pm$ {std_str}")
        else:
             return f"{mean_str} ± {std_str}" # Fallback for non-latex output
    elif not np.isnan(mean_val_f):
         return f"{mean_val_f:.{precision}f}"
    else:
        return default


# --- Helper Function to Format Mean ± Std  ---
def _format_metric_cell(mean_val, std_val, precision=3, default='N/A'):
    """Formats a table cell as 'mean ± std' handling N/A."""
    if isinstance(mean_val, (int, float)) and not np.isnan(mean_val) and \
       isinstance(std_val, (int, float)) and not np.isnan(std_val):
        mean_str = f"{mean_val:.{precision}f}"
        std_str = f"{std_val:.{precision}f}"
        return NoEscape(f"{mean_str} $\\pm$ {std_str}")
    elif isinstance(mean_val, (int, float)) and not np.isnan(mean_val):
         return f"{mean_val:.{precision}f}"
    else:
        return default



# --- Comparison Table Definitions (customize before running table generation) ---
# Every table is configured by three settings. Condition names must be exact keys of
# EXPERIMENTAL_CONDITIONS:
#   CONDITIONS_TO_COMPARE_<X>: conditions shown side by side
#   METRICS_TO_INCLUDE_<X>:    aggregated metrics shown for each condition
#   OUTPUT_PDF_PATH_<X>:       where the generated PDF is written (inside RESULTS_DIR)

# Sex table: global thresholds vs. sex-specific thresholds (Experiment 1 vs 3).
CONDITIONS_TO_COMPARE_SEX = ['Optimized_Global', 'Optimized_Subgroup_Sex']
METRICS_TO_INCLUDE_SEX = ['f1_macro', 'recall_macro', 'precision_macro', 'exact_match_ratio']
OUTPUT_PDF_PATH_SEX = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_sex.pdf')

# Age table: global vs. age-bin-specific thresholds (Experiment 1 vs 4).
CONDITIONS_TO_COMPARE_AGE = ['Optimized_Global', 'Optimized_Subgroup_Age']
METRICS_TO_INCLUDE_AGE = list(METRICS_TO_INCLUDE_SEX)  # same metrics, independent list object
OUTPUT_PDF_PATH_AGE = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_age.pdf')

# Device table: global vs. device-specific thresholds (Experiment 1 vs 5). A shorter metric
# list keeps a table with many device rows readable.
CONDITIONS_TO_COMPARE_DEVICE = ['Optimized_Global', 'Optimized_Subgroup_Device']
METRICS_TO_INCLUDE_DEVICE = METRICS_TO_INCLUDE_SEX[:3]  # f1/recall/precision macro
OUTPUT_PDF_PATH_DEVICE = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_device.pdf')
# To show only a subset of devices, define e.g.:
# DEVICES_TO_FILTER = ['AT-6 C 5.5', 'CS-12', 'CS-12   E']

# Combined table: global vs. combined sex+age+device thresholds (Experiment 1 vs 2).
# Meant for create_sex_comparison_table, which shows the male/female breakdown.
# NOTE(review): the aggregation step stores 'avg_macro_f1_male_female' under the
# 'overall_metrics' key, not under 'all' or a subgroup. A get_stat lookup with sg_type 'all'
# or 'sex' will therefore return its default ('N/A'); confirm how the table code reads it.
CONDITIONS_TO_COMPARE_COMBINED = ['Optimized_Global', 'Optimized_Subgroup_Combined']
METRICS_TO_INCLUDE_COMBINED = METRICS_TO_INCLUDE_SEX + ['avg_macro_f1_male_female']
OUTPUT_PDF_PATH_COMBINED = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_combined.pdf')

print(f"Results Directory: {RESULTS_DIR}")
print(f"Aggregated Stats Path: {AGGREGATED_STATS_PATH}")
print(f"PyLaTeX Available: {PYLATEX_AVAILABLE}")
print(f"IPython Available: {IPYTHON_AVAILABLE}")


pylatex library loaded successfully.
Results Directory: evaluation_results
Aggregated Stats Path: evaluation_results/aggregated_evaluation_stats.pkl
PyLaTeX Available: True
IPython Available: True


In [60]:

def create_sex_comparison_table(agg_stats, conditions, metrics, filename_base, metric_precision=3):
    """
    Build a landscape LaTeX table comparing `metrics` across `conditions`,
    broken down by sex, and compile it to PDF.

    Rows are 'Total - Overall Metrics' (the 'all' group), 'Female' (sex group
    '0') and 'Male' (sex group '1'). Each condition contributes one block of
    metric columns. Every cell holds the mean and std returned by `get_stat`,
    rendered by `_format_metric_cell`.

    Args:
        agg_stats: Aggregated statistics keyed by condition name. The first
            condition must contain both an 'all' and a 'sex' entry.
        conditions: Condition names to show as column groups. The first one
            also supplies the per-row sample counts (N).
        metrics: Metric keys to show for each condition.
        filename_base: Output path. Its extension is stripped, then
            `<base>.pdf` (and `<base>.tex`) is written.
        metric_precision: Decimal places used for metric values.

    Returns:
        None. Validation problems are printed and abort early. If PDF
        compilation fails, the LaTeX source is saved instead.
    """
    if not PYLATEX_AVAILABLE:
        print("Skipping PDF generation as pylatex is not available.")
        return

    # --- Input Validation ---
    if not conditions:
        print("Error: No conditions provided.")
        return
    first_cond = conditions[0]
    if first_cond not in agg_stats:
        print(f"Error: Condition '{first_cond}' not found in agg_stats.")
        return
    for required_group in ('all', 'sex'):
        if required_group not in agg_stats[first_cond]:
            print(f"Error: '{required_group}' group data not found for condition '{first_cond}'.")
            return

    num_conditions = len(conditions)
    num_metrics = len(metrics)

    def _stat_to_float(stat_str):
        # get_stat returns a formatted number or 'N/A'. Treat anything that is
        # not parseable as missing, so one odd value cannot abort the table.
        if stat_str == 'N/A':
            return np.nan
        try:
            return float(stat_str)
        except (TypeError, ValueError):
            return np.nan

    # --- Document Setup ---
    geometry_options = {
        "tmargin": "0.75in", "bmargin": "0.75in",
        "lmargin": "1in",   "rmargin": "1in",
        "includeheadfoot": True,
        "landscape": True
    }
    doc = Document(geometry_options=geometry_options)
    doc.preamble.append(Command('usepackage', 'booktabs'))

    doc.append(NoEscape(r'\section*{Comparison of Metrics by Sex Subgroup}'))
    doc.append(NoEscape(r'\vspace{0.5em}'))

    # --- Table Setup ---
    # Label column, then one block of right-aligned metric columns per
    # condition. Vertical rules separate the condition blocks only; there is
    # no rule between the label column and the first block.
    table_format = 'l' + '|'.join(['r' * num_metrics] * num_conditions)

    doc.append(NoEscape(r'\footnotesize'))  # smaller font for a potentially wide table
    table = Tabular(table_format, booktabs=True)

    # --- Header Rows ---
    header1 = [""]
    for i, cond_name in enumerate(conditions):
        # \multicolumn replaces the spanned columns' spec, so the separator
        # between condition blocks must be repeated in its alignment.
        align_str = 'c' if i == num_conditions - 1 else 'c|'
        header1.append(MultiColumn(num_metrics, align=align_str, data=cond_name.replace('_', ' ')))
    table.add_row(header1)
    table.add_hline()

    display_metric_names = [m.replace('_', ' ').title() for m in metrics]
    table.add_row(["Subgroup (# Samples)"] + display_metric_names * num_conditions)
    table.add_hline()

    # --- Data Rows ---
    # (sg_type, sg_name, display label). sg_type 'all' selects the overall group.
    subgroups_to_show = [
        ('all', 'all', 'Total - Overall Metrics'),
        ('sex', '0', 'Female'),
        ('sex', '1', 'Male'),
    ]

    for item_index, (sg_type, sg_name, row_label_base) in enumerate(subgroups_to_show):
        # The sample count N is taken from the first condition only.
        n_samples_str = get_stat(agg_stats, first_cond, sg_type, sg_name, 'num_samples', 'mean', precision=0)
        row_label = f"{row_label_base} " + (f"(N={n_samples_str})" if n_samples_str != 'N/A' else "(N=?)")
        row_data = [row_label]

        for cond_name in conditions:
            for metric in metrics:
                mean_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'mean', precision=metric_precision)
                std_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'std', precision=metric_precision)
                row_data.append(_format_metric_cell(_stat_to_float(mean_str), _stat_to_float(std_str),
                                                    precision=metric_precision))

        table.add_row(row_data)
        # Partial rule under the data columns only, between subgroup rows.
        if item_index < len(subgroups_to_show) - 1:
            table.add_hline(start=2, end=len(row_data))

    # --- Add Table to Document ---
    doc.append(table)
    doc.append(NoEscape(r'\normalsize'))  # revert font size
    doc.append(NoEscape(r'\vspace{1em}'))

    # --- Generate PDF ---
    # Computed outside the try, so the .tex fallback below can always use it.
    pdf_path_base = os.path.splitext(filename_base)[0]
    try:
        doc.generate_pdf(pdf_path_base, clean_tex=False)  # clean_tex=True would delete the .tex afterwards
        print(f"PDF generated successfully: {pdf_path_base}.pdf")
    except Exception as e:
        print(f"Error generating PDF '{filename_base}': {e}")
        print("Attempting to save the .tex file for inspection.")
        try:
            tex_path = pdf_path_base + ".tex"
            doc.generate_tex(pdf_path_base)
            print(f"LaTeX source saved to: {tex_path}")
        except Exception as tex_e:
            print(f"Could not save .tex file: {tex_e}")



In [61]:
# Condition keys must match the EXPERIMENTAL_CONDITIONS names, i.e. the
# top-level keys of the aggregated stats pickle.
CONDITIONS_TO_COMPARE_SEX = ['Optimized_Global', 'Optimized_Subgroup_Sex']

# Metric columns shown for each condition
METRICS_TO_INCLUDE_SEX = ['f1_macro', 'recall_macro', 'precision_macro', 'exact_match_ratio']

# Aggregated statistics saved by the evaluation step, and the output PDF path
AGGREGATED_STATS_PATH = os.path.join(RESULTS_DIR, 'aggregated_evaluation_stats.pkl')
OUTPUT_PDF_PATH_SEX = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_sex.pdf')

# --- Main Execution Logic (Sex Comparison) ---
# Failures raise instead of calling exit(): under IPython/Jupyter, exit() shuts
# the kernel down and discards all state. This also matches the RuntimeError
# used by the configuration check at the top of the notebook.

# 1. Load Aggregated Statistics
print(f"\nLoading aggregated statistics from: {AGGREGATED_STATS_PATH}")
try:
    # NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
    with open(AGGREGATED_STATS_PATH, 'rb') as f:
        aggregated_stats = pickle.load(f)
except FileNotFoundError as e:
    raise RuntimeError(f"Error: Aggregated stats file not found: {AGGREGATED_STATS_PATH}.") from e
except Exception as e:
    raise RuntimeError(f"Error loading aggregated stats file: {e}") from e
print("Aggregated stats loaded successfully.")
print("Keys found in aggregated_stats:", aggregated_stats.keys())

# 2. Validate Conditions
valid_conditions = [cond for cond in CONDITIONS_TO_COMPARE_SEX if cond in aggregated_stats]
if not valid_conditions:
    raise RuntimeError(f"Error: None of the conditions {CONDITIONS_TO_COMPARE_SEX} found in stats.")
if len(valid_conditions) < len(CONDITIONS_TO_COMPARE_SEX):
    print(f"Warning: Only found conditions: {valid_conditions}")

# 3. Create and Generate Table PDF if pylatex is available
if PYLATEX_AVAILABLE:
    print(f"\nGenerating Sex comparison table PDF ({OUTPUT_PDF_PATH_SEX})...")
    create_sex_comparison_table(
        agg_stats=aggregated_stats,
        conditions=valid_conditions,
        metrics=METRICS_TO_INCLUDE_SEX,
        filename_base=OUTPUT_PDF_PATH_SEX,
        metric_precision=3,
    )
    # Optional: Display in Notebook
    # display(IFrame(f"{OUTPUT_PDF_PATH_SEX}", width=900, height=600))
else:
    print("\nSkipping PDF generation because pylatex is not installed.")

print("\nSex table generation process finished.")


Loading aggregated statistics from: evaluation_results/aggregated_evaluation_stats.pkl
Aggregated stats loaded successfully.
Keys found in aggregated_stats: dict_keys(['Optimized_Global', 'Optimized_Subgroup_Combined', 'Optimized_Subgroup_Sex', 'Optimized_Subgroup_Age', 'Optimized_Subgroup_Device'])

Generating Sex comparison table PDF (evaluation_results/comparison_table_global_vs_sex.pdf)...
Error generating PDF 'evaluation_results/comparison_table_global_vs_sex.pdf': No LaTex compiler was found
Either specify a LaTex compiler or make sure you have latexmk or pdfLaTex installed.
Attempting to save the .tex file for inspection.
LaTeX source saved to: evaluation_results/comparison_table_global_vs_sex.tex

Sex table generation process finished.


In [62]:
def create_age_comparison_table(agg_stats, conditions, metrics, filename_base, metric_precision=3):
    """
    Build a landscape LaTeX table comparing `metrics` across `conditions`,
    broken down by age bin, and compile it to PDF.

    The first row is 'Total - Overall Metrics' (the 'all' group). It is
    followed by one row per key of agg_stats[conditions[0]]['age_bin'], in
    sorted order. Every cell holds the mean and std returned by `get_stat`,
    rendered by `_format_metric_cell`.

    Args:
        agg_stats: Aggregated statistics keyed by condition name. The first
            condition must contain both an 'all' and an 'age_bin' entry.
        conditions: Condition names to show as column groups. The first one
            also defines the age bins and supplies the sample counts (N).
        metrics: Metric keys to show for each condition.
        filename_base: Output path. Its extension is stripped, then
            `<base>.pdf` (and `<base>.tex`) is written.
        metric_precision: Decimal places used for metric values.

    Returns:
        None. Validation problems are printed and abort early. If PDF
        compilation fails, the LaTeX source is saved instead.
    """
    if not PYLATEX_AVAILABLE:
        print("Skipping PDF generation as pylatex is not available.")
        return

    # --- Input Validation ---
    if not conditions:
        print("Error: No conditions provided.")
        return
    first_cond = conditions[0]
    if first_cond not in agg_stats:
        print(f"Error: Condition '{first_cond}' not found in agg_stats.")
        return
    for required_group in ('all', 'age_bin'):
        if required_group not in agg_stats[first_cond]:
            print(f"Error: '{required_group}' group data not found for condition '{first_cond}'.")
            return

    num_conditions = len(conditions)
    num_metrics = len(metrics)

    def _stat_to_float(stat_str):
        # get_stat returns a formatted number or 'N/A'. Treat anything that is
        # not parseable as missing, so one odd value cannot abort the table.
        if stat_str == 'N/A':
            return np.nan
        try:
            return float(stat_str)
        except (TypeError, ValueError):
            return np.nan

    # --- Document Setup ---
    geometry_options = {
        "tmargin": "0.75in", "bmargin": "0.75in",
        "lmargin": "1in",   "rmargin": "1in",
        "includeheadfoot": True,
        "landscape": True
    }
    doc = Document(geometry_options=geometry_options)
    doc.preamble.append(Command('usepackage', 'booktabs'))

    doc.append(NoEscape(r'\section*{Comparison of Metrics by Age Bin Subgroup}'))
    doc.append(NoEscape(r'\vspace{0.5em}'))

    # --- Table Setup ---
    # Label column, then one block of right-aligned metric columns per
    # condition. Vertical rules separate the condition blocks only.
    table_format = 'l' + '|'.join(['r' * num_metrics] * num_conditions)
    doc.append(NoEscape(r'\footnotesize'))
    table = Tabular(table_format, booktabs=True)

    # --- Header Rows ---
    header1 = [""]
    for i, cond_name in enumerate(conditions):
        # \multicolumn replaces the spanned columns' spec, so repeat the separator.
        align_str = 'c' if i == num_conditions - 1 else 'c|'
        header1.append(MultiColumn(num_metrics, align=align_str, data=cond_name.replace('_', ' ')))
    table.add_row(header1)
    table.add_hline()

    display_metric_names = [m.replace('_', ' ').title() for m in metrics]
    table.add_row(["Subgroup (# Samples)"] + display_metric_names * num_conditions)
    table.add_hline()

    # --- Data Rows ---
    # Total row first, then every age bin found for the first condition.
    subgroups_to_show = [('all', 'all', 'Total - Overall Metrics')]
    try:
        age_bins = sorted(agg_stats[first_cond]['age_bin'].keys())
        subgroups_to_show.extend(('age_bin', bin_name, f"Age Bin {bin_name}") for bin_name in age_bins)
    except (KeyError, AttributeError, TypeError):
        # Missing/non-dict 'age_bin' entry, or bin keys that cannot be sorted together.
        print(f"Warning: Could not dynamically determine age bins for condition '{conditions[0]}'. Age rows might be missing.")

    for item_index, (sg_type, sg_name, row_label_base) in enumerate(subgroups_to_show):
        # The sample count N is taken from the first condition only.
        n_samples_str = get_stat(agg_stats, first_cond, sg_type, sg_name, 'num_samples', 'mean', precision=0)
        row_label = f"{row_label_base} " + (f"(N={n_samples_str})" if n_samples_str != 'N/A' else "(N=?)")
        row_data = [row_label]

        for cond_name in conditions:
            for metric in metrics:
                mean_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'mean', precision=metric_precision)
                std_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'std', precision=metric_precision)
                row_data.append(_format_metric_cell(_stat_to_float(mean_str), _stat_to_float(std_str),
                                                    precision=metric_precision))

        table.add_row(row_data)
        # Partial rule under the data columns only, between subgroup rows.
        if item_index < len(subgroups_to_show) - 1:
            table.add_hline(start=2, end=len(row_data))

    # --- Add Table and Generate PDF ---
    doc.append(table)
    doc.append(NoEscape(r'\normalsize'))
    doc.append(NoEscape(r'\vspace{1em}'))

    # Computed outside the try, so the .tex fallback below can always use it.
    pdf_path_base = os.path.splitext(filename_base)[0]
    try:
        doc.generate_pdf(pdf_path_base, clean_tex=False)
        print(f"PDF generated successfully: {pdf_path_base}.pdf")
    except Exception as e:
        print(f"Error generating PDF '{filename_base}': {e}")
        print("Attempting to save the .tex file for inspection.")
        try:
            tex_path = pdf_path_base + ".tex"
            doc.generate_tex(pdf_path_base)
            print(f"LaTeX source saved to: {tex_path}")
        except Exception as tex_e:
            print(f"Could not save .tex file: {tex_e}")

# --- Call the Age Comparison Function ---
CONDITIONS_TO_COMPARE_AGE = ['Optimized_Global', 'Optimized_Subgroup_Age']
METRICS_TO_INCLUDE_AGE = ['f1_macro', 'recall_macro', 'precision_macro', 'exact_match_ratio']
OUTPUT_PDF_PATH_AGE = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_age.pdf')

if PYLATEX_AVAILABLE:
    print(f"\nGenerating Age comparison table PDF ({OUTPUT_PDF_PATH_AGE})...")
    # Reuse stats loaded by an earlier cell when present. Otherwise load them
    # here, so this cell also works on a fresh kernel. (Previously this branch
    # only printed a message, and the next statement raised NameError.)
    if 'aggregated_stats' not in globals():
        print(f"Reloading aggregated statistics from: {AGGREGATED_STATS_PATH}")
        # NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
        with open(AGGREGATED_STATS_PATH, 'rb') as f:
            aggregated_stats = pickle.load(f)

    valid_conditions_age = [cond for cond in CONDITIONS_TO_COMPARE_AGE if cond in aggregated_stats]
    if valid_conditions_age:
        create_age_comparison_table(
            agg_stats=aggregated_stats,
            conditions=valid_conditions_age,
            metrics=METRICS_TO_INCLUDE_AGE,
            filename_base=OUTPUT_PDF_PATH_AGE,
            metric_precision=3,
        )
        # Optional: display(IFrame(f"{OUTPUT_PDF_PATH_AGE}", width=900, height=600))
    else:
        print(f"Error: None of the conditions {CONDITIONS_TO_COMPARE_AGE} found in stats.")
else:
    print("\nSkipping Age PDF generation because pylatex is not installed.")

print("\nAge table generation process finished.")


Generating Age comparison table PDF (evaluation_results/comparison_table_global_vs_age.pdf)...
Error generating PDF 'evaluation_results/comparison_table_global_vs_age.pdf': No LaTex compiler was found
Either specify a LaTex compiler or make sure you have latexmk or pdfLaTex installed.
Attempting to save the .tex file for inspection.
LaTeX source saved to: evaluation_results/comparison_table_global_vs_age.tex

Age table generation process finished.


In [63]:
def create_device_comparison_table(agg_stats, conditions, metrics, filename_base, metric_precision=3, devices_to_include=None):
    """
    Build a landscape LaTeX table comparing `metrics` across `conditions`,
    broken down by recording device, and compile it to PDF.

    The first row is 'Total - Overall Metrics' (the 'all' group). It is
    followed by one row per key of agg_stats[conditions[0]]['device'], in
    sorted order, optionally restricted to `devices_to_include`. Every cell
    holds the mean and std returned by `get_stat`, rendered by
    `_format_metric_cell`.

    Args:
        agg_stats: Aggregated statistics keyed by condition name. The first
            condition must contain both an 'all' and a 'device' entry.
        conditions: Condition names to show as column groups. The first one
            also defines the device list and supplies the sample counts (N).
        metrics: Metric keys to show for each condition.
        filename_base: Output path. Its extension is stripped, then
            `<base>.pdf` (and `<base>.tex`) is written.
        metric_precision: Decimal places used for metric values.
        devices_to_include: Optional collection of device names to keep. If it
            is falsy (None or empty), all devices are shown. Requested devices
            that are not present in the stats are reported with a warning.

    Returns:
        None. Validation problems are printed and abort early. If PDF
        compilation fails, the LaTeX source is saved instead.
    """
    if not PYLATEX_AVAILABLE:
        print("Skipping PDF generation as pylatex is not available.")
        return

    # --- Input Validation ---
    if not conditions:
        print("Error: No conditions provided.")
        return
    first_cond = conditions[0]
    if first_cond not in agg_stats:
        print(f"Error: Condition '{first_cond}' not found in agg_stats.")
        return
    for required_group in ('all', 'device'):
        if required_group not in agg_stats[first_cond]:
            print(f"Error: '{required_group}' group data not found for condition '{first_cond}'.")
            return

    num_conditions = len(conditions)
    num_metrics = len(metrics)

    def _stat_to_float(stat_str):
        # get_stat returns a formatted number or 'N/A'. Treat anything that is
        # not parseable as missing, so one odd value cannot abort the table.
        if stat_str == 'N/A':
            return np.nan
        try:
            return float(stat_str)
        except (TypeError, ValueError):
            return np.nan

    # --- Document Setup ---
    geometry_options = {
        "tmargin": "0.75in", "bmargin": "0.75in",
        "lmargin": "0.75in",   "rmargin": "0.75in",  # narrower side margins for a wide table
        "includeheadfoot": True,
        "landscape": True
    }
    doc = Document(geometry_options=geometry_options)
    doc.preamble.append(Command('usepackage', 'booktabs'))
    # Long device names could be rotated with the 'rotating' package if needed:
    # doc.preamble.append(Command('usepackage', 'rotating'))

    doc.append(NoEscape(r'\section*{Comparison of Metrics by Device Subgroup}'))
    doc.append(NoEscape(r'\vspace{0.5em}'))

    # --- Table Setup ---
    # Label column, then one block of right-aligned metric columns per
    # condition. Vertical rules separate the condition blocks only.
    table_format = 'l' + '|'.join(['r' * num_metrics] * num_conditions)
    doc.append(NoEscape(r'\tiny'))  # many device rows: use an even smaller font
    table = Tabular(table_format, booktabs=True)

    # --- Header Rows ---
    header1 = [""]
    for i, cond_name in enumerate(conditions):
        # \multicolumn replaces the spanned columns' spec, so repeat the separator.
        align_str = 'c' if i == num_conditions - 1 else 'c|'
        header1.append(MultiColumn(num_metrics, align=align_str, data=cond_name.replace('_', ' ')))
    table.add_row(header1)
    table.add_hline()

    display_metric_names = [m.replace('_', ' ').title() for m in metrics]
    table.add_row(["Subgroup (# Samples)"] + display_metric_names * num_conditions)
    table.add_hline()

    # --- Data Rows ---
    subgroups_to_show = [('all', 'all', 'Total - Overall Metrics')]
    try:
        all_devices = sorted(agg_stats[first_cond]['device'].keys())
        if devices_to_include:
            devices_to_iterate = [d for d in all_devices if d in devices_to_include]
            print(f"Including devices: {devices_to_iterate}")
            known_devices = set(all_devices)
            unknown_devices = [d for d in devices_to_include if d not in known_devices]
            if unknown_devices:
                print(f"Warning: Requested devices not found for condition '{first_cond}': {unknown_devices}")
        else:
            devices_to_iterate = all_devices
            print(f"Including all {len(devices_to_iterate)} devices.")

        # The device name itself is the row label.
        subgroups_to_show.extend(('device', dev_name, f"{dev_name}") for dev_name in devices_to_iterate)
    except (KeyError, AttributeError, TypeError):
        # Missing/non-dict 'device' entry, or device keys that cannot be sorted together.
        print(f"Warning: Could not dynamically determine devices for condition '{conditions[0]}'. Device rows might be missing.")

    for item_index, (sg_type, sg_name, row_label_base) in enumerate(subgroups_to_show):
        # The sample count N is taken from the first condition only.
        n_samples_str = get_stat(agg_stats, first_cond, sg_type, sg_name, 'num_samples', 'mean', precision=0)
        row_label = f"{row_label_base} " + (f"(N={n_samples_str})" if n_samples_str != 'N/A' else "(N=?)")
        row_data = [row_label]

        for cond_name in conditions:
            for metric in metrics:
                mean_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'mean', precision=metric_precision)
                std_str = get_stat(agg_stats, cond_name, sg_type, sg_name, metric, 'std', precision=metric_precision)
                row_data.append(_format_metric_cell(_stat_to_float(mean_str), _stat_to_float(std_str),
                                                    precision=metric_precision))

        table.add_row(row_data)
        # Partial rule under the data columns only, between subgroup rows.
        if item_index < len(subgroups_to_show) - 1:
            table.add_hline(start=2, end=len(row_data))

    # --- Add Table and Generate PDF ---
    doc.append(table)
    doc.append(NoEscape(r'\normalsize'))
    doc.append(NoEscape(r'\vspace{1em}'))

    # Computed outside the try, so the .tex fallback below can always use it.
    pdf_path_base = os.path.splitext(filename_base)[0]
    try:
        doc.generate_pdf(pdf_path_base, clean_tex=False)
        print(f"PDF generated successfully: {pdf_path_base}.pdf")
    except Exception as e:
        print(f"Error generating PDF '{filename_base}': {e}")
        print("Attempting to save the .tex file for inspection.")
        try:
            tex_path = pdf_path_base + ".tex"
            doc.generate_tex(pdf_path_base)
            print(f"LaTeX source saved to: {tex_path}")
        except Exception as tex_e:
            print(f"Could not save .tex file: {tex_e}")


# --- Call the Device Comparison Function ---
CONDITIONS_TO_COMPARE_DEVICE = ['Optimized_Global', 'Optimized_Subgroup_Device']
METRICS_TO_INCLUDE_DEVICE = ['f1_macro', 'recall_macro', 'precision_macro']  # fewer metrics keep the wide table readable
OUTPUT_PDF_PATH_DEVICE = os.path.join(RESULTS_DIR, 'comparison_table_global_vs_device.pdf')
# Restrict the table to specific devices if it becomes too large, e.g.
# DEVICES_TO_FILTER = ['AT-6 C 5.5', 'CS-12', 'CS-12   E']
# None means "include all devices".
DEVICES_TO_FILTER = None

if PYLATEX_AVAILABLE:
    print(f"\nGenerating Device comparison table PDF ({OUTPUT_PDF_PATH_DEVICE})...")
    # Reuse stats loaded by an earlier cell when present. Otherwise load them
    # here, so this cell also works on a fresh kernel. (Previously this branch
    # only printed a message, and the next statement raised NameError.)
    if 'aggregated_stats' not in globals():
        print(f"Reloading aggregated statistics from: {AGGREGATED_STATS_PATH}")
        # NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
        with open(AGGREGATED_STATS_PATH, 'rb') as f:
            aggregated_stats = pickle.load(f)

    valid_conditions_device = [cond for cond in CONDITIONS_TO_COMPARE_DEVICE if cond in aggregated_stats]
    if valid_conditions_device:
        create_device_comparison_table(
            agg_stats=aggregated_stats,
            conditions=valid_conditions_device,
            metrics=METRICS_TO_INCLUDE_DEVICE,
            filename_base=OUTPUT_PDF_PATH_DEVICE,
            metric_precision=3,
            devices_to_include=DEVICES_TO_FILTER,
        )
        # display(IFrame(f"{OUTPUT_PDF_PATH_DEVICE}", width=900, height=600))
    else:
        print(f"Error: None of the conditions {CONDITIONS_TO_COMPARE_DEVICE} found in stats.")
else:
    print("\nSkipping Device PDF generation because pylatex is not installed.")

print("\nDevice table generation process finished.")


Generating Device comparison table PDF (evaluation_results/comparison_table_global_vs_device.pdf)...
Including all 11 devices.
Error generating PDF 'evaluation_results/comparison_table_global_vs_device.pdf': No LaTex compiler was found
Either specify a LaTex compiler or make sure you have latexmk or pdfLaTex installed.
Attempting to save the .tex file for inspection.
LaTeX source saved to: evaluation_results/comparison_table_global_vs_device.tex

Device table generation process finished.


In [64]:
def create_metric_class_summary_table(agg_stats, conditions, filename_base, metric_precision=3):
    """
    Build a landscape LaTeX table of overall and per-class metrics for each
    condition and compile it to PDF.

    The first rows show overall metrics for the 'all' group (F1 / precision /
    recall macro and exact-match ratio), fetched via `get_stat`. After a full
    rule, each class gets F1, Precision and Recall rows, read from
    agg_stats[cond]['all'][<metric>_per_class][class_name]['mean' / 'std'].
    Each condition contributes one column, formatted by `_format_metric_cell`.

    Args:
        agg_stats: Aggregated statistics keyed by condition name. The first
            condition must contain an 'all' entry with at least one of the
            per-class keys; the class names are taken from that entry.
        conditions: Condition names to show as columns.
        filename_base: Output path. Its extension is stripped, then
            `<base>.pdf` (and `<base>.tex`) is written.
        metric_precision: Decimal places used for metric values.

    Returns:
        None. Validation problems are printed and abort early. If PDF
        compilation fails, the LaTeX source is saved instead.
    """
    if not PYLATEX_AVAILABLE:
        print("Skipping PDF: pylatex not available.")
        return

    # --- Define Metrics to Display ---
    overall_metrics_to_show = {
        'f1_macro': 'Total - F1 Macro',
        'precision_macro': 'Total - Precision Macro',
        'recall_macro': 'Total - Recall Macro',
        'exact_match_ratio': 'Total - Exact Match'
    }
    # Display name -> per-class key in agg_stats[cond]['all']
    per_class_metrics_to_show = {
        'F1': 'f1_per_class',
        'Precision': 'precision_per_class',
        'Recall': 'recall_per_class'
    }

    # --- Input Validation and Class Name Extraction ---
    if not conditions:
        print("Error: No conditions provided.")
        return
    first_cond = conditions[0]
    if first_cond not in agg_stats:
        print(f"Error: Cond '{first_cond}' not found.")
        return
    if 'all' not in agg_stats[first_cond]:
        print(f"Error: 'all' data missing for {first_cond}.")
        return

    # Class names come from the first per-class key present for the first condition.
    first_per_class_key = next(
        (pc_key for pc_key in per_class_metrics_to_show.values() if pc_key in agg_stats[first_cond]['all']),
        None,
    )
    if first_per_class_key is None:
        print(f"Error: None of the per-class keys found in 'all' data: {list(per_class_metrics_to_show.values())}")
        return

    try:
        class_names = sorted(agg_stats[first_cond]['all'][first_per_class_key].keys())
        if not class_names:
            print("Warning: No class names found.")
            return
    except Exception as e:
        print(f"Error extracting class names: {e}")
        return

    num_conditions = len(conditions)
    num_data_cols_per_cond = 1  # each condition is a single 'mean ± std' column
    num_table_cols = 1 + num_conditions * num_data_cols_per_cond

    def _stat_to_float(stat_str):
        # get_stat returns a formatted number or 'N/A'. Treat anything that is
        # not parseable as missing, so one odd value cannot abort the table.
        if stat_str == 'N/A':
            return np.nan
        try:
            return float(stat_str)
        except (TypeError, ValueError):
            return np.nan

    # --- Document Setup (Landscape) ---
    geometry_options = {"tmargin": "0.75in", "bmargin": "0.75in", "lmargin": "1in", "rmargin": "1in", "includeheadfoot": True, "landscape": True}
    doc = Document(geometry_options=geometry_options)
    doc.preamble.append(Command('usepackage', 'booktabs'))
    doc.append(NoEscape(r'\section*{Metric Summary by Condition and Class}'))
    doc.append(NoEscape(r'\vspace{0.5em}'))

    # --- Table Setup ---
    # Label column, then one column per condition, separated by vertical rules.
    table_format = 'l' + '|'.join(['r' * num_data_cols_per_cond] * num_conditions)

    doc.append(NoEscape(r'\footnotesize'))
    table = Tabular(table_format, booktabs=True)

    # --- Header Row ---
    header1 = ["Metric / Class"]
    for i, cond_name in enumerate(conditions):
        # \multicolumn replaces the column spec *including* its trailing '|',
        # so the separator must be repeated here (as in the sibling tables).
        align_str = 'c' if i == num_conditions - 1 else 'c|'
        header1.append(MultiColumn(num_data_cols_per_cond, align=align_str, data=cond_name.replace('_', ' ')))
    table.add_row(header1)
    table.add_hline()

    # --- Overall Metric Rows ---
    print("Fetching Overall Metrics...")
    for metric_key, display_label in overall_metrics_to_show.items():
        row_data = [display_label]
        for cond_name in conditions:
            mean_str = get_stat(agg_stats, cond_name, 'all', 'all', metric_key, 'mean', precision=metric_precision)
            std_str = get_stat(agg_stats, cond_name, 'all', 'all', metric_key, 'std', precision=metric_precision)
            row_data.append(_format_metric_cell(_stat_to_float(mean_str), _stat_to_float(std_str),
                                                precision=metric_precision))
        table.add_row(row_data)

    # --- Separator Line (full width) ---
    table.add_hline(start=1, end=num_table_cols)

    # --- Per-Class Metric Rows ---
    print("Fetching Per-Class Metrics...")
    for class_index, class_name in enumerate(class_names):
        print(f"  Class: {class_name}")
        for metric_display_name, per_class_key in per_class_metrics_to_show.items():
            row_data = [f"{class_name} - {metric_display_name}"]

            for cond_name in conditions:
                mean_val, std_val = np.nan, np.nan
                try:
                    class_metric_dict = agg_stats[cond_name]['all'][per_class_key][class_name]
                    mean_val = class_metric_dict.get('mean', np.nan)
                    std_val = class_metric_dict.get('std', np.nan)
                except (KeyError, TypeError, AttributeError):
                    # AttributeError: the entry exists but is not a dict.
                    print(f"Warn: Data missing for {cond_name}/all/{per_class_key}/{class_name}")

                row_data.append(_format_metric_cell(mean_val, std_val, precision=metric_precision))
            table.add_row(row_data)

        # Partial rule under the data columns between classes (not after the last one).
        if class_index < len(class_names) - 1:
            table.add_hline(start=2, end=num_table_cols)

    # --- Add Table to Document ---
    doc.append(table)
    doc.append(NoEscape(r'\normalsize'))
    doc.append(NoEscape(r'\vspace{1em}'))

    # --- Generate PDF ---
    # Computed outside the try, so the .tex fallback below can always use it.
    pdf_path_base = os.path.splitext(filename_base)[0]
    try:
        doc.generate_pdf(pdf_path_base, clean_tex=False)
        print(f"PDF generated successfully: {pdf_path_base}.pdf")
    except Exception as e:
        print(f"Error generating PDF: {e}")
        print("Attempting to save the .tex file for inspection.")
        try:
            tex_path = pdf_path_base + ".tex"
            doc.generate_tex(pdf_path_base)
            print(f"LaTeX source saved to: {tex_path}")
        except Exception as tex_e:
            print(f"Could not save .tex file: {tex_e}")


# Define paths, condition (singular), and metrics
RESULTS_DIR = 'evaluation_results'
AGGREGATED_STATS_FILENAME = 'aggregated_evaluation_stats.pkl'
# Specific, descriptive name for this output PDF
SUMMARY_TABLE_PDF_FILENAME = 'overall_metrics_table_global_thresh.pdf'

AGGREGATED_STATS_PATH = os.path.join(RESULTS_DIR, AGGREGATED_STATS_FILENAME)
SUMMARY_OUTPUT_PDF_PATH = os.path.join(RESULTS_DIR, SUMMARY_TABLE_PDF_FILENAME)

# Only the global-threshold condition. This must be the exact key used in aggregated_stats.
CONDITIONS_TO_COMPARE = ['Optimized_Global']

METRIC_PRECISION = 3  # decimal places for formatting

# Failures raise instead of calling exit(): under IPython/Jupyter, exit() shuts
# the kernel down and discards all state. This also matches the RuntimeError
# used by the configuration check at the top of the notebook.

# 1. Load Aggregated Statistics
print(f"\nLoading aggregated statistics from: {AGGREGATED_STATS_PATH}")
if not os.path.exists(AGGREGATED_STATS_PATH):
    raise RuntimeError(f"Error: Aggregated stats file not found: {AGGREGATED_STATS_PATH}.")
try:
    # NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
    with open(AGGREGATED_STATS_PATH, 'rb') as f:
        aggregated_stats = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Error loading aggregated stats file: {e}") from e
print("Aggregated stats loaded successfully.")

# 2. Validate Conditions
valid_conditions = [cond for cond in CONDITIONS_TO_COMPARE if cond in aggregated_stats]
if not valid_conditions:
    raise RuntimeError(
        f"Error: Condition(s) {CONDITIONS_TO_COMPARE} not found in loaded stats. "
        f"Available keys: {list(aggregated_stats.keys())}"
    )
if len(valid_conditions) < len(CONDITIONS_TO_COMPARE):
    print(f"Warning: Only found conditions: {valid_conditions}")

# 3. Create and Generate Metric x Class Summary Table PDF
if PYLATEX_AVAILABLE:
    print(f"\nGenerating Overall Metrics Table for '{CONDITIONS_TO_COMPARE[0]}' (landscape)...")
    create_metric_class_summary_table(
        agg_stats=aggregated_stats,
        conditions=valid_conditions,
        filename_base=SUMMARY_OUTPUT_PDF_PATH,
        metric_precision=METRIC_PRECISION,
    )
    # Optional inline preview. Only show it when the PDF really exists:
    # compilation fails without a LaTeX compiler, and an IFrame would then
    # point at a missing file.
    if os.path.exists(SUMMARY_OUTPUT_PDF_PATH):
        try:
            from IPython.display import display, IFrame
            if IPYTHON_AVAILABLE:
                display(IFrame(f"{SUMMARY_OUTPUT_PDF_PATH}", width=900, height=600))
        except ImportError:
            pass  # IPython not available: skip the preview
    else:
        print(f"PDF not found at {SUMMARY_OUTPUT_PDF_PATH}; skipping inline preview.")
else:
    print("\nSkipping PDF generation because pylatex is not installed.")

print(f"\nOverall Metrics Table generation process finished ({SUMMARY_TABLE_PDF_FILENAME}).")


Loading aggregated statistics from: evaluation_results/aggregated_evaluation_stats.pkl
Aggregated stats loaded successfully.

Generating Overall Metrics Table for 'Optimized_Global' (landscape)...
Fetching Overall Metrics...
Fetching Per-Class Metrics...
  Class: CD
  Class: HYP
  Class: MI
  Class: NORM
  Class: STTC
Error generating PDF: No LaTex compiler was found
Either specify a LaTex compiler or make sure you have latexmk or pdfLaTex installed.
Attempting to save the .tex file for inspection.
LaTeX source saved to: evaluation_results/overall_metrics_table_global_thresh.tex



Overall Metrics Table generation process finished (overall_metrics_table_global_thresh.pdf).
