## Setup and Data Loading

In [3]:
import sys
sys.path.insert(0, "..")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np

from mergedna.dataloader import dataloader, merge_sequences
from mergedna.backbone import MergeDNAModel
from mergedna.merging import MergeDNALayer, MergeDNAUnmerge

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: mps


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os

# Where the input sequences live and where the pre-training checkpoint is written.
DATA_DIR = "../data/human_nontata_promoters/"
CHECKPOINT_PATH = "../checkpoints/mergedna_pretrain.pt"

# Ensure the checkpoint's parent directory exists before anything is saved to it.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

In [None]:
# Run configuration.
# The cells below read every constant defined here. Originally only `batch` was
# assigned, so a fresh kernel failed with NameError at the first use.
batch = 4  # kept from the original cell; not referenced by the visible cells

# Values recovered from the outputs recorded in this notebook.
MAX_SEQ_LEN = 251  # "Sample sequence length: 251"
DIM = 64           # recorded latent shape: torch.Size([1, 126, 64])
BATCH_SIZE = 16    # "Batch size: 16"
EPOCHS = 10        # "Epoch 1/10"

# Same value as the default of train_epoch(lambda_latent=...).
LAMBDA_LATENT = 0.25

# NOTE(review): the three values below cannot be recovered from the recorded
# outputs; they are placeholders. Confirm them before training. The recorded
# parameter count (442,581) may not reproduce with these depths.
LEARNING_RATE = 1e-4
LATENT_ENC_DEPTH = 2
LATENT_DEC_DEPTH = 2

In [None]:
# Load the promoter sequences from DATA_DIR (the original referenced an
# undefined `directory_path`).
pretrain_sequences = dataloader(DATA_DIR)

# Later cells read `all_sequences`, so bind the merged result to that name.
# NOTE(review): presumably merge_sequences returns one flat list of strings
# (later cells index it and call .strip()/.upper() on items) - confirm.
all_sequences = merge_sequences(pretrain_sequences)
sequences = all_sequences  # original name kept as an alias

Loading sequences...
test
positive
negative
train
positive
negative
Total sequences loaded: 36131
Sample sequence length: 251
Sample sequence: CAGGAATCGAACACTAGGAATCCTACTCGATAGGTGGCACGGATTGCGGA...


## Training

In [7]:
# Base -> integer index lookup. 'N' shares index 0 with 'A' as a fallback.
DNA_VOCAB = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 0}


