# Promoter Prediction with MergeDNA

This notebook demonstrates how to use the MergeDNA model for **promoter prediction** - a binary classification task where we predict whether a DNA sequence contains a promoter region.

Promoters are DNA regions located just upstream of a gene's transcription start site (TSS), where RNA polymerase and transcription factors bind to initiate transcription. Identifying them computationally is a key step toward understanding gene regulation.

## What we'll cover:
1. DNA tokenization and encoding
2. Loading the MergeDNA model
3. Adapting the model for classification
4. Training on synthetic promoter data
5. Evaluation and visualization


## 1. Setup and Imports


In [1]:
import sys
sys.path.insert(0, '../mergedna')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import matplotlib.pyplot as plt

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 2. DNA Tokenization

DNA sequences are composed of 4 nucleotides: **A** (Adenine), **T** (Thymine), **C** (Cytosine), **G** (Guanine).

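We'll create a simple tokenizer that maps each nucleotide to an integer index. The vocabulary also reserves special tokens (`<PAD>`, `<UNK>`, `<CLS>`, `<SEP>`) and a few IUPAC ambiguity codes; any other character is mapped to `<UNK>`.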


In [None]:
class DNATokenizer:
    """Simple DNA sequence tokenizer."""
    
    def __init__(self):
        # Token vocabulary
        self.vocab = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<CLS>': 2,
            '<SEP>': 3,
            'A': 4,
            'T': 5,
            'C': 6,
            'G': 7,
            # IUPAC ambiguity codes
            'N': 8,   # Any nucleotide
            'R': 9,   # A or G (purine)
            'Y': 10,  # C or T (pyrimidine)
            'S': 11,  # G or C
        }
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)
    
    def encode(self, sequence: str, max_length: int = None) -> torch.Tensor:
        """Encode a DNA sequence to token indices."""
        tokens = [self.vocab.get(c.upper(), self.vocab['<UNK>']) for c in sequence]
        
        if max_length:
            if len(tokens) < max_length:
                tokens += [self.vocab['<PAD>']] * (max_length - len(tokens))
            else:
                tokens = tokens[:max_length]
        
        return torch.tensor(tokens, dtype=torch.long)
    
    def decode(self, tokens: torch.Tensor) -> str:
        """Decode token indices back to DNA sequence."""
        return ''.join(self.inv_vocab.get(t.item(), '?') for t in tokens)

# Initialize tokenizer
tokenizer = DNATokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Vocabulary: {tokenizer.vocab}")

# Test encoding/decoding
test_seq = "ATCGATCG"
encoded = tokenizer.encode(test_seq)
decoded = tokenizer.decode(encoded)
print(f"\nTest: '{test_seq}' -> {encoded.tolist()} -> '{decoded}'")
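
Because `decode` maps every index through `inv_vocab`, decoding a padded sequence returns literal `<PAD>` strings. The stdlib sketch below mirrors the tokenizer's vocabulary and adds a hypothetical `strip_pad` option (not part of `DNATokenizer`); it also shows that lowercase input is upper-cased and that an IUPAC code missing from the vocabulary (here `W`) becomes `<UNK>`.

```python
# Stdlib mirror of the notebook's DNATokenizer vocabulary, used to show
# padding, truncation, and a hypothetical pad-stripping decode.
VOCAB = {'<PAD>': 0, '<UNK>': 1, '<CLS>': 2, '<SEP>': 3,
         'A': 4, 'T': 5, 'C': 6, 'G': 7, 'N': 8, 'R': 9, 'Y': 10, 'S': 11}
INV_VOCAB = {v: k for k, v in VOCAB.items()}

def encode(sequence, max_length=None):
    """Map nucleotides to ids; symbols not in VOCAB become <UNK>."""
    ids = [VOCAB.get(c.upper(), VOCAB['<UNK>']) for c in sequence]
    if max_length:
        # Truncate, then right-pad with <PAD> up to max_length
        ids = ids[:max_length] + [VOCAB['<PAD>']] * max(0, max_length - len(ids))
    return ids

def decode(ids, strip_pad=True):
    """Inverse of encode; strip_pad (hypothetical option) drops <PAD> ids."""
    pad = VOCAB['<PAD>']
    return ''.join(INV_VOCAB.get(i, '?') for i in ids if not (strip_pad and i == pad))

ids = encode("acgtw", max_length=8)
print(ids)          # [4, 6, 7, 5, 1, 0, 0, 0]
print(decode(ids))  # ACGT<UNK>
```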


## 3. Synthetic Promoter Dataset

For this example, we'll create synthetic data that mimics promoter vs non-promoter sequences.

**Promoter characteristics** (simplified):
- TATA box motif (~25-35 bp upstream of the transcription start site, TSS): `TATAAA` or variants
- GC-rich regions (CpG island-like) near the TSS
- Initiator element (Inr) overlapping the TSS, with consensus `YYANWYY` (Y = C/T, W = A/T, N = any base)

We'll generate sequences containing these motifs for the positive class and random sequences for the negative class.
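
A quick way to check sequences for these motifs is to expand IUPAC codes into regular-expression character classes. The sketch below uses the standard IUPAC nucleotide codes; `motif_to_regex` and `find_motif` are illustrative helpers, and `TATAWA` is just an example pattern covering two of the TATA variants used in the generator, not an authoritative consensus.

```python
import re

# Standard IUPAC nucleotide codes mapped to regex character classes
IUPAC = {'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T',
         'R': '[AG]', 'Y': '[CT]', 'S': '[GC]', 'W': '[AT]',
         'K': '[GT]', 'M': '[AC]', 'N': '[ACGT]'}

def motif_to_regex(motif):
    """Compile an IUPAC motif such as 'YYANWYY' into a regex."""
    return re.compile(''.join(IUPAC[c] for c in motif.upper()))

def find_motif(sequence, motif):
    """Return the start positions of all (possibly overlapping) matches."""
    pattern = motif_to_regex(motif).pattern
    # A zero-width lookahead lets finditer report overlapping hits
    lookahead = re.compile(f'(?=({pattern}))')
    return [m.start() for m in lookahead.finditer(sequence.upper())]

print(find_motif("GGTATAAAGG", "TATAWA"))  # [2]
print(find_motif("TCAGTTC", "YYANWYY"))    # [0]
```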


In [None]:
def generate_random_dna(length: int) -> str:
    """Generate a random DNA sequence."""
    return ''.join(random.choices(['A', 'T', 'C', 'G'], k=length))

def generate_promoter_sequence(length: int = 256) -> str:
    """
    Generate a synthetic promoter-like sequence with characteristic motifs.
    """
    seq = list(generate_random_dna(length))
    
    # Insert a TATA box. In real promoters it sits ~25-35 bp upstream of the TSS; here we simply place it at positions ~45-60
    tata_variants = ['TATAAA', 'TATATA', 'TATAAG', 'TATAAT']
    tata_pos = random.randint(45, 55)
    tata_seq = random.choice(tata_variants)
    for i, nucleotide in enumerate(tata_seq):
        if tata_pos + i < length:
            seq[tata_pos + i] = nucleotide
    
    # Insert GC-rich region (CpG island-like) around position 80-120
    gc_start = random.randint(75, 85)
    gc_length = random.randint(30, 40)
    for i in range(gc_length):
        if gc_start + i < length:
            seq[gc_start + i] = random.choice(['G', 'C', 'G', 'C', 'G', 'C', 'A', 'T'])  # 75% GC
    
    # Insert initiator element around position 130-140
    inr_pos = random.randint(125, 135)
    inr_patterns = ['TCAGTT', 'CCAATT', 'TCAATT', 'CCAGTT']
    inr_seq = random.choice(inr_patterns)
    for i, nucleotide in enumerate(inr_seq):
        if inr_pos + i < length:
            seq[inr_pos + i] = nucleotide
    
    return ''.join(seq)

def generate_non_promoter_sequence(length: int = 256) -> str:
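    """
    Generate a non-promoter sequence: random DNA with every TATA-box variant
    used for the positive class replaced by a GC-rich filler.
    """
    seq = generate_random_dna(length)
    tata_variants = ('TATAAA', 'TATATA', 'TATAAG', 'TATAAT')
    # Repeat until no variant remains; each replacement removes T's, so this terminates
    while any(v in seq for v in tata_variants):
        for v in tata_variants:
            seq = seq.replace(v, 'GGCCGG')
    return seq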

# Generate example sequences
print("Example Promoter Sequence:")
promoter_ex = generate_promoter_sequence(256)
print(f"  ...{promoter_ex[40:70]}... (TATA region)")
print(f"  ...{promoter_ex[75:120]}... (GC-rich region)")

print("\nExample Non-Promoter Sequence:")
non_promoter_ex = generate_non_promoter_sequence(256)
print(f"  ...{non_promoter_ex[40:70]}...")
print(f"  ...{non_promoter_ex[75:120]}...")
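
To check that the GC-rich insert actually stands out, a sliding-window GC fraction is handy. Each position in the insert is drawn from `['G', 'C', 'G', 'C', 'G', 'C', 'A', 'T']`, so its expected GC content is 6/8 = 0.75, versus 0.5 for uniform random DNA. The helpers below are an illustrative sketch; `gc_fraction` and `sliding_gc` are not part of the notebook.

```python
def gc_fraction(seq):
    """Fraction of G/C bases in seq (0.0 for an empty string)."""
    seq = seq.upper()
    return (seq.count('G') + seq.count('C')) / len(seq) if seq else 0.0

def sliding_gc(seq, window=20, step=10):
    """GC fraction per window, as a list of (start, fraction) pairs."""
    return [(i, gc_fraction(seq[i:i + window]))
            for i in range(0, len(seq) - window + 1, step)]

print(sliding_gc("GGGGAAAA", window=4, step=4))  # [(0, 1.0), (4, 0.0)]
```

Applied to `promoter_ex` with `window=20, step=10`, windows inside the inserted region should tend to sit well above 0.5.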


In [None]:
class PromoterDataset(Dataset):
    """Dataset for promoter prediction task."""
    
    def __init__(self, num_samples: int = 1000, seq_length: int = 256, tokenizer: DNATokenizer = None):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.tokenizer = tokenizer or DNATokenizer()
        
        # Generate balanced dataset
        self.sequences = []
        self.labels = []
        
        for i in range(num_samples):
            if i < num_samples // 2:
                # Promoter (positive)
                seq = generate_promoter_sequence(seq_length)
                label = 1
            else:
                # Non-promoter (negative)
                seq = generate_non_promoter_sequence(seq_length)
                label = 0
            
            self.sequences.append(seq)
            self.labels.append(label)
        
        # Shuffle
        combined = list(zip(self.sequences, self.labels))
        random.shuffle(combined)
        self.sequences, self.labels = zip(*combined)
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        label = self.labels[idx]
        
        # Tokenize
        tokens = self.tokenizer.encode(seq, max_length=self.seq_length)
        
        return {
            'input_ids': tokens,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Create datasets
SEQ_LENGTH = 256  # Must be divisible by 2^local_layers (2^2 = 4)

train_dataset = PromoterDataset(num_samples=2000, seq_length=SEQ_LENGTH, tokenizer=tokenizer)
val_dataset = PromoterDataset(num_samples=400, seq_length=SEQ_LENGTH, tokenizer=tokenizer)
test_dataset = PromoterDataset(num_samples=400, seq_length=SEQ_LENGTH, tokenizer=tokenizer)

# Create dataloaders
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Check a batch
batch = next(iter(train_loader))
print(f"\nBatch shape: {batch['input_ids'].shape}")
print(f"Labels shape: {batch['label'].shape}")


## 4. MergeDNA Model for Classification

We'll adapt the MergeDNA architecture for sequence classification. The key idea is to use the encoder to get compressed representations, then add a classification head on top.

The MergeDNA model uses **dynamic token merging** to compress long sequences efficiently, which is particularly useful for DNA sequences that can be thousands of base pairs long.
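
The `SEQ_LENGTH` note ("divisible by 2^local_layers") and the `[B, S/4, D]` shape comments assume that each local layer halves the sequence length. The toy sketch below reproduces only that shape arithmetic by averaging fixed adjacent pairs. MergeDNA's actual merging is learned and dynamic, so `merge_adjacent` is a hypothetical stand-in, not the model's merge rule.

```python
def merge_adjacent(tokens):
    """Average non-overlapping adjacent pairs: length L -> L // 2.
    Toy stand-in for a learned merge step (not MergeDNA's actual rule)."""
    if len(tokens) % 2:
        raise ValueError("sequence length must be even to merge in pairs")
    return [[(a + b) / 2 for a, b in zip(tokens[i], tokens[i + 1])]
            for i in range(0, len(tokens), 2)]

seq_len, dim = 256, 4
x = [[float(i)] * dim for i in range(seq_len)]  # [S, D] toy embeddings
for layer in range(2):  # two "local layers"
    x = merge_adjacent(x)
print(len(x))  # 64, i.e. 256 / 2**2
```

An odd length at any layer would raise here, which is why the sequence length must be divisible by 2 once per merging layer under this assumption.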


In [None]:
# Import MergeDNA components
from model.backbone import MergeDNAModel
from model.encoder import LocalEncoder, LatentEncoder
from mdna.utils import TransformerBlock, SinusoidalPositionalEmbedding

class MergeDNAClassifier(nn.Module):
    """
    MergeDNA model adapted for sequence classification.
    Uses the encoder for feature extraction and adds a classification head.
    """
    
    def __init__(
        self,
        vocab_size: int = 12,
        dim: int = 64,
        local_layers: int = 2,
        latent_layers: int = 2,
        heads: int = 4,
        mlp_dim: int = 256,
        num_classes: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.dim = dim
        self.local_layers = local_layers
        
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_emb = SinusoidalPositionalEmbedding(dim)
        
        # Encoder (from MergeDNA)
        self.local_encoder = LocalEncoder(local_layers, dim, heads, mlp_dim, dropout)
        self.latent_encoder = LatentEncoder(latent_layers, dim, heads, mlp_dim, dropout)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes)
        )
        
    def forward(self, x):
        """
        x: [Batch, Seq_Len] - token indices
        Returns: [Batch, num_classes] - logits
        """
        # Embed tokens
        x_emb = self.embedding(x)  # [B, S, D]
        x_emb = self.pos_emb(x_emb)
        
        # Encode - this reduces sequence length via token merging
        x_local = self.local_encoder(x_emb)  # [B, S/4, D] with 2 local layers
        x_latent = self.latent_encoder(x_local)  # [B, S/4, D]
        
        # Global pooling for classification
        # Mean pooling over sequence dimension
        x_pooled = x_latent.mean(dim=1)  # [B, D]
        
        # Classification
        logits = self.classifier(x_pooled)  # [B, num_classes]
        
        return logits
    
    def get_embeddings(self, x):
        """Get the latent embeddings (useful for visualization)."""
        x_emb = self.embedding(x)
        x_emb = self.pos_emb(x_emb)
        x_local = self.local_encoder(x_emb)
        x_latent = self.latent_encoder(x_local)
        return x_latent

# Initialize model
model = MergeDNAClassifier(
    vocab_size=tokenizer.vocab_size,
    dim=64,
    local_layers=2,
    latent_layers=2,
    heads=4,
    mlp_dim=256,
    num_classes=2,
    dropout=0.1
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {num_params:,}")
print(f"Trainable parameters: {num_trainable:,}")

# Test forward pass
test_input = batch['input_ids'].to(device)
test_output = model(test_input)
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
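
`MergeDNAClassifier.forward` averages `x_latent` over every position. Here every sequence is exactly `SEQ_LENGTH` bases, so no `<PAD>` tokens occur, but with variable-length real data the padding would leak into the pooled vector. A common remedy is a masked mean; the stdlib sketch below shows it at the token level (in PyTorch the same idea is `(x * mask).sum(1) / mask.sum(1)` with a `[B, S, 1]` float mask). How the mask should be carried through MergeDNA's token merging depends on the encoder's implementation, which is an assumption not covered here.

```python
PAD_ID = 0  # matches the notebook's '<PAD>' index

def masked_mean(embeddings, token_ids, pad_id=PAD_ID):
    """Mean over positions whose token id is not pad_id.
    embeddings: list of [D] vectors; token_ids: ints, same length."""
    kept = [e for e, t in zip(embeddings, token_ids) if t != pad_id]
    if not kept:
        raise ValueError("sequence contains only padding")
    dim = len(kept[0])
    return [sum(v[d] for v in kept) / len(kept) for d in range(dim)]

emb = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
ids = [4, 7, 0]  # 'A', 'G', '<PAD>'
print(masked_mean(emb, ids))  # [2.0, 3.0]
```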


## 5. Training Loop

Now let's train the model on our promoter prediction task.


In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        logits = model(input_ids)
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        # Weight by batch size so a smaller final batch doesn't skew the mean
        total_loss += loss.item() * labels.size(0)
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    return total_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    """Evaluate the model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            
            logits = model(input_ids)
            loss = criterion(logits, labels)
            
            # Weight by batch size so a smaller final batch doesn't skew the mean
            total_loss += loss.item() * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return total_loss / total, correct / total, all_preds, all_labels


In [None]:
# Training configuration
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0

print("Starting training...")
print("-" * 50)

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)
    
    scheduler.step()
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save best model
        torch.save(model.state_dict(), 'best_promoter_model.pt')
    
    print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

print("-" * 50)
print(f"Best validation accuracy: {best_val_acc:.4f}")
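
`CosineAnnealingLR` with `T_max=NUM_EPOCHS` and the default `eta_min=0` decays the learning rate along a half cosine. The sketch below evaluates the closed-form schedule given in the PyTorch documentation, lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2, for the settings used here; it should match the scheduler up to floating-point error. Because `scheduler.step()` is called once at the end of each epoch, epoch t (0-indexed) trains at lr_t.

```python
import math

def cosine_lr(epoch, base_lr=1e-3, t_max=15, eta_min=0.0):
    """Closed-form cosine annealing learning rate at a given epoch."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in range(15):
    print(f"epoch {epoch + 1:2d}: lr = {cosine_lr(epoch):.2e}")
```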


## 6. Training Visualization


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', color='#2ecc71', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', color='#e74c3c', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', color='#2ecc71', linewidth=2)
axes[1].plot(history['val_acc'], label='Val Acc', color='#e74c3c', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training & Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## 7. Test Set Evaluation


In [None]:
# Load best model and evaluate on test set
model.load_state_dict(torch.load('best_promoter_model.pt', map_location=device))

test_loss, test_acc, test_preds, test_labels = evaluate(model, test_loader, criterion, device)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

cm = confusion_matrix(test_labels, test_preds)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
            xticklabels=['Non-Promoter', 'Promoter'],
            yticklabels=['Non-Promoter', 'Promoter'])
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(test_labels, test_preds, 
                          target_names=['Non-Promoter', 'Promoter']))
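
For the positive (Promoter) class, `classification_report` derives precision, recall and F1 from the same counts shown in the confusion matrix. The hypothetical helper below makes that arithmetic explicit for sklearn's binary layout `[[TN, FP], [FN, TP]]` (rows = actual, columns = predicted). The counts in the example are illustrative, not results from this notebook.

```python
def binary_metrics(cm):
    """cm = [[TN, FP], [FN, TP]], as from sklearn's confusion_matrix
    with labels [0, 1] (rows = actual, columns = predicted)."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return {'precision': precision, 'recall': recall, 'f1': f1, 'accuracy': accuracy}

# Illustrative counts only
print(binary_metrics([[180, 20], [10, 190]]))
```

Passing the notebook's `cm` array directly should also work, since unpacking iterates over its rows.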


## 8. Inference Example

Let's test the model on some new sequences to see how it performs.


In [None]:
def predict_promoter(model, sequence: str, tokenizer: DNATokenizer, device, seq_length: int = 256):
    """
    Predict whether a DNA sequence is a promoter.
    
    Returns:
        dict with prediction label, probabilities, and confidence
    """
    model.eval()
    
    # Tokenize
    tokens = tokenizer.encode(sequence, max_length=seq_length).unsqueeze(0).to(device)
    
    # Predict
    with torch.no_grad():
        logits = model(tokens)
        probs = F.softmax(logits, dim=-1)
        pred = logits.argmax(dim=-1).item()
    
    return {
        'sequence_preview': sequence[:50] + '...' if len(sequence) > 50 else sequence,
        'prediction': 'Promoter' if pred == 1 else 'Non-Promoter',
        'confidence': probs[0, pred].item(),
        'probabilities': {
            'non_promoter': probs[0, 0].item(),
            'promoter': probs[0, 1].item()
        }
    }

# Test on new sequences
print("=" * 60)
print("INFERENCE EXAMPLES")
print("=" * 60)

# Generate a new promoter sequence
new_promoter = generate_promoter_sequence(256)
result = predict_promoter(model, new_promoter, tokenizer, device)
print(f"\n[Generated Promoter Sequence]")
print(f"  Preview: {result['sequence_preview']}")
print(f"  Prediction: {result['prediction']}")
print(f"  Confidence: {result['confidence']:.4f}")

# Generate a new non-promoter sequence
new_non_promoter = generate_non_promoter_sequence(256)
result = predict_promoter(model, new_non_promoter, tokenizer, device)
print(f"\n[Generated Non-Promoter Sequence]")
print(f"  Preview: {result['sequence_preview']}")
print(f"  Prediction: {result['prediction']}")
print(f"  Confidence: {result['confidence']:.4f}")

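# Hand-built sequence: GC repeats, a TATA-like AT run, and an Inr-like TCAGTT, followed by random bases.
# Note that these motifs sit closer to the start of the sequence than in the training data.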
tata_sequence = "GCGCGCGCGCATATATATATATAAAAAAAATCAGTTGCGCGCGCGCGCGC" + generate_random_dna(206)
result = predict_promoter(model, tata_sequence, tokenizer, device)
print(f"\n[Sequence with TATA-like motif]")
print(f"  Preview: {result['sequence_preview']}")
print(f"  Prediction: {result['prediction']}")
print(f"  Confidence: {result['confidence']:.4f}")
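
`predict_promoter` reports as confidence the softmax probability of the argmax class. With two classes this value is always at least 0.5, and it is not guaranteed to be calibrated. The stdlib sketch below reproduces the computation for a pair of hypothetical logits.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.0, math.log(3.0)]  # hypothetical [non_promoter, promoter] logits
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)
print(pred, f"{probs[pred]:.2f}")  # 1 0.75
```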


## 9. Embedding Visualization

Let's visualize how the model separates promoter and non-promoter sequences in the latent space using t-SNE.


In [None]:
from sklearn.manifold import TSNE

# Get embeddings for test set
model.eval()
all_embeddings = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label']
        
        # Get latent embeddings and pool
        embeddings = model.get_embeddings(input_ids)  # [B, S/4, D]
        pooled = embeddings.mean(dim=1)  # [B, D]
        
        all_embeddings.append(pooled.cpu().numpy())
        all_labels.extend(labels.numpy())

all_embeddings = np.vstack(all_embeddings)
all_labels = np.array(all_labels)

# Apply t-SNE
print("Computing t-SNE projection...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
embeddings_2d = tsne.fit_transform(all_embeddings)

# Plot
fig, ax = plt.subplots(figsize=(8, 6))

colors = ['#3498db', '#e74c3c']
labels_text = ['Non-Promoter', 'Promoter']

for i in range(2):
    mask = all_labels == i
    ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], 
               c=colors[i], label=labels_text[i], alpha=0.6, s=30)

ax.set_xlabel('t-SNE Component 1')
ax.set_ylabel('t-SNE Component 2')
ax.set_title('t-SNE Visualization of Learned Embeddings')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nEmbedding shape: {all_embeddings.shape}")


## Summary

In this notebook, we demonstrated:

1. **DNA Tokenization**: Converting nucleotide sequences (A, T, C, G) into token indices
2. **Synthetic Data Generation**: Creating promoter-like sequences with characteristic motifs (TATA box, GC-rich regions, initiator elements)
3. **MergeDNA for Classification**: Adapting the MergeDNA encoder architecture for binary classification by adding a pooling layer and classification head
4. **Training Pipeline**: Standard PyTorch training loop with validation monitoring
5. **Evaluation**: Confusion matrix and classification metrics on held-out test set
6. **Embedding Visualization**: t-SNE projection showing learned representations

### Key Features of MergeDNA for Genomics:

- **Dynamic Token Merging**: Shortens the sequence by merging tokens, so later layers process far fewer positions than there are base pairs
- **Hierarchical Encoding**: The local encoder targets short-range patterns (motifs); the latent encoder models longer-range dependencies over the merged tokens
- **Scalability**: Because self-attention cost grows quadratically with sequence length, merging is intended to make inputs of thousands of base pairs tractable (this notebook uses only 256-bp sequences)

### Next Steps:

- Use real promoter datasets (e.g., EPDnew, ENCODE)
- Experiment with different sequence lengths and model sizes
- Add attention visualization to identify important motifs
- Fine-tune on organism-specific promoter data
