# **Baseline Model: Replication of msBERT-Promoter**

**Objective:**
This notebook implements a baseline model inspired by the core concepts of "msBERT-Promoter" (Li et al., 2024), focusing on the promoter identification task. It features:
1.  Multi-scale k-mer tokenization (k=3, 4, 5, 6).
2.  Training an individual Transformer-based model for each k value, with **epoch-level checkpointing and resuming capability**.
3.  The main loop **skips k-mer models that have already completed training** (unless forced by `force_kmer_model_rebuild`).
4.  Ensembling predictions using soft voting.

This baseline is trained on the same data splits as the main proposed model for fair comparison.

---

## **0. Setup and Imports**
This cell imports all required libraries.

In [1]:
# %% 0. Setup and Imports
# ============================================================================
# Standard library imports
import os
import re
import glob
import gzip
import time
import argparse
import datetime
import sys
import warnings
import random
import traceback  # for printing full stack traces when an error is caught

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split # For splitting indices
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                           f1_score, roc_auc_score, confusion_matrix)
from tqdm.notebook import tqdm # Use notebook version for Jupyter
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

print("Imports successful.")
print(f"PyTorch Version: {torch.__version__}")

Imports successful.
PyTorch Version: 2.7.0+cu126


## **1. Configuration / Constants (for msBERT-Promoter Replication)**
Defines paths, model hyperparameters, and training settings.

In [2]:
# %% 1. Configuration / Constants for msBERT-Promoter Replication
# ============================================================================
# --- Data Paths ---
BASE_DATA_DIR_REPLICATION = './data/'
SEQ_DATA_DIR_REPLICATION = os.path.join(BASE_DATA_DIR_REPLICATION, 'raw/human_genome_annotation')
OUTPUT_DIR_MSBERT_REPLICATION = 'results_msBERT_replication'

PROMOTER_SEQ_FILE_REPLICATION = os.path.join(SEQ_DATA_DIR_REPLICATION, 'updated_promoter_features_clean.csv')
NON_PROMOTER_SEQ_FILE_REPLICATION = os.path.join(SEQ_DATA_DIR_REPLICATION, 'updated_non_promoter_sequences.csv')

# --- K-mer Model Specific Hyperparameters ---
KMER_VALUES = [3, 4, 5, 6]
KMER_STRIDE = 1
SEQ_LEN_FOR_KMER_LOADING = 2000

KMER_EMBED_DIM = 64
KMER_NUM_HEADS = 4
KMER_NUM_TRANSFORMER_LAYERS = 2
KMER_TRANSFORMER_FF_DIM = KMER_EMBED_DIM * 4
KMER_DROPOUT_RATE = 0.1

# --- Training Hyperparameters ---
LEARNING_RATE_KMER = 0.0001
BATCH_SIZE_KMER = 32
NUM_EPOCHS_KMER = 10
VALIDATION_SPLIT = 0.15
TEST_SPLIT = 0.15
RANDOM_SEED = 42
OPTIMIZER_WEIGHT_DECAY_KMER = 1e-5

# --- Output Files ---
os.makedirs(OUTPUT_DIR_MSBERT_REPLICATION, exist_ok=True)
KMER_MODEL_BEST_SAVE_PATH_TPL = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, 'best_kmer_model_k{k}.pth')
KMER_CHECKPOINT_DIR_TPL = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, "checkpoints_k{k}")
KMER_LOSS_PLOT_PATH_TPL = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, 'loss_kmer_k{k}.png')
ENSEMBLE_RESULTS_CSV_PATH = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, 'test_results_kmer_ensemble.csv')
ENSEMBLE_CM_PLOT_PATH_BASE = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, 'confusion_matrix_kmer_ensemble')
LOG_FILE_PATH_KMER = os.path.join(OUTPUT_DIR_MSBERT_REPLICATION, 'training_log_kmer_baseline.txt')

# --- Hardware ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Setup ---
random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"K-mer Replication: Using device: {DEVICE}")
print(f"K-mer Replication: Output directory: {OUTPUT_DIR_MSBERT_REPLICATION}")

K-mer Replication: Using device: cpu
K-mer Replication: Output directory: results_msBERT_replication
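Because the loader keeps only sequences of exactly `SEQ_LEN_FOR_KMER_LOADING` bases with no `N`, the per-k token count and an upper bound on vocabulary size follow directly from the settings above. A quick arithmetic sketch, with the constants copied by hand rather than imported from the cell, and assuming only A/C/G/T remain after the `N` filter:

```python
# Derived sizes for each k, assuming fixed-length, ACGT-only sequences.
SEQ_LEN = 2000          # mirrors SEQ_LEN_FOR_KMER_LOADING
STRIDE = 1              # mirrors KMER_STRIDE
NUM_SPECIAL_TOKENS = 2  # <PAD> and <UNK>

def kmer_count(seq_len: int, k: int, stride: int = 1) -> int:
    """Number of k-mers produced by a sliding window of width k."""
    return (seq_len - k) // stride + 1

def max_vocab_size(k: int) -> int:
    """Upper bound on vocabulary size: all 4**k k-mers plus the special tokens."""
    return 4 ** k + NUM_SPECIAL_TOKENS

for k in (3, 4, 5, 6):
    print(f"k={k}: tokens per sequence={kmer_count(SEQ_LEN, k, STRIDE)}, "
          f"vocab size <= {max_vocab_size(k)}")
```

All four models therefore see roughly 2000 tokens per sequence (1998 for k=3, 1995 for k=6), so self-attention cost is nearly identical across k; what grows with k is the embedding table, from at most 66 rows for k=3 to at most 4098 for k=6.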


## **2. Utility Functions (for K-mer Processing)**
Helper functions for logging, k-mer tokenization, vocabulary building, numericalizing k-mer sequences, and loading raw DNA sequences for this baseline.
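A toy walk-through of the tokenization and fixed-length numericalization performed by `sequence_to_kmers` and `kmer_numericalize_sequence`. This is a standalone re-implementation for illustration only: the names `to_kmers`, `numericalize` and the tiny `toy_vocab` are hypothetical, but the conventions (`<PAD>`=0, `<UNK>`=1, upper-casing, right-padding or truncation) match the helpers in the next cell.

```python
import numpy as np

PAD, UNK = "<PAD>", "<UNK>"

def to_kmers(seq: str, k: int, stride: int = 1) -> list[str]:
    """Sliding-window k-mers, upper-cased; sequences shorter than k yield [PAD]."""
    if len(seq) < k:
        return [PAD]
    return [seq[i:i + k].upper() for i in range(0, len(seq) - k + 1, stride)]

def numericalize(seq: str, k: int, vocab: dict[str, int], max_len: int, stride: int = 1) -> np.ndarray:
    """Map k-mers to ids (unknown k-mers -> UNK id), then right-pad or truncate to max_len."""
    ids = [vocab.get(kmer, vocab[UNK]) for kmer in to_kmers(seq, k, stride)]
    ids = ids[:max_len] + [vocab[PAD]] * max(0, max_len - len(ids))
    return np.array(ids, dtype=np.int64)

toy_vocab = {PAD: 0, UNK: 1, "ACG": 2, "CGT": 3, "GTA": 4}
print(to_kmers("acgtac", k=3))                          # ['ACG', 'CGT', 'GTA', 'TAC']
print(numericalize("acgtac", 3, toy_vocab, max_len=6))  # [2 3 4 1 0 0]
```

`TAC` is not in the toy vocabulary, so it maps to the `<UNK>` id 1, and the two trailing positions are filled with the `<PAD>` id 0.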

In [3]:
# %% 2. Utility Functions (for K-mer Processing)
# ============================================================================
def log_message_kmer(message, log_file=LOG_FILE_PATH_KMER):
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    full_message = f"[{timestamp}] {message}"
    print(full_message)
    log_dir = os.path.dirname(log_file)
    if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir)
    try:
        with open(log_file, 'a', encoding='utf-8') as f: f.write(full_message + '\n')
    except IOError as e: print(f"Error writing k-mer log to {log_file}: {e}")

try:
    with open(LOG_FILE_PATH_KMER, 'w', encoding='utf-8') as f: f.write(f"--- K-mer Baseline Log Initialized: {datetime.datetime.now()} ---\n")
    log_message_kmer(f"K-mer Baseline Log started at: {LOG_FILE_PATH_KMER}")
except IOError as e: print(f"CRITICAL ERROR: Could not write k-mer log {LOG_FILE_PATH_KMER}: {e}")

log_message_kmer(f"Using device: {DEVICE} for k-mer operations.")

KMER_PAD_TOKEN = "<PAD>"
KMER_UNK_TOKEN = "<UNK>"

def sequence_to_kmers(sequence, k, stride=KMER_STRIDE):
    kmers = []
    if not isinstance(sequence, str) or len(sequence) < k: return [KMER_PAD_TOKEN]
    for i in range(0, len(sequence) - k + 1, stride): kmers.append(sequence[i:i+k].upper())
    if not kmers: kmers = [KMER_PAD_TOKEN]
    return kmers

def build_kmer_vocab(all_raw_sequences_list, k_values_list_for_vocab, kmer_stride_for_vocab, min_freq=1):
    kmer_vocabs_map_build = {}
    for k_val_build in k_values_list_for_vocab:
        log_message_kmer(f"Building vocabulary for k={k_val_build} using stride={kmer_stride_for_vocab}...")
        all_kmers_for_this_k_build = []
        for seq_build in tqdm(all_raw_sequences_list, desc=f"Tokenizing for k={k_val_build} vocab"):
            if not isinstance(seq_build, str): continue
            all_kmers_for_this_k_build.extend(sequence_to_kmers(seq_build, k_val_build, stride=kmer_stride_for_vocab))
        kmer_counts_build = pd.Series(all_kmers_for_this_k_build).value_counts()
        frequent_kmers_build = kmer_counts_build[kmer_counts_build >= min_freq].index.tolist()
        kmer_to_idx_build = {KMER_PAD_TOKEN: 0, KMER_UNK_TOKEN: 1}
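        for kmer_vocab_item in frequent_kmers_build:
            # Skip special tokens (e.g. <PAD> emitted for sequences shorter than k) so <PAD>/<UNK> keep ids 0 and 1.
            if kmer_vocab_item not in kmer_to_idx_build:
                kmer_to_idx_build[kmer_vocab_item] = len(kmer_to_idx_build)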
        idx_to_kmer_build = {idx: kmer for kmer, idx in kmer_to_idx_build.items()}
        kmer_vocabs_map_build[k_val_build] = {
            'kmer_to_idx': kmer_to_idx_build, 'idx_to_kmer': idx_to_kmer_build,
            'vocab_size': len(kmer_to_idx_build),
            'pad_idx': kmer_to_idx_build[KMER_PAD_TOKEN],
            'unk_idx': kmer_to_idx_build[KMER_UNK_TOKEN]}
        log_message_kmer(f"Vocab for k={k_val_build} size: {len(kmer_to_idx_build)}.")
    return kmer_vocabs_map_build

def kmer_numericalize_sequence(sequence_str, k_val_num, kmer_to_idx_map, max_kmer_seq_len_for_k, kmer_stride_for_num):
    if not isinstance(sequence_str, str): sequence_str = ""
    kmers_list = sequence_to_kmers(sequence_str, k_val_num, stride=kmer_stride_for_num)
    pad_idx_local = kmer_to_idx_map[KMER_PAD_TOKEN]
    unk_idx_local = kmer_to_idx_map[KMER_UNK_TOKEN]
    numericalized = [kmer_to_idx_map.get(kmer, unk_idx_local) for kmer in kmers_list]
    current_len = len(numericalized)
    if current_len < max_kmer_seq_len_for_k:
        numericalized.extend([pad_idx_local] * (max_kmer_seq_len_for_k - current_len))
    elif current_len > max_kmer_seq_len_for_k:
        numericalized = numericalized[:max_kmer_seq_len_for_k]
    return np.array(numericalized, dtype=np.int64)

def load_raw_sequences_for_kmer(file_path, is_promoter=True, seq_len_expected=SEQ_LEN_FOR_KMER_LOADING):
    log_message_kmer(f"Loading RAW sequences from {file_path} (expected len: {seq_len_expected})...")
    start_time = time.time()
    try:
        if not os.path.exists(file_path): raise FileNotFoundError(f"Seq file missing: {file_path}")
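        df = None; encodings_to_try = ['utf-8', 'ISO-8859-1']  # 'latin1' is an alias of ISO-8859-1
        for encoding in encodings_to_try:
            try: df = pd.read_csv(file_path, encoding=encoding, low_memory=False); break
            except (UnicodeDecodeError, pd.errors.ParserError): continue  # try the next encoding
        if df is None: raise ValueError(f"Could not read {file_path} with encodings {encodings_to_try}")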
        seq_col, id_col = None, None
        possible_seq_cols=['promoter_sequence','sequence']; possible_id_cols=['gene_id', df.columns[0] if not df.empty else None]
        for col_s in possible_seq_cols:
            if col_s in df.columns: seq_col=col_s; break
        for col_i in possible_id_cols:
            if col_i in df.columns: id_col=col_i; break
        if seq_col is None or id_col is None: raise ValueError(f"Required cols not in {file_path}. Has: {df.columns.tolist()}")
        df[seq_col]=df[seq_col].astype(str); df=df[df[seq_col].str.strip()!='']
        initial_raw_count = len(df)  # count after dropping empty strings, before N/length QC
        df=df[~df[seq_col].str.contains('N',na=False,case=False)]  # drop sequences containing N/n
        initial_count_after_N = len(df)
        if seq_len_expected and seq_len_expected > 0:
            df['seq_len_actual_temp'] = df[seq_col].str.len()
            df_correct_len = df[df['seq_len_actual_temp'] == seq_len_expected]
            removed_len = initial_count_after_N - len(df_correct_len)
            df = df_correct_len.drop(columns=['seq_len_actual_temp'])
            log_message_kmer(f"QC {os.path.basename(file_path)}: Initial Raw={initial_raw_count}. After N-rem={initial_count_after_N}. Len-rem:{removed_len}. Final:{len(df)}.")
        else:
            log_message_kmer(f"QC {os.path.basename(file_path)}: Initial Raw={initial_raw_count}. After N-rem:{initial_count_after_N}. Len check skip. Final:{len(df)}.")
        if df.empty: log_message_kmer(f"Warning: No valid raw sequences in {file_path}."); return [],[]
        raw_sequences_list = df[seq_col].tolist(); labels_list = [1 if is_promoter else 0]*len(raw_sequences_list)
        log_message_kmer(f"Loaded {len(raw_sequences_list)} RAW sequences from {os.path.basename(file_path)} in {time.time()-start_time:.2f}s.")
        return raw_sequences_list, labels_list
    except Exception as e: log_message_kmer(f"CRITICAL ERROR loading RAW seq file {file_path}: {e}"); raise

[2025-06-19 13:18:59] K-mer Baseline Log started at: results_msBERT_replication/training_log_kmer_baseline.txt
[2025-06-19 13:18:59] Using device: cpu for k-mer operations.
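`build_kmer_vocab` first collects every k-mer of every sequence into one Python list (about 2000 strings per sequence, for each k) before counting. A lower-memory alternative, sketched here as an assumption rather than the notebook's method, streams counts through `collections.Counter` and reserves ids 0 and 1 for `<PAD>`/`<UNK>` so special tokens are never re-assigned. The function name `build_vocab_streaming` is hypothetical, and it skips sequences shorter than k instead of counting a `<PAD>` token for them.

```python
from collections import Counter

PAD, UNK = "<PAD>", "<UNK>"

def build_vocab_streaming(sequences, k: int, stride: int = 1, min_freq: int = 1) -> dict[str, int]:
    """Count k-mers sequence by sequence, then assign ids by descending frequency.

    Ids 0 and 1 are reserved for PAD and UNK; sequences shorter than k are skipped.
    """
    counts = Counter()
    for seq in sequences:
        if not isinstance(seq, str) or len(seq) < k:
            continue
        seq = seq.upper()
        counts.update(seq[i:i + k] for i in range(0, len(seq) - k + 1, stride))
    vocab = {PAD: 0, UNK: 1}
    for kmer, freq in counts.most_common():  # ties keep first-seen order
        if freq >= min_freq and kmer not in vocab:
            vocab[kmer] = len(vocab)
    return vocab

vocab_k2 = build_vocab_streaming(["ACGT", "acga", "TT"], k=2)
print(vocab_k2)  # {'<PAD>': 0, '<UNK>': 1, 'AC': 2, 'CG': 3, 'GT': 4, 'GA': 5, 'TT': 6}
```

Tie-breaking can differ from pandas' `value_counts`, so ids produced this way are not guaranteed to match the notebook's vocabulary; models trained with one vocabulary should not be evaluated with the other.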


## **3. K-mer Transformer Model Definition**
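Defines the `KmerTransformer` class: a k-mer embedding plus a learned positional embedding, a stack of pre-LayerNorm Transformer encoder layers (`norm_first=True`), mean pooling over non-padding positions, and a small MLP head that outputs one logit per sequence.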

In [4]:
# %% 3. K-mer Transformer Model Definition
# ============================================================================
class KmerTransformer(nn.Module):
    def __init__(self, kmer_vocab_size, max_kmer_seq_len,
                 embed_dim=KMER_EMBED_DIM, num_heads=KMER_NUM_HEADS,
                 ff_dim=KMER_TRANSFORMER_FF_DIM, num_layers=KMER_NUM_TRANSFORMER_LAYERS,
                 dropout=KMER_DROPOUT_RATE, kmer_pad_idx=0): # Default kmer_pad_idx to 0 as per vocab
        super().__init__()
        self.embedding = nn.Embedding(kmer_vocab_size, embed_dim, padding_idx=kmer_pad_idx)
        self.positional_encoding = nn.Parameter(torch.randn(1, max_kmer_seq_len, embed_dim))
        self.embed_dropout = nn.Dropout(p=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, activation='relu', batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(p=dropout),
            nn.Linear(embed_dim // 2, 1))

    def forward(self, kmer_idx_sequence):
        N, L_kmer_actual = kmer_idx_sequence.shape
        x = self.embedding(kmer_idx_sequence) + self.positional_encoding[:, :L_kmer_actual, :]
        x = self.embed_dropout(x)
        padding_mask = (kmer_idx_sequence == self.embedding.padding_idx)
        transformer_output = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        mask = (~padding_mask).unsqueeze(-1).float()
        aggregated_output = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.classifier(aggregated_output)
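The forward pass pools the encoder output with a masked mean, so padded positions contribute neither to the sum nor to the count. A small standalone check of that pooling expression on toy tensors (the helper name `masked_mean` is hypothetical):

```python
import torch

def masked_mean(hidden: torch.Tensor, token_ids: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Average hidden states over non-pad positions, as in KmerTransformer.forward.

    hidden: (N, L, D) encoder output; token_ids: (N, L) k-mer ids.
    The clamp avoids division by zero for an all-padding row.
    """
    keep = (token_ids != pad_idx).unsqueeze(-1).float()  # (N, L, 1)
    return (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # (1, 3, 2)
token_ids = torch.tensor([[5, 7, 0]])                               # last position is <PAD>
print(masked_mean(hidden, token_ids))  # tensor([[2., 3.]])
```

With the 2000 bp length filter, every sequence yields exactly the same number of k-mers, so in this notebook the mask only matters if the length filter is disabled.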


## **4. K-mer Dataset Definition**
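Defines `KmerDataset`, which holds the raw DNA strings and converts each one into a fixed-length array of k-mer indices on the fly in `__getitem__`.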

In [5]:
# %% 4. K-mer Dataset Definition
# ============================================================================
class KmerDataset(Dataset):
    def __init__(self, list_of_raw_dna_strings, list_of_labels, k_value,
                 kmer_vocabulary_dict, max_kmer_len_for_this_k, kmer_stride_val): # Pass stride
        self.raw_sequences = list_of_raw_dna_strings
        self.labels = torch.tensor(list_of_labels, dtype=torch.float32).unsqueeze(1)
        self.k = k_value
        self.kmer_to_idx = kmer_vocabulary_dict['kmer_to_idx']
        self.max_kmer_seq_len = max_kmer_len_for_this_k
        self.kmer_stride = kmer_stride_val # Store stride

    def __len__(self): return len(self.raw_sequences)
    def __getitem__(self, idx):
        raw_seq_item = self.raw_sequences[idx]
        if not isinstance(raw_seq_item, str): raw_seq_item = ""
        kmer_indices_item = kmer_numericalize_sequence(
            raw_seq_item, self.k, self.kmer_to_idx,
            self.max_kmer_seq_len, kmer_stride_for_num=self.kmer_stride # Use stored stride
        )
        return {'kmer_indices': kmer_indices_item, 'label': self.labels[idx]}
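Each `KmerDataset` item is a dict holding a NumPy `int64` array and a `(1,)` float label. PyTorch's default collate function converts and stacks these into a `(batch, max_kmer_seq_len)` `int64` tensor and a `(batch, 1)` label tensor, matching the model's `(batch, 1)` logits for `BCEWithLogitsLoss`. A minimal check with a hypothetical toy dataset that returns identically shaped items:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ToyKmerDataset(Dataset):
    """Hypothetical stand-in that returns items shaped like KmerDataset's."""
    def __init__(self, n_items: int, seq_len: int):
        self.ids = [np.arange(seq_len, dtype=np.int64) + i for i in range(n_items)]
        self.labels = torch.tensor([i % 2 for i in range(n_items)], dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        return {"kmer_indices": self.ids[idx], "label": self.labels[idx]}

loader = DataLoader(ToyKmerDataset(n_items=5, seq_len=8), batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch["kmer_indices"].shape, batch["kmer_indices"].dtype)  # torch.Size([4, 8]) torch.int64
print(batch["label"].shape)                                      # torch.Size([4, 1])
```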

## **5. Training Function for Individual K-mer Models (with Checkpointing)**
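Defines `train_kmer_model`, which trains a single k-mer model, saves the weights with the lowest validation loss to `best_kmer_model_k{k}.pth`, writes a full checkpoint (model, optimizer, scheduler, loss history) after every epoch, and on restart resumes from the newest checkpoint in `checkpoints_k{k}`.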
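The training loop reads the learning rate before calling `scheduler_k.step(epoch_val_loss_k)`, so each logged LR is the one used during that epoch. With the `ReduceLROnPlateau(..., 'min', factor=0.2, patience=3)` configuration used in the main block, the rate is cut only after the validation loss has failed to improve for more than three consecutive epochs. A toy run on a dummy parameter, with made-up loss values:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.2, patience=3)

fake_val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]  # illustrative values only
lrs = []
for epoch, val_loss in enumerate(fake_val_losses):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used for this epoch (the logged value)
    scheduler.step(val_loss)                     # may lower the LR for the next epoch
    print(f"epoch {epoch}: val_loss={val_loss}, lr used={lrs[-1]:.1e}")
```

The loss stops improving after epoch 1; the fourth non-improving epoch (epoch 5) triggers the cut, so epoch 6 runs at 1e-4 x 0.2 = 2e-5. With `NUM_EPOCHS_KMER = 10`, at most a couple of reductions can occur per model.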

In [6]:
# %% 5. Training Function for Individual K-mer Models (with Checkpointing)
# ============================================================================
def train_kmer_model(k_value_train, kmer_model_instance, train_loader_k, val_loader_k,
                     criterion_k, optimizer_k, num_epochs_k, device_k,
                     model_best_save_path_tpl_k, loss_plot_path_tpl_k, scheduler_k=None,
                     checkpoint_dir_tpl_k=KMER_CHECKPOINT_DIR_TPL):

    k_checkpoint_dir = checkpoint_dir_tpl_k.format(k=k_value_train)
    os.makedirs(k_checkpoint_dir, exist_ok=True)
    k_model_best_save_path = model_best_save_path_tpl_k.format(k=k_value_train)
    k_loss_plot_actual = loss_plot_path_tpl_k.format(k=k_value_train)

    start_epoch = 0; train_losses_k, val_losses_k = [], []; best_val_loss_k = float('inf')

    checkpoint_files = sorted(
        glob.glob(os.path.join(k_checkpoint_dir, "checkpoint_epoch_*.pth")),
        key=lambda x: int(re.search(r"epoch_(\d+)\.pth", os.path.basename(x)).group(1)) if re.search(r"epoch_(\d+)\.pth", os.path.basename(x)) else -1, # Robust key
        reverse=True
    )
    if checkpoint_files:
        latest_checkpoint_path = checkpoint_files[0]
        log_message_kmer(f"Resuming k={k_value_train} from checkpoint: {latest_checkpoint_path}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location=device_k)
            kmer_model_instance.load_state_dict(checkpoint['model_state_dict'])
            optimizer_k.load_state_dict(checkpoint['optimizer_state_dict'])
            if scheduler_k and 'scheduler_state_dict' in checkpoint: scheduler_k.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            train_losses_k = checkpoint.get('train_losses', []); val_losses_k = checkpoint.get('val_losses', [])
            best_val_loss_k = checkpoint.get('best_val_loss', float('inf'))
            log_message_kmer(f"Resumed from epoch {start_epoch-1}. Best Val Loss: {best_val_loss_k:.4f}")
        except Exception as e:
            log_message_kmer(f"Error loading checkpoint {latest_checkpoint_path}: {e}. Training k={k_value_train} from scratch.")
            start_epoch = 0; train_losses_k, val_losses_k = [], []; best_val_loss_k = float('inf')
    
    if start_epoch >= num_epochs_k:
        log_message_kmer(f"k={k_value_train} already trained for {num_epochs_k} epochs. Skipping.")
        if os.path.exists(k_model_best_save_path):
            try: kmer_model_instance.load_state_dict(torch.load(k_model_best_save_path, map_location=device_k))
            except Exception as e: log_message_kmer(f"Could not load best model for k={k_value_train}: {e}")
        return kmer_model_instance

    log_message_kmer(f"--- Training KmerTransformer (k={k_value_train}) from epoch {start_epoch} to {num_epochs_k-1} ---")
    for epoch_k in range(start_epoch, num_epochs_k):
        epoch_start_time = time.time(); kmer_model_instance.train(); running_train_loss_k = 0.0
        train_loop_k = tqdm(train_loader_k, desc=f"k={k_value_train} E{epoch_k+1}/{num_epochs_k} [Train]", leave=False)
        for batch_k in train_loop_k:
            kmer_indices_k=batch_k['kmer_indices'].to(device_k); labels_k=batch_k['label'].to(device_k)
            optimizer_k.zero_grad(); logits_k=kmer_model_instance(kmer_indices_k); loss_k=criterion_k(logits_k,labels_k)
            loss_k.backward(); optimizer_k.step(); running_train_loss_k += loss_k.item()
            train_loop_k.set_postfix(loss=f"{loss_k.item():.4f}")
        epoch_train_loss_k = running_train_loss_k/len(train_loader_k) if len(train_loader_k)>0 else 0.0
        train_losses_k.append(epoch_train_loss_k)

        kmer_model_instance.eval(); running_val_loss_k = 0.0
        val_loop_k = tqdm(val_loader_k, desc=f"k={k_value_train} E{epoch_k+1}/{num_epochs_k} [Val]", leave=False)
        with torch.no_grad():
            for batch_k_val in val_loop_k:
                kmer_indices_k_val=batch_k_val['kmer_indices'].to(device_k); labels_k_val=batch_k_val['label'].to(device_k)
                logits_k_val=kmer_model_instance(kmer_indices_k_val); loss_k_val=criterion_k(logits_k_val,labels_k_val)
                running_val_loss_k += loss_k_val.item(); val_loop_k.set_postfix(loss=f"{loss_k_val.item():.4f}")
        epoch_val_loss_k = running_val_loss_k/len(val_loader_k) if len(val_loader_k)>0 else 0.0
        val_losses_k.append(epoch_val_loss_k)
        current_lr_k = optimizer_k.param_groups[0]['lr']
        if scheduler_k: scheduler_k.step(epoch_val_loss_k)
        log_message_kmer(f"k={k_value_train} E {epoch_k+1}/{num_epochs_k} - TrL: {epoch_train_loss_k:.4f}, VaL: {epoch_val_loss_k:.4f}, Dur: {time.time()-epoch_start_time:.2f}s, LR: {current_lr_k:.2e}")
        
        if epoch_val_loss_k < best_val_loss_k:
            best_val_loss_k = epoch_val_loss_k; torch.save(kmer_model_instance.state_dict(), k_model_best_save_path)
            log_message_kmer(f"k={k_value_train} Saved NEW BEST model (Val Loss: {best_val_loss_k:.4f}) to {k_model_best_save_path}")
        
        checkpoint_save_path = os.path.join(k_checkpoint_dir, f"checkpoint_epoch_{epoch_k:03d}.pth")
        checkpoint_data_save = {'epoch': epoch_k, 'model_state_dict': kmer_model_instance.state_dict(),
                                'optimizer_state_dict': optimizer_k.state_dict(), 'train_losses': train_losses_k,
                                'val_losses': val_losses_k, 'best_val_loss': best_val_loss_k}
        if scheduler_k: checkpoint_data_save['scheduler_state_dict'] = scheduler_k.state_dict()
        torch.save(checkpoint_data_save, checkpoint_save_path)

    if train_losses_k and val_losses_k and len(train_losses_k) > 0:
        try:
            epochs_plotted = len(train_losses_k); epochs_range_k = range(1, epochs_plotted + 1)
            plt.figure(figsize=(10,6)); plt.plot(epochs_range_k,train_losses_k,label=f'k={k_value_train} Train',marker='.'); plt.plot(epochs_range_k,val_losses_k,label=f'k={k_value_train} Val',marker='.')
            plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.title(f'k={k_value_train} Train/Val Loss (Up to Epoch {epochs_plotted})'); plt.legend(); plt.grid(True,linestyle=':'); plt.tight_layout(); plt.savefig(k_loss_plot_actual, dpi=300); plt.close()
            log_message_kmer(f"k={k_value_train} Saved final loss plot to {k_loss_plot_actual}")
        except Exception as e: log_message_kmer(f"Error plotting k={k_value_train} loss: {e}")
    
    log_message_kmer(f"--- k={k_value_train} Training Finished/Resumed --- Best Val Loss: {best_val_loss_k:.4f}")
    if os.path.exists(k_model_best_save_path):
        try: kmer_model_instance.load_state_dict(torch.load(k_model_best_save_path, map_location=device_k))
        except Exception as e: log_message_kmer(f"k={k_value_train} Error loading BEST model for return: {e}")
    return kmer_model_instance
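The resume logic relies on zero-padded epoch numbers in `checkpoint_epoch_{epoch:03d}.pth` and on a checkpoint dict that stores everything needed to continue. A condensed sketch of that save / find-newest / restore cycle, with a temporary directory, a one-layer model and SGD standing in for the real objects (scheduler state is omitted for brevity; `latest_checkpoint` and `save_checkpoint` are hypothetical helper names). `torch.load` is called with its defaults; since PyTorch 2.6 that means `weights_only=True`, which accepts a plain dict of tensors, numbers and lists like this one.

```python
import glob
import os
import re
import tempfile

import torch

def latest_checkpoint(ckpt_dir: str):
    """Return the checkpoint path with the highest epoch number, or None."""
    pattern = re.compile(r"checkpoint_epoch_(\d+)\.pth$")
    found = []
    for path in glob.glob(os.path.join(ckpt_dir, "checkpoint_epoch_*.pth")):
        match = pattern.search(os.path.basename(path))
        if match:
            found.append((int(match.group(1)), path))
    return max(found)[1] if found else None

def save_checkpoint(ckpt_dir, epoch, model, optimizer, train_losses, val_losses, best_val_loss):
    torch.save({"epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_losses": train_losses, "val_losses": val_losses,
                "best_val_loss": best_val_loss},
               os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch:03d}.pth"))

with tempfile.TemporaryDirectory() as ckpt_dir:
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(3):  # pretend three epochs ran before an interruption
        save_checkpoint(ckpt_dir, epoch, model, optimizer,
                        [0.5] * (epoch + 1), [0.6] * (epoch + 1), 0.6)

    # Fresh objects, as after a kernel restart; restore them from the newest file.
    resumed_model = torch.nn.Linear(4, 1)
    resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.1)
    ckpt = torch.load(latest_checkpoint(ckpt_dir), map_location="cpu")
    resumed_model.load_state_dict(ckpt["model_state_dict"])
    resumed_optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    start_epoch = ckpt["epoch"] + 1
    print(f"resume at epoch {start_epoch}, history length {len(ckpt['train_losses'])}")
```

Sorting on the parsed integer, as both this sketch and the notebook's regex key do, keeps the ordering correct even past epoch 999, where the three-digit zero padding would stop sorting lexicographically.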

## **6. Ensemble Evaluation Function**
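Defines `evaluate_kmer_ensemble`, which combines the k-mer models by soft voting (averaging their sigmoid probabilities and thresholding the mean at 0.5), then reports test metrics and saves confusion-matrix plots. The helper `RawSequenceDatasetForEnsemble` yields raw sequences so that each model can tokenize them with its own k.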
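Soft voting as implemented below averages each k-model's sigmoid probability and thresholds the mean at 0.5, with a missing model contributing a neutral 0.5. In isolation, with made-up logits for three sequences:

```python
import torch

# Hypothetical logits, shape (batch, 1), one tensor per k-mer model.
logits_per_k = {
    3: torch.tensor([[2.0], [-1.0], [3.0]]),
    4: torch.tensor([[1.0], [-2.0], [-0.2]]),
    5: torch.tensor([[0.5], [0.5], [-0.2]]),
    6: None,  # e.g. a model that failed to load
}

probs = []
for k, logits in logits_per_k.items():
    if logits is None:
        probs.append(torch.full((3, 1), 0.5))  # neutral vote, as in evaluate_kmer_ensemble
    else:
        probs.append(torch.sigmoid(logits))

mean_probs = torch.stack(probs, dim=0).mean(dim=0)  # (batch, 1)
predictions = (mean_probs > 0.5).float()
print(mean_probs.squeeze(1), predictions.squeeze(1))
```

The third sequence shows how soft voting differs from a hard majority vote: the k=3 model is confident (probability about 0.95) while k=4 and k=5 lean only weakly negative (about 0.45 each), so the averaged probability is about 0.59 and the ensemble predicts a promoter even though two of the three loaded models would vote "non-promoter".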

In [7]:
# %% 6. Ensemble Evaluation Function
# ============================================================================
class RawSequenceDatasetForEnsemble(Dataset):
    def __init__(self, list_of_raw_dna_strings, list_of_labels):
        self.raw_seqs = list_of_raw_dna_strings
        self.labels = torch.tensor(list_of_labels, dtype=torch.float32)
        if self.labels.ndim == 1: self.labels = self.labels.unsqueeze(1)
    def __len__(self): return len(self.raw_seqs)
    def __getitem__(self, idx):
        return {'raw_sequence_list': [self.raw_seqs[idx]], 'label_list': self.labels[idx]}

def plot_confusion_matrix_kmer(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, save_path="cm.png"):
    try:
        if normalize:
            cm_sum = cm.sum(axis=1)[:, np.newaxis]
            cm_plot = np.divide(cm.astype('float'), cm_sum, out=np.zeros_like(cm,dtype=float), where=cm_sum!=0)
            fmt = '.2f'; plot_title = title + ' (Normalized)'
        else: fmt = 'd'; cm_plot = cm; plot_title = title
        plt.figure(figsize=(8,6)); sns.heatmap(cm_plot, annot=True, fmt=fmt, cmap=cmap, xticklabels=classes, yticklabels=classes, square=True, cbar=False, linewidths=.5, linecolor='grey', annot_kws={"size":10}); plt.title(plot_title, fontsize=14); plt.ylabel('True label', fontsize=12); plt.xlabel('Predicted label', fontsize=12); plt.xticks(fontsize=10); plt.yticks(fontsize=10,rotation=0); plt.tight_layout(); plt.savefig(save_path, dpi=300); plt.close()
        log_message_kmer(f"Saved CM to {save_path}")
    except Exception as e: log_message_kmer(f"Error plotting CM: {e}")

def evaluate_kmer_ensemble(dict_of_trained_k_models, test_raw_seq_loader, k_val_list,
                           kmer_vocab_map_eval, max_kmer_len_map_eval, device_eval,
                           kmer_stride_for_eval,
                           results_csv_path_ens=ENSEMBLE_RESULTS_CSV_PATH,
                           cm_plot_path_base_ens=ENSEMBLE_CM_PLOT_PATH_BASE):
    log_message_kmer(f"--- Evaluating K-mer Ensemble on Test Set ---")
    all_ensemble_labels, all_ensembled_probabilities = [], []
    for model_k in dict_of_trained_k_models.values():
        if model_k: model_k.eval()
    with torch.no_grad():
        test_raw_loop = tqdm(test_raw_seq_loader, desc="Ensemble Test Evaluation", leave=False)
        for batch_raw in test_raw_loop:
            raw_sequences_in_batch_outer = batch_raw['raw_sequence_list']
            raw_sequences_in_batch_flat = [item for sublist in raw_sequences_in_batch_outer for item in sublist]
            labels_in_batch = batch_raw['label_list']
            batch_all_k_probs_list = []
            for k_val_ens in k_val_list:
                model_instance_ens = dict_of_trained_k_models.get(k_val_ens)
                if model_instance_ens:
                    kmer_indices_for_batch_k = np.array([kmer_numericalize_sequence(
                        s, k_val_ens, kmer_vocab_map_eval[k_val_ens]['kmer_to_idx'],
                        max_kmer_len_map_eval[k_val_ens], kmer_stride_for_num=kmer_stride_for_eval
                    ) for s in raw_sequences_in_batch_flat])
                    kmer_indices_tensor_k = torch.tensor(kmer_indices_for_batch_k).to(device_eval)
                    logits_k_ens = model_instance_ens(kmer_indices_tensor_k); probs_k_ens = torch.sigmoid(logits_k_ens)
                    batch_all_k_probs_list.append(probs_k_ens)
                else: batch_all_k_probs_list.append(torch.full((len(raw_sequences_in_batch_flat),1),0.5,device=device_eval))
            if batch_all_k_probs_list:
                avg_probs_for_batch = torch.mean(torch.stack(batch_all_k_probs_list, dim=0), dim=0)
                all_ensembled_probabilities.extend(avg_probs_for_batch.cpu().numpy())
                all_ensemble_labels.extend(labels_in_batch.cpu().numpy() if isinstance(labels_in_batch,torch.Tensor) else labels_in_batch)
            else:
                all_ensemble_labels.extend(labels_in_batch.cpu().numpy() if isinstance(labels_in_batch,torch.Tensor) else labels_in_batch)
                all_ensembled_probabilities.extend(np.full((len(labels_in_batch),1),0.5))
    if not all_ensemble_labels: log_message_kmer("No predictions for ensemble."); return {}
    all_labels_np=np.array(all_ensemble_labels).flatten(); all_probs_np=np.array(all_ensembled_probabilities).flatten(); all_preds_np=(all_probs_np > 0.5).astype(float)
    acc=accuracy_score(all_labels_np,all_preds_np); prec=precision_score(all_labels_np,all_preds_np,zero_division=0)
    rec=recall_score(all_labels_np,all_preds_np,zero_division=0); f1=f1_score(all_labels_np,all_preds_np,zero_division=0)
    unique_lbls=np.unique(all_labels_np)  # used below to guard the AUC computation
    # Fixing labels=[0,1] always yields a 2x2 matrix, even when only one class occurs in the
    # labels or predictions, so FN/FP counts are never dropped in the single-class case.
    cm=confusion_matrix(all_labels_np,all_preds_np,labels=[0,1])
    tn,fp,fn,tp=cm.ravel()
    spec=tn/(tn+fp) if (tn+fp)>0 else 0.0
    auc=roc_auc_score(all_labels_np,all_probs_np) if len(unique_lbls)>1 else np.nan
    if np.isnan(auc):log_message_kmer("Ens AUC is NaN.")
    results={'model':'msBERT_rep_ensemble','acc':acc,'prec':prec,'rec':rec,'spec':spec,'f1':f1,'auc':auc,'TP':int(tp),'FP':int(fp),'TN':int(tn),'FN':int(fn)}
    log_message_kmer("\n--- K-mer Ensemble Test Results ---")
    for k,v in results.items(): log_message_kmer(f"{k.upper()}: {v:.4f}" if isinstance(v,float) else f"{k.upper()}: {v}")
    plot_confusion_matrix_kmer(cm,['Non-P','P'],title='K-mer Ens Test CM (Counts)',save_path=cm_plot_path_base_ens+"_counts.png")
    plot_confusion_matrix_kmer(cm,['Non-P','P'],normalize=True,title='K-mer Ens Test CM (Norm)',save_path=cm_plot_path_base_ens+"_norm.png")
    pd.DataFrame([results]).to_csv(results_csv_path_ens,index=False); log_message_kmer(f"Saved K-mer Ens test results to {results_csv_path_ens}")
    return results
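Passing `labels=[0, 1]` to scikit-learn's `confusion_matrix` fixes the row and column order, so `ravel()` can always be unpacked as `tn, fp, fn, tp`, and specificity is TN/(TN+FP). A small check with made-up labels and probabilities, including the single-class edge case:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, roc_auc_score

y_true = np.array([1, 1, 1, 0, 0, 0])
y_prob = np.array([0.9, 0.6, 0.4, 0.3, 0.7, 0.1])
y_pred = (y_prob > 0.5).astype(int)  # same 0.5 threshold as the ensemble

tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
print(tn, fp, fn, tp, round(specificity, 3), round(sensitivity, 3))  # 2 1 1 2 0.667 0.667
print("AUC:", roc_auc_score(y_true, y_prob))

# Single-class labels: still a 2x2 matrix, with the missed positive counted as FN.
print(confusion_matrix(np.array([1, 1]), np.array([1, 0]), labels=[0, 1]))
```

The AUC here is 7/9: of the nine (positive, negative) pairs, seven rank the positive higher.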

## **7. Main Execution Block for msBERT-Promoter Replication**
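This block orchestrates the replication: it loads and QC-filters the raw sequences, builds one k-mer vocabulary per k, computes the maximum k-mer sequence length for each k, creates a seeded train/validation/test split (about 70/15/15), and then trains, resumes, or loads one `KmerTransformer` per k.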
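The split below uses a seeded shuffle rather than `train_test_split`: the first `floor(0.15 * N)` shuffled indices form the test set, the next `floor(0.15 * N)` the validation set, and the remainder the training set. Because `np.random.seed(RANDOM_SEED)` is reset immediately before the shuffle, the same N (and the same sequence order from the loaders) always produces the same split. A standalone version of that arithmetic; the function name `seeded_split` is hypothetical:

```python
import numpy as np

def seeded_split(n_items: int, test_frac: float = 0.15, val_frac: float = 0.15, seed: int = 42):
    """Mirror the notebook's split: seed, shuffle, then slice test | val | train."""
    indices = list(range(n_items))
    np.random.seed(seed)
    np.random.shuffle(indices)
    test_end = int(np.floor(test_frac * n_items))
    val_end = test_end + int(np.floor(val_frac * n_items))
    return indices[val_end:], indices[test_end:val_end], indices[:test_end]

train_idx, val_idx, test_idx = seeded_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 700 150 150
```

The split is not stratified, so the promoter/non-promoter ratio in each subset only approximately matches the full dataset.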

In [8]:
# %% 7. Main Execution Block for msBERT-Promoter Replication
# ============================================================================
if __name__ == "__main__":
    log_message_kmer(f"--- Starting msBERT-Promoter Replication Workflow (PID: {os.getpid()}) ---")
    msbert_start_time = time.time()

    args_rep = argparse.Namespace()  # simple attribute container standing in for CLI arguments
    args_rep.force_kmer_model_rebuild = False # Set True to retrain all k-mer models, even if checkpoints exist
    args_rep.kmer_epochs = NUM_EPOCHS_KMER
    args_rep.kmer_batch_size = BATCH_SIZE_KMER
    args_rep.kmer_lr = LEARNING_RATE_KMER
    # Example Override:
    # args_rep.force_kmer_model_rebuild = True

    log_message_kmer(f"K-mer Baseline Config: Epochs={args_rep.kmer_epochs}, Batch={args_rep.kmer_batch_size}, LR={args_rep.kmer_lr}, ForceRebuild={args_rep.force_kmer_model_rebuild}")

    try:
        log_message_kmer("\n--- Step 1: Loading Raw Sequence Data ---")
        promoter_seqs_raw, promoter_labels_raw = load_raw_sequences_for_kmer(
            PROMOTER_SEQ_FILE_REPLICATION, is_promoter=True, seq_len_expected=SEQ_LEN_FOR_KMER_LOADING)
        nonpromoter_seqs_raw, nonpromoter_labels_raw = load_raw_sequences_for_kmer(
            NON_PROMOTER_SEQ_FILE_REPLICATION, is_promoter=False, seq_len_expected=SEQ_LEN_FOR_KMER_LOADING)
        if not promoter_seqs_raw or not nonpromoter_seqs_raw: raise ValueError("Failed to load raw sequences.")
        all_raw_sequences_for_kmer_baseline = promoter_seqs_raw + nonpromoter_seqs_raw
        all_labels_for_kmer_baseline = promoter_labels_raw + nonpromoter_labels_raw
        log_message_kmer(f"Total raw sequences for k-mer models: {len(all_raw_sequences_for_kmer_baseline)}")

        log_message_kmer("\n--- Step 2: Building K-mer Vocabularies ---")
        kmer_vocabs_map = build_kmer_vocab(all_raw_sequences_for_kmer_baseline, KMER_VALUES, kmer_stride_for_vocab=KMER_STRIDE)

        max_kmer_seq_lens_map = {}
        log_message_kmer("\n--- Step 3: Calculating Max K-mer Sequence Lengths ---")
        for k_val_len in KMER_VALUES:
            current_k_max_len = 0
            for seq_len_calc in tqdm(all_raw_sequences_for_kmer_baseline, desc=f"Max k-mer len for k={k_val_len}"):
                if isinstance(seq_len_calc, str):
                    current_k_max_len = max(current_k_max_len, len(sequence_to_kmers(seq_len_calc, k_val_len, stride=KMER_STRIDE)))
            max_kmer_seq_lens_map[k_val_len] = current_k_max_len if current_k_max_len > 0 else 1
            log_message_kmer(f"Max k-mer seq length for k={k_val_len}: {max_kmer_seq_lens_map[k_val_len]}")

        log_message_kmer("\n--- Step 4: Training/Loading Individual K-mer Models ---")
        trained_kmer_models_dict = {}
        kmer_full_dataset_size = len(all_raw_sequences_for_kmer_baseline)
        kmer_all_indices = list(range(kmer_full_dataset_size))
        np.random.seed(RANDOM_SEED); np.random.shuffle(kmer_all_indices)
        k_test_split_idx = int(np.floor(TEST_SPLIT * kmer_full_dataset_size))
        k_val_split_idx = k_test_split_idx + int(np.floor(VALIDATION_SPLIT * kmer_full_dataset_size))
        k_test_indices = kmer_all_indices[:k_test_split_idx]
        k_val_indices  = kmer_all_indices[k_test_split_idx:k_val_split_idx]
        k_train_indices= kmer_all_indices[k_val_split_idx:]
        if not k_train_indices or not k_val_indices or not k_test_indices: raise ValueError("K-mer dataset splitting error.")
        log_message_kmer(f"K-mer Data Split: Train={len(k_train_indices)}, Val={len(k_val_indices)}, Test={len(k_test_indices)}")

        for k_val_run in KMER_VALUES:
            log_message_kmer(f"\n--- Processing for k={k_val_run} ---")
            k_model_best_path_run = KMER_MODEL_BEST_SAVE_PATH_TPL.format(k=k_val_run)
            k_loss_plot_path_run = KMER_LOSS_PLOT_PATH_TPL.format(k=k_val_run)
            k_checkpoint_dir_run = KMER_CHECKPOINT_DIR_TPL.format(k=k_val_run) # Specific dir for this k

            k_model_instance = KmerTransformer(
                kmer_vocab_size=kmer_vocabs_map[k_val_run]['vocab_size'],
                max_kmer_seq_len=max_kmer_seq_lens_map[k_val_run],
                embed_dim=KMER_EMBED_DIM, num_heads=KMER_NUM_HEADS,
                ff_dim=KMER_TRANSFORMER_FF_DIM, num_layers=KMER_NUM_TRANSFORMER_LAYERS,
                dropout=KMER_DROPOUT_RATE, kmer_pad_idx=kmer_vocabs_map[k_val_run]['pad_idx']
            ).to(DEVICE)
            # log_message_kmer(f"k={k_val_run} Initialized Model Params: {sum(p.numel() for p in k_model_instance.parameters() if p.requires_grad):,}") # Verbose

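            # Skip training if the best model AND the final loss plot exist and no rebuild is forced.
            # The loss plot is written only when training ends, so together the two files mark a completed run.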
            if os.path.exists(k_model_best_path_run) and os.path.exists(k_loss_plot_path_run) and not args_rep.force_kmer_model_rebuild:
                log_message_kmer(f"Found completed training for k={k_val_run} (best model & plot exist). Loading: {k_model_best_path_run}")
                try:
                    k_model_instance.load_state_dict(torch.load(k_model_best_path_run, map_location=DEVICE))
                    trained_kmer_models_dict[k_val_run] = k_model_instance
                    log_message_kmer(f"Successfully loaded best model for k={k_val_run}.")
                    continue # Skip to the next k-value
                except Exception as e:
                    log_message_kmer(f"Error loading existing best model for k={k_val_run}: {e}. Will attempt to train/resume.")
            
            temp_full_k_dataset = KmerDataset(all_raw_sequences_for_kmer_baseline, all_labels_for_kmer_baseline,
                k_val_run, kmer_vocabs_map[k_val_run], max_kmer_seq_lens_map[k_val_run], kmer_stride_val=KMER_STRIDE)
            k_train_subset = Subset(temp_full_k_dataset, k_train_indices)
            k_val_subset = Subset(temp_full_k_dataset, k_val_indices)
            k_num_workers = 0
            # Drop the final training batch only if it would contain exactly one sample.
            k_drop_last = len(k_train_subset) > 1 and len(k_train_subset) % args_rep.kmer_batch_size == 1
            k_train_loader_run = DataLoader(k_train_subset, batch_size=args_rep.kmer_batch_size, shuffle=True,
                                            num_workers=k_num_workers, drop_last=k_drop_last)
            k_val_loader_run = DataLoader(k_val_subset, batch_size=args_rep.kmer_batch_size, shuffle=False, num_workers=k_num_workers)

            k_criterion_run = nn.BCEWithLogitsLoss()
            k_optimizer_run = optim.AdamW(k_model_instance.parameters(), lr=args_rep.kmer_lr, weight_decay=OPTIMIZER_WEIGHT_DECAY_KMER)
            k_scheduler_run = optim.lr_scheduler.ReduceLROnPlateau(k_optimizer_run, 'min', factor=0.2, patience=3) 

            trained_model_k = train_kmer_model(
                k_val_run, k_model_instance, k_train_loader_run, k_val_loader_run, k_criterion_run, k_optimizer_run,
                args_rep.kmer_epochs, DEVICE, 
                KMER_MODEL_BEST_SAVE_PATH_TPL, # Pass template for best model
                KMER_LOSS_PLOT_PATH_TPL,
                k_scheduler_run,
                checkpoint_dir_tpl_k=KMER_CHECKPOINT_DIR_TPL # Pass template for epoch checkpoints
            )
            trained_kmer_models_dict[k_val_run] = trained_model_k
            if trained_model_k is None: log_message_kmer(f"ERROR: Training/Resuming failed for k={k_val_run}.")

        log_message_kmer("\n--- Step 5: Evaluating K-mer Ensemble ---")
        if any(model is None for model in trained_kmer_models_dict.values()):
            log_message_kmer("Warning: One or more k-mer models missing. Ensemble results may be based on fewer models.")
        
        valid_trained_k_models = {k: m for k, m in trained_kmer_models_dict.items() if m is not None}
        valid_k_values_for_ensemble = list(valid_trained_k_models.keys())

        if not valid_trained_k_models:
            log_message_kmer("No k-mer models available for ensemble evaluation. Skipping.")
        else:
            raw_test_sequences = [all_raw_sequences_for_kmer_baseline[i] for i in k_test_indices]
            raw_test_labels = [all_labels_for_kmer_baseline[i] for i in k_test_indices]
            ensemble_test_dataset = RawSequenceDatasetForEnsemble(raw_test_sequences, raw_test_labels)
            ensemble_test_loader = DataLoader(ensemble_test_dataset, batch_size=args_rep.kmer_batch_size, shuffle=False, num_workers=0)

            evaluate_kmer_ensemble(valid_trained_k_models, ensemble_test_loader, valid_k_values_for_ensemble,
                                   kmer_vocabs_map, max_kmer_seq_lens_map, DEVICE,
                                   kmer_stride_for_eval=KMER_STRIDE,
                                   results_csv_path_ens=ENSEMBLE_RESULTS_CSV_PATH,
                                   cm_plot_path_base_ens=ENSEMBLE_CM_PLOT_PATH_BASE)

        total_msbert_duration = time.time() - msbert_start_time
        log_message_kmer(f"--- msBERT Replication Workflow Completed in {total_msbert_duration:.2f}s ({total_msbert_duration/3600:.2f} hrs) ---")

    except FileNotFoundError as fnf_error:
        log_message_kmer(f"\nKMER WORKFLOW FileNotFoundError: {fnf_error}")
    except ValueError as val_error:
        log_message_kmer(f"\nKMER WORKFLOW ValueError: {val_error}\n{traceback.format_exc()}")
    except RuntimeError as rt_error:
        log_message_kmer(f"\nKMER WORKFLOW RuntimeError: {rt_error}\n{traceback.format_exc()}")
    except Exception as main_error:
        log_message_kmer(f"\nKMER WORKFLOW Exception: {type(main_error).__name__}: {main_error}\n{traceback.format_exc()}")


[2025-06-19 13:18:59] --- Starting msBERT-Promoter Replication Workflow (PID: 3459) ---
[2025-06-19 13:18:59] K-mer Baseline Config: Epochs=10, Batch=32, LR=0.0001, ForceRebuild=False
[2025-06-19 13:18:59] 
--- Step 1: Loading Raw Sequence Data ---
[2025-06-19 13:18:59] Loading RAW sequences from ./data/raw/human_genome_annotation/updated_promoter_features_clean.csv (expected len: 2000)...
[2025-06-19 13:19:00] QC updated_promoter_features_clean.csv: Initial Raw=20028. After N-rem=20028. Len-rem:0. Final:20028.
[2025-06-19 13:19:00] Loaded 20028 RAW sequences from updated_promoter_features_clean.csv in 0.71s.
[2025-06-19 13:19:00] Loading RAW sequences from ./data/raw/human_genome_annotation/updated_non_promoter_sequences.csv (expected len: 2000)...
[2025-06-19 13:19:01] QC updated_non_promoter_sequences.csv: Initial Raw=20028. After N-rem=20028. Len-rem:0. Final:20028.
[2025-06-19 13:19:01] Loaded 20028 RAW sequences from updated_non_promoter_sequences.csv in 0.64s.
[2025-06-19 13:19:

[2025-06-19 13:19:25] Vocab for k=3 size: 66.
[2025-06-19 13:19:25] Building vocabulary for k=4 using stride=1...


[2025-06-19 13:19:49] Vocab for k=4 size: 258.
[2025-06-19 13:19:49] Building vocabulary for k=5 using stride=1...


[2025-06-19 13:20:15] Vocab for k=5 size: 1026.
[2025-06-19 13:20:15] Building vocabulary for k=6 using stride=1...


[2025-06-19 13:20:42] Vocab for k=6 size: 4098.
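Each logged vocabulary size equals 4^k + 2 (64 + 2, 256 + 2, 1024 + 2, 4096 + 2): every possible k-mer over A, C, G, T plus two extra entries, presumably special tokens such as padding (the model receives a `pad_idx`) and an unknown-token fallback. The notebook builds its vocabulary by scanning the sequences, so the match with 4^k + 2 indicates that every possible k-mer occurs somewhere in the data. A minimal sketch that enumerates the vocabulary directly; `build_kmer_vocab` and the token names `<PAD>` and `<UNK>` are hypothetical.

```python
from itertools import product


def build_kmer_vocab(k: int, specials=("<PAD>", "<UNK>")) -> dict:
    """Hypothetical sketch: map every possible ACGT k-mer to an integer id.

    Special tokens take the first ids; the notebook's own builder may order them differently.
    """
    vocab = {tok: i for i, tok in enumerate(specials)}
    for kmer_tuple in product("ACGT", repeat=k):
        vocab["".join(kmer_tuple)] = len(vocab)
    return vocab


for k in (3, 4, 5, 6):
    print(k, len(build_kmer_vocab(k)))
# 3 66
# 4 258
# 5 1026
# 6 4098
```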
[2025-06-19 13:20:43] 
--- Step 3: Calculating Max K-mer Sequence Lengths ---


[2025-06-19 13:20:57] Max k-mer seq length for k=3: 1998


[2025-06-19 13:21:11] Max k-mer seq length for k=4: 1997


[2025-06-19 13:21:25] Max k-mer seq length for k=5: 1996


[2025-06-19 13:21:40] Max k-mer seq length for k=6: 1995
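With stride 1, a window of width k slides over an L-nt sequence in L - k + 1 positions, so the 2000-nt inputs yield 1998, 1997, 1996, and 1995 k-mers for k = 3 to 6, exactly the maximum lengths logged above. For a general stride s the count is floor((L - k) / s) + 1. The helpers below (`num_kmers`, `kmerize`) are an illustrative sketch of that sliding-window rule, assuming the notebook's tokenizer steps by `KMER_STRIDE` without padding the sequence ends.

```python
def num_kmers(seq_len: int, k: int, stride: int = 1) -> int:
    """Number of windows of width k, moved by `stride`, that fit inside seq_len."""
    if seq_len < k:
        return 0
    return (seq_len - k) // stride + 1


def kmerize(seq: str, k: int, stride: int = 1) -> list:
    """Split a sequence into k-mers (stride 1 gives fully overlapping k-mers)."""
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, stride)]


print([num_kmers(2000, k) for k in (3, 4, 5, 6)])  # [1998, 1997, 1996, 1995]
print(kmerize("ACGTAC", 3))                         # ['ACG', 'CGT', 'GTA', 'TAC']
```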
[2025-06-19 13:21:40] 
--- Step 4: Training/Loading Individual K-mer Models ---
[2025-06-19 13:21:40] K-mer Data Split: Train=28040, Val=6008, Test=6008
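The split sizes can be checked by hand. The configuration values are not shown in this section, but `TEST_SPLIT = VALIDATION_SPLIT = 0.15` (an assumption) reproduces them: floor(0.15 × 40056) = 6008 for both test and validation, leaving 28040 for training. With batch size 32 this gives ceil(28040 / 32) = 877 training batches and ceil(6008 / 32) = 188 validation batches per epoch; since 28040 mod 32 = 8, the `drop_last` rule (drop only a trailing batch of exactly one sample) does not trigger. The sketch mirrors the index-slicing logic of the workflow cell; `split_indices` and `n_batches` are illustrative names, and the seed 42 is a placeholder because the sizes do not depend on it.

```python
import math

import numpy as np


def split_indices(n: int, test_frac: float, val_frac: float, seed: int):
    """Shuffle 0..n-1 with the seeded global NumPy RNG, then slice off test, val, train in that order."""
    idx = list(range(n))
    np.random.seed(seed)
    np.random.shuffle(idx)
    t_end = int(np.floor(test_frac * n))
    v_end = t_end + int(np.floor(val_frac * n))
    return idx[v_end:], idx[t_end:v_end], idx[:t_end]  # train, val, test


def n_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields for n_samples items."""
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)


train, val, test = split_indices(40056, 0.15, 0.15, seed=42)
print(len(train), len(val), len(test))  # 28040 6008 6008

# Same rule as the workflow: drop the last batch only if it would hold a single sample.
drop = len(train) > 1 and len(train) % 32 == 1
print(drop, n_batches(len(train), 32, drop), n_batches(len(val), 32, False))  # False 877 188
```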
[2025-06-19 13:21:40] 
--- Processing for k=3 ---
[2025-06-19 13:21:41] Resuming k=3 from checkpoint: results_msBERT_replication/checkpoints_k3/checkpoint_epoch_003.pth
[2025-06-19 13:21:41] Resumed from epoch 3. Best Val Loss: 0.4831
[2025-06-19 13:21:41] --- Training KmerTransformer (k=3) from epoch 4 to 9 ---


[2025-06-19 20:22:21] k=3 E 5/10 - TrL: 0.4937, VaL: 0.4784, Dur: 25240.26s, LR: 1.00e-04
[2025-06-19 20:22:21] k=3 Saved NEW BEST model (Val Loss: 0.4784) to results_msBERT_replication/best_kmer_model_k3.pth


[2025-06-20 03:07:35] k=3 E 6/10 - TrL: 0.4910, VaL: 0.4741, Dur: 24314.11s, LR: 1.00e-04
[2025-06-20 03:07:35] k=3 Saved NEW BEST model (Val Loss: 0.4741) to results_msBERT_replication/best_kmer_model_k3.pth


[2025-06-20 10:02:59] k=3 E 7/10 - TrL: 0.4883, VaL: 0.5122, Dur: 24923.51s, LR: 1.00e-04


[2025-06-20 16:52:46] k=3 E 8/10 - TrL: 0.4839, VaL: 0.4696, Dur: 24587.20s, LR: 1.00e-04
[2025-06-20 16:52:46] k=3 Saved NEW BEST model (Val Loss: 0.4696) to results_msBERT_replication/best_kmer_model_k3.pth


[2025-06-20 23:49:10] k=3 E 9/10 - TrL: 0.4822, VaL: 0.4851, Dur: 24984.03s, LR: 1.00e-04


[2025-06-21 06:36:15] k=3 E 10/10 - TrL: 0.4783, VaL: 0.4966, Dur: 24425.07s, LR: 1.00e-04
[2025-06-21 06:36:17] k=3 Saved final loss plot to results_msBERT_replication/loss_kmer_k3.png
[2025-06-21 06:36:17] --- k=3 Training Finished/Resumed --- Best Val Loss: 0.4696
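For k = 3 the run resumed from `checkpoint_epoch_003.pth` ("Resumed from epoch 3") and trained "from epoch 4 to 9", while the per-epoch log lines are labelled E 5/10 through E 10/10. The checkpoint number and the training range therefore appear to be 0-based epoch indices and the E-labels 1-based. `train_kmer_model` is defined outside this section, so the following is only a hypothetical sketch of how the latest checkpoint could be located and the remaining epochs derived; `latest_checkpoint` and `resume_range` are illustrative names, and the filename pattern is taken from the resume log line.

```python
import glob
import os
import re

# Assumed naming scheme, copied from the log: checkpoint_epoch_XXX.pth (0-based epoch index)
CKPT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.pth$")


def latest_checkpoint(ckpt_dir: str):
    """Return (path, epoch_idx) of the highest-numbered checkpoint, or (None, -1) if none exist."""
    best_path, best_epoch = None, -1
    for path in glob.glob(os.path.join(ckpt_dir, "checkpoint_epoch_*.pth")):
        match = CKPT_PATTERN.search(os.path.basename(path))
        if match and int(match.group(1)) > best_epoch:
            best_path, best_epoch = path, int(match.group(1))
    return best_path, best_epoch


def resume_range(last_completed_idx: int, total_epochs: int) -> range:
    """0-based epoch indices still to run; a 1-based display label is idx + 1."""
    return range(last_completed_idx + 1, total_epochs)


print(list(resume_range(3, 10)))  # [4, 5, 6, 7, 8, 9]
```

Parsing the epoch number as an integer, rather than sorting filenames, keeps the choice correct even if the zero-padding width ever changes.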
[2025-06-21 06:36:17] 
--- Processing for k=4 ---
[2025-06-21 06:36:17] --- Training KmerTransformer (k=4) from epoch 0 to 9 ---


[2025-06-21 13:22:39] k=4 E 1/10 - TrL: 0.6056, VaL: 0.5360, Dur: 24382.13s, LR: 1.00e-04
[2025-06-21 13:22:39] k=4 Saved NEW BEST model (Val Loss: 0.5360) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-21 20:03:34] k=4 E 2/10 - TrL: 0.5302, VaL: 0.5077, Dur: 24054.65s, LR: 1.00e-04
[2025-06-21 20:03:34] k=4 Saved NEW BEST model (Val Loss: 0.5077) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-22 02:44:17] k=4 E 3/10 - TrL: 0.5145, VaL: 0.5036, Dur: 24042.68s, LR: 1.00e-04
[2025-06-22 02:44:17] k=4 Saved NEW BEST model (Val Loss: 0.5036) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-22 09:24:16] k=4 E 4/10 - TrL: 0.5038, VaL: 0.4891, Dur: 23999.41s, LR: 1.00e-04
[2025-06-22 09:24:16] k=4 Saved NEW BEST model (Val Loss: 0.4891) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-22 17:49:20] k=4 E 5/10 - TrL: 0.4999, VaL: 0.4927, Dur: 30304.26s, LR: 1.00e-04


[2025-06-23 02:19:40] k=4 E 6/10 - TrL: 0.4960, VaL: 0.4901, Dur: 30619.89s, LR: 1.00e-04


[2025-06-23 09:00:34] k=4 E 7/10 - TrL: 0.4902, VaL: 0.4816, Dur: 24054.04s, LR: 1.00e-04
[2025-06-23 09:00:34] k=4 Saved NEW BEST model (Val Loss: 0.4816) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-23 17:19:59] k=4 E 8/10 - TrL: 0.4871, VaL: 0.4769, Dur: 29964.41s, LR: 1.00e-04
[2025-06-23 17:19:59] k=4 Saved NEW BEST model (Val Loss: 0.4769) to results_msBERT_replication/best_kmer_model_k4.pth


[2025-06-23 23:58:56] k=4 E 9/10 - TrL: 0.4862, VaL: 0.4794, Dur: 23937.15s, LR: 1.00e-04


[2025-06-24 06:38:23] k=4 E 10/10 - TrL: 0.4825, VaL: 0.4789, Dur: 23966.86s, LR: 1.00e-04
[2025-06-24 06:38:24] k=4 Saved final loss plot to results_msBERT_replication/loss_kmer_k4.png
[2025-06-24 06:38:24] --- k=4 Training Finished/Resumed --- Best Val Loss: 0.4769
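Every epoch line reports `LR: 1.00e-04`, so the `ReduceLROnPlateau(..., factor=0.2, patience=3)` scheduler never lowered the rate. Assuming `train_kmer_model` calls `scheduler.step(val_loss)` once per epoch, this is expected: PyTorch reduces the rate only after more than `patience` consecutive epochs without a relative improvement, and none of the four models went more than two epochs without a new best validation loss. The pure-Python function below is a simplified re-implementation of that rule (no cooldown, no `min_lr`), written only to check the logged k=4 losses; it is not the PyTorch class itself.

```python
def plateau_lrs(val_losses, lr=1e-4, factor=0.2, patience=3, threshold=1e-4):
    """Simplified re-implementation of ReduceLROnPlateau(mode='min', threshold_mode='rel').

    Returns the learning rate in effect after each step(val_loss); cooldown and min_lr are ignored.
    """
    best = float("inf")
    num_bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):  # relative improvement over the best so far
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        if num_bad_epochs > patience:      # needs MORE than `patience` bad epochs in a row
            lr *= factor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs


# Validation losses logged above for k=4 (epochs 1-10)
k4_val_losses = [0.5360, 0.5077, 0.5036, 0.4891, 0.4927, 0.4901, 0.4816, 0.4769, 0.4794, 0.4789]
print(set(plateau_lrs(k4_val_losses)))  # {0.0001}
print(plateau_lrs([0.5] * 5)[-1])       # first reduction happens on the 5th flat epoch
```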
[2025-06-24 06:38:24] 
--- Processing for k=5 ---
[2025-06-24 06:38:24] --- Training KmerTransformer (k=5) from epoch 0 to 9 ---


[2025-06-24 14:26:04] k=5 E 1/10 - TrL: 0.6119, VaL: 0.5528, Dur: 28059.74s, LR: 1.00e-04
[2025-06-24 14:26:04] k=5 Saved NEW BEST model (Val Loss: 0.5528) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-24 21:14:21] k=5 E 2/10 - TrL: 0.5393, VaL: 0.5277, Dur: 24496.94s, LR: 1.00e-04
[2025-06-24 21:14:21] k=5 Saved NEW BEST model (Val Loss: 0.5277) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-25 03:58:51] k=5 E 3/10 - TrL: 0.5250, VaL: 0.5122, Dur: 24269.64s, LR: 1.00e-04
[2025-06-25 03:58:51] k=5 Saved NEW BEST model (Val Loss: 0.5122) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-25 10:43:38] k=5 E 4/10 - TrL: 0.5135, VaL: 0.5142, Dur: 24287.68s, LR: 1.00e-04


[2025-06-25 17:39:01] k=5 E 5/10 - TrL: 0.5062, VaL: 0.4964, Dur: 24923.06s, LR: 1.00e-04
[2025-06-25 17:39:01] k=5 Saved NEW BEST model (Val Loss: 0.4964) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-26 00:24:19] k=5 E 6/10 - TrL: 0.4974, VaL: 0.4877, Dur: 24317.79s, LR: 1.00e-04
[2025-06-26 00:24:19] k=5 Saved NEW BEST model (Val Loss: 0.4877) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-26 07:08:51] k=5 E 7/10 - TrL: 0.4911, VaL: 0.4837, Dur: 24271.84s, LR: 1.00e-04
[2025-06-26 07:08:51] k=5 Saved NEW BEST model (Val Loss: 0.4837) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-26 14:07:06] k=5 E 8/10 - TrL: 0.4876, VaL: 0.4793, Dur: 25094.23s, LR: 1.00e-04
[2025-06-26 14:07:06] k=5 Saved NEW BEST model (Val Loss: 0.4793) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-26 20:57:01] k=5 E 9/10 - TrL: 0.4818, VaL: 0.4720, Dur: 24595.19s, LR: 1.00e-04
[2025-06-26 20:57:01] k=5 Saved NEW BEST model (Val Loss: 0.4720) to results_msBERT_replication/best_kmer_model_k5.pth


[2025-06-27 03:42:53] k=5 E 10/10 - TrL: 0.4768, VaL: 0.4717, Dur: 24352.57s, LR: 1.00e-04
[2025-06-27 03:42:54] k=5 Saved NEW BEST model (Val Loss: 0.4717) to results_msBERT_replication/best_kmer_model_k5.pth
[2025-06-27 03:42:54] k=5 Saved final loss plot to results_msBERT_replication/loss_kmer_k5.png
[2025-06-27 03:42:54] --- k=5 Training Finished/Resumed --- Best Val Loss: 0.4717
[2025-06-27 03:42:54] 
--- Processing for k=6 ---
[2025-06-27 03:42:54] --- Training KmerTransformer (k=6) from epoch 0 to 9 ---


[2025-06-27 10:22:06] k=6 E 1/10 - TrL: 0.6420, VaL: 0.5701, Dur: 23951.45s, LR: 1.00e-04
[2025-06-27 10:22:06] k=6 Saved NEW BEST model (Val Loss: 0.5701) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-27 17:45:24] k=6 E 2/10 - TrL: 0.5609, VaL: 0.5480, Dur: 26597.69s, LR: 1.00e-04
[2025-06-27 17:45:24] k=6 Saved NEW BEST model (Val Loss: 0.5480) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-28 00:26:34] k=6 E 3/10 - TrL: 0.5364, VaL: 0.5340, Dur: 24070.20s, LR: 1.00e-04
[2025-06-28 00:26:34] k=6 Saved NEW BEST model (Val Loss: 0.5340) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-28 07:07:03] k=6 E 4/10 - TrL: 0.5237, VaL: 0.5161, Dur: 24028.94s, LR: 1.00e-04
[2025-06-28 07:07:03] k=6 Saved NEW BEST model (Val Loss: 0.5161) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-28 13:52:28] k=6 E 5/10 - TrL: 0.5145, VaL: 0.5101, Dur: 24325.46s, LR: 1.00e-04
[2025-06-28 13:52:29] k=6 Saved NEW BEST model (Val Loss: 0.5101) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-28 21:11:48] k=6 E 6/10 - TrL: 0.5072, VaL: 0.5040, Dur: 26359.60s, LR: 1.00e-04
[2025-06-28 21:11:48] k=6 Saved NEW BEST model (Val Loss: 0.5040) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-29 03:52:56] k=6 E 7/10 - TrL: 0.4982, VaL: 0.4979, Dur: 24068.22s, LR: 1.00e-04
[2025-06-29 03:52:57] k=6 Saved NEW BEST model (Val Loss: 0.4979) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-29 10:34:03] k=6 E 8/10 - TrL: 0.4904, VaL: 0.4895, Dur: 24066.74s, LR: 1.00e-04
[2025-06-29 10:34:03] k=6 Saved NEW BEST model (Val Loss: 0.4895) to results_msBERT_replication/best_kmer_model_k6.pth


[2025-06-29 18:21:26] k=6 E 9/10 - TrL: 0.4822, VaL: 0.4897, Dur: 28042.61s, LR: 1.00e-04


[2025-06-30 01:02:17] k=6 E 10/10 - TrL: 0.4773, VaL: 0.4818, Dur: 24051.29s, LR: 1.00e-04
[2025-06-30 01:02:17] k=6 Saved NEW BEST model (Val Loss: 0.4818) to results_msBERT_replication/best_kmer_model_k6.pth
[2025-06-30 01:02:18] k=6 Saved final loss plot to results_msBERT_replication/loss_kmer_k6.png
[2025-06-30 01:02:18] --- k=6 Training Finished/Resumed --- Best Val Loss: 0.4818
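Step 5 combines the four k-mer models by soft voting. `evaluate_kmer_ensemble` is defined outside this section, so the following is a minimal sketch of the usual rule, assuming each model emits one logit per sequence (consistent with `BCEWithLogitsLoss`): convert each model's logit to a probability with a sigmoid, average the probabilities with equal weights, and predict the promoter class when the mean reaches 0.5. The equal weighting, the 0.5 threshold, and the name `soft_vote` are assumptions.

```python
import numpy as np


def soft_vote(logits_per_model: dict, threshold: float = 0.5):
    """Average sigmoid probabilities across models (equal weights) and threshold the mean.

    logits_per_model maps k -> sequence of raw logits, one per input sequence.
    Returns (mean_probabilities, hard_predictions).
    """
    probs = [1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
             for logits in logits_per_model.values()]
    mean_prob = np.mean(np.stack(probs, axis=0), axis=0)
    return mean_prob, (mean_prob >= threshold).astype(int)


# Toy logits for three sequences from four k-mer models (made-up numbers)
toy = {3: [2.0, -1.0, 3.0], 4: [1.5, -0.5, -0.2], 5: [0.5, -2.0, -0.2], 6: [1.0, 0.3, -0.2]}
mean_prob, preds = soft_vote(toy)
print(preds)  # [1 0 1]
```

The third toy sequence shows how soft voting differs from a majority vote: three of the four models lean slightly negative, but the k=3 model is confident enough (probability about 0.95) to pull the mean to about 0.58.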
[2025-06-30 01:02:18] 
--- Step 5: Evaluating K-mer Ensemble ---
[2025-06-30 01:02:19] --- Evaluating K-mer Ensemble on Test Set ---


[2025-06-30 02:33:03] 
--- K-mer Ensemble Test Results ---
[2025-06-30 02:33:03] MODEL: msBERT_rep_ensemble
[2025-06-30 02:33:03] ACC: 0.7833
[2025-06-30 02:33:03] PREC: 0.8796
[2025-06-30 02:33:03] REC: 0.6587
[2025-06-30 02:33:03] SPEC: 0.9090
[2025-06-30 02:33:03] F1: 0.7533
[2025-06-30 02:33:03] AUC: 0.8539
[2025-06-30 02:33:03] TP: 1988
[2025-06-30 02:33:03] FP: 272
[2025-06-30 02:33:03] TN: 2718
[2025-06-30 02:33:03] FN: 1030
[2025-06-30 02:33:04] Saved CM to results_msBERT_replication/confusion_matrix_kmer_ensemble_counts.png
[2025-06-30 02:33:04] Saved CM to results_msBERT_replication/confusion_matrix_kmer_ensemble_norm.png
[2025-06-30 02:33:04] Saved K-mer Ens test results to results_msBERT_replication/test_results_kmer_ensemble.csv
[2025-06-30 02:33:04] --- msBERT Replication Workflow Completed in 911644.50s (253.23 hrs) ---
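The reported metrics follow directly from the four confusion-matrix counts, and recomputing them is a quick consistency check: the counts sum to 6008, the test-split size, and the test set holds 3018 promoters (TP + FN) and 2990 non-promoters (TN + FP). The helper below uses the standard definitions; it is an illustrative check, not the notebook's `evaluate_kmer_ensemble`. AUC cannot be recovered from the counts because it needs the per-sequence probabilities.

```python
def metrics_from_counts(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Standard binary-classification metrics from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)            # sensitivity / true-positive rate
    return {
        "ACC": (tp + tn) / (tp + fp + tn + fn),
        "PREC": precision,
        "REC": recall,
        "SPEC": tn / (tn + fp),        # true-negative rate
        "F1": 2 * precision * recall / (precision + recall),
    }


for name, value in metrics_from_counts(tp=1988, fp=272, tn=2718, fn=1030).items():
    print(f"{name}: {value:.4f}")
# ACC: 0.7833
# PREC: 0.8796
# REC: 0.6587
# SPEC: 0.9090
# F1: 0.7533
```

The gap between specificity (0.909) and recall (0.659) means the ensemble misses about a third of the promoters (1030 of 3018) while flagging only about 9% of non-promoters (272 of 2990).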