class DNADataset(Dataset):
    """Map-style dataset that encodes DNA strings as fixed-length index tensors.

    Each item is stripped, upper-cased, truncated to ``max_len`` and right-padded
    with index 0. Characters missing from ``DNA_VOCAB`` are also encoded as 0.
    """

    def __init__(self, sequences, max_len=251):
        self.sequences = sequences
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a ``torch.long`` tensor of length ``max_len``."""
        cleaned = self.sequences[idx].strip().upper()[:self.max_len]
        encoded = [DNA_VOCAB.get(base, 0) for base in cleaned]
        # Right-pad with 0 up to max_len (no-op when already full length).
        encoded.extend([0] * (self.max_len - len(encoded)))
        return torch.tensor(encoded, dtype=torch.long)

# Smoke test: build the dataset and inspect its first encoded item.
dataset = DNADataset(all_sequences, max_len=MAX_SEQ_LEN)
first_item = dataset[0]
print(f"Dataset size: {len(dataset)}")
print(f"Sample tensor shape: {first_item.shape}")
print(f"Sample tensor values: {first_item[:20]}")


Dataset size: 36131
Sample tensor shape: torch.Size([251])
Sample tensor values: tensor([1, 0, 2, 2, 0, 0, 3, 1, 2, 0, 0, 1, 0, 1, 3, 0, 2, 2, 0, 0])


In [10]:
class LocalEncoder(nn.Module):
    """
    Local Encoder: embeds DNA bases and merges adjacent tokens by average pooling.

    Maps N bases -> ceil(N / 2) tokens and returns a source map recording which
    merged token each input position was folded into.

    Args:
        dim: Embedding / model width. The attention layer uses 4 heads, so
            ``dim`` must be divisible by 4.
        vocab_size: Number of distinct base indices accepted by the embedding.
        merge_ratio: Stored on the module. Pooling is fixed at pairs of adjacent
            tokens, so this value does not currently change the output length.
    """
    def __init__(self, dim, vocab_size=4, merge_ratio=0.5):
        super().__init__()
        self.dim = dim
        self.merge_ratio = merge_ratio
        self.embed = nn.Embedding(vocab_size, dim)
        # Learned positional table; its length (512) caps the usable sequence length.
        self.pos_embed = nn.Parameter(torch.randn(1, 512, dim) * 0.02)

        # Local attention for context before merging
        self.local_attn = nn.TransformerEncoderLayer(
            d_model=dim, nhead=4, dim_feedforward=dim*4,
            dropout=0.1, batch_first=True
        )

        # Not used by forward(); kept so the module's parameters (and therefore
        # its state_dict keys) stay the same as before.
        self.merge_proj = nn.Linear(dim, 1)

    def forward(self, x):
        """
        Args:
            x: [B, N] base indices, or [B, N, C] one-hot (reduced with argmax).

        Returns:
            z_l: [B, ceil(N / 2), D] merged tokens.
            source_map: [B, N] index of the merged token each position belongs to.

        Raises:
            ValueError: if N exceeds the number of positional embeddings.
        """
        x_idx = x.argmax(dim=-1) if x.dim() == 3 else x
        B, N = x_idx.shape

        max_positions = self.pos_embed.shape[1]
        if N > max_positions:
            raise ValueError(
                f"Sequence length {N} exceeds the {max_positions} available "
                f"positional embeddings"
            )

        # Embed, add positions, then contextualize with one attention layer.
        z = self.embed(x_idx) + self.pos_embed[:, :N, :]  # [B, N, D]
        z = self.local_attn(z)

        # Average each complete pair of adjacent tokens.
        n_pairs = N // 2
        z_merged = z[:, :2 * n_pairs, :].reshape(B, n_pairs, 2, self.dim).mean(dim=2)

        # For odd N the last token has no partner. Carry it over unchanged:
        # averaging it with a zero pad (the previous behavior) halved it.
        if N % 2 == 1:
            z_merged = torch.cat([z_merged, z[:, -1:, :]], dim=1)

        # Source map: tokens 0,1 -> 0; tokens 2,3 -> 1; etc.
        source_map = torch.arange(N, device=z.device) // 2
        source_map = source_map.unsqueeze(0).expand(B, -1).clone()  # [B, N]

        return z_merged, source_map


class LocalDecoder(nn.Module):
    """
    Local Decoder: expands merged tokens back to per-base positions, then refines.

    Each original position copies the merged token named by ``source_map``. The
    expanded sequence then goes through one Transformer encoder layer and a
    linear projection.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Post-unmerge refinement: one self-attention layer followed by a linear map.
        self.refine = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=4,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, z_l, source_map):
        """
        Args:
            z_l: [B, L, D] merged tokens.
            source_map: [B, N] indices into the L axis, one per output position.

        Returns:
            [B, N, D] refined per-position features.
        """
        _, _, width = z_l.shape
        n_positions = source_map.shape[1]

        # Broadcast each position's owner index across the feature axis, then gather.
        gather_idx = source_map.unsqueeze(-1).expand(-1, n_positions, width)  # [B, N, D]
        expanded = torch.gather(z_l, 1, gather_idx)

        return self.proj(self.refine(expanded))

# Smoke test: push one sequence through a fresh encoder/decoder pair and report shapes.
test_input = dataset[0].unsqueeze(0).to(device)
local_enc = LocalEncoder(DIM).to(device)
local_dec = LocalDecoder(DIM).to(device)

z_l, s_map = local_enc(test_input)
x_hat = local_dec(z_l, s_map)

