# Chapter 5: CNN for DNA Sequence Classification

Welcome to your first complete biology project! 🧬

In this notebook, we'll apply CNNs (from Chapter 2) to classify DNA sequences. Specifically, we'll build a model that can identify **promoter regions** - special DNA sequences that control when genes are turned on or off.

## 🎯 The Biological Problem

### What are Promoters?

Think of DNA as an instruction manual for building proteins. But how does a cell know WHEN to read each instruction? That's where promoters come in:

- **Promoters** are DNA sequences (typically 100-1000 base pairs) located upstream of genes
- They act like "ON switches" for genes
- RNA polymerase (the enzyme that reads DNA) binds to promoters to start transcription
- Promoters often contain recognizable patterns called **motifs**, such as the TATA box (consensus TATAAA)

### Why This Matters

Identifying promoters helps us:
- Understand gene regulation (what controls gene activity)
- Predict gene expression patterns
- Design synthetic biology systems
- Diagnose diseases caused by promoter mutations

### The Machine Learning Task

**Input:** A DNA sequence (string of A, T, G, C)
```
ATGCGATATATAAAGCTAGC...
```

**Output:** Is this a promoter region? (Yes/No)

**Challenge:** Promoters don't have exact sequences - they have patterns that can vary. Perfect job for deep learning!
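
To make "patterns that can vary" concrete, here is a minimal sketch that scans a sequence for a degenerate TATA-box pattern with a regular expression. The pattern `TATA[AT]A[AT]` (the commonly cited TATAWAW consensus) and the helper name `find_tata_like` are illustrative assumptions, not part of this notebook's pipeline.

```python
import re

# Assumed degenerate consensus: TATA, then A or T, then A, then A or T
TATA_PATTERN = re.compile(r"TATA[AT]A[AT]")

def find_tata_like(sequence):
    """Return (start, matched_text) for every non-overlapping TATA-like match."""
    return [(m.start(), m.group()) for m in TATA_PATTERN.finditer(sequence)]

example = "GCGCTATAAAAGGCCGTATATATGC"
print(find_tata_like(example))
```

A hand-written pattern like this misses weaker variants and also fires on chance matches, which is one motivation for learning the detectors from data instead.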

## 📚 What You'll Learn

1. **Data Representation:** How to convert DNA sequences (ATGC) into numbers that neural networks can process
2. **One-Hot Encoding:** The standard way to represent categorical data
3. **1D Convolutions:** Like 2D convolutions for images, but for sequences
4. **Motif Detection:** How CNNs automatically learn to recognize important DNA patterns
5. **Evaluation:** How to measure if your model is working

## 🔧 Skills You'll Practice

- Preparing biological sequence data
- Building and training a CNN for sequences
- Visualizing what the network learned
- Interpreting results in biological context

Let's get started! 🚀

---


## 1. Setup and Imports

First, let's import all necessary libraries. We'll use PyTorch as our deep learning framework.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Data Generation

Since we're using synthetic data for educational purposes, we'll create a simple dataset:

- **Promoter sequences**: Will contain common promoter motifs like TATA box ("TATAAA")
- **Non-promoter sequences**: Random DNA sequences

In real applications, you would load data from databases like DBTSS or EPD.
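
One caveat of this synthetic setup is that a random sequence can contain a motif purely by chance. The sketch below gives a back-of-the-envelope estimate, assuming every base is drawn independently and uniformly from A, T, C, G (as the generators in this section do). The function name is hypothetical.

```python
def expected_chance_occurrences(motif_length, sequence_length=200, alphabet_size=4):
    """Expected number of exact motif matches in a uniformly random sequence."""
    n_windows = sequence_length - motif_length + 1   # possible start positions
    p_match = (1 / alphabet_size) ** motif_length    # chance one window matches
    return n_windows * p_match

print(f"TATAAA: {expected_chance_occurrences(6):.4f}")  # 195 / 4096
print(f"CCAAT:  {expected_chance_occurrences(5):.4f}")  # 196 / 1024
```

Under these assumptions roughly 1 in 20 non-promoter sequences is expected to carry an exact TATAAA, so a small fraction of negatives may be hard to tell apart from positives.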

In [None]:
def generate_promoter_sequence(length=200):
    """
    Generate a synthetic promoter sequence with common motifs.
    Real promoters often contain TATA box (~25-30bp upstream of TSS)
    """
    bases = ['A', 'T', 'C', 'G']
    sequence = [np.random.choice(bases) for _ in range(length)]
    
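    # Insert TATA box motif at a random position
    tata_box = list("TATAAA")
    tata_pos = np.random.randint(20, length - 30)
    sequence[tata_pos:tata_pos + len(tata_box)] = tata_box
    
    # Add a CAAT box to about half of the sequences
    if np.random.random() > 0.5:
        caat_box = list("CCAAT")
        caat_pos = np.random.randint(40, length - 20)
        # Re-draw until the CAAT box does not overwrite the TATA box
        while tata_pos - len(caat_box) < caat_pos < tata_pos + len(tata_box):
            caat_pos = np.random.randint(40, length - 20)
        sequence[caat_pos:caat_pos + len(caat_box)] = caat_box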
    
    return ''.join(sequence)

def generate_non_promoter_sequence(length=200):
    """
    Generate a random DNA sequence with no motif deliberately inserted.
    (A motif such as TATAAA can still occur by chance.)
    """
    bases = ['A', 'T', 'C', 'G']
    return ''.join([np.random.choice(bases) for _ in range(length)])

# Generate dataset
n_samples = 1000
sequences = []
labels = []

for i in range(n_samples):
    if i < n_samples // 2:
        sequences.append(generate_promoter_sequence())
        labels.append(1)  # Promoter
    else:
        sequences.append(generate_non_promoter_sequence())
        labels.append(0)  # Non-promoter

print(f"Generated {len(sequences)} sequences")
print(f"Example promoter sequence: {sequences[0][:50]}...")
print(f"Example non-promoter sequence: {sequences[-1][:50]}...")

## 3. DNA Sequence Encoding

Neural networks work with numbers, not letters. We need to convert DNA sequences into numerical representations.

### One-Hot Encoding

We'll use **one-hot encoding**, where each nucleotide is represented as a 4-dimensional vector:
- A = [1, 0, 0, 0]
- T = [0, 1, 0, 0]
- C = [0, 0, 1, 0]
- G = [0, 0, 0, 1]

A sequence of length L becomes a matrix of shape (4, L), which is perfect for CNNs!

In [None]:
def one_hot_encode(sequence):
    """
    Convert DNA sequence to one-hot encoding.
    
    Args:
        sequence: DNA sequence string
    
    Returns:
        numpy array of shape (4, len(sequence))
    """
    mapping = {'A': 0, 'T': 1, 'C': 2, 'G': 3}
    seq_len = len(sequence)
    one_hot = np.zeros((4, seq_len), dtype=np.float32)
    
    for i, nucleotide in enumerate(sequence):
        # Characters outside A/T/C/G (e.g. N) are left as all-zero columns
        if nucleotide in mapping:
            one_hot[mapping[nucleotide], i] = 1
    
    return one_hot

# Test the encoding
test_seq = "ATCG"
encoded = one_hot_encode(test_seq)
print(f"Sequence: {test_seq}")
print(f"Encoded shape: {encoded.shape}")
print(f"Encoded matrix:\n{encoded}")

## 4. PyTorch Dataset Class

We create a custom Dataset class to handle our sequences efficiently. This is PyTorch's standard way of organizing data.

In [None]:
class DNASequenceDataset(Dataset):
    """
    PyTorch Dataset for DNA sequences.
    """
    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        
        # Encode sequence
        encoded = one_hot_encode(sequence)
        
        return torch.tensor(encoded, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# Create dataset
dataset = DNASequenceDataset(sequences, labels)

# Split into train, validation, and test sets (70%, 15%, 15%)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

## 5. Create DataLoaders

DataLoaders handle batching and shuffling of our data during training.
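
As a quick sanity check on the batch counts printed below, the number of batches per epoch is the dataset size divided by the batch size, rounded up, because `DataLoader` keeps the final partial batch by default (`drop_last=False`). A minimal sketch, with a hypothetical helper name:

```python
import math

def batches_per_epoch(n_samples, batch_size):
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)

# Split sizes used in this notebook: 700 / 150 / 150 samples, batch size 32
for name, size in [("train", 700), ("val", 150), ("test", 150)]:
    print(f"{name}: {size} samples -> {batches_per_epoch(size, 32)} batches")
```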

In [None]:
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")
print(f"Number of test batches: {len(test_loader)}")

## 6. Build the CNN Model

Now for the exciting part - building our CNN!

### Architecture Components:

1. **Conv1d Layers**: Detect local patterns (motifs) in sequences
   - First layer: Detects short motifs (kernel size 6, so up to 6 nucleotides)
   - Deeper layers: Detect combinations of motifs over a wider window

2. **MaxPooling**: Halves the sequence length and adds tolerance to small shifts in motif position

3. **Fully Connected Layers**: Combine detected features for final classification

4. **Dropout**: Prevents overfitting by randomly dropping neurons during training
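
Before wiring the layers together, it helps to trace how the sequence length changes, since the first fully connected layer needs the exact flattened size. The sketch below applies the standard output-length formula for `Conv1d` and `MaxPool1d` to the hyperparameters used in this section (kernel size 6 with padding 2, then pooling with kernel 2 and stride 2). The helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1, dilation=1):
    """Output length of a 1D convolution (the same formula covers max pooling)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def trace_lengths(length=200, n_blocks=3):
    """Follow the length through conv(k=6, p=2) + maxpool(k=2, s=2) blocks."""
    trace = [length]
    for _ in range(n_blocks):
        length = conv1d_out_len(length, kernel_size=6, padding=2)  # conv
        trace.append(length)
        length = conv1d_out_len(length, kernel_size=2, stride=2)   # pool
        trace.append(length)
    return trace

print(trace_lengths())
```

The final length is 24, so with 128 output channels the flattened feature vector has 128 * 24 = 3072 entries.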

In [None]:
class DNASequenceCNN(nn.Module):
    """
    CNN for DNA sequence classification.
    """
    def __init__(self, sequence_length=200, num_classes=2):
        super(DNASequenceCNN, self).__init__()
        
        # First convolutional block
        # Input: (batch, 4, 200) - 4 channels (A,T,C,G), sequence length 200
        # Kernel size 6: looks at 6 nucleotides at a time (typical motif length)
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=32, kernel_size=6, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        
        # Second convolutional block
        # Detects more complex patterns by combining first-layer features
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=6, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        
        # Third convolutional block
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=6, padding=2)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        
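        # Calculate the size after convolutions and pooling
        # With kernel_size=6 and padding=2, each conv shortens the sequence by 1:
        # 200 -> conv1: 199 -> pool1: 99 -> conv2: 98 -> pool2: 49 -> conv3: 48 -> pool3: 24
        # NOTE: hard-coded for sequence_length=200; the argument is not used here
        self.fc_input_size = 128 * 24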
        
        # Fully connected layers
        self.fc1 = nn.Linear(self.fc_input_size, 256)
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)  # Randomly drop 50% of neurons during training
        self.fc2 = nn.Linear(256, num_classes)
    
    def forward(self, x):
        # Forward pass through the network
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        
        # Flatten for fully connected layers
        x = x.view(x.size(0), -1)
        
        x = self.dropout(self.relu4(self.fc1(x)))
        x = self.fc2(x)
        
        return x

# Create model and move to GPU if available
model = DNASequenceCNN().to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Training Setup

### Loss Function: Cross-Entropy Loss
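Suitable for classification tasks: it penalizes the model for assigning low probability to the true class. `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally, which is why the model's final layer has no softmax.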
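
For intuition, the loss for a single example can be computed by hand as the negative log of the softmax probability assigned to the true class. This is a plain-Python sketch of that formula, not PyTorch's implementation. The function name and the example logits are made up for illustration.

```python
import math

def cross_entropy_from_logits(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]

# Logits are [non-promoter, promoter]
print(cross_entropy_from_logits([2.0, 0.0], target=0))  # confident and correct: small loss
print(cross_entropy_from_logits([2.0, 0.0], target=1))  # confident and wrong: large loss
print(cross_entropy_from_logits([0.0, 0.0], target=1))  # undecided: log(2)
```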

### Optimizer: Adam
An adaptive learning rate optimizer that works well for most problems. It scales each parameter's update using running estimates of the gradient's mean and squared magnitude.

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler: halves the learning rate when validation loss plateaus
# (the `verbose` argument is omitted because it is deprecated in recent PyTorch releases)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

## 8. Training Loop

The training loop:
1. **Forward pass**: Feed data through the network
2. **Calculate loss**: How wrong are our predictions?
3. **Backward pass**: Calculate gradients (how to adjust weights)
4. **Update weights**: Make the model better

We'll track both training and validation metrics to ensure the model generalizes well.
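
The four steps can be sketched on a tiny logistic-regression model in plain Python, with the gradient written out by hand instead of computed by `loss.backward()`. The toy data, learning rate, and variable names are made up for illustration. The real training functions follow in the next cell.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Toy data (made up): one input feature, binary label
data = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]
w, b, lr = 0.0, 0.0, 0.5
losses = []

for epoch in range(50):
    grad_w = grad_b = total_loss = 0.0
    for x, y in data:
        p = sigmoid(w * x + b)                                         # 1. forward pass
        total_loss += -(y * math.log(p) + (1 - y) * math.log(1 - p))   # 2. loss
        grad_w += (p - y) * x                                          # 3. backward pass
        grad_b += (p - y)
    w -= lr * grad_w / len(data)                                       # 4. update weights
    b -= lr * grad_b / len(data)
    losses.append(total_loss / len(data))

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```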

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Train for one epoch.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for sequences, labels in tqdm(loader, desc="Training"):
        sequences, labels = sequences.to(device), labels.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(sequences)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Statistics
        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)  # index of the highest logit per example
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

def validate(model, loader, criterion, device):
    """
    Validate the model.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for sequences, labels in loader:
            sequences, labels = sequences.to(device), labels.to(device)
            
            outputs = model(sequences)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)  # index of the highest logit per example
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

In [None]:
# Training
num_epochs = 20
train_losses = []
val_losses = []
train_accs = []
val_accs = []

print("Starting training...\n")

for epoch in range(num_epochs):
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\n")

print("Training completed!")

## 9. Visualize Training Progress

Let's plot the training and validation metrics to understand how well our model learned.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot losses
ax1.plot(train_losses, label='Train Loss', marker='o')
ax1.plot(val_losses, label='Validation Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot accuracies
ax2.plot(train_accs, label='Train Accuracy', marker='o')
ax2.plot(val_accs, label='Validation Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key observations:")
print("- Training loss should decrease over time")
print("- If validation loss increases while training loss decreases, we have overfitting")
print("- Similar train and val accuracy indicates good generalization")

## 10. Evaluate on Test Set

Finally, we evaluate on the test set that the model has never seen during training.

In [None]:
def evaluate_model(model, loader, device):
    """
    Comprehensive evaluation of the model.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    
    with torch.no_grad():
        for sequences, labels in loader:
            sequences, labels = sequences.to(device), labels.to(device)
            
            outputs = model(sequences)
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
    
    return np.array(all_predictions), np.array(all_labels), np.array(all_probabilities)

# Evaluate
test_predictions, test_labels, test_probabilities = evaluate_model(model, test_loader, device)

# Print classification report
print("Classification Report:")
print(classification_report(test_labels, test_predictions, target_names=['Non-Promoter', 'Promoter']))

# Calculate test accuracy
test_accuracy = 100 * np.sum(test_predictions == test_labels) / len(test_labels)
print(f"\nTest Accuracy: {test_accuracy:.2f}%")

## 11. Confusion Matrix

A confusion matrix shows where our model makes mistakes:
- **True Positives (TP)**: Correctly identified promoters
- **True Negatives (TN)**: Correctly identified non-promoters
- **False Positives (FP)**: Non-promoters incorrectly labeled as promoters
- **False Negatives (FN)**: Promoters incorrectly labeled as non-promoters
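
The four counts, and the precision and recall derived from them, can be computed directly from label lists. This is a minimal sketch with made-up labels. `confusion_counts` is a hypothetical helper, not a scikit-learn function.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Count TP, TN, FP, FN for a binary problem."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    tn = sum(t != positive and p != positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    return tp, tn, fp, fn

# Made-up labels: 1 = promoter, 0 = non-promoter
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 0]

tp, tn, fp, fn = confusion_counts(y_true, y_pred)
precision = tp / (tp + fp)  # of predicted promoters, how many are real
recall = tp / (tp + fn)     # of real promoters, how many were found
print(f"TP={tp} TN={tn} FP={fp} FN={fn}")
print(f"precision={precision:.2f} recall={recall:.2f}")
```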

In [None]:
# Confusion matrix
cm = confusion_matrix(test_labels, test_predictions)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Non-Promoter', 'Promoter'],
            yticklabels=['Non-Promoter', 'Promoter'])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.show()

## 12. ROC Curve

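The ROC (Receiver Operating Characteristic) curve plots the true positive rate (sensitivity) against the false positive rate (1 - specificity) as the decision threshold varies. A perfect classifier has AUC = 1.0, while random guessing gives about 0.5.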
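
One way to read the AUC is as the probability that a randomly chosen promoter receives a higher score than a randomly chosen non-promoter. The sketch below computes it that way by brute force over all pairs. The function name and scores are made up for illustration, and for real use `roc_curve` and `auc` from scikit-learn are the better choice.

```python
def auc_by_pairs(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

y_true = [1, 1, 1, 0, 0, 0]
scores = [0.9, 0.8, 0.4, 0.7, 0.3, 0.2]  # predicted promoter probabilities
print(f"AUC = {auc_by_pairs(y_true, scores):.3f}")  # 8 of 9 pairs ranked correctly
```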

In [None]:
# ROC curve
fpr, tpr, thresholds = roc_curve(test_labels, test_probabilities[:, 1])
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.show()

print(f"AUC Score: {roc_auc:.4f}")

## 13. Visualize Learned Filters

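Let's examine what patterns (motifs) the first convolutional layer has learned to detect. Filters with a clear pattern may resemble the inserted motifs (TATAAA, CCAAT), although individual filters are often noisy or capture only part of a motif.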
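
A rough way to summarize a filter as text is to take the highest-weight nucleotide at each position. The sketch below does this for a hand-made filter (not a learned one), using the same A, T, C, G row order as the one-hot encoding. `filter_to_consensus` is a hypothetical helper.

```python
NUCLEOTIDES = ['A', 'T', 'C', 'G']  # same row order as the one-hot encoding

def filter_to_consensus(filter_weights):
    """Pick the highest-weight nucleotide at each position of a (4, k) filter."""
    kernel_size = len(filter_weights[0])
    consensus = []
    for pos in range(kernel_size):
        column = [filter_weights[row][pos] for row in range(4)]
        consensus.append(NUCLEOTIDES[column.index(max(column))])
    return ''.join(consensus)

# Hand-made filter (not learned) that prefers T, A, T, A, A, A
toy_filter = [
    [-0.2,  0.9, -0.1,  0.8,  0.7,  0.9],  # A
    [ 0.8, -0.3,  0.9, -0.2,  0.1, -0.1],  # T
    [-0.5, -0.4, -0.6, -0.3, -0.2, -0.4],  # C
    [-0.1, -0.2, -0.2, -0.3, -0.6, -0.4],  # G
]
print(filter_to_consensus(toy_filter))
```

With the trained model, something like `filter_to_consensus(first_layer_weights[i].tolist())` should work once the next cell has run. Because the argmax discards weight magnitudes, even a weak or noisy filter yields a string, so treat the result as a rough summary only.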

In [None]:
# Get first layer filters
first_layer_weights = model.conv1.weight.data.cpu().numpy()

# Plot first 8 filters
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.ravel()

nucleotides = ['A', 'T', 'C', 'G']

for i in range(8):
    filter_weights = first_layer_weights[i]  # Shape: (4, kernel_size)
    
    # Create visualization
    im = axes[i].imshow(filter_weights, cmap='RdBu_r', aspect='auto')
    axes[i].set_yticks(range(4))
    axes[i].set_yticklabels(nucleotides)
    axes[i].set_xlabel('Position in motif')
    axes[i].set_title(f'Filter {i+1}')
    plt.colorbar(im, ax=axes[i])

plt.tight_layout()
plt.suptitle('Learned Convolutional Filters (Motif Detectors)', y=1.02, fontsize=14)
plt.show()

print("\nInterpretation:")
print("- Red: Strong positive weights (these nucleotides increase the filter's activation)")
print("- Blue: Strong negative weights (these nucleotides decrease the filter's activation)")
print("- Filters with a clear pattern may resemble the inserted motifs (TATAAA, CCAAT)")

## Summary and Key Takeaways

In this notebook, we:

1. ✅ **Encoded DNA sequences** using one-hot encoding for neural network input
2. ✅ **Built a CNN architecture** specifically designed for sequence classification
3. ✅ **Trained the model** with proper train/validation/test splits
4. ✅ **Evaluated performance** using multiple metrics (accuracy, confusion matrix, ROC curve)
5. ✅ **Visualized learned patterns** to understand what the model detects

### Why This Approach Works:

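- **CNNs excel at local pattern recognition**: A good fit for finding short motifs in sequences
- **Translation equivariance**: The same filter is applied at every position, so a motif can be detected wherever it occurs; pooling adds tolerance to small shifts
- **Parameter efficiency**: Sharing weights across positions keeps the parameter count low, which helps limit overfitting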
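
The position argument can be illustrated without any deep learning library. Sliding a fixed motif along a sequence and counting matching bases is what a one-hot convolution computes when the filter has weight 1 at each motif base and 0 elsewhere (no bias). The score peak moves with the motif, while the maximum score stays the same. The sequences and the helper name are made up for illustration.

```python
MOTIF = "TATAAA"

def motif_scores(sequence, motif=MOTIF):
    """Count matching bases for every window of the sequence."""
    k = len(motif)
    return [
        sum(a == b for a, b in zip(sequence[i:i + k], motif))
        for i in range(len(sequence) - k + 1)
    ]

early = "GGTATAAAGGCCGGCCGGCC"  # motif starts at index 2
late = "GGCCGGCCGGCCTATAAAGG"   # motif starts at index 12

for name, seq in [("early", early), ("late", late)]:
    scores = motif_scores(seq)
    print(name, "peak position:", scores.index(max(scores)), "max score:", max(scores))
```

Note that the model in this notebook uses local pooling followed by flattening rather than a global maximum, so its fully connected layers still see roughly where each motif occurred.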

### Next Steps:

- Use real promoter databases (e.g., EPD, DBTSS)
- Experiment with different architectures
- Try attention mechanisms to identify important regions
- Apply to other sequence classification tasks (splice sites, transcription factor binding)

### Real-World Applications:

- Gene annotation
- Regulatory element prediction
- Variant effect prediction
- Drug target identification