# Promoter Prediction with MergeDNA

This notebook demonstrates how to use the **MergeDNA** model for promoter prediction - a binary classification task to predict whether a DNA sequence contains a promoter region.

**Promoters** are DNA regions, typically located just upstream of a gene's transcription start site, where the transcription machinery binds to initiate transcription. Identifying them computationally is crucial for understanding gene regulation.

## What we'll cover:
1. Loading data utilities from `mergedna.data`
2. Using the MergeDNA classifier from `mergedna.model`
3. Training on synthetic promoter data
4. Evaluation and visualization
5. Inference using `mergedna.scripts.inference`


## 1. Setup and Imports


In [None]:
import sys
import os

# Add mergedna to path. Resolve to an absolute path so the import keeps working
# even if the working directory changes later, and insert it only once so
# re-running this cell does not keep growing sys.path.
MERGEDNA_DIR = os.path.abspath('../mergedna')
if MERGEDNA_DIR not in sys.path:
    sys.path.insert(0, MERGEDNA_DIR)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import random
import matplotlib.pyplot as plt

# Set seeds for reproducibility (torch, numpy and Python's `random` are seeded separately)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 2. Load Data Utilities from MergeDNA

We use the `DNATokenizer` and `PromoterDataset` from `mergedna.data.dataloader`.


In [None]:
# Pull the tokenizer, dataset and synthetic-sequence generators from
# mergedna's data.dataloader module (importable as `data.*` via the sys.path setup above).
from data.dataloader import (
    DNATokenizer,
    PromoterDataset,
    generate_promoter_sequence,
    generate_non_promoter_sequence,
    generate_random_dna,
)

# A single tokenizer instance is reused by the datasets, the model config and inference.
tokenizer = DNATokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Vocabulary: {tokenizer.vocab}")

# Sanity check: push a short sequence through encode and back through decode.
test_seq = "ATCGATCG"
encoded = tokenizer.encode(test_seq)
decoded = tokenizer.decode(encoded)
print(f"\nTest: '{test_seq}' -> {encoded.tolist()} -> '{decoded}'")


## 3. Create Datasets

The `PromoterDataset` generates a labelled mix of synthetic promoter and non-promoter sequences. The promoter sequences are built from:
- **TATA box** motifs (`TATAAA` variants)
- **GC-rich regions** (CpG island-like)
- **Initiator elements**


In [None]:
# Peek at the same two windows of one synthetic sequence from each class.
# NOTE(review): the window positions assume where generate_promoter_sequence
# places its TATA / GC-rich motifs — confirm against mergedna's dataloader.
tata_window = slice(40, 70)
gc_window = slice(75, 120)

print("Example Promoter Sequence:")
promoter_ex = generate_promoter_sequence(256)
print(f"  ...{promoter_ex[tata_window]}... (TATA region)")
print(f"  ...{promoter_ex[gc_window]}... (GC-rich region)")

print("\nExample Non-Promoter Sequence:")
non_promoter_ex = generate_non_promoter_sequence(256)
print(f"  ...{non_promoter_ex[tata_window]}...")
print(f"  ...{non_promoter_ex[gc_window]}...")


In [None]:
# Create datasets
SEQ_LENGTH = 256  # Must be divisible by 2^local_layers (2^2 = 4)
BATCH_SIZE = 32

# Enforce the length requirement up front with a readable message rather than
# letting it surface later inside the model. 4 == 2**local_layers for the
# local_layers=2 used in the model cell — keep the two in sync.
assert SEQ_LENGTH % 4 == 0, f"SEQ_LENGTH={SEQ_LENGTH} must be divisible by 4 (2^local_layers)"

train_dataset = PromoterDataset(num_samples=2000, seq_length=SEQ_LENGTH, tokenizer=tokenizer)
val_dataset = PromoterDataset(num_samples=400, seq_length=SEQ_LENGTH, tokenizer=tokenizer)
test_dataset = PromoterDataset(num_samples=400, seq_length=SEQ_LENGTH, tokenizer=tokenizer)

# Create dataloaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Check a batch (this `batch` is also reused by the model smoke test below)
batch = next(iter(train_loader))
print(f"\nBatch input_ids shape: {batch['input_ids'].shape}")
print(f"Batch labels shape: {batch['label'].shape}")


## 4. Load MergeDNA Classifier

We use `MergeDNAClassifier` from `mergedna.model.backbone`. This model:
- Uses **dynamic token merging** to compress sequences
- Applies **hierarchical encoding** (local → latent)
- Adds a **classification head** on top


In [None]:
# Import model from mergedna
from model.backbone import MergeDNAClassifier

