# Chapter 6: RNN/LSTM for Protein Sequence Analysis

Welcome to protein sequence analysis! 🧪

In this notebook, we'll use Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks to classify protein families. Proteins are the workhorses of biology, and understanding them is crucial for drug design, disease research, and biotechnology.

## 🎯 The Biological Problem

### What are Proteins?

Proteins are chains of amino acids (20 different types, represented by letters like A, G, W, etc.):

```
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEV...
```

- **Length:** Typically 50-1000 amino acids (some much longer!)
- **Function:** Determined by the sequence and 3D structure
- **Families:** Proteins with similar sequences usually have similar functions

### Why Classify Protein Families?

When scientists discover a new protein sequence, they want to know:
- What does it do? (enzyme, transport, signaling, etc.)
- What family does it belong to?
- Can we predict its structure or function?

**Example families:**
- Kinases (proteins that add phosphate groups)
- G-protein coupled receptors (cell signaling)
- Immunoglobulins (antibodies)

### The Machine Learning Task

**Input:** Protein sequence
```
MGAAASIQTTVNTLSERISSKLEQEANASAQTKCDIEIGNFYIRQNHGCNLTVKNMCSAD
```

**Output:** Which family does this belong to? (multi-class classification)

**Challenge:** Sequences vary in length, and similar function can come from very different sequences!

## 📚 What You'll Learn

1. **RNNs:** Networks that can process sequences of varying lengths
2. **LSTMs:** Advanced RNNs that can remember long-range dependencies
3. **Embeddings:** How to represent amino acids as vectors
4. **Bidirectional Processing:** Reading sequences forward AND backward
5. **Handling Variable Lengths:** Padding and masking techniques

## 🔧 Why RNNs/LSTMs for Sequences?

Unlike CNNs (which look at local patterns), RNNs:
- Process one amino acid at a time
- Maintain a "memory" of what they've seen
- Can handle sequences of any length
- Can, in principle, capture long-range dependencies (amino acids far apart but functionally related), although plain RNNs struggle with this in practice and LSTMs do much better
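
The recurrence behind these bullet points can be sketched in a few lines of plain Python. This is a toy illustration, not the PyTorch model used later: the function name `rnn_final_state`, the hidden size of 2, the hand-picked weights, and the scalar inputs (standing in for embedded residues) are all assumptions made for the example. It applies the standard update h_t = tanh(W_x x_t + W_h h_{t-1} + b) once per step, reusing the same weights at every position.

```python
import math

def rnn_final_state(inputs, w_x, w_h, bias):
    """Apply h_t = tanh(w_x * x_t + W_h @ h_{t-1} + b) over a sequence of scalar inputs."""
    hidden = [0.0] * len(bias)  # the "memory" starts empty
    for x_t in inputs:
        new_hidden = []
        for i in range(len(bias)):
            recurrent = sum(w_h[i][j] * hidden[j] for j in range(len(hidden)))
            new_hidden.append(math.tanh(w_x[i] * x_t + recurrent + bias[i]))
        hidden = new_hidden
    return hidden

# Hand-picked toy weights (hypothetical values)
w_x = [0.5, -0.3]
w_h = [[0.1, 0.2], [-0.2, 0.4]]
bias = [0.0, 0.1]

# The same weights handle a 3-step and a 7-step sequence
short_state = rnn_final_state([1.0, 2.0, 3.0], w_x, w_h, bias)
long_state = rnn_final_state([1.0, 2.0, 3.0, 0.5, -1.0, 2.5, 1.5], w_x, w_h, bias)
print(len(short_state), len(long_state))
```

Both calls return a state of the same size regardless of input length, which is what lets a classifier sit on top of sequences of different lengths. Swapping the order of the inputs changes the final state, so the network is sensitive to residue order.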

Let's dive in! 🚀

---


## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Understanding Amino Acids and Protein Encoding

Proteins are made of 20 standard amino acids. Each has different properties:
- **Hydrophobic**: A, V, I, L, M, F, W
- **Polar**: S, T, N, Q, Y, C
- **Charged**: K, R, H (positive), D, E (negative)
- **Special**: G (flexible), P (rigid)

We'll use integer encoding where each amino acid gets a unique number.

In [None]:
# Amino acid vocabulary
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
aa_to_idx = {aa: idx + 1 for idx, aa in enumerate(AMINO_ACIDS)}  # 0 reserved for padding
aa_to_idx['<PAD>'] = 0
idx_to_aa = {idx: aa for aa, idx in aa_to_idx.items()}

print("Amino acid encoding:")
for aa, idx in list(aa_to_idx.items())[:5]:
    print(f"  {aa}: {idx}")
print(f"  ...")
print(f"Total vocabulary size: {len(aa_to_idx)}")

## 3. Generate Synthetic Protein Sequences

We'll create three simplified protein families:
1. **Kinases**: Often have catalytic domains with specific motifs
2. **Proteases**: Have active site residues (catalytic triad)
3. **Transporters**: Hydrophobic regions for membrane spanning

In practice, you would use databases like Pfam, UniProt, or SCOP.

In [None]:
def generate_kinase_sequence(length=150):
    """
    Generate synthetic kinase-like sequence.
    Kinases often have an ATP-binding motif (GXGXXG) and a DFG motif at the start of the activation loop.
    """
    sequence = [np.random.choice(list(AMINO_ACIDS)) for _ in range(length)]
    
    # Insert ATP-binding motif
    motif = ['G', 'K', 'G', 'S', 'F', 'G']
    pos = np.random.randint(10, 40)
    sequence[pos:pos+len(motif)] = motif
    
    # Insert DFG motif (start of the activation loop)
    catalytic = ['D', 'F', 'G']
    pos = np.random.randint(60, 90)
    sequence[pos:pos+len(catalytic)] = catalytic
    
    return ''.join(sequence)

def generate_protease_sequence(length=150):
    """
    Generate synthetic protease-like sequence.
    Many proteases have a catalytic triad (e.g., Ser-His-Asp or Cys-His-Asn).
    """
    sequence = [np.random.choice(list(AMINO_ACIDS)) for _ in range(length)]
    
    # Insert catalytic triad residues at different positions
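    # Cap the last range at the sequence length: callers pass lengths as short as 120,
    # and an index of 120-129 would otherwise raise IndexError
    positions = [np.random.randint(20, 50),
                 np.random.randint(60, 90),
                 np.random.randint(100, min(130, length))]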
    sequence[positions[0]] = 'S'  # Serine
    sequence[positions[1]] = 'H'  # Histidine
    sequence[positions[2]] = 'D'  # Aspartate
    
    return ''.join(sequence)

def generate_transporter_sequence(length=150):
    """
    Generate synthetic transporter-like sequence.
    Transporters have hydrophobic transmembrane regions.
    """
    hydrophobic_aa = 'AVILMFW'
    sequence = []
    
    # Create alternating hydrophobic and mixed regions
    pos = 0
    while pos < length:
        if pos % 40 < 20:  # Transmembrane region
            sequence.append(np.random.choice(list(hydrophobic_aa)))
        else:  # Cytoplasmic/extracellular region
            sequence.append(np.random.choice(list(AMINO_ACIDS)))
        pos += 1
    
    return ''.join(sequence)

# Generate dataset
n_samples_per_class = 400
sequences = []
labels = []
label_names = ['Kinase', 'Protease', 'Transporter']

for i in range(n_samples_per_class):
    sequences.append(generate_kinase_sequence(length=np.random.randint(120, 180)))
    labels.append(0)
    
    sequences.append(generate_protease_sequence(length=np.random.randint(120, 180)))
    labels.append(1)
    
    sequences.append(generate_transporter_sequence(length=np.random.randint(120, 180)))
    labels.append(2)

print(f"Generated {len(sequences)} protein sequences")
print(f"\nClass distribution:")
for i, name in enumerate(label_names):
    count = labels.count(i)
    print(f"  {name}: {count}")

print(f"\nExample sequences:")
# Sequences are appended in kinase, protease, transporter order,
# so indices 0-2 hold one example of each class
for i in range(3):
    print(f"  {label_names[i]}: {sequences[i][:50]}...")
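
Before training, it can help to check that the signals we planted are actually detectable. The sketch below is one possible way to do that with the standard library only. The helper names `has_kinase_motifs` and `hydrophobic_fraction` and the two toy strings are hypothetical and hand-written for this example; they are not real proteins and not output from the generators above.

```python
import re

# GXGXXG: glycine, any residue, glycine, any two residues, glycine
ATP_MOTIF = re.compile(r"G.G..G")
HYDROPHOBIC = set("AVILMFW")

def has_kinase_motifs(seq):
    """True if a GXGXXG match is followed somewhere later by DFG."""
    atp = ATP_MOTIF.search(seq)
    return atp is not None and "DFG" in seq[atp.end():]

def hydrophobic_fraction(seq):
    """Fraction of residues drawn from the hydrophobic set."""
    return sum(aa in HYDROPHOBIC for aa in seq) / len(seq)

# Hand-written toy strings (hypothetical)
kinase_like = "MKTAYGKGSFGQRQISDFGKSH"
transporter_like = "AVILMFWAVILMFWAVILKD"

print(has_kinase_motifs(kinase_like))
print(has_kinase_motifs(transporter_like))
print(hydrophobic_fraction(transporter_like))
```

Because the background residues are random, a short motif such as DFG can also appear by chance in a non-kinase sequence, so a check like this is a sanity test rather than a classifier.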

## 4. Sequence Encoding and Padding

RNNs can handle variable-length sequences, but batching requires padding:
- All sequences in a batch must have the same length
- We pad shorter sequences with zeros
- We'll use PyTorch's `pack_padded_sequence` to handle this efficiently
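
To see what packing changes, here is a small sketch with toy dimensions (the sizes, the random inputs, and the variable names are assumptions for illustration only). It compares the final hidden state of a short sequence processed three ways: alone, in a padded batch without packing, and in a packed batch.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True)

long_seq = torch.randn(5, 4)   # 5 time steps
short_seq = torch.randn(2, 4)  # 2 time steps
batch = pad_sequence([long_seq, short_seq], batch_first=True)  # shape (2, 5, 4), zeros after step 2
lengths = torch.tensor([5, 2])  # already sorted longest-first

with torch.no_grad():
    packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=True)
    _, (h_packed, _) = lstm(packed)                  # stops at each sequence's true length
    _, (h_padded, _) = lstm(batch)                   # keeps stepping through the zero padding
    _, (h_alone, _) = lstm(short_seq.unsqueeze(0))   # reference: the short sequence by itself

packed_matches = torch.allclose(h_packed[0, 1], h_alone[0, 0], atol=1e-5)
padded_matches = torch.allclose(h_padded[0, 1], h_alone[0, 0], atol=1e-5)
print(packed_matches, padded_matches)
```

With packing, the short sequence's final state matches the one obtained by running it alone. Without packing, the LSTM keeps updating its state on the padding steps, so the final state drifts away from the true end-of-sequence state.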

In [None]:
def encode_sequence(sequence, aa_to_idx):
    """
    Convert amino acid sequence to integers.
    Unknown letters (e.g., 'X') map to 0, the same index used for padding.
    """
    return [aa_to_idx.get(aa, 0) for aa in sequence]

def pad_sequences(sequences, max_length=None):
    """
    Pad sequences to the same length.
    """
    if max_length is None:
        max_length = max(len(seq) for seq in sequences)
    
    padded = []
    lengths = []
    
    for seq in sequences:
        length = len(seq)
        lengths.append(length)
        padded.append(seq + [0] * (max_length - length))
    
    return padded, lengths

# Test encoding
test_seq = "ACDEFG"
encoded = encode_sequence(test_seq, aa_to_idx)
print(f"Sequence: {test_seq}")
print(f"Encoded: {encoded}")
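
As a quick check of the encode-then-pad idea, the following self-contained sketch pads a small batch, builds a boolean mask of real positions, and decodes back to letters. The helper names `encode`, `pad_batch`, and `decode` are hypothetical and written for this example. Unlike `encode_sequence` above, this sketch raises a `KeyError` on unknown letters instead of mapping them to 0.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}  # 0 is reserved for padding
idx_to_aa = {i: aa for aa, i in aa_to_idx.items()}

def encode(seq):
    """Map each letter to its integer index."""
    return [aa_to_idx[aa] for aa in seq]

def pad_batch(encoded_seqs):
    """Right-pad with zeros and return (padded rows, mask of real positions)."""
    max_len = max(len(s) for s in encoded_seqs)
    padded = [s + [0] * (max_len - len(s)) for s in encoded_seqs]
    mask = [[idx != 0 for idx in row] for row in padded]
    return padded, mask

def decode(row):
    """Drop padding and map indices back to letters."""
    return "".join(idx_to_aa[idx] for idx in row if idx != 0)

padded, mask = pad_batch([encode("ACDEFG"), encode("MKT")])
print(padded[1])
print(mask[1])
print(decode(padded[1]))
```

The mask marks which positions hold real residues, which is the same information that the `lengths` tensor carries for `pack_padded_sequence`.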

## 5. PyTorch Dataset for Variable-Length Sequences

In [None]:
class ProteinSequenceDataset(Dataset):
    """
    PyTorch Dataset for protein sequences.
    """
    def __init__(self, sequences, labels, aa_to_idx):
        self.sequences = sequences
        self.labels = labels
        self.aa_to_idx = aa_to_idx
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        
        # Encode sequence
        encoded = encode_sequence(sequence, self.aa_to_idx)
        length = len(encoded)
        
        return torch.tensor(encoded, dtype=torch.long), length, torch.tensor(label, dtype=torch.long)

def collate_fn(batch):
    """
    Custom collate function to handle variable-length sequences.
    """
    # Sort batch by sequence length (required for pack_padded_sequence)
    batch.sort(key=lambda x: x[1], reverse=True)
    
    sequences, lengths, labels = zip(*batch)
    
    # Pad sequences
    sequences_padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
    
    return sequences_padded, torch.tensor(lengths), torch.stack(labels)

# Create dataset
dataset = ProteinSequenceDataset(sequences, labels, aa_to_idx)

# Split dataset
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

print(f"Dataset splits:")
print(f"  Train: {len(train_dataset)}")
print(f"  Validation: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

## 6. Create DataLoaders

In [None]:
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"Number of batches:")
print(f"  Train: {len(train_loader)}")
print(f"  Validation: {len(val_loader)}")
print(f"  Test: {len(test_loader)}")

## 7. Build Bidirectional LSTM Model

### Why Bidirectional?

Bidirectional LSTMs process the sequence in both directions:
- **Forward LSTM**: Reads sequence left-to-right (N-terminus to C-terminus)
- **Backward LSTM**: Reads sequence right-to-left
- **Combined**: Captures context from both directions

This is powerful for proteins because functional regions can depend on context from both sides!

### LSTM Components:

1. **Embedding Layer**: Converts amino acid indices to dense vectors
2. **LSTM Layers**: Process the sequence while maintaining memory
3. **Dropout**: Prevents overfitting
4. **Fully Connected**: Maps the LSTM output to class scores (logits); `CrossEntropyLoss` applies the softmax internally
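
Conceptually, the embedding layer in step 1 is a lookup table with one row per index. The sketch below illustrates that idea in plain Python. It is not how `nn.Embedding` is implemented, and the names `table` and `embed` and the toy dimension of 4 are assumptions for the example. Row 0 is kept at zero to mirror the role of the padding index.

```python
import random

random.seed(0)
VOCAB_SIZE = 21     # 20 amino acids + padding index 0
EMBEDDING_DIM = 4   # toy size; the model below uses 64

# One row per index; row 0 (padding) stays all zeros
table = [[0.0] * EMBEDDING_DIM]
table += [[random.gauss(0.0, 1.0) for _ in range(EMBEDDING_DIM)]
          for _ in range(VOCAB_SIZE - 1)]

def embed(indices):
    """Replace each integer index with its row from the table."""
    return [table[i] for i in indices]

embedded = embed([1, 2, 0, 0])  # two residues followed by two padding positions
print(len(embedded), len(embedded[0]))
```

In the real layer the rows are trainable parameters, so residues that behave similarly for the task can end up with similar vectors.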

In [None]:
class ProteinLSTM(nn.Module):
    """
    Bidirectional LSTM for protein sequence classification.
    """
    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128, 
                 num_layers=2, num_classes=3, dropout=0.3):
        super(ProteinLSTM, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Embedding layer: converts amino acid indices to dense vectors
        # Each amino acid gets a learnable embedding
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0  # Padding token
        )
        
        # Bidirectional LSTM
        # bidirectional=True means we process sequence in both directions
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Fully connected layer
        # *2 because bidirectional (forward + backward)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
    
    def forward(self, sequences, lengths):
        # Embed sequences
        embedded = self.embedding(sequences)  # (batch, seq_len, embedding_dim)
        
        # Pack padded sequences for efficient processing
        # This tells PyTorch to ignore padding during computation
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=True)
        
        # Pass through LSTM
        packed_output, (hidden, cell) = self.lstm(packed)
        
        # Unpack sequences
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        
        # Use the final hidden state from both directions
        # hidden shape: (num_layers * 2, batch, hidden_dim)
        # We take the last layer's forward and backward hidden states
        forward_hidden = hidden[-2, :, :]
        backward_hidden = hidden[-1, :, :]
        
        # Concatenate forward and backward
        combined = torch.cat((forward_hidden, backward_hidden), dim=1)
        
        # Apply dropout and fully connected layer
        out = self.dropout(combined)
        out = self.fc(out)
        
        return out

# Create model
vocab_size = len(aa_to_idx)
model = ProteinLSTM(vocab_size=vocab_size, num_classes=3).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
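
The indexing `hidden[-2]` and `hidden[-1]` in `forward` relies on how PyTorch orders the final hidden states of a multi-layer bidirectional LSTM. The sketch below checks that layout on a toy LSTM (the sizes and variable names are assumptions for illustration) using unpadded input, so the last time step is the true end of every sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, batch, hidden_dim = 2, 3, 5
lstm = nn.LSTM(input_size=4, hidden_size=hidden_dim, num_layers=num_layers,
               batch_first=True, bidirectional=True)

x = torch.randn(batch, 7, 4)  # 7 time steps, no padding
with torch.no_grad():
    output, (hidden, _) = lstm(x)

# (num_layers * 2, batch, hidden_dim) -> (layer, direction, batch, hidden_dim)
by_layer = hidden.reshape(num_layers, 2, batch, hidden_dim)
last_forward, last_backward = by_layer[-1, 0], by_layer[-1, 1]

same_as_indexing = (torch.equal(last_forward, hidden[-2])
                    and torch.equal(last_backward, hidden[-1]))
# Forward direction finishes at the last step, backward direction at the first step
forward_is_last_step = torch.allclose(last_forward, output[:, -1, :hidden_dim])
backward_is_first_step = torch.allclose(last_backward, output[:, 0, hidden_dim:])
print(same_as_indexing, forward_is_last_step, backward_is_first_step)
```

This is why concatenating the two states gives a summary that has seen the whole sequence from both ends.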

## 8. Training Setup

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
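# `verbose` is deprecated in recent PyTorch releases, so it is omitted here;
# inspect optimizer.param_groups[0]['lr'] to see the current learning rate
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)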
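
`nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally. As a rough sketch of what it computes for a single sample, here is the same calculation in plain Python. The function names and the example logits are assumptions for illustration; the PyTorch version works on batches and, with its default settings, averages the per-sample losses.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the correct class."""
    return -math.log(softmax(logits)[target])

logits = [2.0, 0.5, -1.0]  # hypothetical scores for Kinase, Protease, Transporter
probs = softmax(logits)
print([round(p, 3) for p in probs])
print(cross_entropy(logits, 0), cross_entropy(logits, 2))
```

The loss is small when the correct class has the highest score and grows as the model puts probability on the wrong classes. With three classes and uniform scores, the loss equals log(3), a useful baseline when reading the first epoch's numbers.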

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Train for one epoch.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for sequences, lengths, labels in tqdm(loader, desc="Training"):
        sequences = sequences.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(sequences, lengths)
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

def validate(model, loader, criterion, device):
    """
    Validate the model.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for sequences, lengths, labels in loader:
            sequences = sequences.to(device)
            labels = labels.to(device)
            
            outputs = model(sequences, lengths)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

## 10. Train the Model

Note: We use **gradient clipping** to prevent exploding gradients, which are common in RNNs.
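
Norm-based clipping rescales all gradients by one shared factor whenever their combined L2 norm exceeds the limit, so the update direction is preserved. The following plain-Python sketch illustrates the idea on nested lists. The function name `clip_by_global_norm` and the example values are assumptions; the training loop itself uses `torch.nn.utils.clip_grad_norm_`.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient by the same factor so the joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm

grads = [[3.0, 4.0], [12.0]]  # joint norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
print(norm_before, norm_after)
```

Gradients whose norm is already below the limit pass through unchanged, so clipping only intervenes on the occasional oversized step.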

In [None]:
num_epochs = 25
train_losses = []
val_losses = []
train_accs = []
val_accs = []

print("Starting training...\n")

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    scheduler.step(val_loss)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\n")

print("Training completed!")

## 11. Visualize Training Progress

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Losses
ax1.plot(train_losses, label='Train Loss', marker='o')
ax1.plot(val_losses, label='Validation Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracies
ax2.plot(train_accs, label='Train Accuracy', marker='o')
ax2.plot(val_accs, label='Validation Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 12. Evaluate on Test Set

In [None]:
def evaluate_model(model, loader, device):
    """
    Comprehensive model evaluation.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for sequences, lengths, labels in loader:
            sequences = sequences.to(device)
            labels = labels.to(device)
            
            outputs = model(sequences, lengths)
            _, predicted = torch.max(outputs, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    return np.array(all_predictions), np.array(all_labels)

# Evaluate
test_predictions, test_labels = evaluate_model(model, test_loader, device)

# Classification report
print("Classification Report:")
print(classification_report(
    test_labels, test_predictions, 
    target_names=['Kinase', 'Protease', 'Transporter']
))

test_accuracy = 100 * np.sum(test_predictions == test_labels) / len(test_labels)
print(f"\nTest Accuracy: {test_accuracy:.2f}%")

## 13. Confusion Matrix

In [None]:
cm = confusion_matrix(test_labels, test_predictions)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Kinase', 'Protease', 'Transporter'],
            yticklabels=['Kinase', 'Protease', 'Transporter'])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix - Protein Family Classification')
plt.show()

# Per-class accuracy
print("\nPer-class accuracy:")
for i, name in enumerate(['Kinase', 'Protease', 'Transporter']):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    print(f"  {name}: {100 * class_correct / class_total:.2f}%")
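
The per-class accuracy above divides the diagonal by the row total, which is the same quantity as recall. Dividing by the column total instead gives precision. The sketch below computes both from a confusion matrix laid out like scikit-learn's (rows are true labels, columns are predictions). The function name `per_class_metrics` and the counts in `toy_cm` are hypothetical; they are not results from this notebook.

```python
def per_class_metrics(cm):
    """Return a (precision, recall) pair per class from a square confusion matrix."""
    n = len(cm)
    metrics = []
    for i in range(n):
        true_pos = cm[i][i]
        row_total = sum(cm[i])                       # samples whose true class is i
        col_total = sum(cm[r][i] for r in range(n))  # samples predicted as class i
        recall = true_pos / row_total if row_total else 0.0
        precision = true_pos / col_total if col_total else 0.0
        metrics.append((precision, recall))
    return metrics

# Hypothetical counts for illustration only
toy_cm = [[50, 10, 0],
          [5, 55, 0],
          [0, 0, 60]]
names = ["Kinase", "Protease", "Transporter"]
for name, (precision, recall) in zip(names, per_class_metrics(toy_cm)):
    print(f"{name}: precision={precision:.3f}, recall={recall:.3f}")
```

Looking at both numbers shows whether errors for a class come from missed examples (low recall) or from other classes being mislabeled as it (low precision).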

## 14. Understanding LSTM vs Simple RNN

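Let's define a simple RNN for comparison. The cell below only defines the model and summarizes the differences; to compare accuracy, you would train it with the same loop used for the LSTM.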

In [None]:
class SimpleRNN(nn.Module):
    """
    Simple RNN for comparison with LSTM.
    """
    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128, num_classes=3):
        super(SimpleRNN, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.rnn = nn.RNN(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
    
    def forward(self, sequences, lengths):
        embedded = self.embedding(sequences)
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=True)
        _, hidden = self.rnn(packed)
        
        forward_hidden = hidden[-2, :, :]
        backward_hidden = hidden[-1, :, :]
        combined = torch.cat((forward_hidden, backward_hidden), dim=1)
        
        return self.fc(combined)

print("\nKey differences between RNN and LSTM:")
print("\n1. RNN:")
print("   - Simple recurrent connections")
print("   - Suffers from vanishing/exploding gradients")
print("   - Poor at learning long-term dependencies")
print("\n2. LSTM:")
print("   - Has memory cells and gates (input, forget, output)")
print("   - Better gradient flow through time")
print("   - Can learn long-range dependencies")
print("   - More parameters but more powerful")
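
To put a number on "more parameters", the sketch below counts the weights in a single recurrent layer, assuming PyTorch's parameter layout (an input-to-hidden matrix, a hidden-to-hidden matrix, and two bias vectors per weight block). The function name `recurrent_params` and the argument `weight_blocks` are hypothetical names chosen for this example. The sizes match the first layer of the models above (embedding dimension 64, hidden size 128, bidirectional).

```python
def recurrent_params(input_size, hidden_size, weight_blocks, bidirectional=True):
    """Parameter count for one recurrent layer.

    weight_blocks: 1 for a plain RNN, 3 for a GRU, 4 for an LSTM.
    """
    per_direction = weight_blocks * (
        hidden_size * (input_size + hidden_size)  # input-to-hidden + hidden-to-hidden weights
        + 2 * hidden_size                         # two bias vectors
    )
    return per_direction * (2 if bidirectional else 1)

rnn_count = recurrent_params(64, 128, weight_blocks=1)
gru_count = recurrent_params(64, 128, weight_blocks=3)
lstm_count = recurrent_params(64, 128, weight_blocks=4)
print(rnn_count, gru_count, lstm_count)
```

For the same sizes, an LSTM layer carries four times the parameters of a plain RNN layer, and a GRU sits in between with three times as many.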

## Summary and Key Takeaways

In this notebook, we:

1. ✅ **Encoded protein sequences** using integer encoding and embeddings
2. ✅ **Built a bidirectional LSTM** to capture context from both directions
3. ✅ **Handled variable-length sequences** efficiently with padding and packing
4. ✅ **Trained the model** with dropout for regularization and gradient clipping for stability
5. ✅ **Evaluated multi-class classification** performance

### Why LSTMs for Protein Sequences?

- **Sequential nature**: Proteins have sequential dependencies
- **Variable length**: Proteins vary greatly in length (50-30,000 amino acids)
- **Long-range interactions**: Functional sites can be far apart in sequence
- **Bidirectional context**: N-terminal and C-terminal regions both matter

### Advanced Techniques (Next Steps):

- **Attention mechanisms**: Focus on important regions
- **Transfer learning**: Use pre-trained protein language models (ESM, ProtTrans)
- **Multi-task learning**: Predict multiple properties simultaneously
- **GRU**: Alternative to LSTM with fewer parameters

### Real-World Applications:

- Protein function prediction
- Subcellular localization
- Post-translational modification sites
- Protein-protein interaction prediction
- Drug target identification