shape_report = (
    ("Input", test_input),
    ("Latent", z_l),
    ("Source map", s_map),
    ("Reconstructed", x_hat),
)
for shape_label, shape_tensor in shape_report:
    print(f"{shape_label} shape: {shape_tensor.shape}")


Input shape: torch.Size([1, 251])
Latent shape: torch.Size([1, 126, 64])
Source map shape: torch.Size([1, 251])
Reconstructed shape: torch.Size([1, 251, 64])


In [11]:
# Build the local encoder/decoder pair that the full model wraps.
local_encoder = LocalEncoder(DIM, merge_ratio=0.5).to(device)
local_decoder = LocalDecoder(DIM).to(device)

# Assemble the full MergeDNA model from the pieces above.
model_kwargs = dict(
    local_encoder=local_encoder,
    local_decoder=local_decoder,
    dim=DIM,
    latent_enc_depth=LATENT_ENC_DEPTH,
    latent_dec_depth=LATENT_DEC_DEPTH,
    vocab_size=4,
)
model = MergeDNAModel(**model_kwargs).to(device)

# Parameter census: total vs. those that receive gradients.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


Total parameters: 442,581
Trainable parameters: 442,581


In [12]:
# Shuffled mini-batch loader over the encoded dataset.
# Pinned host memory is requested only when running on CUDA.
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=use_pinned_memory,
)

print(f"Number of batches: {len(train_loader)}")
print(f"Batch size: {BATCH_SIZE}")


Number of batches: 2259
Batch size: 16


In [13]:
def train_epoch(model, dataloader, optimizer, device, lambda_latent=0.25):
    """Train for one epoch and return the mean of each loss term.

    Args:
        model: Module exposing ``forward_train(batch, lambda_latent=...)`` that
            returns ``(loss, logs)``, where ``logs`` contains ``'loss_mtr'``,
            ``'loss_latent'`` and ``'loss_amtm'``.
        dataloader: Iterable yielding batches that support ``.to(device)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        lambda_latent: Forwarded unchanged to ``forward_train``.

    Returns:
        Dict with keys ``'total'``, ``'mtr'``, ``'latent'``, ``'amtm'``, each the
        mean over the batches that completed, as a built-in ``float``.

    Raises:
        RuntimeError: if no batch completed (empty loader or every forward
            pass failed), instead of silently returning NaN.
    """
    model.train()

    epoch_losses = {'total': [], 'mtr': [], 'latent': [], 'amtm': []}
    skipped = 0
    last_error = None

    pbar = tqdm(dataloader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Best effort: a batch whose forward pass fails is reported and skipped.
        try:
            loss, logs = model.forward_train(batch, lambda_latent=lambda_latent)
        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            skipped += 1
            last_error = e
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read each value once; float() also accepts 0-dim tensors, so the
        # stored history is plain numbers whatever the model puts in `logs`.
        loss_value = loss.item()
        mtr_value = float(logs['loss_mtr'])
        latent_value = float(logs['loss_latent'])
        amtm_value = float(logs['loss_amtm'])

        epoch_losses['total'].append(loss_value)
        epoch_losses['mtr'].append(mtr_value)
        epoch_losses['latent'].append(latent_value)
        epoch_losses['amtm'].append(amtm_value)

        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'mtr': f"{mtr_value:.3f}",
            'amtm': f"{amtm_value:.3f}"
        })

    if not epoch_losses['total']:
        raise RuntimeError(
            f"No batch completed successfully ({skipped} skipped because of errors)"
        ) from last_error
    if skipped:
        print(f"Skipped {skipped} batch(es) because of errors")

    # Built-in floats (not NumPy scalars) so callers can store the values in a
    # checkpoint without pickling NumPy objects.
    return {k: float(np.mean(v)) for k, v in epoch_losses.items()}


In [14]:
# Initialize optimizer and scheduler (cosine decay over EPOCHS down to 1e-6).
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

# Training history: one entry per completed epoch.
history = {'total': [], 'mtr': [], 'latent': [], 'amtm': [], 'lr': []}

print("Starting pre-training...")
print("="*60)