# Initialize model
model = MergeDNAClassifier(
    vocab_size=tokenizer.vocab_size,
    dim=64,
    local_layers=2,
    latent_layers=2,
    heads=4,
    mlp_dim=256,
    num_classes=2,
    dropout=0.1,
    pooling='mean'
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {num_params:,}")
print(f"Trainable parameters: {num_trainable:,}")

# Smoke-test a forward pass on the batch inspected above. no_grad() stops
# autograd from recording a graph that would otherwise be kept alive by the
# global `test_output` for the rest of the session.
test_input = batch['input_ids'].to(device)
with torch.no_grad():
    test_output = model(test_input)
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")


## 5. Training Functions


In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: classifier mapping ``input_ids`` to class logits; the argmax over
            the last dim below assumes class scores live on the final axis.
        loader: iterable of dict batches with ``'input_ids'`` and ``'label'`` tensors.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss that returns the *mean* over the batch (e.g. the default
            ``nn.CrossEntropyLoss``). It is re-weighted by batch size so the
            returned loss is a true per-sample average.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen in the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()

        logits = model(input_ids)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not count as much
        # as a full one when averaging over the epoch.
        total_loss += loss.item() * batch_size
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Puts the model in eval mode (it is left there on return).

    Args:
        model: classifier mapping ``input_ids`` to class logits; the argmax over
            the last dim below assumes class scores live on the final axis.
        loader: iterable of dict batches with ``'input_ids'`` and ``'label'`` tensors.
        criterion: loss that returns the *mean* over the batch (e.g. the default
            ``nn.CrossEntropyLoss``). It is re-weighted by batch size so the
            returned loss is a true per-sample average.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``. Loss and accuracy are
        per-sample averages over the whole loader. ``all_preds`` / ``all_labels``
        are flat lists in loader order (one NumPy scalar per sample, assuming
        2-D logits and 1-D labels).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            batch_size = labels.size(0)

            logits = model(input_ids)
            loss = criterion(logits, labels)

            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * batch_size
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += batch_size

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return total_loss / total, correct / total, all_preds, all_labels


## 6. Training Loop


In [None]:
# Training configuration
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training history (one entry per epoch)
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# Start below any achievable accuracy so the first epoch always writes a
# checkpoint; with 0.0 and a strict `>` a run whose validation accuracy never
# rises above zero would leave no file for the test-set cell to load.
best_val_acc = float('-inf')
CHECKPOINT_PATH = 'best_promoter_model.pt'

print("Starting training...")
print("-" * 60)

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save best model (selected by validation accuracy)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }, CHECKPOINT_PATH)

    print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

print("-" * 60)
print(f"Best validation accuracy: {best_val_acc:.4f}")


## 7. Training Visualization


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One panel per metric: (axis, history-key suffix, legend labels, y-axis label, title).
# Train curves are green, validation curves red, on both panels.
panel_specs = [
    (axes[0], 'loss', ('Train Loss', 'Val Loss'), 'Loss', 'Training & Validation Loss'),
    (axes[1], 'acc', ('Train Acc', 'Val Acc'), 'Accuracy', 'Training & Validation Accuracy'),
]

for ax, metric, (train_label, val_label), y_label, title in panel_specs:
    ax.plot(history['train_' + metric], label=train_label, color='#2ecc71', linewidth=2)
    ax.plot(history['val_' + metric], label=val_label, color='#e74c3c', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## 8. Test Set Evaluation


In [None]:
# Load best model and evaluate on test set.
# map_location lets the checkpoint load even if it was written on a different
# device (e.g. saved on GPU, reloaded on a CPU-only kernel).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']+1} with val_acc: {checkpoint['val_acc']:.4f}")

test_loss, test_acc, test_preds, test_labels = evaluate(model, test_loader, criterion, device)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")


In [None]:
# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Pin the label order explicitly: class 0 = Non-Promoter, class 1 = Promoter.
# Without this, a class missing from the predictions/labels shrinks the matrix
# and desynchronises it from the tick labels / target_names below.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ['Non-Promoter', 'Promoter']

cm = confusion_matrix(test_labels, test_preds, labels=CLASS_LABELS)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(test_labels, test_preds,
                            labels=CLASS_LABELS,
                            target_names=CLASS_NAMES))


## 9. Inference with MergeDNA Utilities

Use the inference utilities from `mergedna.scripts.inference` for easy predictions.


In [None]:
# Import inference utilities
from scripts.inference import predict_promoter, batch_predict, get_embeddings

print("=" * 60)
print("INFERENCE EXAMPLES")
print("=" * 60)


def show_prediction(title, sequence):
    """Classify ``sequence`` with ``predict_promoter`` and print a short summary.

    Returns the result dict from ``predict_promoter``; this helper reads its
    'sequence_preview', 'prediction' and 'confidence' entries.
    """
    result = predict_promoter(model, sequence, tokenizer, device)
    print(f"\n[{title}]")
    print(f"  Preview: {result['sequence_preview']}")
    print(f"  Prediction: {result['prediction']}")
    print(f"  Confidence: {result['confidence']:.4f}")
    return result


# Test on new sequences, generated at the same length the model was trained on
new_promoter = generate_promoter_sequence(SEQ_LENGTH)
result = show_prediction("Generated Promoter Sequence", new_promoter)

new_non_promoter = generate_non_promoter_sequence(SEQ_LENGTH)
result = show_prediction("Generated Non-Promoter Sequence", new_non_promoter)

# Test with explicit TATA box: a hand-written motif prefix padded with random DNA
# up to SEQ_LENGTH, so the total stays valid if either the prefix or SEQ_LENGTH changes.
TATA_PREFIX = "GCGCGCGCGCATATATATATATAAAAAAAATCAGTTGCGCGCGCGCGCGC"
tata_seq = TATA_PREFIX + generate_random_dna(SEQ_LENGTH - len(TATA_PREFIX))
result = show_prediction("Sequence with TATA-like motif", tata_seq)


In [None]:
# Batch prediction example: two sequences from each synthetic class, generated
# at the training length (SEQ_LENGTH) rather than a hard-coded 256.
# Order: promoter, promoter, non-promoter, non-promoter.
test_sequences = (
    [generate_promoter_sequence(SEQ_LENGTH) for _ in range(2)]
    + [generate_non_promoter_sequence(SEQ_LENGTH) for _ in range(2)]
)

results = batch_predict(model, test_sequences, tokenizer, device)

print("Batch Predictions:")
for i, res in enumerate(results):
    print(f"  Seq {i+1}: {res['prediction']:12s} (confidence: {res['confidence']:.4f})")
