# **Supplementary Code: BWAF Multi-Modal Promoter Prediction Framework**

**Description:**
This notebook provides the complete Python implementation for the research thesis titled: **"Biologically Weighted Dynamic Fusion of Transformer and Graph Attention Networks for Promoter Region Identification in the Human Genome Using Multi-Omics Data"**.

The code implements a multi-modal deep learning pipeline integrating:
1.  **Genomic Sequences:** Processed using a Transformer network.
2.  **Gene-TF Interaction Networks:** From the GRAND database (36 tissues), processed using Graph Attention Networks (GATs). The GATs operate on **Gene-Gene graphs** derived from TF co-regulation patterns within each tissue.
3.  **Biological Priors:** Derived from promoter motif counts, log-transformed.

A key novelty is the **Biologically Weighted Attention Fusion (BWAF)** layer, which uses biological priors to dynamically modulate the fusion of sequence and network features. The pipeline includes data loading, extensive preprocessing (including a refined GAT edge index creation strategy with optional edge capping), model definition, training, and evaluation routines.

---

## **1. Imports**

This section imports all necessary libraries for data manipulation (Pandas, NumPy), deep learning (PyTorch, PyTorch Geometric), machine learning evaluation (scikit-learn), plotting (Matplotlib, Seaborn), and system utilities (os, glob, etc.).

In [1]:
# %% 1. Imports
# ============================================================================
# Standard library imports
import os
import re
import glob
import gzip
import time
import argparse
import datetime
import sys
import warnings
import random # Ensure random is imported

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split # Used for splitting indices
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                           f1_score, roc_auc_score, confusion_matrix)
from scipy.sparse import coo_matrix # For sparse matrix operations if needed elsewhere
from tqdm.notebook import tqdm # Use notebook version for Jupyter
# from tqdm import tqdm # Use standard tqdm otherwise
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch Geometric imports
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.utils import from_scipy_sparse_matrix # Useful for converting scipy sparse to edge_index
except ImportError:
    print("PyTorch Geometric not found. Please install it: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html")
    raise ImportError("PyTorch Geometric is required but not found.")


# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torch_geometric.nn.conv.gat_conv')
warnings.filterwarnings("ignore", category=UserWarning, module='torch_geometric.nn.conv.gatv2_conv')
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

## **2. Configuration / Constants**

This section defines crucial parameters and file paths used throughout the script. **Users must verify and potentially modify the `Data Paths` section** to match their local environment. Hyperparameters for the model and training process are also set here, allowing for easy modification and experimentation.
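
Since every later cell depends on these paths, it can help to check them before starting any expensive preprocessing. The helper below is a minimal sketch and not part of the original notebook; `find_missing_paths` is a hypothetical name, and it assumes only that the path constants are plain strings.

```python
import os

def find_missing_paths(named_paths):
    """Return the names of all entries whose path does not exist on disk."""
    return [name for name, path in named_paths.items() if not os.path.exists(path)]

# Hypothetical usage with the constants from the configuration cell:
# missing = find_missing_paths({
#     "PROMOTER_SEQ_FILE": PROMOTER_SEQ_FILE,
#     "NON_PROMOTER_SEQ_FILE": NON_PROMOTER_SEQ_FILE,
#     "PRIOR_FILE": PRIOR_FILE,
#     "GAT_RAW_DATA_DIR": GAT_RAW_DATA_DIR,
# })
# if missing:
#     raise FileNotFoundError(f"Missing input paths: {missing}")
```

Failing early in this way reports all missing inputs at once rather than one per run.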

In [2]:
# %% 2. Configuration / Constants
# ============================================================================
# --- Data Paths ---
# !!! USER MUST MODIFY THESE PATHS !!!
BASE_DATA_DIR = 'data/' # Modify to your base data directory
SEQ_DATA_DIR = os.path.join(BASE_DATA_DIR, 'raw/human_genome_annotation')
PRIOR_DATA_DIR = os.path.join(BASE_DATA_DIR, 'raw/human_genome_annotation')
GAT_RAW_DATA_DIR = os.path.join(BASE_DATA_DIR, 'raw/gene_interaction_network/GRAND_networks') # Where GRAND *.csv are (UNNORMALIZED)
GAT_PREPROCESSED_DIR = os.path.join(BASE_DATA_DIR, 'preprocessed/gat_normalized') # Filtered/normalized output
GAT_EDGE_INDEX_DIR = os.path.join(BASE_DATA_DIR, 'preprocessed/gat_edge_indices') # Edge index output

PROMOTER_SEQ_FILE = os.path.join(SEQ_DATA_DIR, 'updated_promoter_features_clean.csv')
NON_PROMOTER_SEQ_FILE = os.path.join(SEQ_DATA_DIR, 'updated_non_promoter_sequences.csv') # Ensure this file exists
PRIOR_FILE = os.path.join(PRIOR_DATA_DIR, 'biological_prior_for_transformer_branch.csv')

# --- Model Hyperparameters ---
SEQ_LEN = 2000
NUCLEOTIDES = ['A', 'T', 'C', 'G']
PAD_IDX = 4
VOCAB_SIZE = len(NUCLEOTIDES) + 1

EMBEDDING_DIM = 64 # Dimension for Transformer output and GAT output
NUM_ATTN_HEADS = 4 # Transformer MHA heads
NUM_TRANSFORMER_LAYERS = 2
TRANSFORMER_FF_DIM = EMBEDDING_DIM * 4

NUM_GAT_LAYERS = 1 # Number of GAT layers per tissue
NUM_TISSUES = 36 # Number of tissue-specific networks to process
GAT_HEADS = 4 # GAT MHA heads (intermediate layers if NUM_GAT_LAYERS > 1)
GAT_FINAL_HEADS = 1 # GAT final layer heads (output averaged if > 1 and concat=False)

GAT_INTERACTION_THRESHOLD_STD_FACTOR = 2.5
MAX_EDGES_PER_TISSUE_APPROX = 1000000 # Approx. 1M undirected edges (2M directed)

# BWAF Fusion Layer specific
FUSION_HIDDEN_DIM = 128 # Hidden dimension within the fusion classifier part

# General
DROPOUT_RATE = 0.3

# --- Training Hyperparameters ---
LEARNING_RATE = 0.0005
BATCH_SIZE = 16
NUM_EPOCHS = 10 # Adjust as needed
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.15
RANDOM_SEED = 42
OPTIMIZER_WEIGHT_DECAY = 1e-5

# --- Output Files ---
OUTPUT_DIR = 'results_bwaf_v3/' # New directory for this version's results
MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, 'best_promoter_model_bwaf.pth')
LOSS_PLOT_PATH = os.path.join(OUTPUT_DIR, 'training_validation_loss_bwaf.png')
CONFUSION_MATRIX_PATH = os.path.join(OUTPUT_DIR, 'confusion_matrix_bwaf.png')
RESULTS_CSV_PATH = os.path.join(OUTPUT_DIR, 'test_set_evaluation_results_bwaf.csv')
LOG_FILE_PATH = os.path.join(OUTPUT_DIR, 'training_log_bwaf.txt')

# --- Hardware ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Setup ---
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
os.makedirs(GAT_PREPROCESSED_DIR, exist_ok=True)
os.makedirs(GAT_EDGE_INDEX_DIR, exist_ok=True)

## **3. Utility Functions**

This section contains helper functions for logging, DNA sequence encoding, log-transformation of priors, and standardized gene ID extraction.
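
As a quick illustration of what these helpers produce, the sketch below reproduces the three transformations on toy inputs. It is a simplified stand-in, not the notebook's code: `encode` and `PAD` are hypothetical names that mirror `integer_encode_sequence` and `PAD_IDX`, and the versioned gene ID is only an example input.

```python
import re
import numpy as np

PAD = 4  # plays the same role as PAD_IDX in the configuration cell

def encode(seq, max_len):
    """Map A/T/C/G to 0-3; any other character and all unused positions become PAD."""
    table = {"A": 0, "T": 1, "C": 2, "G": 3}
    out = np.full(max_len, PAD, dtype=np.int64)
    for i, base in enumerate(seq[:max_len].upper()):
        out[i] = table.get(base, PAD)
    return out

encoded = encode("acgtn", max_len=8)
print(encoded)  # [0 2 3 1 4 4 4 4]

counts = np.array([0, 1, 9])
transformed = np.log1p(counts)  # log(1 + x): a zero count stays zero
print(transformed)

match = re.search(r"(ENSG\d+)", "ENSG00000141510.18")
print(match.group(1))  # ENSG00000141510 (the version suffix is dropped)
```

Note that padding and unknown bases share the same index, so the embedding layer cannot tell them apart.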

In [3]:
# %% 3. Utility Functions
# ============================================================================
# --- Logging ---
def log_message(message, log_file=LOG_FILE_PATH):
    """Appends a timestamped message to the log file and prints it."""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    full_message = f"[{timestamp}] {message}"
    print(full_message)
    log_dir = os.path.dirname(log_file)
    if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir)
    try:
        with open(log_file, 'a', encoding='utf-8') as f: f.write(full_message + '\n')
    except IOError as e: print(f"Error writing log: {e}")

try: # Initialize Log File
    with open(LOG_FILE_PATH, 'w', encoding='utf-8') as f: f.write(f"--- BWAF Training Log Initialized: {datetime.datetime.now()} ---\n")
    log_message(f"Log file started at: {LOG_FILE_PATH}")
except IOError as e: print(f"CRITICAL ERROR: Could not write initial log file {LOG_FILE_PATH}: {e}"); sys.exit(1)

log_message(f"Using device: {DEVICE}")
log_message(f"Random Seed: {RANDOM_SEED}")

# --- Data Processing ---
def integer_encode_sequence(sequence, max_len=SEQ_LEN):
    """Encodes a DNA sequence (string) into integer indices for PyTorch nn.Embedding."""
    encoding_map = {'A': 0, 'T': 1, 'C': 2, 'G': 3}
    encoded = np.full(max_len, PAD_IDX, dtype=np.int64)
    seq_len_actual = min(len(sequence), max_len)
    n_unmapped = 0
    for i, nucleotide in enumerate(sequence[:seq_len_actual]):
        idx = encoding_map.get(nucleotide.upper())
        if idx is not None: encoded[i] = idx
        else:
            encoded[i] = PAD_IDX
            if nucleotide.upper() != 'N' and n_unmapped < 5: # Log first few unexpected chars
                 log_message(f"Warning: Non-ATCGN char '{nucleotide}' at pos {i}. Mapped to PAD_IDX.")
                 n_unmapped += 1
    return encoded

def log_transform_priors(prior_counts):
    """Applies log(1+x) transformation element-wise to biological prior counts."""
    prior_counts = np.asarray(prior_counts, dtype=np.float32)
    # np.maximum returns a new array, so negative counts are clipped to 0
    # without modifying the caller's data in place.
    return np.log1p(np.maximum(prior_counts, 0))

def extract_clean_gene_id(raw_id_series):
    r"""Uses regex (r'(ENSG\d+)') to extract the Ensembl Gene ID from a pandas Series."""
    if not isinstance(raw_id_series, pd.Series): raw_id_series = pd.Series(raw_id_series)
    return raw_id_series.astype(str).str.extract(r'(ENSG\d+)', expand=False).fillna('UNKNOWN')

[2025-06-04 18:28:51] Log file started at: results_bwaf_v3/training_log_bwaf.txt
[2025-06-04 18:28:51] Using device: cpu
[2025-06-04 18:28:51] Random Seed: 42


## **4. Data Loading Functions**

These functions handle loading and preprocessing for sequences, priors, and GAT network data.
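*   `load_sequences`: Loads sequences, drops any that contain `N` or are not exactly `SEQ_LEN` bp long, and assigns labels.
*   `load_priors`: Loads motif counts, aligns them to the sample gene order (genes without priors get zeros), and applies log(1+x).
*   `_preprocess_and_save_gat_matrix`: Helper that aligns a raw GAT matrix to the master TF and gene orders, min-max normalizes it over the whole matrix, and saves it gzip-compressed.
*   `preprocess_all_gat_data`: Runs the helper over all raw matrices and determines the master TF list (taken from the first file).
*   `load_processed_gat_data`: Loads the normalized matrices and creates initial node features: each gene's TF interaction profile, averaged across tissues.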
*   `load_or_create_edge_indices`: **Refined function to create Gene-Gene graphs.** Edges connect genes co-regulated by the same TF above a TF-specific dynamic threshold, with an optional cap on total edges per tissue.
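
The edge rule in the last bullet can be sketched on a toy TF-by-gene matrix as follows. This is a simplified illustration under stated assumptions, not the notebook's function: `coregulation_edges` is a hypothetical name, the edge cap and the minimum-standard-deviation safeguard are omitted, and the toy `std_factor` of 0.5 is far smaller than the configured 2.5 because four scores cannot exceed their mean by 2.5 standard deviations.

```python
import itertools
import numpy as np

def coregulation_edges(tf_by_gene, std_factor):
    """Return sorted undirected gene pairs that share at least one strongly scoring TF."""
    scores = np.abs(np.asarray(tf_by_gene, dtype=float))
    pairs = set()
    for tf_scores in scores:                      # one row per TF
        nonzero = tf_scores[tf_scores > 1e-6]
        if nonzero.size < 2:
            continue
        # TF-specific threshold: mean + factor * std of that TF's non-zero scores
        threshold = nonzero.mean() + std_factor * nonzero.std()
        strong = np.flatnonzero(tf_scores > threshold)
        pairs.update(itertools.combinations(strong.tolist(), 2))
    return sorted(pairs)

toy = [
    [0.9, 0.8, 0.1, 0.1, 0.0],  # TF 0 scores genes 0 and 1 highly
    [0.0, 0.1, 0.1, 0.7, 0.9],  # TF 1 scores genes 3 and 4 highly
]
edges = coregulation_edges(toy, std_factor=0.5)
print(edges)  # [(0, 1), (3, 4)]

# Store each undirected pair in both directions, as a 2 x (2 * n_pairs) array.
src = [a for a, b in edges] + [b for a, b in edges]
dst = [b for a, b in edges] + [a for a, b in edges]
edge_index = np.array([src, dst])
print(edge_index.shape)  # (2, 4)
```

Every TF with k genes above its threshold contributes k(k-1)/2 pairs, which is why a cap on edges per tissue becomes useful for TFs with many strong targets.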

In [4]:
# %% 4. Data Loading Functions
# ============================================================================

def load_sequences(file_path, is_promoter=True):
    """Loads sequences, performs QC (N removal, length check), assigns labels, handles encodings."""
    log_message(f"Loading sequences from {file_path}...")
    start_time = time.time()
    try:
        if not os.path.exists(file_path): raise FileNotFoundError(f"Sequence file not found: {file_path}")
        df = None; encodings_to_try = ['utf-8', 'ISO-8859-1', 'latin1']
        for encoding in encodings_to_try:
            try: df = pd.read_csv(file_path, encoding=encoding, low_memory=False); break
            except (UnicodeDecodeError, pd.errors.ParserError): pass
        if df is None: raise ValueError(f"Could not read {file_path} with attempted encodings.")

        seq_col, id_col = None, None
        possible_seq_cols = ['promoter_sequence', 'sequence']
        possible_id_cols = ['gene_id', df.columns[0] if not df.empty else None]

        for col in possible_seq_cols:
            if col in df.columns: seq_col = col; break
        for col in possible_id_cols:
            if col in df.columns: id_col = col; break
        if seq_col is None or id_col is None: raise ValueError(f"Required columns not in {file_path}. Cols: {df.columns.tolist()}")

        df[seq_col] = df[seq_col].astype(str); df['clean_gene_id'] = extract_clean_gene_id(df[id_col])
        df.dropna(subset=['clean_gene_id', seq_col], inplace=True); df = df[df['clean_gene_id'] != 'UNKNOWN']; df = df[df[seq_col].str.strip() != '']

        initial_qc_count = len(df)
        df = df[~df[seq_col].str.contains('N', na=False, case=False)]; removed_n = initial_qc_count - len(df)
        initial_qc_count = len(df)
        # Only check length if SEQ_LEN is defined and positive
        if SEQ_LEN and SEQ_LEN > 0:
            df['seq_len_actual'] = df[seq_col].str.len()
            df = df[df['seq_len_actual'] == SEQ_LEN]; removed_len = initial_qc_count - len(df)
            df = df.drop(columns=['seq_len_actual'])
            log_message(f"Sequence QC for {os.path.basename(file_path)}: Total initial: {len(df) + removed_n + removed_len}. Removed {removed_n} (with 'N'). Removed {removed_len} (not {SEQ_LEN}bp). Final: {len(df)}.")
        else:
            log_message(f"Sequence QC for {os.path.basename(file_path)}: Total initial: {len(df) + removed_n}. Removed {removed_n} (with 'N'). Length check skipped. Final: {len(df)}.")


        if df.empty: log_message(f"Warning: No valid sequences in {file_path}."); return [], [], []
        sequences = df[seq_col].tolist(); gene_ids = df['clean_gene_id'].tolist(); labels = [1 if is_promoter else 0] * len(sequences)
        log_message(f"Loaded and filtered {len(sequences)} sequences from {os.path.basename(file_path)} in {time.time() - start_time:.2f}s.")
        return sequences, gene_ids, labels
    except Exception as e: log_message(f"CRITICAL ERROR loading sequence file {file_path}: {e}"); raise

def load_priors(file_path, gene_id_order):
    """Loads priors, aligns, log-transforms."""
    log_message(f"Loading priors from {file_path}...")
    start_time = time.time()
    try:
        if not os.path.exists(file_path): raise FileNotFoundError(f"Prior file not found: {file_path}")
        df = pd.read_csv(file_path)
        if 'gene_id' not in df.columns: raise ValueError("'gene_id' column missing in prior file.")
        df['clean_gene_id'] = extract_clean_gene_id(df['gene_id'])
        df.dropna(subset=['clean_gene_id'], inplace=True); df = df[df['clean_gene_id'] != 'UNKNOWN']
        count_columns = [col for col in df.columns if '(Count)' in col]
        if not count_columns: raise ValueError("No prior count columns found in prior file.")
        prior_dim = len(count_columns); log_message(f"Found {prior_dim} prior count columns.")
        df_priors = df[['clean_gene_id'] + count_columns].copy()
        for col in count_columns: df_priors[col] = pd.to_numeric(df_priors[col], errors='coerce')
        df_priors.fillna(0, inplace=True)
        df_priors.set_index('clean_gene_id', inplace=True)
        df_priors = df_priors[~df_priors.index.duplicated(keep='first')]
        aligned_df = df_priors.reindex(gene_id_order, fill_value=0)
        log_transformed_priors = log_transform_priors(aligned_df.values)
        log_message(f"Processed priors for {log_transformed_priors.shape[0]} genes in {time.time() - start_time:.2f}s.")
        return log_transformed_priors, prior_dim
    except Exception as e: log_message(f"CRITICAL ERROR processing prior file {file_path}: {e}"); raise

def _preprocess_and_save_gat_matrix(raw_file_path, output_file_path, gene_id_order_master_unique, master_tf_ids_list=None):
    """Internal: Filters raw GAT matrix by unique gene_id_order and master_tf_ids_list, normalizes, saves."""
    try:
        df_raw = pd.read_csv(raw_file_path, index_col=0)
        if df_raw.columns.has_duplicates: df_raw = df_raw.loc[:, ~df_raw.columns.duplicated(keep='first')]
        current_tf_ids_raw = df_raw.index.tolist()
        if master_tf_ids_list is None: master_tf_ids_list = current_tf_ids_raw

        df_tf_aligned = df_raw.reindex(index=master_tf_ids_list, fill_value=0.0)
        master_gene_set_unique = set(gene_id_order_master_unique)
        common_genes_in_raw_order = [gene for gene in df_tf_aligned.columns if gene in master_gene_set_unique]

        if not common_genes_in_raw_order:
            df_gene_aligned = pd.DataFrame(0.0, index=master_tf_ids_list, columns=gene_id_order_master_unique)
        else:
            df_gene_aligned = df_tf_aligned[common_genes_in_raw_order]
            df_gene_aligned = df_gene_aligned.reindex(columns=gene_id_order_master_unique, fill_value=0.0)

        numeric_data = df_gene_aligned.values.astype(np.float32)
        min_val, max_val = np.min(numeric_data), np.max(numeric_data)
        range_val = max_val - min_val
        normalized_values = np.zeros_like(numeric_data) if range_val < 1e-9 else (numeric_data - min_val) / (range_val + 1e-9)
        df_normalized_aligned = pd.DataFrame(normalized_values, index=master_tf_ids_list, columns=gene_id_order_master_unique)
        df_normalized_aligned.to_csv(output_file_path, compression='gzip')
        return df_normalized_aligned, master_tf_ids_list, df_normalized_aligned.columns.tolist()
    except Exception as e:
        log_message(f"Error in _preprocess_and_save_gat_matrix for {raw_file_path}: {e}")
        import traceback; log_message(traceback.format_exc())
        return None, master_tf_ids_list, None

def preprocess_all_gat_data(raw_data_dir, processed_data_dir, gene_id_order_master_unique, force_preprocess=False):
    """Filters & normalizes GAT matrices, establishes master TF list."""
    log_message(f"Preprocessing GAT data: {raw_data_dir} -> {processed_data_dir}")
    os.makedirs(processed_data_dir, exist_ok=True)
    raw_files = sorted(glob.glob(os.path.join(raw_data_dir, '*.csv')))
    if not raw_files: raise FileNotFoundError(f"No raw GAT *.csv files in {raw_data_dir}")

    expected_processed = [os.path.join(processed_data_dir, f"normalized_{os.path.basename(f)}.gz") for f in raw_files]
    all_exist = all(os.path.exists(f) for f in expected_processed)
    master_tfs_determined = None; num_genes_aligned = len(gene_id_order_master_unique)

    if all_exist and not force_preprocess:
        log_message("Processed GAT files found. Verifying TF/Gene alignment...")
        try:
            df_sample = pd.read_csv(expected_processed[0], index_col=0, compression='gzip')
            master_tfs_determined = df_sample.index.tolist()
            if df_sample.columns.tolist() != gene_id_order_master_unique:
                 log_message("CRITICAL: Existing processed GAT columns DO NOT match master gene order. Forcing re-preprocess."); force_preprocess = True
        except Exception as e: log_message(f"Error reading sample: {e}. Forcing re-preprocess."); force_preprocess = True
    
    if not all_exist or force_preprocess: # Condition to enter processing loop
        if all_exist and force_preprocess: log_message("Forcing re-preprocessing of GAT files...")
        else: log_message("Processing raw GAT files (filter/normalize)...")

        processed_count = 0
        for i, raw_fp in enumerate(tqdm(raw_files, desc="Preprocessing GAT matrices")):
            out_fp = os.path.join(processed_data_dir, f"normalized_{os.path.basename(raw_fp)}.gz")
            df_norm, current_tfs_list, _ = _preprocess_and_save_gat_matrix(raw_fp, out_fp, gene_id_order_master_unique, master_tfs_determined)
            if df_norm is not None: # On failure the helper returns None for the matrix but still echoes back the TF list
                processed_count += 1
                if master_tfs_determined is None: # First successful processing sets the master
                    master_tfs_determined = current_tfs_list
                    log_message(f"Master TF list ({len(master_tfs_determined)} TFs) set from {os.path.basename(raw_fp)}")
                elif master_tfs_determined != current_tfs_list: # Should not happen if alignment is correct in _preprocess...
                    log_message(f"Warning: TF list from {os.path.basename(raw_fp)} differs from master after alignment.")
        if processed_count == 0 or master_tfs_determined is None: raise ValueError("Failed to process GAT files or determine master TFs.")
        if processed_count != len(raw_files): log_message(f"Warning: Processed {processed_count}/{len(raw_files)} raw GAT files.")

    num_tfs_final = len(master_tfs_determined) if master_tfs_determined else 0
    log_message(f"GAT preprocessing complete. Master TFs: {num_tfs_final}, Aligned Genes: {num_genes_aligned}")
    return num_tfs_final, num_genes_aligned, gene_id_order_master_unique, master_tfs_determined

def load_processed_gat_data(processed_data_dir, final_gene_order_gat, master_tf_ids_list):
    """Loads preprocessed (normalized & aligned) GAT matrices for initial node features."""
    log_message(f"Loading processed GAT data from {processed_data_dir} (aligning to {len(final_gene_order_gat)} genes, {len(master_tf_ids_list)} TFs)...")
    start_time = time.time()
    processed_files = sorted(glob.glob(os.path.join(processed_data_dir, 'normalized_*.csv.gz')))
    if not processed_files: raise FileNotFoundError(f"No processed GAT files in {processed_data_dir}")

    actual_tissues_found = len(processed_files)
    global NUM_TISSUES_EFFECTIVE # To pass to GATNetwork
    NUM_TISSUES_EFFECTIVE = actual_tissues_found
    if actual_tissues_found != NUM_TISSUES:
        log_message(f"Note: Found {actual_tissues_found} processed GAT files. NUM_TISSUES config is {NUM_TISSUES}. Effective tissues: {NUM_TISSUES_EFFECTIVE}.")

    all_tissue_matrices = []; num_tfs_expected = len(master_tf_ids_list); num_genes_expected = len(final_gene_order_gat)
    for file_path in tqdm(processed_files, desc="Loading processed GAT matrices"):
        try:
            df = pd.read_csv(file_path, index_col=0, compression='gzip')
            # Re-confirm TF and Gene alignment (should be correct if preprocessing was done right)
            df = df.reindex(index=master_tf_ids_list, columns=final_gene_order_gat, fill_value=0.0)
            if df.shape != (num_tfs_expected, num_genes_expected):
                log_message(f"CRITICAL Shape Mismatch: {file_path} has {df.shape}, expected ({num_tfs_expected},{num_genes_expected}).")
            all_tissue_matrices.append(torch.tensor(df.values, dtype=torch.float32))
        except Exception as e: log_message(f"Error loading GAT file {file_path}: {e}"); raise

    if not all_tissue_matrices: raise ValueError("Failed to load GAT matrices.")
    all_tissue_tensors = torch.stack(all_tissue_matrices, dim=0)
    avg_interactions_tf_gene = all_tissue_tensors.mean(dim=0)
    initial_node_features = avg_interactions_tf_gene.T
    elapsed = time.time() - start_time
    log_message(f"Loaded {len(all_tissue_matrices)} GAT matrices, created initial node features in {elapsed:.2f}s.")
    return initial_node_features, num_tfs_expected

def load_or_create_edge_indices(
    num_genes_in_graph, num_tfs_in_raw, num_tissues_to_process, # num_tissues_to_process should be NUM_TISSUES_EFFECTIVE
    raw_gat_data_dir, gene_order_for_raw_alignment, edge_index_dir,
    threshold_std_factor, force_rebuild=False, max_edges_per_tissue_approx=None
):
    """
    Refined: Loads or creates Gene-Gene graph edge indices for each tissue.
    Edges connect genes if they are co-regulated by the same TF above a TF-specific dynamic threshold.
    """
    log_message(f"GAT Edge Indices (TF Co-reg Gene-Gene, ThrFactor: {threshold_std_factor}, MaxEdges: {max_edges_per_tissue_approx})...")
    os.makedirs(edge_index_dir, exist_ok=True); edge_indices_list = []
    raw_files_found = sorted(glob.glob(os.path.join(raw_gat_data_dir, '*.csv')))
    if not raw_files_found: raise FileNotFoundError(f"No raw GAT files in {raw_gat_data_dir}")

    actual_tissues_avail = len(raw_files_found)
    # Use num_tissues_to_process passed from main, which is NUM_TISSUES_EFFECTIVE
    files_to_loop = raw_files_found[:min(num_tissues_to_process, actual_tissues_avail)]

    if len(files_to_loop) != num_tissues_to_process:
        log_message(f"Warning: Will create edge indices for {len(files_to_loop)} tissues (based on raw files), not {num_tissues_to_process}.")

    for raw_file_path in tqdm(files_to_loop, desc="Creating Gene-Gene edge indices"):
        tissue_name = os.path.basename(raw_file_path).replace('.csv', '')
        fname_suffix = f"_gg_th{threshold_std_factor:.1f}" + (f"_maxE{max_edges_per_tissue_approx//1000}k" if max_edges_per_tissue_approx else "") + ".pt"
        edge_index_file = os.path.join(edge_index_dir, tissue_name + fname_suffix)
        current_edge_index = None

        if os.path.exists(edge_index_file) and not force_rebuild:
            try: current_edge_index = torch.load(edge_index_file)
            except Exception as e: log_message(f"Error loading {edge_index_file}: {e}. Recreating."); current_edge_index = None
        
        if current_edge_index is None:
            # log_message(f"Creating Gene-Gene edge index for {tissue_name}...") # Can be too verbose
            try:
                if not os.path.exists(raw_file_path): raise FileNotFoundError(f"Raw file missing: {raw_file_path}")
                df_raw = pd.read_csv(raw_file_path, index_col=0)
                # Ensure row (TF) alignment to the master list of TFs determined from preprocessing
                # This uses num_tfs_in_raw for the expected count from the *master* TF list.
                # df_raw_tf_aligned = df_raw.reindex(index=master_tf_ids_list_from_preprocessing, fill_value=0.0) 
                # The above is complex if master_tf_ids_list_from_preprocessing isn't available here easily.
                # Assuming df_raw has the TFs we care about (num_tfs_in_raw of them)
                if df_raw.shape[0] != num_tfs_in_raw:
                    log_message(f"Warning: Raw TF count in {tissue_name} ({df_raw.shape[0]}) "
                                f"differs from expected master TF count ({num_tfs_in_raw}) for edge creation.")
                    # Decide on handling: reindex to master_tfs, or use available TFs in file?
                    # For now, proceed with TFs available in the file, but this means num_tfs_in_raw might vary by file in this loop.
                    # This part needs careful sync with how num_tfs_in_raw is passed.

                df_raw_aligned_cols = df_raw.reindex(columns=gene_order_for_raw_alignment, fill_value=0.0)
                if df_raw_aligned_cols.shape[1] != num_genes_in_graph: raise ValueError(f"Gene alignment error for {tissue_name}")

                abs_scores_matrix = np.abs(df_raw_aligned_cols.values) # (N_tfs_in_file, N_genes_in_graph)
                all_gene_pairs = set()

                for tf_idx in range(df_raw_aligned_cols.shape[0]): # Iterate over TFs IN THIS FILE
                    tf_scores = abs_scores_matrix[tf_idx, :]
                    non_zero_tf = tf_scores[tf_scores > 1e-6] # Consider scores effectively non-zero
                    if non_zero_tf.size < 2: continue # Need at least two genes for a pair

                    mean_tf = np.mean(non_zero_tf); std_tf = np.std(non_zero_tf)
                    robust_std_tf = max(std_tf, 1e-3 * (mean_tf if mean_tf > 1e-6 else 0.1)) # Avoid zero std
                    tf_specific_threshold = mean_tf + threshold_std_factor * robust_std_tf
                    
                    strongly_regulated_indices = np.where(tf_scores > tf_specific_threshold)[0]
                    if len(strongly_regulated_indices) >= 2:
                        genes_list = list(strongly_regulated_indices)
                        for i in range(len(genes_list)):
                            for j in range(i + 1, len(genes_list)):
                                all_gene_pairs.add(tuple(sorted((genes_list[i], genes_list[j]))))
                                if max_edges_per_tissue_approx and len(all_gene_pairs) > max_edges_per_tissue_approx * 1.1: break
                            if max_edges_per_tissue_approx and len(all_gene_pairs) > max_edges_per_tissue_approx * 1.1: break
                    if max_edges_per_tissue_approx and len(all_gene_pairs) > max_edges_per_tissue_approx * 1.1: break
                
                if max_edges_per_tissue_approx and len(all_gene_pairs) > max_edges_per_tissue_approx:
                    all_gene_pairs = set(random.sample(list(all_gene_pairs), max_edges_per_tissue_approx))
                
                src, trg = [], []
                for g1, g2 in all_gene_pairs: src.extend([g1,g2]); trg.extend([g2,g1])
                
                if src: current_edge_index = torch.tensor(np.vstack((src, trg)), dtype=torch.long)
                else: current_edge_index = torch.empty((2,0), dtype=torch.long)
                
                torch.save(current_edge_index, edge_index_file)
                # log_message(f"Saved Gene-Gene edge index for {tissue_name} ({len(all_gene_pairs)} undir edges) to {edge_index_file}")

            except Exception as e:
                log_message(f"Error creating edge index for {tissue_name}: {e}")
                current_edge_index = torch.empty((2,0), dtype=torch.long) # Empty on error
        
        edge_indices_list.append(current_edge_index)

    log_message(f"Finished processing {len(edge_indices_list)} Gene-Gene edge indices.")
    return edge_indices_list

## **5. PyTorch Dataset Class**

The `PromoterGATDataset` class prepares individual samples for the model. It stores integer-encoded sequences, log-transformed prior features, and binary labels. It also provides an `aligned_gat_idx` for each sample, which maps the sample's gene ID to its row in the precomputed GAT output tensor so that the `FullModel` forward pass can select the matching graph features. Samples whose gene ID is absent from the GAT gene order receive `aligned_gat_idx = -1`.
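
The index lookup can be illustrated with a few made-up gene IDs. The sketch below uses NumPy for brevity, and the zero-fill for missing genes is an assumption made for illustration; how `FullModel` actually treats `-1` is defined in its forward pass, not here.

```python
import numpy as np

gat_gene_order = ["ENSG01", "ENSG02", "ENSG03"]    # hypothetical IDs, one per graph node
sample_gene_ids = ["ENSG03", "ENSG99", "ENSG01"]   # ENSG99 is absent from the graph

gene_to_row = {gene: row for row, gene in enumerate(gat_gene_order)}
aligned_idx = np.array([gene_to_row.get(g, -1) for g in sample_gene_ids])
print(aligned_idx)  # [ 2 -1  0]

gat_output = np.arange(6, dtype=np.float32).reshape(3, 2)  # one 2-d embedding per gene

# Plain indexing with -1 would silently pick the LAST row, so mask it explicitly.
has_node = aligned_idx >= 0
batch_features = np.zeros((len(aligned_idx), gat_output.shape[1]), dtype=np.float32)
batch_features[has_node] = gat_output[aligned_idx[has_node]]
print(batch_features)
```

The explicit mask matters because negative indices are valid in both NumPy and PyTorch, so an unhandled `-1` returns another gene's features instead of raising an error.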

In [5]:
# %% 5. PyTorch Dataset Class
# ============================================================================
class PromoterGATDataset(Dataset):
    """PyTorch Dataset for multi-modal promoter prediction."""
    def __init__(self, sequences, gene_ids_master_for_dataset, labels, biological_priors_for_dataset, final_gene_order_in_gat_data):
        n_seq = len(sequences)
        if not (n_seq == len(gene_ids_master_for_dataset) == len(labels) == len(biological_priors_for_dataset)):
            raise ValueError(f"Dataset input length mismatch: seq={n_seq}, ids={len(gene_ids_master_for_dataset)}, labels={len(labels)}, priors={len(biological_priors_for_dataset)}")

        self.sequences = torch.tensor(sequences, dtype=torch.int64)
        self.labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
        self.priors = torch.tensor(biological_priors_for_dataset, dtype=torch.float32)
        self.gat_gene_to_idx_map = {gene: idx for idx, gene in enumerate(final_gene_order_in_gat_data)}
        self.sample_id_to_aligned_gat_idx = np.array([self.gat_gene_to_idx_map.get(gid, -1) for gid in gene_ids_master_for_dataset], dtype=np.int64)
        num_missing = (self.sample_id_to_aligned_gat_idx == -1).sum()
        if num_missing > 0: log_message(f"Dataset Warning: {num_missing}/{len(gene_ids_master_for_dataset)} sample gene IDs not in final GAT gene order. Will have aligned_gat_idx = -1.")

    def __len__(self): return len(self.labels)
    def __getitem__(self, idx):
        return {'sequence': self.sequences[idx], 'priors': self.priors[idx], 'label': self.labels[idx], 'aligned_gat_idx': self.sample_id_to_aligned_gat_idx[idx]}


## **6. Model Architecture Classes**

This section defines the PyTorch modules for the multi-modal architecture.
*   **`GATLayer` & `GATNetwork`**: Implement the graph attention mechanism for processing tissue-specific gene-gene networks (derived from TF co-regulation) and aggregating information across tissues.
*   **`TransformerBranch`**: Processes DNA sequences using embeddings, positional encoding, and multi-head self-attention.
*   **`BWAFusionLayer` (Novel Component)**: The core of the proposed fusion strategy. It uses the biological prior features to learn attention weights ($\alpha_{\text{seq}}, \alpha_{\text{graph}}$). These weights then modulate the outputs of the Transformer and GAT branches. The modulated features are summed and then concatenated with the original priors before passing to a final classifier. This allows prior knowledge to dynamically influence the contribution of sequence versus network information.
*   **`FullModel`**: Integrates all branches and the BWAF layer. It includes a method to precompute the GAT branch output for training efficiency.

### Mathematical Formulation of BWAF
Given sequence features $h_{\text{seq}} \in \mathbb{R}^{d_{\text{embed}}}$, graph features $h_{\text{graph}} \in \mathbb{R}^{d_{\text{embed}}}$, and prior features $h_{\text{prior}} \in \mathbb{R}^{d_{\text{priors}}}$ for a sample:
1.  Attention weights from priors (one scalar gate per branch, each in $(0, 1)$):
    $$ \alpha_{\text{seq}} = \sigma(W_{s} h_{\text{prior}} + b_{s}) \quad ; \quad \alpha_{\text{graph}} = \sigma(W_{g} h_{\text{prior}} + b_{g}) $$
    The code replaces each affine map with a small two-layer network (Linear, ReLU, Linear) that outputs a scalar. The two gates are computed independently and are not normalized to sum to 1.
2.  Modulate features (the scalar gate is broadcast over all feature dimensions): $h_{\text{seq\_w}} = \alpha_{\text{seq}} \odot h_{\text{seq}}$ ; $h_{\text{graph\_w}} = \alpha_{\text{graph}} \odot h_{\text{graph}}$
3.  Combine modulated: $h_{\text{fused\_wgtd}} = h_{\text{seq\_w}} + h_{\text{graph\_w}}$
4.  Final integration: $h_{\text{final}} = \text{concat}(h_{\text{fused\_wgtd}}, h_{\text{prior}})$
5.  Prediction: $\text{logits} = \text{Classifier}(\text{LayerNorm}(h_{\text{final}}))$
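Steps 1 to 4 can be traced by hand for a single sample. The sketch below is framework-free and assumes the single-layer gate form written above with hypothetical weights; `bwaf_fuse` is an illustrative name, and the LayerNorm and classifier of step 5 are omitted.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(weights, bias, x):
    """Scalar-output affine map: w . x + b."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def bwaf_fuse(h_seq, h_graph, h_prior, w_s, b_s, w_g, b_g):
    # Step 1: one scalar gate per branch, computed from the priors only.
    alpha_seq = sigmoid(affine(w_s, b_s, h_prior))
    alpha_graph = sigmoid(affine(w_g, b_g, h_prior))
    # Steps 2 and 3: scale each branch by its gate, then sum element-wise.
    fused = [alpha_seq * s + alpha_graph * g for s, g in zip(h_seq, h_graph)]
    # Step 4: concatenate the fused vector with the unmodified priors.
    return alpha_seq, alpha_graph, fused + list(h_prior)


h_seq, h_graph = [1.0, -2.0], [0.5, 4.0]

# Uninformative priors (all zeros, zero bias): both gates sit at 0.5.
a_s, a_g, h_final = bwaf_fuse(h_seq, h_graph, [0.0, 0.0],
                              w_s=[1.0, 0.0], b_s=0.0, w_g=[-1.0, 0.0], b_g=0.0)
print(a_s, a_g, h_final)  # 0.5 0.5 [0.75, 1.0, 0.0, 0.0]

# A strong first prior shifts weight toward the sequence branch.
a_s2, a_g2, _ = bwaf_fuse(h_seq, h_graph, [2.0, 0.0],
                          w_s=[1.0, 0.0], b_s=0.0, w_g=[-1.0, 0.0], b_g=0.0)
print(round(a_s2, 3), round(a_g2, 3))  # 0.881 0.119
```

With these hand-picked weights the same prior raises one gate and lowers the other; in general the two gates are learned separately and need not move in opposite directions.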

In [6]:
# %% 6. Model Architecture Classes
# ============================================================================

class GATLayer(nn.Module):
    """
    Wrapper for PyTorch Geometric's GATConv layer.
    This layer performs graph attention on node features based on the graph structure.

    Args:
        in_channels (int): Number of input features per node.
        out_channels (int): Number of output features per node (per head).
        heads (int): Number of attention heads.
        dropout (float): Dropout rate for attention coefficients.
        concat (bool): If True, head outputs are concatenated; otherwise, averaged.
        activation_fn (callable, optional): Activation applied to the GATConv output.
            GATConv itself applies no output activation; its LeakyReLU (`negative_slope`)
            acts only on the raw attention scores.
        add_self_loops (bool): If True, adds self-loops to the adjacency matrix.
    """
    def __init__(self, in_channels, out_channels, heads=1, dropout=0.6,
                 concat=True, activation_fn=F.elu, add_self_loops=True):
        super().__init__()
        self.activation_fn = activation_fn
        # GATConv applies LeakyReLU (negative_slope=0.2) to the raw attention scores only;
        # it applies no nonlinearity to the output node features. The output activation
        # (ELU by default) is therefore applied explicitly in forward().
        self.gat_conv = GATConv(in_channels, out_channels, heads=heads, dropout=dropout,
                                concat=concat, add_self_loops=add_self_loops,
                                negative_slope=0.2) # Attention LeakyReLU slope used in the GAT paper
        self.concat = concat
        self.out_channels_per_head = out_channels

    def forward(self, x, edge_index):
        """
        x: Node features (Num_Nodes, In_Channels)
        edge_index: Graph connectivity (2, Num_Edges)
        """
        h = self.gat_conv(x, edge_index)
        # Apply the output activation, if one was configured
        if self.activation_fn is not None:
            h = self.activation_fn(h)
        return h

    def get_output_dim(self):
        """Returns the output dimension of this layer."""
        if self.concat:
            return self.out_channels_per_head * self.gat_conv.heads
        else: # Heads are averaged by GATConv if concat=False
            return self.out_channels_per_head

class GATNetwork(nn.Module):
    """
    Processes multiple tissue-specific gene graphs using GAT layers.
    Applies GAT layers independently for each tissue and then aggregates
    the resulting gene embeddings (e.g., by averaging).

    Args:
        num_genes_nodes (int): Total number of unique genes (nodes) in the graphs.
        initial_node_feature_dim (int): Dimensionality of the input features for each gene (e.g., N_TFs).
        num_tissues_to_process (int): The number of tissue-specific graphs/edge_indices expected.
        hidden_dim (int): Target output dimension of the GAT branch features (per gene) after aggregation.
        num_layers (int): Number of GAT layers to stack per tissue.
        heads (int): Number of attention heads for intermediate GAT layers.
        final_heads (int): Number of attention heads for the final GAT layer.
        dropout (float): Dropout rate.
    """
    def __init__(self, num_genes_nodes, initial_node_feature_dim, num_tissues_to_process,
                 hidden_dim, num_layers=NUM_GAT_LAYERS, heads=GAT_HEADS,
                 final_heads=GAT_FINAL_HEADS, dropout=DROPOUT_RATE):
        super().__init__()
        if num_layers < 1: raise ValueError("Number of GAT layers must be at least 1.")

        self.num_tissues_to_process = num_tissues_to_process
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.dropout_layer = nn.Dropout(p=dropout)

        in_channels = initial_node_feature_dim
        for i in range(self.num_layers):
            is_final_layer = (i == self.num_layers - 1)
            current_heads = heads if not is_final_layer else final_heads
            # Intermediate layers concat, final layer averages to get hidden_dim
            concat_heads = not is_final_layer
            # Output channels per head. If final layer averages, it directly outputs hidden_dim.
            # If intermediate layer concats, each head outputs hidden_dim / current_heads.
            out_ch_per_head = hidden_dim
            if concat_heads:
                if hidden_dim % current_heads != 0:
                    raise ValueError(f"GATNetwork: hidden_dim ({hidden_dim}) must be divisible by "
                                     f"intermediate heads ({current_heads}) if concat=True.")
                out_ch_per_head = hidden_dim // current_heads

            # ELU is applied to the output of every layer except the last; the last layer's
            # raw output is averaged across tissues. GATConv's internal LeakyReLU acts on
            # attention scores only, so it does not replace this output activation.
            layer_activation = F.elu if not is_final_layer else None

            gat_layer = GATLayer(in_channels, out_ch_per_head, heads=current_heads,
                                 dropout=dropout, concat=concat_heads,
                                 activation_fn=layer_activation, add_self_loops=True)
            self.layers.append(gat_layer)
            in_channels = gat_layer.get_output_dim() # hidden_dim for both concatenating and averaging layers

        self.final_output_dim = in_channels # This should be hidden_dim after the GAT stack

    def forward(self, x_node, edge_indices_list):
        """
        Applies GAT layers per tissue and aggregates the final gene embeddings.
        x_node: Initial node features (Num_Genes_Aligned, Initial_Node_Feature_Dim).
        edge_indices_list: List of edge_index tensors for each tissue.
        """
        num_available_edge_indices = len(edge_indices_list)
        if num_available_edge_indices == 0 and self.num_tissues_to_process > 0:
            log_message("GATNetwork FWD Error: No edge indices provided but tissues expected. Returning zeros.")
            return torch.zeros(x_node.shape[0], self.final_output_dim, device=x_node.device)

        # Determine how many tissues will actually be processed
        actual_tissues_to_loop_over = min(num_available_edge_indices, self.num_tissues_to_process)
        if actual_tissues_to_loop_over != self.num_tissues_to_process:
             log_message(f"GATNetwork FWD Warning: Processing {actual_tissues_to_loop_over} tissues "
                         f"(based on {num_available_edge_indices} available edge_indices), "
                         f"though GATNetwork configured for {self.num_tissues_to_process} tissues.")

        tissue_final_layer_outputs = []
        for i in range(actual_tissues_to_loop_over):
            h = x_node # Start with the same initial features for each tissue-specific pass
            edge_index = edge_indices_list[i]

            if edge_index is None or edge_index.numel() == 0 :
                 log_message(f"GATNetwork FWD: Skipping tissue index {i} due to missing or empty edge_index.")
                 tissue_final_layer_outputs.append(torch.zeros(x_node.shape[0], self.final_output_dim, device=x_node.device))
                 continue
            if x_node.shape[0] > 0 and edge_index.max().item() >= x_node.shape[0]: # Check edge indices bounds
                log_message(f"GATNetwork FWD ERROR: Max edge_idx {edge_index.max().item()} >= num_nodes {x_node.shape[0]} "
                            f"for tissue index {i}. Using zeros for this tissue output.")
                tissue_final_layer_outputs.append(torch.zeros(x_node.shape[0], self.final_output_dim, device=x_node.device))
                continue

            for k_layer_idx, layer_module in enumerate(self.layers):
                 h = layer_module(h, edge_index)
                 # Apply dropout *after* activation and *before* next GAT layer (for intermediate layers)
                 if k_layer_idx < self.num_layers - 1 : # Check against self.num_layers
                      if layer_module.activation_fn is not None: # Only apply dropout if there was an activation
                           h = self.dropout_layer(h)
            tissue_final_layer_outputs.append(h)

        if not tissue_final_layer_outputs: # If all tissues were skipped
            log_message("GATNetwork FWD Error: No valid GAT outputs generated across any tissue.")
            return torch.zeros(x_node.shape[0], self.final_output_dim, device=x_node.device)

        stacked_embeddings = torch.stack(tissue_final_layer_outputs, dim=0)
        aggregated_embedding = torch.mean(stacked_embeddings, dim=0)
        return aggregated_embedding


class TransformerBranch(nn.Module):
    """
    Transformer branch for sequence processing using nn.TransformerEncoder.
    Includes embedding, learnable positional encoding, multi-head self-attention,
    masked aggregation, and a final preparation layer for fusion.
    """
    def __init__(self, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN, embed_dim=EMBEDDING_DIM,
                 num_heads=NUM_ATTN_HEADS, ff_dim=TRANSFORMER_FF_DIM,
                 num_layers=NUM_TRANSFORMER_LAYERS, dropout=DROPOUT_RATE):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.positional_encoding = nn.Parameter(torch.randn(1, seq_len, embed_dim), requires_grad=True)
        self.embed_dropout = nn.Dropout(p=dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, activation='relu', batch_first=True, norm_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion_prep_layer = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim), # Output dimension is embed_dim
            nn.ReLU(),
            nn.Dropout(p=dropout)
        )

    def forward(self, seq_data):
        """Forward pass for the Transformer branch."""
        N, L_actual = seq_data.shape
        x = self.embedding(seq_data) # (N, L_actual, E)
        pos_enc = self.positional_encoding[:, :L_actual, :] # (1, L_actual, E)
        x = x + pos_enc
        x = self.embed_dropout(x)

        padding_mask = (seq_data == PAD_IDX) # (N, L_actual)
        transformer_output = self.transformer_encoder(x, src_key_padding_mask=padding_mask) # (N, L_actual, E)

        mask = (~padding_mask).unsqueeze(-1).float() # (N, L_actual, 1)
        summed_output = (transformer_output * mask).sum(dim=1) # (N, E)
        num_non_padding = mask.sum(dim=1).clamp(min=1.0) # (N, 1), prevent div by zero
        aggregated_output = summed_output / num_non_padding # (N, E)

        seq_features = self.fusion_prep_layer(aggregated_output) # (N, E)
        return seq_features


class BWAFusionLayer(nn.Module):
    """
    Implements Biologically Weighted Attention Fusion (BWAF).
    Uses biological priors (h_prior) to generate attention weights (alpha_seq, alpha_graph)
    that dynamically modulate sequence and graph features before fusion.
    """
    def __init__(self, seq_dim, graph_dim, prior_dim, common_feature_dim,
                 fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE):
        super().__init__()
        self.prior_dim = prior_dim # Store prior_dim
        self.common_feature_dim = common_feature_dim

        # Projection layers to make seq_dim and graph_dim equal to common_feature_dim
        self.project_seq = nn.Linear(seq_dim, self.common_feature_dim) if seq_dim != self.common_feature_dim else nn.Identity()
        self.project_graph = nn.Linear(graph_dim, self.common_feature_dim) if graph_dim != self.common_feature_dim else nn.Identity()

        # Attention Weight Generation
        attn_gen_input_dim = max(1, prior_dim) # Handle prior_dim = 0 for Linear layer
        attn_gen_hidden = max(16, attn_gen_input_dim // 2) if attn_gen_input_dim > 0 else 16

        self.weight_generator_seq = nn.Sequential(
            nn.Linear(attn_gen_input_dim, attn_gen_hidden), nn.ReLU(), nn.Linear(attn_gen_hidden, 1)
        )
        self.weight_generator_graph = nn.Sequential(
            nn.Linear(attn_gen_input_dim, attn_gen_hidden), nn.ReLU(), nn.Linear(attn_gen_hidden, 1)
        )

        # Classifier Part
        # Input dimension = common_feature_dim (from weighted sum) + prior_dim (original priors concatenated)
        final_combined_dim = self.common_feature_dim + prior_dim # Use original prior_dim for cat size
        self.layer_norm = nn.LayerNorm(final_combined_dim)

        self.fc_block = nn.Sequential(
            nn.Linear(final_combined_dim, fusion_hidden_dim), nn.ReLU(), nn.Dropout(p=dropout),
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim // 2), nn.ReLU(), nn.Dropout(p=dropout),
            nn.Linear(fusion_hidden_dim // 2, 1) # Output raw logits
        )

    def forward(self, seq_features, graph_features, prior_features):
        # Handle case where prior_features might be empty if prior_dim is 0
        if self.prior_dim > 0:
            alpha_seq = torch.sigmoid(self.weight_generator_seq(prior_features))
            alpha_graph = torch.sigmoid(self.weight_generator_graph(prior_features))
        else: # Fallback if no prior features (prior_dim = 0): weight both branches equally
            batch_s = seq_features.size(0)
            alpha_seq = torch.full((batch_s, 1), 0.5, device=seq_features.device)
            alpha_graph = torch.full((batch_s, 1), 0.5, device=graph_features.device)

        seq_proj = self.project_seq(seq_features)
        graph_proj = self.project_graph(graph_features)

        seq_weighted = alpha_seq * seq_proj
        graph_weighted = alpha_graph * graph_proj
        weighted_fused = seq_weighted + graph_weighted # Element-wise sum

        # Ensure prior_features is 2D for concatenation even if prior_dim is 0
        priors_to_concat = prior_features if self.prior_dim > 0 else torch.empty(weighted_fused.size(0), 0, device=weighted_fused.device)

        final_input = torch.cat([weighted_fused, priors_to_concat], dim=1)

        # Apply LayerNorm before the fully connected block
        return self.fc_block(self.layer_norm(final_input))


class FullModel(nn.Module):
    """Integrates Transformer, GAT, and BWAFusionLayer."""
    def __init__(self, num_genes_nodes, initial_node_feature_dim, num_tissues_to_process, prior_dim,
                 embed_dim=EMBEDDING_DIM, num_attn_heads=NUM_ATTN_HEADS,
                 transformer_ff_dim=TRANSFORMER_FF_DIM, num_transformer_layers=NUM_TRANSFORMER_LAYERS,
                 gat_hidden_dim=EMBEDDING_DIM, num_gat_layers=NUM_GAT_LAYERS,
                 gat_heads=GAT_HEADS, gat_final_heads=GAT_FINAL_HEADS,
                 fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE):
        super().__init__()
        self.transformer = TransformerBranch(
            vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN, embed_dim=embed_dim, num_heads=num_attn_heads,
            ff_dim=transformer_ff_dim, num_layers=num_transformer_layers, dropout=dropout
        )
        self.gat = GATNetwork(
            num_genes_nodes=num_genes_nodes,
            initial_node_feature_dim=initial_node_feature_dim,
            num_tissues_to_process=num_tissues_to_process,
            hidden_dim=gat_hidden_dim, # GAT output dim before aggregation
            num_layers=num_gat_layers,
            heads=gat_heads,
            final_heads=gat_final_heads,
            dropout=dropout
        )
        gat_output_dim = self.gat.final_output_dim # Actual output dim from GAT branch

        self.fusion = BWAFusionLayer(
            seq_dim=embed_dim,          # Output dim of Transformer branch
            graph_dim=gat_output_dim,   # Output dim of GAT branch
            prior_dim=prior_dim,        # Dimension of prior features
            common_feature_dim=embed_dim, # Target dim for weighted sum in BWAF
            fusion_hidden_dim=fusion_hidden_dim,
            dropout=dropout
        )

        self.precomputed_gat_output = None
        self.gat_device = 'cpu' # Device where GAT output was precomputed

    def precompute_gat(self, x_node_static, edge_indices_list_static, device):
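        """Runs the GAT branch once on `device` and caches the result on CPU.

        The output is computed under torch.no_grad(), so the cached features carry
        no gradient and the GAT weights are not updated when training on them.
        """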
        log_message("Precomputing GAT branch output...")
        start_time = time.time()
        self.gat.eval() # Set GAT to evaluation mode for precomputation
        try:
            with torch.no_grad():
                x_node_static_dev = x_node_static.to(device)
                # Ensure edge indices are moved, handle None gracefully
                edge_indices_list_static_dev = [
                    ei.to(device) if ei is not None and ei.numel() > 0 else None
                    for ei in edge_indices_list_static
                ]
                # Pass only non-None, non-empty edge indices to GAT forward
                valid_edge_indices = [ei for ei in edge_indices_list_static_dev if ei is not None]

                if not valid_edge_indices:
                    log_message("ERROR during GAT precomputation: No valid edge indices found after filtering! GAT output will be None.")
                    self.precomputed_gat_output = None
                    # Potentially create a zero tensor of expected shape to avoid downstream errors if critical
                    # num_nodes = x_node_static.shape[0]
                    # self.precomputed_gat_output = torch.zeros(num_nodes, self.gat.final_output_dim, device='cpu')
                    return # Exit if no valid edges to process

                # If GATNetwork's num_tissues_to_process is strict, this could be an issue
                if len(valid_edge_indices) != self.gat.num_tissues_to_process:
                     log_message(f"Warning during GAT precomputation: Processing with {len(valid_edge_indices)} "
                                 f"valid edge indices. GATNetwork configured for {self.gat.num_tissues_to_process} tissues.")

                self.precomputed_gat_output = self.gat(x_node_static_dev, valid_edge_indices).cpu() # Store on CPU
                self.gat_device = device # Remember the device it ran on
        except Exception as e:
            import traceback
            log_message(f"Error during GAT precomputation: {e}")
            log_message(traceback.format_exc())
            self.precomputed_gat_output = None # Reset on error
            # raise # Optionally re-raise to halt execution
        finally:
            self.gat.train() # Set GAT back to train mode

        elapsed = time.time() - start_time
        if self.precomputed_gat_output is not None:
            log_message(f"GAT output precomputed on {device}. Shape: {self.precomputed_gat_output.shape}. Time: {elapsed:.2f}s.")
        else:
            log_message(f"GAT precomputation failed or resulted in no output after {elapsed:.2f}s.")


    def forward(self, batch):
        """Forward pass implementing the BWAF architecture."""
        current_device = batch['sequence'].device # Assume batch data is already on the correct device

        # 1. Transformer Branch
        seq_features = self.transformer(batch['sequence']) # (N, embed_dim)

        # 2. GAT Branch Output (Retrieve Precomputed)
        if self.precomputed_gat_output is None:
             # Fallback: return zeros if GAT output is not available, log critical warning
             log_message("CRITICAL RUNTIME ERROR: GAT output is None in forward pass. Using zeros for graph features.")
             batch_graph_features = torch.zeros(seq_features.size(0), self.gat.final_output_dim, device=current_device)
        else:
            # Move precomputed output to the current batch's device
            h_graph_agg_all_genes = self.precomputed_gat_output.to(current_device) # (N_genes_aligned, gat_output_dim)

            # 3. Select GAT features for the current batch using aligned indices
            aligned_gat_indices = batch['aligned_gat_idx'] # (N,) tensor of indices or -1

            # Ensure aligned_gat_indices is a tensor and on the correct device for indexing
            if not isinstance(aligned_gat_indices, torch.Tensor):
                aligned_gat_indices_tensor = torch.tensor(aligned_gat_indices, device=h_graph_agg_all_genes.device, dtype=torch.long)
            else:
                aligned_gat_indices_tensor = aligned_gat_indices.to(device=h_graph_agg_all_genes.device, dtype=torch.long)

            valid_indices_mask = (aligned_gat_indices_tensor != -1) # Boolean mask
            valid_gat_indices_for_indexing = aligned_gat_indices_tensor[valid_indices_mask] # Actual indices to use

            batch_graph_features = torch.zeros(seq_features.size(0), self.gat.final_output_dim, device=current_device)

            if valid_gat_indices_for_indexing.numel() > 0: # If there are any valid genes for this batch
                if h_graph_agg_all_genes.shape[0] == 0:
                     log_message(f"RUNTIME WARNING: Precomputed GAT output (h_graph_agg_all_genes) is empty. Graph features will be zero.")
                # Check bounds before indexing
                elif valid_gat_indices_for_indexing.max().item() >= h_graph_agg_all_genes.shape[0]:
                     log_message(f"RUNTIME ERROR: Max index {valid_gat_indices_for_indexing.max().item()} for GAT features out of bounds ({h_graph_agg_all_genes.shape[0]})")
                     # Zeros will be used as batch_graph_features was initialized with zeros
                else:
                     selected_graph_features = h_graph_agg_all_genes.index_select(0, valid_gat_indices_for_indexing)
                     # Use the boolean mask (on the current_device) to place the selected features
                     batch_graph_features[valid_indices_mask.to(current_device)] = selected_graph_features
            # If valid_gat_indices_for_indexing.numel() == 0, batch_graph_features remains zeros

        # 4. BWAF Fusion Layer
        prior_features = batch['priors'] # (N, prior_dim)
        logits = self.fusion(seq_features, batch_graph_features, prior_features) # (N, 1)

        return logits

## **7. Training Function**

This function (`train_model`) handles the model training loop, including forward/backward passes, optimization, validation, learning rate scheduling, model saving, and loss plotting.
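The "save on improvement" rule used for model saving can be traced on a list of validation losses. The sketch below is a plain-Python illustration with made-up loss values; `checkpoint_epochs` is a hypothetical helper, not a function from the notebook.

```python
def checkpoint_epochs(val_losses):
    """Return the 1-based epochs at which a strict '<' save-best rule fires,
    together with the best validation loss seen."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # Ties do not trigger a new save
            best = loss
            saved.append(epoch)
    return saved, best


# Made-up validation losses for six epochs.
saved, best = checkpoint_epochs([0.70, 0.62, 0.65, 0.62, 0.58, 0.60])
print(saved, best)  # [1, 2, 5] 0.58
```

In this trace the last epoch is not the best one. Since `train_model` returns the model as it stands after the final epoch, the checkpoint written to `model_save_path` is the one to reload (for example with `model.load_state_dict(torch.load(...))`) before evaluation.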

In [7]:
# %% 7. Training Function
# ============================================================================
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                model_save_path=MODEL_SAVE_PATH, loss_plot_path=LOSS_PLOT_PATH, scheduler=None):
    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    # getattr avoids an AttributeError for models that lack the attribute entirely
    if hasattr(model, 'precompute_gat') and getattr(model, 'precomputed_gat_output', None) is None:
        log_message("CRITICAL ERROR: GAT output must be precomputed for training.")
        return None

    log_message(f"--- Starting Training for {num_epochs} Epochs ---")
    for epoch in range(num_epochs):
        start_time = time.time(); model.train(); running_train_loss = 0.0
        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
        for batch in train_loop:
            batch_on_device = {k: v.to(device) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}
            labels = batch_on_device['label']
            optimizer.zero_grad(); logits = model(batch_on_device); loss = criterion(logits, labels)
            loss.backward(); optimizer.step(); running_train_loss += loss.item()
            train_loop.set_postfix(loss=f"{loss.item():.4f}")

        epoch_train_loss = running_train_loss/len(train_loader) if len(train_loader)>0 else 0.0
        train_losses.append(epoch_train_loss)

        model.eval(); running_val_loss = 0.0
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
        with torch.no_grad():
            for batch in val_loop:
                batch_on_device = {k: v.to(device) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}
                labels = batch_on_device['label']
                logits = model(batch_on_device); loss = criterion(logits, labels)
                running_val_loss += loss.item(); val_loop.set_postfix(loss=f"{loss.item():.4f}")
        epoch_val_loss = running_val_loss/len(val_loader) if len(val_loader)>0 else 0.0
        val_losses.append(epoch_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if scheduler: scheduler.step(epoch_val_loss)

        log_message(f"E {epoch+1}/{num_epochs} - TrL: {epoch_train_loss:.4f}, VaL: {epoch_val_loss:.4f}, Dur: {time.time()-start_time:.2f}s, LR: {current_lr:.2e}")
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss; torch.save(model.state_dict(), model_save_path)
            log_message(f"Saved best model (Val Loss: {best_val_loss:.4f})")

    try: # Plotting
        plt.figure(figsize=(10,6))
        epochs_range = range(1, len(train_losses) + 1) # Always matches the number of recorded epochs
        plt.plot(epochs_range, train_losses, label='Train Loss', marker='.')
        plt.plot(epochs_range, val_losses, label='Val Loss', marker='.')
        plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.title('Training & Validation Loss')
        plt.legend(); plt.grid(True, linestyle=':'); plt.tight_layout()
        plt.savefig(loss_plot_path, dpi=300); plt.close()
        log_message(f"Saved loss plot to {loss_plot_path}")
    except Exception as e: log_message(f"Error plotting loss: {e}")
    log_message(f"--- Training Finished --- Best Val Loss: {best_val_loss:.4f}")
    return model

## **8. Evaluation Function**

This section defines functions for model evaluation on the test set:
*   `plot_confusion_matrix`: Generates and saves a heatmap visualization of the confusion matrix.
*   `evaluate_model`: Performs inference on the test set, calculates classification metrics (Accuracy, Precision, Recall, F1, Specificity, AUC-ROC, TP/FP/TN/FN), logs results, plots the confusion matrix, and saves metrics to a CSV file.
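The count-based metrics follow directly from the four confusion-matrix cells. The sketch below uses made-up counts and a hypothetical helper `binary_metrics`; it returns 0.0 whenever a denominator is zero, which is the same convention as `zero_division=0` in scikit-learn.

```python
def safe_div(num, den):
    """Division that returns 0.0 instead of failing on an empty denominator."""
    return num / den if den > 0 else 0.0


def binary_metrics(tp, fp, tn, fn):
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)  # Also called sensitivity
    return {
        "accuracy": safe_div(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "recall": recall,
        "specificity": safe_div(tn, tn + fp),
        "f1": safe_div(2 * precision * recall, precision + recall),
    }


# Made-up counts: 45 promoters and 55 non-promoters in the test set.
m = binary_metrics(tp=40, fp=10, tn=45, fn=5)
print({k: round(v, 4) for k, v in m.items()})
# {'accuracy': 0.85, 'precision': 0.8, 'recall': 0.8889, 'specificity': 0.8182, 'f1': 0.8421}
```

AUC-ROC is not derivable from these counts because it depends on the ranking of the predicted probabilities, which is why it is computed separately from the raw scores.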

In [8]:
# %% 8. Evaluation Function
# ============================================================================
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, save_path=CONFUSION_MATRIX_PATH):
    """Plots and saves the confusion matrix."""
    try:
        if normalize:
            cm_sum = cm.sum(axis=1)[:, np.newaxis]
            cm_plot = np.divide(cm.astype('float'), cm_sum, out=np.zeros_like(cm,dtype=float), where=cm_sum!=0)
            fmt = '.2f'; plot_title = title + ' (Normalized)'
        else:
            fmt = 'd'; cm_plot = cm; plot_title = title
        plt.figure(figsize=(8,6)); sns.heatmap(cm_plot, annot=True, fmt=fmt, cmap=cmap, xticklabels=classes, yticklabels=classes, square=True, cbar=False, linewidths=.5, linecolor='grey', annot_kws={"size":10}); plt.title(plot_title, fontsize=14); plt.ylabel('True label', fontsize=12); plt.xlabel('Predicted label', fontsize=12); plt.xticks(fontsize=10); plt.yticks(fontsize=10,rotation=0); plt.tight_layout(); plt.savefig(save_path, dpi=300); plt.close()
        log_message(f"Saved CM to {save_path}")
    except Exception as e: log_message(f"Error plotting CM: {e}")

def evaluate_model(model, test_loader, criterion, device, results_csv_path=RESULTS_CSV_PATH):
    """Evaluates model, logs metrics, plots CM, saves results."""
    model.eval(); all_labels, all_preds, all_probs = [], [], []; running_test_loss = 0.0
    if hasattr(model, 'precompute_gat') and getattr(model, 'precomputed_gat_output', None) is None:
        log_message("CRITICAL: GAT output not precomputed for eval.")
        return None

    log_message(f"--- Evaluating on test set ({len(test_loader)} batches) ---")
    test_loop = tqdm(test_loader, desc="Evaluating [Test]", leave=False)
    with torch.no_grad():
        for batch in test_loop:
            batch_on_device = {k: v.to(device) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}
            labels = batch_on_device['label']
            logits = model(batch_on_device); loss = criterion(logits, labels)
            running_test_loss += loss.item()
            probs = torch.sigmoid(logits); preds = (probs > 0.5).float()
            all_labels.extend(labels.cpu().numpy()); all_preds.extend(preds.cpu().numpy()); all_probs.extend(probs.cpu().numpy())

    test_loss = running_test_loss/len(test_loader) if len(test_loader)>0 else 0.0
    all_labels=np.array(all_labels).flatten(); all_preds=np.array(all_preds).flatten(); all_probs=np.array(all_probs).flatten()
    
    accuracy=accuracy_score(all_labels,all_preds); precision=precision_score(all_labels,all_preds,zero_division=0)
    recall=recall_score(all_labels,all_preds,zero_division=0); f1=f1_score(all_labels,all_preds,zero_division=0)
    
    # Always request both classes so the matrix is 2x2 even when the test labels
    # or the predictions contain a single class; absent cells are simply zero.
    # (Restricting `labels` to the one observed class would drop FP/FN counts.)
    unique_test_labels = np.unique(all_labels)
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    
    specificity = tn/(tn+fp) if (tn+fp)>0 else 0.0
    auc = roc_auc_score(all_labels, all_probs) if len(unique_test_labels)>1 else np.nan
    if np.isnan(auc): log_message("AUC is undefined (NaN): the test labels contain a single class.")

    results = {'loss':test_loss, 'acc':accuracy, 'prec':precision, 'rec':recall, 'spec':specificity, 'f1':f1, 'auc':auc, 'TP':int(tp),'FP':int(fp),'TN':int(tn),'FN':int(fn)}
    log_message("\n--- Test Set Results ---")
    for k,v in results.items(): log_message(f"{k.upper()}: {v:.4f}" if isinstance(v,float) else f"{k.upper()}: {v}")
    plot_confusion_matrix(cm, ['Non-P','P'], normalize=False, title='Test CM (Counts)')
    plot_confusion_matrix(cm, ['Non-P','P'], normalize=True, title='Test CM', save_path=CONFUSION_MATRIX_PATH.replace('.png','_norm.png'))
    try: pd.DataFrame([results]).to_csv(results_csv_path,index=False); log_message(f"Saved test results to {results_csv_path}")
    except IOError as e: log_message(f"Error saving test CSV: {e}")
    return results

## **9. Main Execution Block**

This block orchestrates the entire pipeline from argument parsing to final evaluation. It ensures data is loaded and preprocessed in the correct order, the model is initialized, GAT outputs are precomputed, training occurs, and the best model is evaluated.
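One way to keep a single source of defaults across notebook and command-line runs is to let `argparse` parse an empty argument list in interactive sessions. The sketch below is an alternative pattern under stated assumptions, not the notebook's approach: the default values and the helpers `build_parser` and `get_args` are hypothetical, and only two of the flags are reproduced.

```python
import argparse
import sys

# Hypothetical defaults for illustration; the notebook defines its own constants.
DEFAULT_EPOCHS = 10
DEFAULT_BATCH_SIZE = 32


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative argument parser")
    parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS)
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    return parser


def get_args(argv=None):
    """Parse `argv`; in a notebook kernel, ignore the kernel's own command line."""
    interactive = "ipykernel" in sys.modules or "google.colab" in sys.modules
    if argv is None:
        argv = [] if interactive else sys.argv[1:]
    return build_parser().parse_args(argv)


args_default = get_args([])                  # Defaults only
args_override = get_args(["--epochs", "2"])  # Quick-test override
print(args_default.epochs, args_default.batch_size)    # 10 32
print(args_override.epochs, args_override.batch_size)  # 2 32
```

Passing an explicit list also makes quick interactive experiments reproducible, since the override is visible in the cell rather than set by mutating attributes afterwards.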

In [9]:
# %% 9. Main Execution Block
# ============================================================================
if __name__ == "__main__":
    log_message("--- Starting BWAF Multi-Modal Promoter Prediction Workflow ---")
    script_start_time = time.time()

    # --- Argument Parsing ---
    class Args: """Mock class for arguments when not using command-line parsing."""
    if 'ipykernel' in sys.modules or 'google.colab' in sys.modules:
        log_message("Interactive environment detected, using default args.")
        args = Args(); args.epochs=NUM_EPOCHS; args.batch_size=BATCH_SIZE; args.lr=LEARNING_RATE; args.embed_dim=EMBEDDING_DIM
        args.dropout=DROPOUT_RATE; args.gat_layers=NUM_GAT_LAYERS; args.transformer_layers=NUM_TRANSFORMER_LAYERS
        args.gat_threshold_factor=GAT_INTERACTION_THRESHOLD_STD_FACTOR; args.fusion_hidden=FUSION_HIDDEN_DIM
        args.force_gat_preprocess=False; args.force_edge_rebuild=False; args.no_cuda=False
        # --- Example: Manual override for interactive testing ---
        # args.epochs = 2 # For quick test
        # args.batch_size = 16
        # args.force_edge_rebuild = True # Uncomment to test edge creation
    else:
        parser = argparse.ArgumentParser(description="Train BWAF Multi-Modal Promoter Prediction Model")
        parser.add_argument('--epochs',type=int,default=NUM_EPOCHS, help=f"Epochs (def: {NUM_EPOCHS})")
        parser.add_argument('--batch_size',type=int,default=BATCH_SIZE, help=f"Batch size (def: {BATCH_SIZE})")
        parser.add_argument('--lr',type=float,default=LEARNING_RATE, help=f"Learning rate (def: {LEARNING_RATE})")
        parser.add_argument('--embed_dim',type=int,default=EMBEDDING_DIM, help=f"Embed/Hidden dim (def: {EMBEDDING_DIM})")
        parser.add_argument('--dropout',type=float,default=DROPOUT_RATE, help=f"Dropout rate (def: {DROPOUT_RATE})")
        parser.add_argument('--gat_layers',type=int,default=NUM_GAT_LAYERS, help=f"GAT layers (def: {NUM_GAT_LAYERS})")
        parser.add_argument('--transformer_layers',type=int,default=NUM_TRANSFORMER_LAYERS, help=f"Transformer layers (def: {NUM_TRANSFORMER_LAYERS})")
        parser.add_argument('--gat_threshold_factor',type=float,default=GAT_INTERACTION_THRESHOLD_STD_FACTOR, help=f"GAT edge TF co-reg threshold factor (def: {GAT_INTERACTION_THRESHOLD_STD_FACTOR})")
        parser.add_argument('--fusion_hidden',type=int,default=FUSION_HIDDEN_DIM, help=f"Fusion hidden dim (def: {FUSION_HIDDEN_DIM})")
        parser.add_argument('--force_gat_preprocess',action='store_true', help='Force filtering/normalization of GAT data.')
        parser.add_argument('--force_edge_rebuild',action='store_true', help='Force rebuilding GAT edge indices.')
        parser.add_argument('--no_cuda',action='store_true', help='Disable CUDA.')
        args = parser.parse_args()

    # Update Config from Args
    NUM_EPOCHS=args.epochs; BATCH_SIZE=args.batch_size; LEARNING_RATE=args.lr; EMBEDDING_DIM=args.embed_dim
    DROPOUT_RATE=args.dropout; NUM_GAT_LAYERS=args.gat_layers; NUM_TRANSFORMER_LAYERS=args.transformer_layers
    GAT_INTERACTION_THRESHOLD_STD_FACTOR=args.gat_threshold_factor; FUSION_HIDDEN_DIM=args.fusion_hidden
    if args.no_cuda: DEVICE=torch.device("cpu"); log_message("CUDA disabled by argument.")
    log_message(f"Runtime Config: Epochs={NUM_EPOCHS}, Batch={BATCH_SIZE}, LR={LEARNING_RATE}, Embed={EMBEDDING_DIM}, Dropout={DROPOUT_RATE}, GAT ThrFactor={GAT_INTERACTION_THRESHOLD_STD_FACTOR}, FusionHidden={FUSION_HIDDEN_DIM}")

    try:
        # 1. Load Sequences & Get Master Gene List for Alignment
        log_message("\n--- Step 1: Loading Sequence Data ---")
        prom_seqs, prom_ids_raw, prom_labels = load_sequences(PROMOTER_SEQ_FILE, True)
        nonprom_seqs, nonprom_ids_raw, nonprom_labels = load_sequences(NON_PROMOTER_SEQ_FILE, False)
        if not prom_seqs or not nonprom_seqs: raise ValueError("Sequence loading failed.")
        all_sequences_raw = prom_seqs + nonprom_seqs
        all_gene_ids_for_samples = prom_ids_raw + nonprom_ids_raw # Order for dataset samples
        all_labels = prom_labels + nonprom_labels
        log_message(f"Total sequences loaded and QC'd: {len(all_sequences_raw)}")
        # Master list for aligning GAT/Priors is unique genes from all samples
        master_gene_id_list_for_alignment = pd.Series(all_gene_ids_for_samples).drop_duplicates(keep='first').tolist()
        log_message(f"Unique master gene IDs for GAT/Prior alignment: {len(master_gene_id_list_for_alignment)}")

        # 2. Load Priors (aligned to master_gene_id_list_for_alignment)
        log_message("\n--- Step 2: Loading and Processing Priors ---")
        priors_features_aligned_to_master_list, prior_dim = load_priors(PRIOR_FILE, master_gene_id_list_for_alignment)
        # Now map these aligned priors back to the order of all_gene_ids_for_samples for the Dataset
        gene_to_aligned_prior_map = {gene_id: priors_features_aligned_to_master_list[i] for i, gene_id in enumerate(master_gene_id_list_for_alignment)}
        final_priors_for_dataset = np.array([gene_to_aligned_prior_map.get(gid, np.zeros(prior_dim, dtype=np.float32)) for gid in all_gene_ids_for_samples])
        log_message(f"Prior dimension: {prior_dim}. Priors prepared for {len(final_priors_for_dataset)} dataset samples.")

        # 3. Encode Sequences
        log_message("\n--- Step 3: Encoding Sequences ---")
        all_sequences_encoded = np.array([integer_encode_sequence(seq) for seq in tqdm(all_sequences_raw, desc="Encoding sequences")])

        # 4. Process GAT Data
        log_message("\n--- Step 4: Processing GAT Data ---")
        num_tfs_master, num_genes_in_aligned_gat_data, final_gene_order_for_gat_cols, master_tf_ids_final = preprocess_all_gat_data(
            GAT_RAW_DATA_DIR, GAT_PREPROCESSED_DIR, master_gene_id_list_for_alignment, args.force_gat_preprocess
        )
        initial_node_features_for_gat, _ = load_processed_gat_data(GAT_PREPROCESSED_DIR, final_gene_order_for_gat_cols, master_tf_ids_final)
        log_message(f"Initial GAT node features (X_node). Shape: {initial_node_features_for_gat.shape}")

        # Determine effective number of tissues based on raw files for edge index creation
        _raw_gat_files = glob.glob(os.path.join(GAT_RAW_DATA_DIR, '*.csv'))
        effective_num_tissues_for_edges = len(_raw_gat_files) if _raw_gat_files else 0
        if effective_num_tissues_for_edges == 0: raise FileNotFoundError(f"No raw GAT files found in {GAT_RAW_DATA_DIR} for edge index generation.")
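        if effective_num_tissues_for_edges != NUM_TISSUES:
            # NUM_TISSUES itself is not reassigned; the effective count is passed on explicitly below
            log_message(f"NUM_TISSUES is {NUM_TISSUES} but {effective_num_tissues_for_edges} raw files were found; using {effective_num_tissues_for_edges} tissues for GAT edges/model.")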

        edge_indices_list = load_or_create_edge_indices(
            num_genes_in_graph=num_genes_in_aligned_gat_data, # Genes in GAT columns
            num_tfs_in_raw=num_tfs_master, # TFs from master list
            num_tissues_to_process=effective_num_tissues_for_edges,
            raw_gat_data_dir=GAT_RAW_DATA_DIR,
            gene_order_for_raw_alignment=final_gene_order_for_gat_cols, # Gene order for raw matrix alignment
            edge_index_dir=GAT_EDGE_INDEX_DIR,
            threshold_std_factor=GAT_INTERACTION_THRESHOLD_STD_FACTOR,
            force_rebuild=args.force_edge_rebuild,
            max_edges_per_tissue_approx=MAX_EDGES_PER_TISSUE_APPROX
        )

        # 5. Datasets & DataLoaders
        log_message("\n--- Step 5: Creating Datasets and DataLoaders ---")
        # PromoterGATDataset needs gene order from GAT data for mapping
        full_dataset = PromoterGATDataset(all_sequences_encoded, all_gene_ids_for_samples, all_labels,
                                          final_priors_for_dataset, final_gene_order_for_gat_cols)
        # Shuffle sample indices with a fixed seed, then carve out test, validation, and train subsets
        dataset_size = len(full_dataset)
        indices = list(range(dataset_size))
        np.random.seed(RANDOM_SEED)
        np.random.shuffle(indices)
        test_split_idx = int(np.floor(TEST_SPLIT * dataset_size))
        val_split_idx = test_split_idx + int(np.floor(VALIDATION_SPLIT * dataset_size))
        test_indices = indices[:test_split_idx]
        val_indices = indices[test_split_idx:val_split_idx]
        train_indices = indices[val_split_idx:]
        if not train_indices or not val_indices or not test_indices: raise ValueError("Dataset splitting resulted in empty sets.")
        train_dataset=Subset(full_dataset,train_indices); val_dataset=Subset(full_dataset,val_indices); test_dataset=Subset(full_dataset,test_indices)
        log_message(f"Dataset split: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
        num_workers = min(os.cpu_count() // 2 if os.cpu_count() is not None else 0, 4) if DEVICE.type=='cuda' else 0
        drop_last_train = (len(train_dataset) % BATCH_SIZE == 1) # Avoid single-sample batch if it causes issues
        train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=num_workers,pin_memory=(DEVICE.type=='cuda'),drop_last=drop_last_train)
        val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=num_workers,pin_memory=(DEVICE.type=='cuda'))
        test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=num_workers,pin_memory=(DEVICE.type=='cuda'))


        # 6. Initialize Model
        log_message("\n--- Step 6: Initializing BWAF Model ---")
        model = FullModel(
            num_genes_nodes=initial_node_features_for_gat.shape[0], # num_genes_in_aligned_gat_data
            initial_node_feature_dim=initial_node_features_for_gat.shape[1], # num_tfs_master
            num_tissues_to_process=effective_num_tissues_for_edges, # Use actual number of tissues with edges
            prior_dim=prior_dim,
            embed_dim=EMBEDDING_DIM, num_attn_heads=NUM_ATTN_HEADS, transformer_ff_dim=TRANSFORMER_FF_DIM,
            num_transformer_layers=NUM_TRANSFORMER_LAYERS, gat_hidden_dim=EMBEDDING_DIM, num_gat_layers=NUM_GAT_LAYERS,
            gat_heads=GAT_HEADS, gat_final_heads=GAT_FINAL_HEADS, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE
        ).to(DEVICE)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log_message(f"Model Initialized. Trainable Params: {total_params:,}")

        # 7. Precompute GAT
        log_message("\n--- Step 7: Precomputing GAT Output ---")
        model.precompute_gat(initial_node_features_for_gat.to(DEVICE), edge_indices_list, DEVICE)
        del initial_node_features_for_gat  # Free memory
        if torch.cuda.is_available(): torch.cuda.empty_cache()

        # 8. Training Components
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=OPTIMIZER_WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5)

        # 8 (continued). Train
        log_message("\n--- Step 8: Training Model ---")
        trained_model = train_model(model,train_loader,val_loader,criterion,optimizer,NUM_EPOCHS,DEVICE,scheduler=scheduler)
        if trained_model is None: raise RuntimeError("Model training failed.")

        # 9. Evaluate
        log_message("\n--- Step 9: Evaluating Best Model ---")
        best_model_path_to_load = MODEL_SAVE_PATH
        log_message(f"Loading best model state from '{best_model_path_to_load}' for final evaluation.")
        if os.path.exists(best_model_path_to_load):
            eval_model = FullModel( # Re-initialize structure
                 num_genes_nodes=num_genes_in_aligned_gat_data, initial_node_feature_dim=num_tfs_master,
                 num_tissues_to_process=effective_num_tissues_for_edges, prior_dim=prior_dim,
                 embed_dim=EMBEDDING_DIM, num_attn_heads=NUM_ATTN_HEADS, transformer_ff_dim=TRANSFORMER_FF_DIM,
                 num_transformer_layers=NUM_TRANSFORMER_LAYERS, gat_hidden_dim=EMBEDDING_DIM, num_gat_layers=NUM_GAT_LAYERS,
                 gat_heads=GAT_HEADS, gat_final_heads=GAT_FINAL_HEADS, fusion_hidden_dim=FUSION_HIDDEN_DIM, dropout=DROPOUT_RATE
            ).to(DEVICE)
            try:
                eval_model.load_state_dict(torch.load(best_model_path_to_load, map_location=DEVICE))
                # Copy precomputed GAT output to the eval model instance
                if trained_model.precomputed_gat_output is not None: # Use from model instance that was trained
                    eval_model.precomputed_gat_output = trained_model.precomputed_gat_output
                    eval_model.gat_device = trained_model.gat_device
                else: # Fallback: re-precompute if not available on trained_model (should be)
                     log_message("Re-precomputing GAT for loaded eval_model as it was not available from trained model instance.")
                     # Use final_gene_order_for_gat_cols (defined in Step 4); the previous name was undefined here
                     temp_initial_node_features, _ = load_processed_gat_data(GAT_PREPROCESSED_DIR, final_gene_order_for_gat_cols, master_tf_ids_final)
                     temp_edge_indices_list = load_or_create_edge_indices(
                        num_genes_in_aligned_gat_data, num_tfs_master, effective_num_tissues_for_edges, GAT_RAW_DATA_DIR, final_gene_order_for_gat_cols,
                        GAT_EDGE_INDEX_DIR, GAT_INTERACTION_THRESHOLD_STD_FACTOR, False, MAX_EDGES_PER_TISSUE_APPROX
                     )
                     eval_model.precompute_gat(temp_initial_node_features.to(DEVICE), temp_edge_indices_list, DEVICE)
                     del temp_initial_node_features, temp_edge_indices_list
                     if torch.cuda.is_available(): torch.cuda.empty_cache()
            except Exception as e:
                log_message(f"Error loading best model state dict: {e}. Evaluating model from last epoch instead.")
                eval_model = trained_model # Fallback to last epoch model
        else:
            log_message(f"Warning: Best model file '{best_model_path_to_load}' not found. Evaluating model from last epoch.")
            eval_model = trained_model

        test_results = evaluate_model(eval_model, test_loader, criterion, DEVICE, results_csv_path=RESULTS_CSV_PATH)

        script_end_time = time.time()
        total_duration_seconds = script_end_time - script_start_time
        log_message(f"\n--- Workflow Completed Successfully in {total_duration_seconds:.2f} seconds ({total_duration_seconds/3600:.2f} hours) ---")

    except FileNotFoundError as fnf_error:
        log_message(f"\n--- WORKFLOW ERROR: FILE NOT FOUND ---"); log_message(f"Error: {fnf_error}")
        log_message("Please ensure all data paths in Section 2 are correct and files exist."); sys.exit(1)
    except ValueError as val_error:
        log_message(f"\n--- WORKFLOW ERROR: VALUE ERROR ---"); log_message(f"Error: {val_error}")
        import traceback; log_message("Traceback:\n" + traceback.format_exc()); sys.exit(1)
    except RuntimeError as rt_error:
        log_message(f"\n--- PYTORCH RUNTIME ERROR (often CUDA memory or tensor shape issues) ---"); log_message(f"Error: {rt_error}")
        import traceback; log_message("Traceback:\n" + traceback.format_exc()); sys.exit(1)
    except Exception as main_error:
        log_message(f"\n--- CRITICAL UNEXPECTED WORKFLOW ERROR ---"); log_message(f"Type: {type(main_error).__name__}"); log_message(f"Msg: {main_error}")
        import traceback; log_message("Traceback:\n" + traceback.format_exc()); sys.exit(1)
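
The floor-based index split in Step 5 and the resulting loader lengths can be checked by hand. The sketch below assumes `TEST_SPLIT` and `VALIDATION_SPLIT` are both 0.15; those constants are set in Section 2 and are not shown here, so the value is inferred from the logged split sizes. The helper names `split_sizes` and `num_batches` are hypothetical.

```python
import math


def split_sizes(n_samples, test_frac, val_frac):
    """Mirror the floor-based split: test first, then validation, remainder is train."""
    n_test = int(math.floor(test_frac * n_samples))
    n_val = int(math.floor(val_frac * n_samples))
    n_train = n_samples - n_test - n_val
    return n_train, n_val, n_test


def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields over n_samples items."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


# 0.15 / 0.15 are assumed fractions, inferred from the logged split sizes.
n_train, n_val, n_test = split_sizes(40056, 0.15, 0.15)
print(n_train, n_val, n_test)                            # 28040 6008 6008
print(num_batches(n_train, 16), num_batches(n_val, 16))  # 1753 376
```

Under these assumptions the numbers agree with the `Dataset split` line and the per-epoch batch counts in the log output. With 28040 training samples and a batch size of 16 the remainder is 8, so `drop_last_train` evaluates to `False` in this run.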

[2025-06-04 18:28:57] --- Starting BWAF Multi-Modal Promoter Prediction Workflow ---
[2025-06-04 18:28:57] Interactive environment detected, using default args.
[2025-06-04 18:28:57] Runtime Config: Epochs=10, Batch=16, LR=0.0005, Embed=64, Dropout=0.3, GAT ThrFactor=2.5, FusionHidden=128
[2025-06-04 18:28:57] 
--- Step 1: Loading Sequence Data ---
[2025-06-04 18:28:57] Loading sequences from data/raw/human_genome_annotation/updated_promoter_features_clean.csv...
[2025-06-04 18:28:57] Sequence QC for updated_promoter_features_clean.csv: Total initial: 20028. Removed 0 (with 'N'). Removed 0 (not 2000bp). Final: 20028.
[2025-06-04 18:28:57] Loaded and filtered 20028 sequences from updated_promoter_features_clean.csv in 0.79s.
[2025-06-04 18:28:57] Loading sequences from data/raw/human_genome_annotation/updated_non_promoter_sequences.csv...
[2025-06-04 18:28:58] Sequence QC for updated_non_promoter_sequences.csv: Total initial: 20028. Removed 0 (with 'N'). Removed 0 (not 2000bp). Final: 2

Encoding sequences:   0%|          | 0/40056 [00:00<?, ?it/s]

[2025-06-04 18:29:16] 
--- Step 4: Processing GAT Data ---
[2025-06-04 18:29:16] Preprocessing GAT data: data/raw/gene_interaction_network/GRAND_networks -> data/preprocessed/gat_normalized
[2025-06-04 18:29:16] Processed GAT files found. Verifying TF/Gene alignment...
[2025-06-04 18:29:20] GAT preprocessing complete. Master TFs: 644, Aligned Genes: 20028
[2025-06-04 18:29:20] Loading processed GAT data from data/preprocessed/gat_normalized (aligning to 20028 genes, 644 TFs)...



[2025-06-04 18:31:55] Loaded 36 GAT matrices, created initial node features in 154.72s.
[2025-06-04 18:31:55] Initial GAT node features (X_node). Shape: torch.Size([20028, 644])
[2025-06-04 18:31:55] GAT Edge Indices (TF Co-reg Gene-Gene, ThrFactor: 2.5, MaxEdges: 1000000)...



[2025-06-04 18:31:56] Finished processing 36 Gene-Gene edge indices.
[2025-06-04 18:31:56] 
--- Step 5: Creating Datasets and DataLoaders ---
[2025-06-04 18:31:56] Dataset split: Train=28040, Val=6008, Test=6008
[2025-06-04 18:31:56] 
--- Step 6: Initializing BWAF Model ---




[2025-06-04 18:31:56] Model Initialized. Trainable Params: 292,601
[2025-06-04 18:31:56] 
--- Step 7: Precomputing GAT Output ---
[2025-06-04 18:31:56] Precomputing GAT branch output...
[2025-06-04 18:32:15] GAT output precomputed on cpu. Shape: torch.Size([20028, 64]). Time: 19.13s.
[2025-06-04 18:32:15] 
--- Step 8: Training Model ---
[2025-06-04 18:32:15] --- Starting Training for 10 Epochs ---


[2025-06-05 02:01:46] E 1/10 - TrL: 0.4090, VaL: 0.1203, Dur: 26970.47s, LR: 5.00e-04
[2025-06-05 02:01:46] Saved best model (Val Loss: 0.1203)
[2025-06-05 09:29:27] E 2/10 - TrL: 0.0807, VaL: 0.0301, Dur: 26860.80s, LR: 5.00e-04
[2025-06-05 09:29:27] Saved best model (Val Loss: 0.0301)
[2025-06-05 17:03:21] E 3/10 - TrL: 0.0363, VaL: 0.0124, Dur: 27234.67s, LR: 5.00e-04
[2025-06-05 17:03:21] Saved best model (Val Loss: 0.0124)
[2025-06-06 00:29:27] E 4/10 - TrL: 0.0250, VaL: 0.0123, Dur: 26765.42s, LR: 5.00e-04
[2025-06-06 00:29:27] Saved best model (Val Loss: 0.0123)
[2025-06-06 07:49:49] E 5/10 - TrL: 0.0205, VaL: 0.0089, Dur: 26422.10s, LR: 5.00e-04
[2025-06-06 07:49:49] Saved best model (Val Loss: 0.0089)
[2025-06-06 15:15:55] E 6/10 - TrL: 0.0167, VaL: 0.0086, Dur: 26765.95s, LR: 5.00e-04
[2025-06-06 15:15:55] Saved best model (Val Loss: 0.0086)
[2025-06-06 23:00:55] E 7/10 - TrL: 0.0160, VaL: 0.0109, Dur: 27900.61s, LR: 5.00e-04
[2025-06-07 07:22:37] E 8/10 - TrL: 0.0135, VaL: 0.0060, Dur: 30101.59s, LR: 5.00e-04
[2025-06-07 07:22:37] Saved best model (Val Loss: 0.0060)
[2025-06-07 16:06:44] E 9/10 - TrL: 0.0115, VaL: 0.0073, Dur: 31446.88s, LR: 5.00e-04

[2025-06-08 00:46:28] E 10/10 - TrL: 0.0112, VaL: 0.0113, Dur: 31183.75s, LR: 5.00e-04
[2025-06-08 00:46:28] Saved loss plot to results_bwaf_v3/training_validation_loss_bwaf.png
[2025-06-08 00:46:28] --- Training Finished --- Best Val Loss: 0.0060
[2025-06-08 00:46:28] 
--- Step 9: Evaluating Best Model ---
[2025-06-08 00:46:28] Loading best model state from 'results_bwaf_v3/best_promoter_model_bwaf.pth' for final evaluation.
[2025-06-08 00:46:28] --- Evaluating on test set (376 batches) ---



[2025-06-08 01:14:48] 
--- Test Set Results ---
[2025-06-08 01:14:48] LOSS: 0.0060
[2025-06-08 01:14:48] ACC: 0.9987
[2025-06-08 01:14:48] PREC: 1.0000
[2025-06-08 01:14:48] REC: 0.9973
[2025-06-08 01:14:48] SPEC: 1.0000
[2025-06-08 01:14:48] F1: 0.9987
[2025-06-08 01:14:48] AUC: 0.9999
[2025-06-08 01:14:48] TP: 3010
[2025-06-08 01:14:48] FP: 0
[2025-06-08 01:14:48] TN: 2990
[2025-06-08 01:14:48] FN: 8
[2025-06-08 01:14:48] Saved CM to results_bwaf_v3/confusion_matrix_bwaf.png
[2025-06-08 01:14:48] Saved CM to results_bwaf_v3/confusion_matrix_bwaf_norm.png
[2025-06-08 01:14:48] Saved test results to results_bwaf_v3/test_set_evaluation_results_bwaf.csv
[2025-06-08 01:14:48] 
--- Workflow Completed Successfully in 283551.59 seconds (78.76 hours) ---
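
The reported test metrics can be cross-checked against the logged confusion-matrix counts (TP=3010, FP=0, TN=2990, FN=8). The following is a minimal sketch using the standard definitions; the helper name `binary_metrics` is hypothetical and is not part of the notebook.

```python
def binary_metrics(tp, fp, tn, fn):
    """Derive standard binary-classification metrics from confusion-matrix counts."""
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    return {
        "acc": (tp + tn) / total if total else 0.0,
        "prec": precision,
        "rec": recall,
        "spec": tn / (tn + fp) if (tn + fp) else 0.0,
        "f1": 2 * precision * recall / denom if denom else 0.0,
    }


metrics = binary_metrics(tp=3010, fp=0, tn=2990, fn=8)
for name, value in metrics.items():
    print(f"{name.upper()}: {value:.4f}")
# ACC: 0.9987
# PREC: 1.0000
# REC: 0.9973
# SPEC: 1.0000
# F1: 0.9987
```

The recomputed values agree with the logged ACC, PREC, REC, SPEC, and F1 to four decimal places. AUC and loss cannot be recovered from the counts alone, because they depend on the predicted probabilities.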