best_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # Read the LR before stepping the scheduler so history['lr'] holds the rate
    # used during this epoch, not the one scheduled for the next.
    current_lr = optimizer.param_groups[0]['lr']

    epoch_losses = train_epoch(model, train_loader, optimizer, device, lambda_latent=LAMBDA_LATENT)

    scheduler.step()

    # Store built-in floats so the history saved in the checkpoint holds no
    # NumPy scalar objects.
    for key in ['total', 'mtr', 'latent', 'amtm']:
        history[key].append(float(epoch_losses[key]))
    history['lr'].append(current_lr)

    print(f"  Total: {epoch_losses['total']:.4f} | MTR: {epoch_losses['mtr']:.4f} | Latent: {epoch_losses['latent']:.4f} | AMTM: {epoch_losses['amtm']:.4f}")

    # Keep only the checkpoint with the lowest mean total training loss.
    if epoch_losses['total'] < best_loss:
        best_loss = float(epoch_losses['total'])
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
            'history': history
        }, CHECKPOINT_PATH)
        print(f"  ✓ Saved best model (loss: {best_loss:.4f})")

print("\n" + "="*60)
print("Pre-training complete!")


Starting pre-training...

Epoch 1/10


Training:   0%|          | 0/2259 [00:13<?, ?it/s]


KeyboardInterrupt: 

## Training Visualization

In [None]:
# Prefer the history stored in the checkpoint (covers an interrupted training
# cell); otherwise keep whatever `history` is already in memory.
if not os.path.exists(CHECKPOINT_PATH):
    print("No checkpoint found, using current history")
else:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    if 'history' not in checkpoint:
        print("No history in checkpoint, using current history")
    else:
        history = checkpoint['history']
        print(f"Loaded history from checkpoint (epoch {checkpoint['epoch'] + 1})")

# Plot training curves: one panel per loss term on a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('MergeDNA Pre-training Losses', fontsize=16, fontweight='bold', y=1.02)

colors = {
    'total': '#2563EB',   # Blue
    'mtr': '#DC2626',     # Red
    'latent': '#F59E0B',  # Amber
    'amtm': '#10B981'     # Emerald
}

epochs_range = range(1, len(history['total']) + 1)

# (grid position, history key, line format, panel title)
panel_specs = [
    ((0, 0), 'total', 'o-', 'Total Loss'),
    ((0, 1), 'mtr', 's-', 'MTR Loss (Merged Token Reconstruction)'),
    ((1, 0), 'latent', '^-', 'Latent MTR Loss (Adaptive Selection)'),
    ((1, 1), 'amtm', 'd-', 'AMTM Loss (Adaptive Masked Modeling)'),
]

