In [1]:
##### # =============================================================================
# IMPORTS
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
import os
import numpy as np
from tqdm.notebook import tqdm
import random
import time
import pandas as pd
import math
from datetime import datetime

# =============================================================================
# DEVICE SETUP
# =============================================================================

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# =============================================================================
# ðŸ”§ CONFIGURATION
# =============================================================================

DATASET_NAME = "Erlich"  # Change to: "Grass", "Organick", "Srinivasavaradhan"

DATASET_CONFIGS = {
    "Erlich": {
        "label_len": 152,
        "max_deviation": 10,  # From paper
        "target_failure": 0.02,
    },
    "Grass": {
        "label_len": 117,
        "max_deviation": 11,  # From paper
        "target_failure": 0.66,
    },
    "Organick": {
        "label_len": 110,
        "max_deviation": 5,  # From paper
        "target_failure": 0.17,
    },
    "Srinivasavaradhan": {
        "label_len": 110,
        "max_deviation": 10,  # From paper  
        "target_failure": 14.58,
    },
}

CONFIG = DATASET_CONFIGS[DATASET_NAME]
LABEL_SEQ_LEN = CONFIG["label_len"]
MAX_DEVIATION = CONFIG["max_deviation"]
MAX_READ_LEN = LABEL_SEQ_LEN + MAX_DEVIATION + 8  # Extra buffer for safety
MAX_CLUSTER_SIZE = 16  # Following DNAFormer paper

# =============================================================================
# TRAINING HYPERPARAMETERS
# =============================================================================

BATCH_SIZE = 200  # DNAFormer uses 64
LEARNING_RATE = 5e-4  
MIN_LR = 5e-8  # DNAFormer's ending LR
WEIGHT_DECAY = 1e-5
EPOCHS = 120
WARMUP_EPOCHS = 10
PATIENCE = 10

# Model architecture
EMBED_DIM = 300
ALIGNMENT_FILTERS = 128  # Lighter than DNAFormer's 128
EMBEDDING_FILTERS = 500  # Lighter than DNAFormer's 1024
GRU_HIDDEN = 300
GRU_LAYERS = 2
DROPOUT = 0.1

# =============================================================================
# FILE PATHS
# =============================================================================

SYNTHETIC_DATA_DIR = Path("./generated_data_corrected")
REAL_DATA_DIR = Path("./Data")


TRAIN_FILE = SYNTHETIC_DATA_DIR / f"binned_synthetic_{DATASET_NAME.lower()}.txt"
EVAL_FILE = REAL_DATA_DIR / f"{DATASET_NAME}.txt"


# Generate a timestamp (e.g., 2023-10-27_14-30)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")

EXPERIMENT_DIR = Path(f"./Experiments/{DATASET_NAME}_ImprovedBiGRU_{timestamp}")
WEIGHTS_DIR = EXPERIMENT_DIR / "Models"
RESULTS_DIR = EXPERIMENT_DIR / "Results"
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"\n{'='*70}")
print(f"DATASET: {DATASET_NAME}")
print(f"{'='*70}")
print(f"  Label length:    {LABEL_SEQ_LEN}")
print(f"  Max deviation:   {MAX_DEVIATION}")
print(f"  Max read length: {MAX_READ_LEN}")
print(f"  Target failure:  {CONFIG['target_failure']}%")
print(f"\n  Files:")
print(f"    Train: {TRAIN_FILE}")
print(f"    Eval:  {EVAL_FILE}")
print(f"  Output: {EXPERIMENT_DIR}")
print(f"{'='*70}\n")

# =============================================================================
# VOCABULARY
# =============================================================================

VOCAB = {'N': 0, 'A': 1, 'C': 2, 'G': 3, 'T': 4}
PADDING_IDX = VOCAB['N']
VOCAB_SIZE = len(VOCAB)
INT_TO_CHAR = {i: c for c, i in VOCAB.items()}

def encode_seq(seq: str, char_to_int: dict, max_len: int, padding_idx: int) -> list:
    """Encode DNA sequence to integers with padding"""
    encoded = [char_to_int.get(c, padding_idx) for c in seq]
    encoded = encoded[:max_len]
    padded = encoded + [padding_idx] * (max_len - len(encoded))
    return padded

