# **Ablation Studies for BWAF Model Validation**

**Objective:**
This notebook implements and evaluates several ablated versions of our BWAF model to systematically dissect the sources of its performance and validate our architectural choices. The goal is to quantify the contribution of each data modality (Sequence, Network, Priors) and the novel BWAF fusion mechanism itself.

**All models will be trained and evaluated on the identical data splits** derived from the main experimental pipeline to ensure a fair and direct comparison.

**Ablation Models to be Tested:**
1.  **Ablation 1 (A1): Transformer + Priors Only** (No GAT branch)
2.  **Ablation 2 (A2): GAT + Priors Only** (No Transformer branch)
3.  **Ablation 3 (A3): Simple Concatenation Fusion** (Transformer + GAT + Priors, no BWAF weighting)
4.  **Ablation 4 (A4): Transformer Only** (No GAT or Prior data)
5.  **Ablation 5 (A5): Priors Only** (Simple MLP on motif counts)
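The list above determines which inputs each ablation sees. The sketch below cross-checks that against the model definitions in Section 3 by deriving each classifier's input width:

- A1 concatenates the 64-dim Transformer embedding with the priors.
- A2 concatenates the 64-dim GAT vector with the priors.
- A3 concatenates all three.
- A4 uses the Transformer embedding alone.
- A5 uses the priors alone.

`PRIOR_DIM = 10` is a placeholder assumption. The real value is the number of `(Count)` columns in the prior file.

```python
# Widths taken from the notebook's configuration (EMBEDDING_DIM, GAT_OUTPUT_DIM)
EMBEDDING_DIM = 64
GAT_OUTPUT_DIM = 64
PRIOR_DIM = 10  # placeholder: the real value is the number of "(Count)" columns

# Feature blocks each ablation concatenates before its classifier head
ABLATION_INPUTS = {
    "A1_Transformer_Priors": ("sequence", "priors"),
    "A2_GAT_Priors": ("graph", "priors"),
    "A3_Simple_Concat": ("sequence", "graph", "priors"),
    "A4_Transformer_Only": ("sequence",),
    "A5_Priors_Only": ("priors",),
}

BLOCK_WIDTHS = {"sequence": EMBEDDING_DIM, "graph": GAT_OUTPUT_DIM, "priors": PRIOR_DIM}

def classifier_input_dim(ablation):
    """Width of the vector fed to the first LayerNorm of the classifier."""
    return sum(BLOCK_WIDTHS[block] for block in ABLATION_INPUTS[ablation])

for name in ABLATION_INPUTS:
    print(f"{name}: {classifier_input_dim(name)}")
```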

---
## **0. Setup and Imports**
This cell imports all required libraries.

In [1]:
# %% 0. Setup and Imports
# ============================================================================
# Standard library imports
import os
import re
import glob
import gzip
import time
import argparse
import datetime
import sys
import warnings
import random
import traceback

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                           f1_score, roc_auc_score, confusion_matrix)
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch Geometric imports (only needed to recompute GAT embeddings; A2/A3 here use precomputed GAT outputs)
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.utils import from_scipy_sparse_matrix
except ImportError:
    print("PyTorch Geometric not found. It is only needed to recompute GAT embeddings, not for these ablations.")

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

print("Imports successful.")
print(f"PyTorch Version: {torch.__version__}")

Imports successful.
PyTorch Version: 2.7.0+cu126


## **1. Configuration / Constants**
Defines paths and hyperparameters. These settings are kept consistent with the main BWAF experiment for fair comparison.

In [2]:
# %% 1. Configuration / Constants
# ============================================================================
# --- Data Paths (Must be the same as the main BWAF experiment) ---
BASE_DATA_DIR = './data/'
SEQ_DATA_DIR = os.path.join(BASE_DATA_DIR, 'raw/human_genome_annotation')
PRIOR_DATA_DIR = os.path.join(BASE_DATA_DIR, 'raw/human_genome_annotation')
GAT_PREPROCESSED_DIR = os.path.join(BASE_DATA_DIR, 'preprocessed/gat_normalized') # Pre-aligned and normalized GAT data

PROMOTER_SEQ_FILE = os.path.join(SEQ_DATA_DIR, 'updated_promoter_features_clean.csv')
NON_PROMOTER_SEQ_FILE = os.path.join(SEQ_DATA_DIR, 'updated_non_promoter_sequences.csv')
PRIOR_FILE = os.path.join(PRIOR_DATA_DIR, 'biological_prior_for_transformer_branch.csv')

# --- Model Hyperparameters (Consistent with BWAF model) ---
SEQ_LEN = 2000
PAD_IDX = 4
VOCAB_SIZE = 5
EMBEDDING_DIM = 64
NUM_ATTN_HEADS = 4
NUM_TRANSFORMER_LAYERS = 2
TRANSFORMER_FF_DIM = EMBEDDING_DIM * 4
GAT_OUTPUT_DIM = 64
FUSION_HIDDEN_DIM = 128
DROPOUT_RATE = 0.3

# --- Training Hyperparameters ---
LEARNING_RATE = 0.0005
BATCH_SIZE = 16 
NUM_EPOCHS = 10 # Same number of epochs as main run for comparison
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.15
RANDOM_SEED = 42
OPTIMIZER_WEIGHT_DECAY = 1e-5

# --- Output Files ---
OUTPUT_DIR_ABLATION = 'results_ablation_studies/'
os.makedirs(OUTPUT_DIR_ABLATION, exist_ok=True)
LOG_FILE_ABLATION = os.path.join(OUTPUT_DIR_ABLATION, 'ablation_log.txt')

# --- Hardware ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Setup ---
random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Ablation/Baseline Runs: Using device: {DEVICE}")
print(f"Ablation/Baseline Output directory: {OUTPUT_DIR_ABLATION}")

Ablation/Baseline Runs: Using device: cpu
Ablation/Baseline Output directory: results_ablation_studies/
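The main execution block (Section 5) applies `VALIDATION_SPLIT` and `TEST_SPLIT` by flooring fractions of the dataset size, and the remainder goes to training. The arithmetic sketch below reproduces that rule. Plugging in 40,056 samples (the sum of the logged split sizes, 28040 + 6008 + 6008) recovers the logged split. With `BATCH_SIZE = 16`, it also gives the number of test batches.

```python
import math

def split_sizes(n_samples, val_frac=0.15, test_frac=0.15, batch_size=16):
    """Mirror the floor-based split used in the main execution block."""
    n_test = math.floor(test_frac * n_samples)
    n_val = math.floor(val_frac * n_samples)
    n_train = n_samples - n_test - n_val  # remainder goes to training
    # Val/test loaders keep the final partial batch (drop_last defaults to False)
    test_batches = math.ceil(n_test / batch_size)
    return n_train, n_val, n_test, test_batches

print(split_sizes(40056))  # (28040, 6008, 6008, 376)
```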


## **2. Utility Functions & Data Loading**
These logging, sequence-encoding, and data-loading utilities are reused from the main script.


In [3]:
# %% 2. Utility Functions & Data Loading
# ============================================================================
def log_message(message, log_file=LOG_FILE_ABLATION):
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    full_message = f"[{timestamp}] {message}"
    print(full_message)
    log_dir = os.path.dirname(log_file);
    if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir)
    with open(log_file, 'a', encoding='utf-8') as f: f.write(full_message + '\n')

with open(LOG_FILE_ABLATION, 'w', encoding='utf-8') as f: f.write(f"--- Ablation Log Initialized ---\n")
log_message(f"Log file: {LOG_FILE_ABLATION}")

def integer_encode_sequence(sequence, max_len=SEQ_LEN):
    encoding_map={'A':0,'T':1,'C':2,'G':3}; encoded=np.full(max_len,PAD_IDX,dtype=np.int64)
    for i,nuc in enumerate(sequence[:max_len]): encoded[i]=encoding_map.get(nuc.upper(), PAD_IDX)
    return encoded
def log_transform_priors(counts): return np.log1p(np.array(counts,dtype=np.float32))
def extract_clean_gene_id(series): return series.astype(str).str.extract(r'(ENSG\d+)', expand=False).fillna('UNKNOWN')

def load_sequences(file_path, is_promoter=True):
    log_message(f"Loading sequences from {file_path}...")
    df = pd.read_csv(file_path); seq_col='promoter_sequence' if 'promoter_sequence' in df.columns else 'sequence'
    df = df[df[seq_col].str.len() == SEQ_LEN]; df = df[~df[seq_col].str.contains('N', na=False, case=False)]
    return df[seq_col].tolist(), extract_clean_gene_id(df['gene_id']).tolist(), [1 if is_promoter else 0]*len(df)

def load_priors(file_path, gene_id_order):
    log_message(f"Loading priors from {file_path}...")
    df=pd.read_csv(file_path); df['clean_gene_id']=extract_clean_gene_id(df['gene_id'])
    count_cols=[c for c in df.columns if '(Count)' in c]; prior_dim=len(count_cols)
    df_priors = df[['clean_gene_id']+count_cols].copy(); df_priors.set_index('clean_gene_id', inplace=True)
    df_priors=df_priors[~df_priors.index.duplicated(keep='first')]; aligned=df_priors.reindex(gene_id_order, fill_value=0)
    return log_transform_priors(aligned.values), prior_dim

class MainDataset(Dataset):
    def __init__(self, sequences, gene_ids_for_samples, labels, biological_priors, precomputed_gat_output, gat_gene_order):
        self.sequences = torch.tensor(sequences, dtype=torch.int64)
        self.labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
        self.priors = torch.tensor(biological_priors, dtype=torch.float32)
        self.precomputed_gat_output = torch.tensor(precomputed_gat_output, dtype=torch.float32)
        gat_gene_to_idx = {gene: idx for idx, gene in enumerate(gat_gene_order)}
        self.sample_id_to_gat_idx = np.array([gat_gene_to_idx.get(gid, -1) for gid in gene_ids_for_samples], dtype=np.int64)
    def __len__(self): return len(self.labels)
    def __getitem__(self, idx):
        gat_idx = self.sample_id_to_gat_idx[idx]
        graph_features = self.precomputed_gat_output[gat_idx] if gat_idx != -1 else torch.zeros(self.precomputed_gat_output.shape[1])
        return {'sequence': self.sequences[idx], 'priors': self.priors[idx], 'graph_features': graph_features, 'label': self.labels[idx]}

[2025-07-09 18:58:11] Log file: results_ablation_studies/ablation_log.txt
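The sketch below applies the same transformations to toy inputs, without external dependencies:

- **Integer encoding with padding.** Any base outside ACGT maps to `PAD_IDX`, so it is later masked exactly like padding.
- **Prior counts.** The `log1p` transform is applied to each count.
- **Ensembl gene-ID extraction.** IDs that do not match fall back to `UNKNOWN`.
- **Per-sample GAT lookup.** A gene with no row gets a zero vector.

The helper names and the toy gene IDs are illustrative, not part of the original code.

```python
import math
import re

PAD_IDX = 4
ENCODING_MAP = {"A": 0, "T": 1, "C": 2, "G": 3}
ENSG_PATTERN = re.compile(r"(ENSG\d+)")

def integer_encode(sequence, max_len):
    """Truncate/pad to max_len; any base outside ACGT becomes PAD_IDX."""
    encoded = [PAD_IDX] * max_len
    for i, nuc in enumerate(sequence[:max_len]):
        encoded[i] = ENCODING_MAP.get(nuc.upper(), PAD_IDX)
    return encoded

def log_transform(counts):
    """log(1 + x), as np.log1p does for the prior count columns."""
    return [math.log1p(c) for c in counts]

def clean_gene_id(raw):
    """First ENSG... match, or 'UNKNOWN' (like str.extract(...).fillna)."""
    match = ENSG_PATTERN.search(str(raw))
    return match.group(1) if match else "UNKNOWN"

def graph_features_for(gene_id, gene_order, gat_rows, dim):
    """Row of the precomputed GAT table, or zeros if the gene has no row."""
    # MainDataset builds this index once in __init__; rebuilt here for brevity
    index = {gene: i for i, gene in enumerate(gene_order)}
    i = index.get(gene_id, -1)
    return gat_rows[i] if i != -1 else [0.0] * dim

print(integer_encode("acgtN", max_len=8))   # [0, 2, 3, 1, 4, 4, 4, 4]
print(clean_gene_id("ENSG00000141510.17"))  # ENSG00000141510
print(clean_gene_id("chr1:100-2100"))       # UNKNOWN
print(graph_features_for("ENSG2", ["ENSG1"], [[0.5, 0.5]], dim=2))  # [0.0, 0.0]
```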


## **3. Model Definitions for Ablation Studies**
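Here we define the architectures for the comparative experiments. `TransformerBranch` is reused from the main project. No graph network is trained in this notebook. Instead, A2 and A3 read the per-gene GAT embeddings precomputed by the main BWAF run (the `graph_features` entry of each batch) and feed them directly into their classifiers.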

In [4]:
# %% 3. Model Definitions for Ablation Studies
# ============================================================================
class TransformerBranch(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN, embed_dim=EMBEDDING_DIM,
                 num_heads=NUM_ATTN_HEADS, ff_dim=TRANSFORMER_FF_DIM,
                 num_layers=NUM_TRANSFORMER_LAYERS, dropout=DROPOUT_RATE):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.positional_encoding = nn.Parameter(torch.randn(1, seq_len, embed_dim))
        self.embed_dropout = nn.Dropout(p=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, activation='relu', batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fusion_prep_layer = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Dropout(p=dropout))
    def forward(self, seq_data):
        x = self.embedding(seq_data) + self.positional_encoding[:, :seq_data.size(1), :]
        x = self.embed_dropout(x); padding_mask = (seq_data == PAD_IDX)
        transformer_output = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        mask = (~padding_mask).unsqueeze(-1).float()
        aggregated_output = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.fusion_prep_layer(aggregated_output)

# --- Ablation 1 (A1): Transformer + Priors Only ---
class TransformerPriorsOnly(nn.Module):
    def __init__(self, prior_dim, embed_dim, fusion_hidden_dim, dropout):
        super().__init__(); self.transformer = TransformerBranch(embed_dim=embed_dim, dropout=dropout)
        combined_dim = embed_dim + prior_dim
        self.classifier=nn.Sequential(nn.LayerNorm(combined_dim), nn.Linear(combined_dim, fusion_hidden_dim),
                                      nn.ReLU(), nn.Dropout(dropout), nn.Linear(fusion_hidden_dim, 1))
    def forward(self, batch):
        seq_features = self.transformer(batch['sequence'])
        combined = torch.cat([seq_features, batch['priors']], dim=1)
        return self.classifier(combined)

# --- Ablation 2 (A2): GAT + Priors Only ---
class GATPriorsOnly(nn.Module):
    def __init__(self, prior_dim, gat_output_dim, fusion_hidden_dim, dropout):
        super().__init__()
        combined_dim = gat_output_dim + prior_dim
        self.classifier=nn.Sequential(nn.LayerNorm(combined_dim), nn.Linear(combined_dim, fusion_hidden_dim),
                                      nn.ReLU(), nn.Dropout(dropout), nn.Linear(fusion_hidden_dim, 1))
    def forward(self, batch):
        combined = torch.cat([batch['graph_features'], batch['priors']], dim=1)
        return self.classifier(combined)

# --- Ablation 3 (A3): Simple Concatenation Fusion ---
class SimpleConcatModel(nn.Module):
    def __init__(self, prior_dim, embed_dim, gat_output_dim, fusion_hidden_dim, dropout):
        super().__init__(); self.transformer = TransformerBranch(embed_dim=embed_dim, dropout=dropout)
        combined_dim = embed_dim + gat_output_dim + prior_dim
        self.classifier=nn.Sequential(nn.LayerNorm(combined_dim), nn.Linear(combined_dim, fusion_hidden_dim),
                                      nn.ReLU(), nn.Dropout(dropout), nn.Linear(fusion_hidden_dim, 1))
    def forward(self, batch):
        seq_features = self.transformer(batch['sequence'])
        combined = torch.cat([seq_features, batch['graph_features'], batch['priors']], dim=1)
        return self.classifier(combined)

# --- Ablation 4 (A4): Transformer Only ---
class TransformerOnly(nn.Module):
    def __init__(self, embed_dim, fusion_hidden_dim, dropout):
        super().__init__(); self.transformer = TransformerBranch(embed_dim=embed_dim, dropout=dropout)
        self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, fusion_hidden_dim),
                                      nn.ReLU(), nn.Dropout(dropout), nn.Linear(fusion_hidden_dim, 1))
    def forward(self, batch):
        seq_features = self.transformer(batch['sequence'])
        return self.classifier(seq_features)

# --- Ablation 5 (A5): Priors Only ---
class PriorsOnly(nn.Module):
    def __init__(self, prior_dim, fusion_hidden_dim, dropout):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.LayerNorm(prior_dim),
            nn.Linear(prior_dim, fusion_hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(fusion_hidden_dim // 2, 1))
    def forward(self, batch):
        return self.classifier(batch['priors'])

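`TransformerBranch.forward` pools the encoder output with a masked mean. Padding positions (`PAD_IDX`) are zeroed out before summing, and the sum is divided by the number of real tokens, clamped to at least 1. The framework-free sketch below shows that arithmetic for a single sequence. Plain lists stand in for tensors, and `masked_mean_pool` is a hypothetical helper name.

```python
PAD_IDX = 4  # same padding index as the notebook

def masked_mean_pool(token_ids, hidden_states, pad_idx=PAD_IDX):
    """Average hidden vectors over non-padding positions of one sequence."""
    dim = len(hidden_states[0])
    totals = [0.0] * dim
    count = 0
    for tok, vec in zip(token_ids, hidden_states):
        if tok == pad_idx:
            continue  # padded positions are excluded, like the mask in forward()
        totals = [t + v for t, v in zip(totals, vec)]
        count += 1
    # Mirrors clamp(min=1.0): an all-padding input returns zeros, not NaN
    denom = max(count, 1)
    return [t / denom for t in totals]

pooled = masked_mean_pool([0, 3, 4, 4], [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]])
print(pooled)  # [2.0, 3.0]
```

The padded vectors (`[9.0, 9.0]`) have no effect on the result, which is the point of masking rather than averaging over the full `SEQ_LEN`.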
## **4. Resumable Experiment Execution Framework**
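This generic `run_experiment` function trains and evaluates any of the ablation models. It saves a checkpoint after every epoch and resumes from the latest checkpoint if a run was interrupted. It then evaluates the weights with the lowest validation loss on the test set.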
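The resume rule works as follows:

1. Parse the epoch number out of every `checkpoint_epoch_*.pth` filename and take the largest.
2. Continue from the epoch after it.
3. If that next epoch is already `num_epochs`, skip training and run only evaluation.

The stdlib sketch below applies this rule to filenames only. `resume_point` is a hypothetical helper, not part of the notebook.

```python
import re

EPOCH_PATTERN = re.compile(r"epoch_(\d+)")

def resume_point(checkpoint_names, num_epochs):
    """Return (latest checkpoint or None, start_epoch, training_already_done)."""
    numbered = []
    for name in checkpoint_names:
        match = EPOCH_PATTERN.search(name)
        if match:  # files without an epoch number are ignored
            numbered.append((int(match.group(1)), name))
    if not numbered:
        return None, 0, False  # nothing saved yet: train from scratch
    last_epoch, latest = max(numbered)
    start_epoch = last_epoch + 1  # a checkpoint stores the last *completed* epoch
    return latest, start_epoch, start_epoch >= num_epochs

print(resume_point(["checkpoint_epoch_008.pth", "checkpoint_epoch_009.pth"], num_epochs=10))
# ('checkpoint_epoch_009.pth', 10, True)
```

The filenames are zero-padded (`{epoch:03d}`), so a plain lexical sort would also work up to epoch 999. Parsing the number keeps the rule correct beyond that.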

In [9]:
# %% 4. Experiment Execution Framework (Resumable, with Duration Logging)
# ============================================================================

def run_experiment(model_name, model_instance, train_loader, val_loader, test_loader,
                   num_epochs, learning_rate, device, output_dir,
                   precomputed_gat_output=None):
    """
    A generic, resumable function to train and evaluate a given model instance for ablation studies.
    Includes per-epoch duration logging and robust checkpointing.
    """
    log_message(f"\n{'='*25} STARTING EXPERIMENT: {model_name} {'='*25}")

    # --- Setup paths for this specific experiment ---
    exp_output_dir = os.path.join(output_dir, model_name)
    os.makedirs(exp_output_dir, exist_ok=True)
    checkpoint_dir = os.path.join(exp_output_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    model_best_save_path = os.path.join(exp_output_dir, f"best_model_{model_name}.pth")
    loss_plot_path = os.path.join(exp_output_dir, f"loss_plot_{model_name}.png")
    
    # --- Initialize Model, Optimizer, Scheduler ---
    model_instance.to(device)
    
    if hasattr(model_instance, 'precomputed_gat_output'):
        model_instance.precomputed_gat_output = precomputed_gat_output
        if precomputed_gat_output is not None and precomputed_gat_output.numel() > 0:
             log_message(f"Provided precomputed GAT output tensor to {model_name}.")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model_instance.parameters(), lr=learning_rate, weight_decay=OPTIMIZER_WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=3)

    # --- RESUME FROM CHECKPOINT LOGIC ---
    start_epoch = 0; train_losses, val_losses = [], []; best_val_loss = float('inf')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_epoch_*.pth")),
                         key=lambda x: int(re.search(r"epoch_(\d+)",x).group(1)) if re.search(r"epoch_(\d+)",x) else -1,
                         reverse=True)
    if checkpoints:
        latest_checkpoint_path = checkpoints[0]
        log_message(f"Resuming {model_name} from checkpoint: {os.path.basename(latest_checkpoint_path)}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
            model_instance.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            train_losses = checkpoint.get('train_losses',[]); val_losses = checkpoint.get('val_losses',[])
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            log_message(f"Resumed from epoch {start_epoch-1}. Best Val Loss so far: {best_val_loss:.4f}")
        except Exception as e:
            log_message(f"Error loading checkpoint for {model_name}: {e}. Training from scratch.")
            start_epoch = 0

    if start_epoch >= num_epochs:
        log_message(f"{model_name} training already completed. Loading best model for evaluation.")
        if os.path.exists(model_best_save_path):
            model_instance.load_state_dict(torch.load(model_best_save_path, map_location=device))
    else:
        # --- TRAINING LOOP ---
        log_message(f"--- Training {model_name} from epoch {start_epoch} to {num_epochs-1} ---")
        for epoch in range(start_epoch, num_epochs):
            epoch_start_time = time.time()
            model_instance.train(); running_loss=0.0
            train_loop = tqdm(train_loader, desc=f"{model_name} E{epoch+1} [Train]", leave=False)
            for batch in train_loop:
                batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                optimizer.zero_grad(); logits = model_instance(batch); loss = criterion(logits, batch['label'])
                loss.backward(); optimizer.step(); running_loss += loss.item()
                train_loop.set_postfix(loss=f"{loss.item():.4f}")
            
            avg_train_loss = running_loss/len(train_loader) if len(train_loader)>0 else 0.0
            train_losses.append(avg_train_loss)

            model_instance.eval(); val_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    batch = {k:v.to(device) for k,v in batch.items() if isinstance(v, torch.Tensor)}
                    logits = model_instance(batch); val_loss += criterion(logits, batch['label']).item()
            avg_val_loss = val_loss/len(val_loader) if len(val_loader)>0 else 0.0
            val_losses.append(avg_val_loss)
            
            epoch_duration = time.time() - epoch_start_time; current_lr = optimizer.param_groups[0]['lr']
            log_message(f"{model_name} E{epoch+1}/{num_epochs}: TrL={avg_train_loss:.4f}, VaL={avg_val_loss:.4f}, Dur={epoch_duration:.2f}s, LR={current_lr:.2e}")
            if scheduler: scheduler.step(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss; torch.save(model_instance.state_dict(), model_best_save_path)
                log_message(f"Saved new best {model_name} model.")
            
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch:03d}.pth")
            torch.save({'epoch': epoch, 'model_state_dict': model_instance.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                        'train_losses': train_losses, 'val_losses': val_losses, 'best_val_loss': best_val_loss}, checkpoint_path)

        if train_losses:
            epochs_plotted = len(train_losses); epochs_range = range(1, epochs_plotted + 1)
            plt.figure(figsize=(10,6)); plt.plot(epochs_range,train_losses,label='Train'); plt.plot(epochs_range,val_losses,label='Val');
            plt.title(f'{model_name} Training Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True); plt.savefig(loss_plot_path, dpi=300); plt.close()

    # --- EVALUATION ---
    log_message(f"--- Evaluating best {model_name} model on Test Set ---")
    if os.path.exists(model_best_save_path): model_instance.load_state_dict(torch.load(model_best_save_path, map_location=device))
    else: log_message(f"Warning: Best model for {model_name} not found. Evaluating last epoch model.")
    
    model_instance.eval(); all_labels, all_probs = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Eval {model_name}", leave=False):
            batch = {k:v.to(device) for k,v in batch.items() if isinstance(v, torch.Tensor)}
            logits = model_instance(batch)
            all_labels.extend(batch['label'].cpu().numpy()); all_probs.extend(torch.sigmoid(logits).cpu().numpy())
    
    all_labels=np.array(all_labels).flatten(); all_probs=np.array(all_probs).flatten()
    all_preds=(all_probs > 0.5).astype(int);
    
    acc = accuracy_score(all_labels, all_preds); f1 = f1_score(all_labels, all_preds, zero_division=0)
    prec = precision_score(all_labels, all_preds, zero_division=0); rec = recall_score(all_labels, all_preds, zero_division=0)
    # Passing labels=[0, 1] always yields a 2x2 matrix, even if one class is absent
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    spec = tn/(tn+fp) if (tn+fp)>0 else 0.0
    auc = roc_auc_score(all_labels, all_probs) if len(np.unique(all_labels))>1 else np.nan

   
    results = {'model': model_name, 'accuracy': acc, 'f1_score': f1, 'auc': auc, 'precision': prec, 'recall': rec, 'specificity': spec}
    log_message(f"--- {model_name} Test Results: ACC={acc:.4f}, F1={f1:.4f}, AUC={auc if not np.isnan(auc) else -1:.4f} ---")
    
    results_path = os.path.join(exp_output_dir, f"results_{model_name}.csv")
    pd.DataFrame([results]).to_csv(results_path, index=False)
    log_message(f"Saved results for {model_name} to {results_path}")
    
    if 'plot_confusion_matrix' in globals():
        cm_path = os.path.join(exp_output_dir, f"cm_{model_name}.png")
        plot_confusion_matrix(cm, ['Non-P','P'], save_path=cm_path, title=f'{model_name} Confusion Matrix')
    
    return results

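All the test-set metrics above come from probabilities thresholded at `> 0.5` and the four confusion-matrix counts. Specificity, which has no scikit-learn call in the notebook, is computed by hand as `tn / (tn + fp)`. The dependency-free sketch below reproduces those formulas. Like the notebook's `zero_division=0`, it returns 0.0 wherever a denominator is zero. It illustrates the formulas and does not replace the scikit-learn calls. AUC is omitted because it depends on the full ranking of scores.

```python
def binary_metrics(labels, probs, threshold=0.5):
    """Confusion counts and derived metrics for binary labels in {0, 1}."""
    preds = [1 if p > threshold else 0 for p in probs]  # strict '>' as in the notebook
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)

    def ratio(num, den):
        return num / den if den > 0 else 0.0  # zero_division=0 behaviour

    return {
        "accuracy": ratio(tp + tn, len(pairs)),
        "precision": ratio(tp, tp + fp),
        "recall": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
        "f1_score": ratio(2 * tp, 2 * tp + fp + fn),
        "confusion": (tn, fp, fn, tp),
    }

m = binary_metrics([1, 1, 1, 0, 0], [0.9, 0.8, 0.3, 0.6, 0.2])
print(m["confusion"])  # (1, 1, 1, 2)
```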
## **5. Main Execution Block for Ablation & Baselines**
This block loads the data and builds the train/validation/test splits once, then runs each experiment in turn. Experiments whose results file already exists are skipped, so the whole suite can be resumed after an interruption.
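`run_experiment` writes each run's results to `<output_dir>/<model_name>/results_<model_name>.csv`, so a completion check has to look in that per-experiment subdirectory. The stdlib sketch below shows the check in isolation. `results_path` and `load_cached_result` are hypothetical helpers. The CSV row it writes is dummy data for the demonstration, not a reported result.

```python
import csv
import os
import tempfile

def results_path(output_dir, name):
    # Same layout run_experiment uses: <output_dir>/<name>/results_<name>.csv
    return os.path.join(output_dir, name, f"results_{name}.csv")

def load_cached_result(output_dir, name):
    """Return the saved results row for `name`, or None if it should be (re)run."""
    path = results_path(output_dir, name)
    if not os.path.exists(path):
        return None
    with open(path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    return rows[0] if rows else None  # an empty file counts as "not done"

with tempfile.TemporaryDirectory() as out_dir:
    # Dummy results file for one experiment (values are placeholders)
    os.makedirs(os.path.join(out_dir, "A5_Priors_Only"))
    with open(results_path(out_dir, "A5_Priors_Only"), "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["model", "auc"])
        writer.writeheader()
        writer.writerow({"model": "A5_Priors_Only", "auc": "0.5"})
    # An empty results file, e.g. left behind by an interrupted write
    os.makedirs(os.path.join(out_dir, "A1_Transformer_Priors"))
    open(results_path(out_dir, "A1_Transformer_Priors"), "w").close()

    cached_a5 = load_cached_result(out_dir, "A5_Priors_Only")
    cached_a1 = load_cached_result(out_dir, "A1_Transformer_Priors")
    cached_a4 = load_cached_result(out_dir, "A4_Transformer_Only")

print(cached_a5, cached_a1, cached_a4)
```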

In [10]:
# %% 5. Main Execution Block for Ablation & Baselines
# ============================================================================
if __name__ == "__main__":
    log_message("--- Starting Ablation & Baseline Experiment Workflow ---")
    
    class ArgsAblation: pass
    args = ArgsAblation()
    args.force_rerun = False # Set to True to retrain all models from scratch

    try:
        # --- Load & Prepare Data ONCE ---
        log_message("\n--- Step 1: Loading and Preparing Data for All Experiments ---")
        prom_seqs, prom_ids, prom_labels = load_sequences(PROMOTER_SEQ_FILE, True)
        nonprom_seqs, nonprom_ids, nonprom_labels = load_sequences(NON_PROMOTER_SEQ_FILE, False)
        all_seqs_raw=prom_seqs+nonprom_seqs; all_gene_ids=prom_ids+nonprom_ids; all_labels=prom_labels+nonprom_labels
        master_gene_list=pd.Series(all_gene_ids).drop_duplicates().tolist()
        priors_aligned, prior_dim = load_priors(PRIOR_FILE, master_gene_list)
        gene_to_prior_map = {gid: priors_aligned[i] for i,gid in enumerate(master_gene_list)}
        final_priors_for_ds = np.array([gene_to_prior_map.get(gid, np.zeros(prior_dim,dtype=np.float32)) for gid in all_gene_ids])
        sequences_encoded = np.array([integer_encode_sequence(s) for s in tqdm(all_seqs_raw, desc="Encoding All Seqs")])

        # --- Load Precomputed GAT Output ---
        main_bwaf_output_dir = 'results_bwaf_final/'
        gat_output_path = os.path.join(main_bwaf_output_dir, "precomputed_gat_output.pt")
        
        if os.path.exists(gat_output_path):
            log_message(f"Loading pre-saved static GAT data from {gat_output_path}...")
            PRECOMPUTED_GAT_OUTPUT_STATIC = torch.load(gat_output_path, map_location='cpu') # Load to CPU
        else:
            log_message(f"WARNING: Precomputed GAT output not found at {gat_output_path}. "
                        "GAT-related ablations (A2, A3) will use a zero tensor.")
            num_genes_in_gat_data = len(master_gene_list)
            PRECOMPUTED_GAT_OUTPUT_STATIC = torch.zeros(num_genes_in_gat_data, GAT_OUTPUT_DIM)

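        # Assumes the rows of the precomputed GAT tensor follow master_gene_list order
        final_gat_gene_order = master_gene_list
        if PRECOMPUTED_GAT_OUTPUT_STATIC.shape[0] != len(final_gat_gene_order):
            log_message(f"WARNING: GAT output has {PRECOMPUTED_GAT_OUTPUT_STATIC.shape[0]} rows but "
                        f"{len(final_gat_gene_order)} genes are listed; the row-to-gene mapping may be wrong.")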

        # --- Create Full Dataset & Splits ONCE ---
        full_dataset = MainDataset(sequences_encoded, all_gene_ids, all_labels,
                                   final_priors_for_ds, PRECOMPUTED_GAT_OUTPUT_STATIC.numpy(), final_gat_gene_order)
        indices = list(range(len(full_dataset))); np.random.seed(RANDOM_SEED); np.random.shuffle(indices)
        test_idx = int(np.floor(TEST_SPLIT*len(full_dataset))); val_idx = test_idx + int(np.floor(VALIDATION_SPLIT*len(full_dataset)))
        train_indices=indices[val_idx:]; val_indices=indices[test_idx:val_idx]; test_indices=indices[:test_idx]
        train_ds=Subset(full_dataset,train_indices); val_ds=Subset(full_dataset,val_indices); test_ds=Subset(full_dataset,test_indices)
        log_message(f"Ablation Data Split: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}")
        num_workers = 0; drop_last_train = (len(train_ds) % BATCH_SIZE == 1) and len(train_ds) > 1
        train_loader=DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=num_workers, drop_last=drop_last_train)
        val_loader=DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=num_workers)
        test_loader=DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=num_workers)

        # --- Define and Run Experiments ---
        all_results = []
        experiments_to_run = {
            "A4_Transformer_Only": TransformerOnly(embed_dim=EMBEDDING_DIM, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE),
            "A5_Priors_Only": PriorsOnly(prior_dim=prior_dim, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE),
            "A1_Transformer_Priors": TransformerPriorsOnly(prior_dim=prior_dim, embed_dim=EMBEDDING_DIM, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE),
            "A2_GAT_Priors": GATPriorsOnly(prior_dim=prior_dim, gat_output_dim=GAT_OUTPUT_DIM, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE),
            "A3_Simple_Concat": SimpleConcatModel(prior_dim=prior_dim, embed_dim=EMBEDDING_DIM, gat_output_dim=GAT_OUTPUT_DIM, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE)
        }

        for name, model in experiments_to_run.items():
            model_results_path = os.path.join(OUTPUT_DIR_ABLATION, name, f"results_{name}.csv")  # same path run_experiment writes to
            if os.path.exists(model_results_path) and not args.force_rerun:
                log_message(f"\nResults for {name} already exist. Skipping experiment.")
                try:
                    all_results.append(pd.read_csv(model_results_path).to_dict('records')[0])
                except pd.errors.EmptyDataError:
                    log_message(f"Warning: Found empty results file for {name}. Rerunning.")
                    # Fall through to run the experiment
                except Exception as e:
                    log_message(f"Warning: Could not read existing results for {name}: {e}. Rerunning.")
                else:
                    continue # Skip to next model

            results = run_experiment(name, model, train_loader, val_loader, test_loader,
                                     NUM_EPOCHS, LEARNING_RATE, DEVICE, OUTPUT_DIR_ABLATION,
                                     precomputed_gat_output=PRECOMPUTED_GAT_OUTPUT_STATIC)
            if results: all_results.append(results)

        # --- Summarize All Results ---
        if all_results:
            summary_df = pd.DataFrame(all_results)
            # Check if 'model' column exists before setting index
            if 'model' in summary_df.columns:
                summary_df = summary_df.set_index('model')
            else:
                log_message("Warning: 'model' column not found in results list, cannot set as index.")
            
            summary_path = os.path.join(OUTPUT_DIR_ABLATION, "ablation_baselines_summary.csv")
            summary_df.to_csv(summary_path)
            log_message(f"\n--- All Ablation/Baseline Experiments Complete. Summary saved to {summary_path} ---")
            
            print("\n--- Final Summary ---")
           
            if 'auc' in summary_df.columns:
                print(summary_df.sort_values(by='auc', ascending=False))
            else:
                log_message("Warning: 'auc' column not found in final summary DataFrame. Cannot sort.")
                print(summary_df)
        else:
            log_message("--- No experiments completed successfully. ---")
            
    except Exception as e:
        log_message(f"--- ABLATION/BASELINE WORKFLOW FAILED ---");
        log_message(f"ERROR: {type(e).__name__}: {e}");
        log_message("Traceback:\n" + traceback.format_exc())

[2025-07-14 07:32:08] --- Starting Ablation & Baseline Experiment Workflow ---
[2025-07-14 07:32:08] 
--- Step 1: Loading and Preparing Data for All Experiments ---
[2025-07-14 07:32:08] Loading sequences from ./data/raw/human_genome_annotation/updated_promoter_features_clean.csv...
[2025-07-14 07:32:09] Loading sequences from ./data/raw/human_genome_annotation/updated_non_promoter_sequences.csv...
[2025-07-14 07:32:10] Loading priors from ./data/raw/human_genome_annotation/biological_prior_for_transformer_branch.csv...
[2025-07-14 07:32:26] Ablation Data Split: Train=28040, Val=6008, Test=6008
[2025-07-14 07:32:26] 
[2025-07-14 07:32:26] Resuming A4_Transformer_Only from checkpoint: checkpoint_epoch_009.pth
[2025-07-14 07:32:26] Resumed from epoch 9. Best Val Loss so far: 0.5827
[2025-07-14 07:32:26] A4_Transformer_Only training already completed. Loading best model for evaluation.
[2025-07-14 07:32:26] --- Evaluating best A4_Transformer_Only model on Test Set ---
[2025-07-14 07:56:19] --- A4_Transformer_Only Test Results: ACC=0.7124, F1=0.6850, AUC=0.7815 ---
[2025-07-14 07:56:19] Saved results for A4_Transformer_Only to results_ablation_studies/A4_Transformer_Only/results_A4_Transformer_Only.csv
[2025-07-14 07:56:19] 
[2025-07-14 07:56:19] Resuming A5_Priors_Only from checkpoint: checkpoint_epoch_009.pth
[2025-07-14 07:56:19] Resumed from epoch 9. Best Val Loss so far: 0.0584
[2025-07-14 07:56:19] A5_Priors_Only training already completed. Loading best model for evaluation.
[2025-07-14 07:56:19] --- Evaluating best A5_Priors_Only model on Test Set ---
[2025-07-14 07:56:19] --- A5_Priors_Only Test Results: ACC=0.9767, F1=0.9763, AUC=0.9970 ---
[2025-07-14 07:56:19] Saved results for A5_Priors_Only to results_ablation_studies/A5_Priors_Only/results_A5_Priors_Only.csv
[2025-07-14 07:56:19] 
[2025-07-14 07:56:19] Resuming A1_Transformer_Priors from checkpoint: checkpoint_epoch_009.pth
[2025-07-14 07:56:19] Resumed from epoch 9. Best Val Loss so far: 0.1530
[2025-07-14 07:56:19] A1_Transformer_Priors training already completed. Loading best model for evaluation.
[2025-07-14 07:56:19] --- Evaluating best A1_Transformer_Priors model on Test Set ---
[2025-07-14 08:19:39] --- A1_Transformer_Priors Test Results: ACC=0.9419, F1=0.9421, AUC=0.9877 ---
[2025-07-14 08:19:39] Saved results for A1_Transformer_Priors to results_ablation_studies/A1_Transformer_Priors/results_A1_Transformer_Priors.csv
[2025-07-14 08:19:39] 
[2025-07-14 08:19:39] Resuming A2_GAT_Priors from checkpoint: checkpoint_epoch_009.pth
[2025-07-14 08:19:39] Resumed from epoch 9. Best Val Loss so far: 0.1897
[2025-07-14 08:19:39] A2_GAT_Priors training already completed. Loading best model for evaluation.
[2025-07-14 08:19:39] --- Evaluating best A2_GAT_Priors model on Test Set ---
[2025-07-14 08:19:39] --- A2_GAT_Priors Test Results: ACC=0.9404, F1=0.9369, AUC=0.9850 ---
[2025-07-14 08:19:39] Saved results for A2_GAT_Priors to results_ablation_studies/A2_GAT_Priors/results_A2_GAT_Priors.csv
[2025-07-14 08:19:39] 
[2025-07-14 08:19:39] Resuming A3_Simple_Concat from checkpoint: checkpoint_epoch_009.pth
[2025-07-14 08:19:39] Resumed from epoch 9. Best Val Loss so far: 0.1600
[2025-07-14 08:19:39] A3_Simple_Concat training already completed. Loading best model for evaluation.
[2025-07-14 08:19:39] --- Evaluating best A3_Simple_Concat model on Test Set ---
[2025-07-14 08:42:50] --- A3_Simple_Concat Test Results: ACC=0.9371, F1=0.9391, AUC=0.9888 ---
[2025-07-14 08:42:50] Saved results for A3_Simple_Concat to results_ablation_studies/A3_Simple_Concat/results_A3_Simple_Concat.csv
[2025-07-14 08:42:50] 
--- All Ablation/Baseline Experiments Complete. Summary saved to results_ablation_studies/ablation_baselines_summary.csv ---

--- Final Summary ---
                       accuracy  f1_score       auc  precision    recall  \
model                                                                      
A5_Priors_Only         0.976698  0.976255  0.996985   1.000000  0.953612   
A3_Simple_Concat       0.937084  0.939130  0.988782   0.913534  0.966203   
A1_Transformer_Priors  0.941911  0.942075  0.987659   0.943798  0.940358   
A2_GAT_Priors          0.940413  0.936950  0.985045   1.000000  0.881378   
A4_Transformer_Only    0.712383  0.685016  0.781548   0.761345  0.622598   

                       specificity  
model                            