for grid_pos, loss_key, line_fmt, panel_title in panel_specs:
    ax = axes[grid_pos]
    curve = history[loss_key]
    shade = colors[loss_key]
    ax.plot(epochs_range, curve, line_fmt, color=shade, linewidth=2.5, markersize=8)
    ax.fill_between(epochs_range, curve, alpha=0.15, color=shade)
    ax.set_title(panel_title, fontsize=13, fontweight='bold', pad=10)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Loss', fontsize=11)
    ax.grid(True, alpha=0.3, linestyle='--')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig('../checkpoints/training_curves.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()

print(f"\nTraining curves saved to ../checkpoints/training_curves.png")


In [None]:
# All four loss curves on a single axis; the total loss is drawn slightly heavier.
fig, ax = plt.subplots(figsize=(12, 6))

# (history key, line format, legend label, line width, marker size)
combined_series = [
    ('total', 'o-', 'Total', 2.5, 7),
    ('mtr', 's-', 'MTR', 2, 6),
    ('latent', '^-', 'Latent MTR', 2, 6),
    ('amtm', 'd-', 'AMTM', 2, 6),
]
for series_key, series_fmt, series_label, series_lw, series_ms in combined_series:
    ax.plot(epochs_range, history[series_key], series_fmt, label=series_label,
            color=colors[series_key], linewidth=series_lw, markersize=series_ms)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('MergeDNA Pre-training: All Losses', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=11, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle='--')
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig('../checkpoints/combined_losses.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()

# Text summary: last-epoch values, plus relative change when there are at least two epochs.
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"\nFinal Losses (Epoch {len(history['total'])}):")
print(f"  • Total:  {history['total'][-1]:.4f}")
print(f"  • MTR:    {history['mtr'][-1]:.4f}")
print(f"  • Latent: {history['latent'][-1]:.4f}")
print(f"  • AMTM:   {history['amtm'][-1]:.4f}")

if len(history['total']) > 1:
    first_total = history['total'][0]
    last_total = history['total'][-1]
    improvement = (first_total - last_total) / first_total * 100
    print(f"\nImprovement: {improvement:.1f}% reduction in total loss")


## Inference: Pure Reconstruction (Autoencoding)

**Mode 1: Anomaly Detection via Reconstruction Error**

- **Input**: Full DNA sequence (e.g., `ATGC...`)
- **Action**: Compress → Process → Decompress
- **Use Case**: Anomaly detection. The model tends to reconstruct what it learned as "normal", so a mutated or unusual input is reproduced poorly; a high reconstruction error therefore indicates that the input was anomalous.

The reconstruction error can serve as an **anomaly score** — higher error indicates the sequence deviates from patterns learned during pre-training.


In [None]:
# Rebuild the architecture and restore the best weights saved during pre-training.
def load_model_from_checkpoint(checkpoint_path, device):
    """Recreate the MergeDNA model and load weights from ``checkpoint_path``.

    The architecture is built from the notebook-level ``DIM``,
    ``LATENT_ENC_DEPTH`` and ``LATENT_DEC_DEPTH`` settings. The checkpoint is
    expected to provide ``'model_state_dict'``, ``'epoch'`` and ``'loss'``.

    Returns:
        The restored model, moved to ``device`` and switched to eval mode.
    """
    restored = MergeDNAModel(
        local_encoder=LocalEncoder(DIM, merge_ratio=0.5).to(device),
        local_decoder=LocalDecoder(DIM).to(device),
        dim=DIM,
        latent_enc_depth=LATENT_ENC_DEPTH,
        latent_dec_depth=LATENT_DEC_DEPTH,
        vocab_size=4
    ).to(device)

    # Restore the saved weights and disable dropout for inference.
    state = torch.load(checkpoint_path, map_location=device)
    restored.load_state_dict(state['model_state_dict'])
    restored.eval()

    print(f"✓ Loaded model from epoch {state['epoch'] + 1}")
    print(f"  Training loss: {state['loss']:.4f}")

    return restored

# Load the model
inference_model = load_model_from_checkpoint(CHECKPOINT_PATH, device)


In [None]:
# Index -> base lookup used to turn predicted indices back into text.
IDX_TO_BASE = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}

def sequence_to_tensor(sequence, max_len=251):
    """Encode a DNA string as a ``torch.long`` tensor of exactly ``max_len`` indices.

    The string is stripped, upper-cased, truncated to ``max_len`` and right-padded
    with 0. Characters missing from ``DNA_VOCAB`` are encoded as 0.
    """
    cleaned = sequence.strip().upper()[:max_len]
    encoded = [DNA_VOCAB.get(base, 0) for base in cleaned]
    encoded.extend([0] * (max_len - len(encoded)))
    return torch.tensor(encoded, dtype=torch.long)

def tensor_to_sequence(tensor):
    """Decode a tensor of base indices into a DNA string via ``IDX_TO_BASE``."""
    return ''.join(map(IDX_TO_BASE.__getitem__, tensor.cpu().numpy()))