def decode_seq(tensor: torch.Tensor, int_to_char: dict) -> str:
    """Decode integer tensor to DNA sequence"""
    if tensor.is_cuda:
        tensor = tensor.cpu()
    ints = tensor.numpy().tolist()
    try:
        first_pad = ints.index(PADDING_IDX)
        ints = ints[:first_pad]
    except ValueError:
        pass
    return "".join([int_to_char.get(i, '?') for i in ints])

# =============================================================================
# DATASET
# =============================================================================

class DnaClusterDataset(Dataset):
    """DNA Cluster Dataset - No filtering (filtering breaks results!)"""
    
    def __init__(self, filepath, max_cluster_size, max_read_len, label_seq_len, 
                 char_to_int, padding_idx):
        self.max_cluster_size = max_cluster_size
        self.max_read_len = max_read_len
        self.label_seq_len = label_seq_len
        self.char_to_int = char_to_int
        self.padding_idx = padding_idx
        self.labels = []
        self.clusters = []
        self._load_data(filepath)

    def _load_data(self, filepath):
        print(f"Loading data from {filepath}...")
        try:
            with open(filepath, 'r') as f:
                content = f.read()
        except FileNotFoundError:
            print(f"ERROR: File not found at {filepath}")
            raise
        
        blocks = content.split('\n\n')
        for block in tqdm(blocks, desc=f"Parsing {filepath.name}"):
            if not block.strip():
                continue
            lines = block.strip().split('\n')
            if len(lines) < 3:
                continue
            label_seq = lines[0]
            reads = lines[2:]
            if not reads or not label_seq:
                continue
            self.labels.append(label_seq)
            self.clusters.append(reads)
        print(f"Successfully loaded {len(self.labels)} clusters.")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label_str = self.labels[idx]
        cluster_reads_str = self.clusters[idx]
        
        label_tensor = torch.tensor(
            encode_seq(label_str, self.char_to_int, self.label_seq_len, self.padding_idx),
            dtype=torch.long
        )
        
        cluster_tensor = torch.full(
            (self.max_cluster_size, self.max_read_len),
            self.padding_idx,
            dtype=torch.long
        )
        
        random.shuffle(cluster_reads_str)
        reads_to_process = cluster_reads_str[:self.max_cluster_size]
        
        for i, read_str in enumerate(reads_to_process):
            cluster_tensor[i] = torch.tensor(
                encode_seq(read_str, self.char_to_int, self.max_read_len, self.padding_idx),
                dtype=torch.long
            )
        
        return cluster_tensor, label_tensor

# =============================================================================
# MODEL COMPONENTS (DNAFormer-inspired but lighter)
# =============================================================================