@torch.no_grad()
def reconstruct_sequence(model, sequence, device):
    """
    Pure reconstruction: Compress → Process → Decompress
    Returns the reconstructed sequence and reconstruction error.
    """
    model.eval()
    
    # Convert to tensor
    x = sequence_to_tensor(sequence).unsqueeze(0).to(device)  # [1, N]
    x_onehot = F.one_hot(x, num_classes=4).float()  # [1, N, 4]
    
    # Forward pass through the autoencoder
    # 1. Local Encode (compress)
    z_l, s_local = model.local_encoder(x_onehot)
    
    # 2. Latent Encode (process)
    z_prime_l = model.latent_encoder(z_l)
    
    # 3. Latent Decode
    z_hat_l = model.latent_decoder(z_prime_l)
    
    # 4. Local Decode (decompress)
    x_hat = model.local_decoder(z_hat_l, s_local)
    
    # 5. Get logits and predictions
    logits = model.head(x_hat)  # [1, N, 4]
    predictions = logits.argmax(dim=-1)  # [1, N]
    
    # Calculate reconstruction error (cross-entropy per position)
    loss = F.cross_entropy(logits.view(-1, 4), x.view(-1), reduction='none')
    per_position_error = loss.view(x.shape)  # [1, N]
    
    # Get reconstructed sequence
    reconstructed = tensor_to_sequence(predictions[0])
    
    # Calculate accuracy
    accuracy = (predictions == x).float().mean().item()
    
    return {
        'original': sequence[:251],
        'reconstructed': reconstructed,
        'accuracy': accuracy,
        'mean_error': per_position_error.mean().item(),
        'per_position_error': per_position_error[0].cpu().numpy(),
        'compression_ratio': z_l.shape[1] / x.shape[1]
    }

print("✓ Reconstruction functions defined")


In [None]:
# Test reconstruction on a normal sequence from the dataset
normal_sequence = all_sequences[0]
result = reconstruct_sequence(inference_model, normal_sequence, device)

print("="*70)
print("RECONSTRUCTION TEST: Normal Sequence")
print("="*70)
print(f"\nOriginal:      {result['original'][:60]}...")
print(f"Reconstructed: {result['reconstructed'][:60]}...")
print(f"\nAccuracy:         {result['accuracy']*100:.2f}%")
print(f"Mean Error:       {result['mean_error']:.4f}")

# compression_ratio is a float, so ratio * 251 can land just below the true
# token count; round() recovers the integer where int() would truncate it.
n_tokens = round(251 * result['compression_ratio'])
print(f"Compression:      {result['compression_ratio']:.2f}x ({n_tokens} tokens)")

# Show mismatches (position-wise comparison of input and reconstruction)
mismatches = sum(1 for a, b in zip(result['original'], result['reconstructed']) if a != b)
print(f"Mismatches:       {mismatches}/{len(result['original'])} positions")


In [None]:
def create_mutated_sequence(sequence, mutation_rate=0.1, rng=None):
    """Create a mutated copy of a sequence by substituting bases at random positions.

    Args:
        sequence: DNA string; it is upper-cased before mutation.
        mutation_rate: Fraction of positions to substitute. The number of
            substitutions is ``int(len(sequence) * mutation_rate)``.
        rng: Optional ``numpy.random.Generator`` for reproducible output. When
            ``None`` (the default) the global ``np.random`` state is used, as
            before.

    Returns:
        ``(mutated_sequence, mutation_positions)``: the mutated string and the
        array of distinct positions that were substituted.
    """
    # Both np.random (legacy global state) and Generator expose .choice().
    sampler = np.random if rng is None else rng

    bases = ['A', 'C', 'G', 'T']
    seq_list = list(sequence.upper())
    n_mutations = int(len(seq_list) * mutation_rate)

    # Sample without replacement so every chosen position is mutated exactly once.
    mutation_positions = sampler.choice(len(seq_list), n_mutations, replace=False)

    for pos in mutation_positions:
        original_base = seq_list[pos]
        # Pick a different base so each selected position really changes.
        new_bases = [b for b in bases if b != original_base]
        seq_list[pos] = sampler.choice(new_bases)

    return ''.join(seq_list), mutation_positions

def create_random_sequence(length=251, rng=None):
    """Create a DNA sequence of ``length`` bases drawn uniformly from A/C/G/T.

    Args:
        length: Number of bases to generate.
        rng: Optional ``numpy.random.Generator`` for reproducible output. When
            ``None`` (the default) the global ``np.random`` state is used, as
            before.
    """
    sampler = np.random if rng is None else rng
    bases = ['A', 'C', 'G', 'T']
    return ''.join(sampler.choice(bases, length))

@torch.no_grad()
def compute_anomaly_score(model, sequence, device):
    """Compute anomaly score based on reconstruction error."""
    result = reconstruct_sequence(model, sequence, device)
    return result['mean_error'], result

print("✓ Anomaly detection functions defined")


In [None]:
# Anomaly Detection Demo: score 20 sequences in each of four groups
# (unchanged, 10% mutated, 30% mutated, fully random) and compare the averages.
print("="*70)
print("ANOMALY DETECTION DEMO")
print("="*70)

# 1. Normal sequences (from training data)
normal_scores = [
    compute_anomaly_score(inference_model, all_sequences[i], device)[0]
    for i in range(20)
]

# 2. Mutated sequences (10% mutation rate)
mutated_scores = [
    compute_anomaly_score(
        inference_model,
        create_mutated_sequence(all_sequences[i], mutation_rate=0.1)[0],
        device,
    )[0]
    for i in range(20)
]

# 3. Heavily mutated sequences (30% mutation rate)
heavily_mutated_scores = [
    compute_anomaly_score(
        inference_model,
        create_mutated_sequence(all_sequences[i], mutation_rate=0.3)[0],
        device,
    )[0]
    for i in range(20)
]

# 4. Random sequences
random_scores = [
    compute_anomaly_score(inference_model, create_random_sequence(251), device)[0]
    for _ in range(20)
]

print(f"\nAnomaly Scores (Reconstruction Error):")
print(f"  Normal sequences:         {np.mean(normal_scores):.4f} ± {np.std(normal_scores):.4f}")
print(f"  Mutated (10%):            {np.mean(mutated_scores):.4f} ± {np.std(mutated_scores):.4f}")
print(f"  Heavily mutated (30%):    {np.mean(heavily_mutated_scores):.4f} ± {np.std(heavily_mutated_scores):.4f}")
print(f"  Random sequences:         {np.mean(random_scores):.4f} ± {np.std(random_scores):.4f}")


In [None]:
# Visualize anomaly score distributions: box plot (left) and mean ± SD bars (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

data = [normal_scores, mutated_scores, heavily_mutated_scores, random_scores]
labels = ['Normal', 'Mutated\n(10%)', 'Heavily\nMutated (30%)', 'Random']
colors_box = ['#10B981', '#F59E0B', '#EF4444', '#6366F1']
means = [np.mean(s) for s in data]
stds = [np.std(s) for s in data]