class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise separable convolution (more efficient than standard conv)"""
    
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.depthwise = nn.Conv1d(
            in_channels, in_channels, 
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels
        )
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
    
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


class MultiKernelConvBlock(nn.Module):
    """
    Multi-kernel convolution block with proper channel handling.
    Fixed: Handles remainder when dividing channels into thirds.
    """
    
    def __init__(self, in_channels, out_channels, seq_len, dropout=0.1):
        super().__init__()
        
        # Split channels properly, handling remainder
        c1 = out_channels // 3
        c2 = out_channels // 3
        c3 = out_channels - c1 - c2  # Gets the remainder
        
        self.conv1 = DepthwiseSeparableConv1d(in_channels, c1, kernel_size=1)
        self.conv3 = DepthwiseSeparableConv1d(in_channels, c2, kernel_size=3, padding=1)
        self.conv5 = DepthwiseSeparableConv1d(in_channels, c3, kernel_size=5, padding=2)
        
        self.norm1 = nn.LayerNorm([c1, seq_len])
        self.norm2 = nn.LayerNorm([c2, seq_len])
        self.norm3 = nn.LayerNorm([c3, seq_len])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Apply different kernel sizes in parallel
        x1 = F.gelu(self.norm1(self.conv1(x)))
        x2 = F.gelu(self.norm2(self.conv3(x)))
        x3 = F.gelu(self.norm3(self.conv5(x)))
        
        # Concatenate multi-kernel outputs (now sums to exactly out_channels)
        out = torch.cat([x1, x2, x3], dim=1)
        out = self.dropout(out)
        return out


class AlignmentModule(nn.Module):
    """
    Alignment module inspired by DNAFormer.
    Processes each read individually to learn alignment before NCI.
    Lighter than DNAFormer (uses 2 conv blocks instead of more complex architecture).
    """
    
    def __init__(self, embed_dim, out_channels, seq_len, dropout=0.1):
        super().__init__()
        self.conv_block1 = MultiKernelConvBlock(embed_dim, out_channels, seq_len, dropout)
        self.conv_block2 = MultiKernelConvBlock(out_channels, out_channels, seq_len, dropout)
    
    def forward(self, x):
        # x shape: (batch, cluster_size, embed_dim, seq_len)
        batch, cluster, emb, seq = x.shape
        
        # Process each read independently
        x = x.view(batch * cluster, emb, seq)
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        
        # Reshape back
        x = x.view(batch, cluster, -1, seq)
        return x


class EmbeddingModule(nn.Module):
    """
    Embedding module - processes cluster after NCI to extract cluster-level features.
    Fixed: Linear layer now correctly transforms sequence length.
    """
    
    def __init__(self, in_channels, out_channels, in_len, out_len, dropout=0.1):
        super().__init__()
        self.conv_block = MultiKernelConvBlock(in_channels, out_channels, in_len, dropout)
        
        # Linear projection to target length
        self.linear = nn.Linear(in_len, out_len)
    
    def forward(self, x):
        # x shape: (batch, channels, seq_len)
        x = self.conv_block(x)  # (B, out_channels, in_len)
        
        # Apply linear transformation to sequence dimension
        # Reshape to apply linear independently to each channel
        batch, channels, seq_len = x.shape
        
        # Reshape: (B, C, L) -> (B*C, L)
        x = x.reshape(batch * channels, seq_len)
        
        # Apply linear: (B*C, in_len) -> (B*C, out_len)
        x = self.linear(x)
        
        # Reshape back: (B*C, out_len) -> (B, C, out_len)
        x = x.reshape(batch, channels, -1)
        
        return x


# =============================================================================
# IMPROVED MODEL (DNAFormer-inspired BiGRU)
# =============================================================================

class ImprovedDNAReconstructionModel(nn.Module):
    """
    Improved DNA reconstruction model inspired by DNAFormer architecture.
    
    Architecture:
    1. Embedding layer
    2. Alignment module (per-read processing with multi-kernel convs)
    3. NCI (Non-Coherent Integration) - sum over cluster dimension
    4. Embedding module (cluster-level feature extraction)
    5. BiGRU (instead of Transformer for efficiency)
    6. Output projection
    
    Key differences from your original BiGRU:
    - Adds per-read alignment processing BEFORE averaging
    - Uses multi-kernel convolutions to capture indel patterns
    - Processes cluster as a whole after NCI
    - Still uses BiGRU (not Transformer) to keep params low
    
    Parameters: ~5-8M (vs DNAFormer's 100M, your BiGRU's ~2M)
    """
    
    def __init__(self, vocab_size, label_seq_len, max_read_len, padding_idx,
                 embed_dim=128, alignment_filters=128, embedding_filters=256,
                 gru_hidden=256, gru_layers=2, dropout=0.2):
        super().__init__()
        
        self.label_seq_len = label_seq_len
        self.max_read_len = max_read_len
        
        # 1. Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        
        # 2. Alignment module (per-read processing)
        self.alignment = AlignmentModule(
            embed_dim=embed_dim,
            out_channels=alignment_filters,
            seq_len=max_read_len,
            dropout=dropout
        )
        
        # 3. NCI (Non-Coherent Integration) is just a sum - no learnable params
        
        # 4. Embedding module (cluster-level processing)
        self.embedding_module = EmbeddingModule(
            in_channels=alignment_filters,
            out_channels=embedding_filters,
            in_len=max_read_len,
            out_len=label_seq_len,
            dropout=dropout
        )
        
        # 5. BiGRU for sequence modeling
        self.gru = nn.GRU(
            embedding_filters,
            gru_hidden,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if gru_layers > 1 else 0
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # 6. Output projection
        self.fc_out = nn.Linear(gru_hidden * 2, vocab_size)
    
    def forward(self, cluster_batch):
        """
        Args:
            cluster_batch: (batch_size, max_cluster_size, max_read_len)
        
        Returns:
            logits: (batch_size, label_seq_len, vocab_size)
        """
        # 1. Embed all reads
        embedded = self.embedding(cluster_batch)  # (B, cluster, read_len, embed_dim)
        embedded = embedded.permute(0, 1, 3, 2)  # (B, cluster, embed_dim, read_len)
        
        # 2. Alignment module - process each read independently
        aligned = self.alignment(embedded)  # (B, cluster, alignment_filters, read_len)
        
        # 3. NCI (Non-Coherent Integration) - sum over cluster dimension
        # This is DNAFormer's key insight: improves SNR and robustness to cluster size
        nci_output = torch.sum(aligned, dim=1)  # (B, alignment_filters, read_len)
        
        # 4. Embedding module - process cluster as a whole
        cluster_features = self.embedding_module(nci_output)  # (B, embedding_filters, label_seq_len)
        
        # 5. Prepare for GRU: (B, seq_len, features)
        x = cluster_features.permute(0, 2, 1)  # (B, label_seq_len, embedding_filters)
        
        # 6. BiGRU
        gru_out, _ = self.gru(x)  # (B, label_seq_len, gru_hidden*2)
        gru_out = self.dropout(gru_out)
        
        # 7. Output projection
        logits = self.fc_out(gru_out)  # (B, label_seq_len, vocab_size)
        
        return logits


# =============================================================================
# COSINE ANNEALING WARMUP SCHEDULER
# =============================================================================

class CosineAnnealingWarmupScheduler:
    """Cosine annealing learning rate scheduler with warmup (DNAFormer uses this)"""
    
    def __init__(self, optimizer, warmup_epochs, total_epochs, max_lr, min_lr):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.current_epoch = 0
    
    def step(self):
        if self.current_epoch < self.warmup_epochs:
            # Linear warmup
            lr = self.max_lr * (self.current_epoch + 1) / self.warmup_epochs
        else:
            # Cosine annealing
            progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
            lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        self.current_epoch += 1
        return lr
    
    def get_last_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]


# =============================================================================
# TRAINING UTILITIES
# =============================================================================

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    loop = tqdm(loader, desc="Training", leave=False)
    
    for clusters, labels in loop:
        clusters, labels = clusters.to(device), labels.to(device)
        
        # Forward pass
        logits = model(clusters)
        loss = criterion(logits.view(-1, VOCAB_SIZE), labels.view(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    
    return total_loss / len(loader)


def validate(model, loader, criterion, device):
    """Validate the model"""
    model.eval()
    total_loss = 0
    loop = tqdm(loader, desc="Validating", leave=False)
    
    with torch.no_grad():
        for clusters, labels in loop:
            clusters, labels = clusters.to(device), labels.to(device)
            logits = model(clusters)
            loss = criterion(logits.view(-1, VOCAB_SIZE), labels.view(-1))
            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())
    
    return total_loss / len(loader)


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    
    # =========================================================================
    # LOAD DATA
    # =========================================================================
    
    print(f"\n{'='*70}")
    print("LOADING DATA")
    print(f"{'='*70}\n")
    
    full_train_dataset = DnaClusterDataset(
        filepath=TRAIN_FILE,
        max_cluster_size=MAX_CLUSTER_SIZE,
        max_read_len=MAX_READ_LEN,
        label_seq_len=LABEL_SEQ_LEN,
        char_to_int=VOCAB,
        padding_idx=PADDING_IDX
    )
    
    # 98% train, 2% validation
    val_size = int(0.02 * len(full_train_dataset))
    if val_size < BATCH_SIZE:
        val_size = min(BATCH_SIZE * 2, len(full_train_dataset) // 10)
    train_size = len(full_train_dataset) - val_size
    
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])
    
    print(f"\nTotal clusters: {len(full_train_dataset):,}")
    print(f"Training:       {len(train_dataset):,}")
    print(f"Validation:     {len(val_dataset):,}")
    
    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True
    )
    
    # =========================================================================
    # INITIALIZE MODEL
    # =========================================================================
    
    model = ImprovedDNAReconstructionModel(
        vocab_size=VOCAB_SIZE,
        label_seq_len=LABEL_SEQ_LEN,
        max_read_len=MAX_READ_LEN,
        padding_idx=PADDING_IDX,
        embed_dim=EMBED_DIM,
        alignment_filters=ALIGNMENT_FILTERS,
        embedding_filters=EMBEDDING_FILTERS,
        gru_hidden=GRU_HIDDEN,
        gru_layers=GRU_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\n{'='*70}")
    print("MODEL ARCHITECTURE")
    print(f"{'='*70}")
    print(f"Total parameters: {num_params:,} ({num_params/1e6:.2f}M)")
    print(f"Embedding dim:    {EMBED_DIM}")
    print(f"Alignment:        {ALIGNMENT_FILTERS} filters")
    print(f"Embedding:        {EMBEDDING_FILTERS} filters")
    print(f"GRU hidden:       {GRU_HIDDEN}")
    print(f"GRU layers:       {GRU_LAYERS}")
    print(f"{'='*70}")
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=PADDING_IDX, label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    scheduler = CosineAnnealingWarmupScheduler(
        optimizer=optimizer,
        warmup_epochs=WARMUP_EPOCHS,
        total_epochs=EPOCHS,
        max_lr=LEARNING_RATE,
        min_lr=MIN_LR
    )
    
    print(f"\nOptimizer: Adam")
    print(f"LR Schedule: Warmup {WARMUP_EPOCHS} â†’ Cosine {LEARNING_RATE:.2e} â†’ {MIN_LR:.2e}")
    print(f"Batch size: {BATCH_SIZE}")
    
    # =========================================================================
    # TRAINING LOOP
    # =========================================================================
    
    print(f"\n{'='*70}")
    print(f"TRAINING: {DATASET_NAME}")
    print(f"{'='*70}\n")
    
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    learning_rates = []
    
    for epoch in range(EPOCHS):
        start_time = time.time()
        current_lr = scheduler.get_last_lr()[0]
        
        # Train and validate
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss = validate(model, val_loader, criterion, DEVICE)
        
        # Step scheduler
        new_lr = scheduler.step()
        
        # Record history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        learning_rates.append(current_lr)
        
        # Print progress
        print(f"Epoch {epoch+1:3d}/{EPOCHS} | "
              f"LR: {current_lr:.2e} | "
              f"Train: {train_loss:.4f} | "
              f"Val: {val_loss:.4f} | "
              f"Time: {time.time()-start_time:.1f}s", end="")
        
        # Save best model
        if val_loss < best_val_loss:
            improvement = best_val_loss - val_loss
            best_val_loss = val_loss
            torch.save(model.state_dict(), WEIGHTS_DIR / f"best_model_{DATASET_NAME}.pth")
            print(f" âœ“ BEST (â†“{improvement:.4f})")
            patience_counter = 0
        else:
            print()
            patience_counter += 1
        
        # Early stopping
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break
    
    # Save final model
    torch.save(model.state_dict(), WEIGHTS_DIR / f"final_model_{DATASET_NAME}.pth")
    
    print(f"\n{'='*70}")
    print(f"Training Finished")
    print(f"{'='*70}")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Models saved to: {WEIGHTS_DIR}")
    
    # Save training history
    pd.DataFrame({
        'epoch': range(1, len(train_losses)+1),
        'train_loss': train_losses,
        'val_loss': val_losses,
        'learning_rate': learning_rates
    }).to_csv(RESULTS_DIR / f"history_{DATASET_NAME}.csv", index=False)
    
    # =========================================================================
    # EVALUATION
    # =========================================================================
    
    print(f"\n{'='*70}")
    print(f"EVALUATION: {DATASET_NAME}")
    print(f"{'='*70}\n")
    
    # Load best model
    inference_model = ImprovedDNAReconstructionModel(
        vocab_size=VOCAB_SIZE,
        label_seq_len=LABEL_SEQ_LEN,
        max_read_len=MAX_READ_LEN,
        padding_idx=PADDING_IDX,
        embed_dim=EMBED_DIM,
        alignment_filters=ALIGNMENT_FILTERS,
        embedding_filters=EMBEDDING_FILTERS,
        gru_hidden=GRU_HIDDEN,
        gru_layers=GRU_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)
    
    model_path = WEIGHTS_DIR / f"best_model_{DATASET_NAME}.pth"
    inference_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    inference_model.eval()
    print(f"âœ“ Loaded model from {model_path}")
    
    # Load test data
    test_dataset = DnaClusterDataset(
        filepath=EVAL_FILE,
        max_cluster_size=MAX_CLUSTER_SIZE,
        max_read_len=MAX_READ_LEN,
        label_seq_len=LABEL_SEQ_LEN,
        char_to_int=VOCAB,
        padding_idx=PADDING_IDX
    )
    
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False,
        num_workers=2, pin_memory=True
    )
    
    # Evaluate
    total_clusters = 0
    failed_clusters = 0
    
    with torch.no_grad():
        for clusters, labels in tqdm(test_loader, desc="Testing"):
            clusters = clusters.to(DEVICE)
            logits = inference_model(clusters)
            predictions = torch.argmax(logits, dim=2).cpu()
            
            for i in range(labels.shape[0]):
                total_clusters += 1
                pred_seq = decode_seq(predictions[i], INT_TO_CHAR)
                label_seq = decode_seq(labels[i], INT_TO_CHAR)
                if pred_seq != label_seq:
                    failed_clusters += 1
    
    failure_rate = (failed_clusters / total_clusters) * 100 if total_clusters > 0 else 0
    
    print(f"\n{'='*70}")
    print(f"RESULTS: {DATASET_NAME}")
    print(f"{'='*70}")
    print(f"Total clusters:  {total_clusters:,}")
    print(f"Failed clusters: {failed_clusters}")
    print(f"Failure rate:    {failure_rate:.4f}%")
    print(f"Target:          {CONFIG['target_failure']}%")
    print(f"{'='*70}")
    
    if failure_rate <= CONFIG['target_failure']:
        print(f"\nðŸŽ‰ SUCCESS! Target achieved!")
    else:
        gap = failure_rate - CONFIG['target_failure']
        print(f"\nðŸ“ˆ Gap to target: {gap:.4f}%")
    
    # Save results
    pd.DataFrame([{
        'dataset': DATASET_NAME,
        'total_clusters': total_clusters,
        'failed_clusters': failed_clusters,
        'failure_rate_percent': failure_rate,
        'target_percent': CONFIG['target_failure'],
        'model_params': num_params
    }]).to_csv(RESULTS_DIR / f"results_{DATASET_NAME}.csv", index=False)
    
    print(f"\nResults saved to: {RESULTS_DIR}")
    print(f"\n{'='*70}")
    print("COMPLETE!")
    print(f"{'='*70}")

PyTorch version: 2.5.1
Using device: cuda
GPU: NVIDIA GeForce RTX 3080

DATASET: Erlich
  Label length:    152
  Max deviation:   10
  Max read length: 170
  Target failure:  0.02%

  Files:
    Train: generated_data_corrected/binned_synthetic_erlich.txt
    Eval:  Data/Erlich.txt
  Output: Experiments/Erlich_ImprovedBiGRU_2026-01-14_12-56


LOADING DATA

Loading data from generated_data_corrected/binned_synthetic_erlich.txt...


Parsing binned_synthetic_erlich.txt:   0%|          | 0/1500001 [00:00<?, ?it/s]

Successfully loaded 1500000 clusters.

Total clusters: 1,500,000
Training:       1,470,000
Validation:     30,000

MODEL ARCHITECTURE
Total parameters: 3,480,949 (3.48M)
Embedding dim:    300
Alignment:        128 filters
Embedding:        500 filters
GRU hidden:       300
GRU layers:       2

Optimizer: Adam
LR Schedule: Warmup 10 â†’ Cosine 5.00e-04 â†’ 5.00e-08
Batch size: 200

TRAINING: Erlich



Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   1/120 | LR: 5.00e-04 | Train: 0.4289 | Val: 0.3995 | Time: 1436.4s âœ“ BEST (â†“inf)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   2/120 | LR: 5.00e-05 | Train: 0.3990 | Val: 0.3980 | Time: 1432.5s âœ“ BEST (â†“0.0015)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   3/120 | LR: 1.00e-04 | Train: 0.3990 | Val: 0.3981 | Time: 1432.4s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   4/120 | LR: 1.50e-04 | Train: 0.3989 | Val: 0.3980 | Time: 1432.1s âœ“ BEST (â†“0.0001)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   5/120 | LR: 2.00e-04 | Train: 0.3989 | Val: 0.3979 | Time: 1432.0s âœ“ BEST (â†“0.0001)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   6/120 | LR: 2.50e-04 | Train: 0.3988 | Val: 0.3983 | Time: 1432.1s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   7/120 | LR: 3.00e-04 | Train: 0.3988 | Val: 0.3980 | Time: 1431.8s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   8/120 | LR: 3.50e-04 | Train: 0.3988 | Val: 0.3982 | Time: 1431.0s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   9/120 | LR: 4.00e-04 | Train: 0.3988 | Val: 0.3978 | Time: 1431.5s âœ“ BEST (â†“0.0001)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  10/120 | LR: 4.50e-04 | Train: 0.3988 | Val: 0.3981 | Time: 1431.0s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  11/120 | LR: 5.00e-04 | Train: 0.3988 | Val: 0.3980 | Time: 1431.8s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  12/120 | LR: 5.00e-04 | Train: 0.3987 | Val: 0.3980 | Time: 1431.7s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  13/120 | LR: 5.00e-04 | Train: 0.3986 | Val: 0.3975 | Time: 1431.5s âœ“ BEST (â†“0.0003)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  14/120 | LR: 5.00e-04 | Train: 0.3985 | Val: 0.3976 | Time: 1431.5s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  15/120 | LR: 4.99e-04 | Train: 0.3985 | Val: 0.3978 | Time: 1431.9s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  16/120 | LR: 4.98e-04 | Train: 0.3984 | Val: 0.3976 | Time: 1431.8s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  17/120 | LR: 4.97e-04 | Train: 0.3984 | Val: 0.3978 | Time: 1431.1s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  18/120 | LR: 4.96e-04 | Train: 0.3983 | Val: 0.3977 | Time: 1431.9s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  19/120 | LR: 4.95e-04 | Train: 0.3983 | Val: 0.3976 | Time: 1431.7s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  20/120 | LR: 4.94e-04 | Train: 0.3983 | Val: 0.3974 | Time: 1431.1s âœ“ BEST (â†“0.0001)


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  21/120 | LR: 4.92e-04 | Train: 0.3983 | Val: 0.3977 | Time: 1431.9s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  22/120 | LR: 4.90e-04 | Train: 0.3983 | Val: 0.3978 | Time: 1431.8s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  23/120 | LR: 4.88e-04 | Train: 0.3983 | Val: 0.3978 | Time: 1431.9s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  24/120 | LR: 4.85e-04 | Train: 0.3983 | Val: 0.3974 | Time: 1432.5s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  25/120 | LR: 4.83e-04 | Train: 0.3983 | Val: 0.3975 | Time: 1431.8s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  26/120 | LR: 4.80e-04 | Train: 0.3982 | Val: 0.3975 | Time: 1431.4s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  27/120 | LR: 4.77e-04 | Train: 0.3982 | Val: 0.3975 | Time: 1431.4s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  28/120 | LR: 4.74e-04 | Train: 0.3982 | Val: 0.3976 | Time: 1431.9s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  29/120 | LR: 4.71e-04 | Train: 0.3982 | Val: 0.3978 | Time: 1432.2s


Training:   0%|          | 0/7350 [00:00<?, ?it/s]

Validating:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch  30/120 | LR: 4.68e-04 | Train: 0.3982 | Val: 0.3976 | Time: 1431.6s

Early stopping at epoch 30

Training Finished
Best val loss: 0.3974
Models saved to: Experiments/Erlich_ImprovedBiGRU_2026-01-14_12-56/Models

EVALUATION: Erlich

âœ“ Loaded model from Experiments/Erlich_ImprovedBiGRU_2026-01-14_12-56/Models/best_model_Erlich.pth
Loading data from Data/Erlich.txt...


  inference_model.load_state_dict(torch.load(model_path, map_location=DEVICE))


Parsing Erlich.txt:   0%|          | 0/72001 [00:00<?, ?it/s]

Successfully loaded 72000 clusters.


Testing:   0%|          | 0/180 [00:00<?, ?it/s]


RESULTS: Erlich
Total clusters:  72,000
Failed clusters: 14
Failure rate:    0.0194%
Target:          0.02%

ðŸŽ‰ SUCCESS! Target achieved!

Results saved to: Experiments/Erlich_ImprovedBiGRU_2026-01-14_12-56/Results

COMPLETE!