# Left panel: one colored box per sequence group.
ax = axes[0]
bp = ax.boxplot(data, labels=labels, patch_artist=True)
for patch, color in zip(bp['boxes'], colors_box):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
ax.set_ylabel('Anomaly Score (Reconstruction Error)', fontsize=11)
ax.set_title('Anomaly Score Distribution by Sequence Type', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y', linestyle='--')
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Right panel: group means with ±1 SD error bars.
ax = axes[1]
bars = ax.bar(labels, means, yerr=stds, capsize=5, color=colors_box, alpha=0.8, edgecolor='black', linewidth=1.2)
ax.set_ylabel('Mean Anomaly Score', fontsize=11)
ax.set_title('Mean Anomaly Scores (±1 SD)', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y', linestyle='--')
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Print each mean just above its bar.
for bar, mean in zip(bars, means):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.05
    ax.text(label_x, label_y, f'{mean:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('../checkpoints/anomaly_detection.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()

print("\nAnomaly detection plot saved to ../checkpoints/anomaly_detection.png")


In [None]:
# Visualize per-position reconstruction error for one sequence, before and
# after a 30% mutation; mutated positions are marked on the lower panel.
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Normal sequence
normal_result = reconstruct_sequence(inference_model, all_sequences[0], device)

# Heavily mutated sequence
mutated_seq, mutation_pos = create_mutated_sequence(all_sequences[0], mutation_rate=0.3)
mutated_result = reconstruct_sequence(inference_model, mutated_seq, device)

# (axis, reconstruction result, color, title, positions to mark)
error_panels = [
    (axes[0], normal_result, '#10B981',
     'Per-Position Reconstruction Error: Normal Sequence', ()),
    (axes[1], mutated_result, '#EF4444',
     'Per-Position Reconstruction Error: Mutated Sequence (30%)', mutation_pos),
]

for ax, panel_result, panel_color, panel_title, marked_positions in error_panels:
    errors = panel_result['per_position_error']
    ax.plot(errors, color=panel_color, linewidth=1.5, alpha=0.8)
    ax.fill_between(range(len(errors)), errors, alpha=0.3, color=panel_color)
    ax.axhline(y=panel_result['mean_error'], color=panel_color, linestyle='--', linewidth=2,
               label=f'Mean: {panel_result["mean_error"]:.3f}')

    # Mark mutation positions
    for pos in marked_positions:
        ax.axvline(x=pos, color='blue', alpha=0.3, linewidth=0.5)

    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.set_xlabel('Position (bp)', fontsize=11)
    ax.set_ylabel('Error', fontsize=11)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3, linestyle='--')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig('../checkpoints/per_position_error.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()

print("\nPer-position error plot saved to ../checkpoints/per_position_error.png")


## Interactive Anomaly Detection

Use this function to test any DNA sequence for anomalies:


In [None]:
def detect_anomaly(sequence, threshold=None):
    """
    Flag a sequence as anomalous when its reconstruction error exceeds a threshold.

    Uses the notebook-level ``inference_model``, ``device`` and (when no
    threshold is given) ``normal_scores``. Prints a short report.

    Args:
        sequence: DNA sequence string (A, C, G, T)
        threshold: Anomaly threshold; when None it is the mean plus two
            standard deviations of ``normal_scores``

    Returns:
        Dictionary with keys 'is_anomaly', 'score', 'threshold', 'accuracy'
        and 'result' (the full reconstruction output)
    """
    # Fall back to a data-driven threshold when the caller gives none.
    cutoff = threshold
    if cutoff is None:
        cutoff = np.mean(normal_scores) + 2 * np.std(normal_scores)

    score, result = compute_anomaly_score(inference_model, sequence, device)
    is_anomaly = score > cutoff

    banner = "="*70
    verdict = '⚠️  ANOMALOUS' if is_anomaly else '✓ NORMAL'

    print(banner)
    print("ANOMALY DETECTION RESULT")
    print(banner)
    print(f"\nSequence (first 50bp): {sequence[:50]}...")
    print(f"Sequence length:       {len(sequence)} bp")
    print(f"\nAnomaly Score:         {score:.4f}")
    print(f"Threshold:             {cutoff:.4f}")
    print(f"Reconstruction Acc:    {result['accuracy']*100:.2f}%")
    print(f"\nVerdict:               {verdict}")

    report = {
        'is_anomaly': is_anomaly,
        'score': score,
        'threshold': cutoff,
        'accuracy': result['accuracy'],
        'result': result
    }
    return report

# Example: Test a normal sequence
# No threshold is passed, so detect_anomaly derives one from `normal_scores`.
# The returned dict is bound to `_` so the cell shows only the printed report.
print("Testing a NORMAL sequence from the dataset:")
_ = detect_anomaly(all_sequences[100])


In [None]:
# Example: Test a RANDOM (anomalous) sequence
# A fresh 251-base sequence is drawn on every run, so the reported score varies
# between runs.
print("\nTesting a RANDOM (anomalous) sequence:")
random_seq = create_random_sequence(251)
_ = detect_anomaly(random_seq)


In [None]:
# Try your own sequence!
# Replace this with any DNA sequence you want to test

# A short repeated motif, tiled and extended with a trailing "ATG".
custom_sequence = "ATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGC" * 4 + "ATG"

print("\nTesting a CUSTOM sequence (repetitive pattern):")
# The slice caps the input at 251 bases before it is scored.
_ = detect_anomaly(custom_sequence[:251])
