# DNA Transformer - Complete Implementation
## K-mers, BPE, One-hot Encoding, and Transformer Models

This notebook demonstrates the improved implementation with:
- Robust FASTA reading
- Three encoding methods (K-mers, BPE, One-hot)
- Transformer architecture
- Complete training pipeline

## 1. Setup and Imports

In [None]:
# Installation (uncomment if needed). Prefer the %pip magic over !pip so the
# packages are installed into the environment of the running kernel.
# %pip install torch torchvision torchaudio
# %pip install biopython
# %pip install tokenizers
# %pip install numpy matplotlib

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from dna_transformer_improved import (
    FASTAReader,
    KmerTokenizer,
    DNABPETokenizer,
    OneHotEncoder,
    DNASequenceDataset,
    DNATransformerEncoder,
    DNATransformerOneHot,
    collate_fn_tokens,
    collate_fn_onehot,
    train_epoch,
    evaluate
)

# Seed the NumPy and PyTorch generators so that weight initialisation and
# DataLoader shuffling are repeatable under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Reading FASTA Files

The improved `FASTAReader` handles both gzipped and regular FASTA files.

In [None]:
# Data source switch. Set FASTA_PATH to a FASTA file (plain or gzipped) to work
# with real data; leave it as None to use the built-in example sequences.
FASTA_PATH = None   # e.g. "your_file.fasta.gz"
FASTA_LIMIT = 100   # forwarded as `limit=` to FASTAReader.read_sequences

if FASTA_PATH is not None:
    # Option 1: read from an actual FASTA file
    fasta_reader = FASTAReader(FASTA_PATH)
    sequences_data = fasta_reader.read_sequences(limit=FASTA_LIMIT)

    # Display sequence information for the first few records
    for i, seq in enumerate(sequences_data[:3]):
        print(f"\nSequence {i+1}:")
        print(f"  ID: {seq['id']}")
        print(f"  V Number: {seq['v_number']}")
        print(f"  Sample: {seq['sample_number']}")
        print(f"  Length: {seq['length']:,} bases")
        print(f"  First 50 bases: {seq['sequence'][:50]}")

    # Get statistics
    stats = fasta_reader.get_sequence_stats()
    print(f"\nDataset Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value:,.0f}")

    sequences = [seq['sequence'] for seq in sequences_data]
else:
    # Option 2: use example sequences
    sequences = [
        "ATGCTAGCTAGCTAGCTAGCTAATGCTAGCGATCGATCGTAGCTAGCTGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTA",
        "GGCTACGTTACGACGTAACGTACGATCGATCGATCGTAGCTAGCTACGATCGATCGACGATCGATCGTAGCTAGCTAGCTAGC",
        "TTACTGACCTGAACCTGACCTACGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGC",
        "ACGTACGTACGTACGTACGTACGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCT",
        "ATGCGGATCCGATCGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCT",
        "GGCATGCTAGCATCGATGCATGCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGC",
        "TTACGATCGATCGTGCACGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCT",
        "CGATCGATCGATCGTAGCTAGCTACGATCGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTACGATCGATCGTA",
    ] * 10  # Replicate for more data (note: this creates exact duplicates)

# Fail early with a clear message instead of printing a NaN average below and
# breaking in a later cell.
if not sequences:
    raise ValueError("No sequences were loaded; check FASTA_PATH / FASTA_LIMIT.")

print(f"Loaded {len(sequences)} sequences")
print(f"Average length: {np.mean([len(s) for s in sequences]):.0f} bases")

## 3. K-mer Tokenization

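K-mer tokenization slides a window of length k along each sequence and emits one token per window. A stride of 1 yields fully overlapping k-mers, a stride of k yields non-overlapping k-mers, and values in between overlap partially.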

In [None]:
# K-mer tokenizer configuration
k = 6
stride = 3  # 1 for overlapping, k for non-overlapping

# Build the tokenizer and fit its vocabulary on the loaded sequences
kmer_tokenizer = KmerTokenizer(k=k, stride=stride)
kmer_tokenizer.build_vocab(sequences)

# Summary of the fitted tokenizer
print(
    f"K-mer settings: k={k}, stride={stride}",
    f"Vocabulary size: {len(kmer_tokenizer.vocab)}",
    f"\nSpecial tokens: {kmer_tokenizer.special_tokens}",
    f"\nSample k-mers from vocabulary:",
    sep="\n",
)

# Show ten vocabulary entries starting at index 8.
# NOTE(review): the offset presumably skips the special tokens - confirm
# against KmerTokenizer's vocabulary layout.
vocab_keys = list(kmer_tokenizer.vocab.keys())
sample_kmers = vocab_keys[8:18]
for kmer in sample_kmers:
    print(f"  {kmer}: {kmer_tokenizer.vocab[kmer]}")

# Round-trip check on the first 60 bases of the first sequence
test_seq = sequences[0][:60]
encoded = kmer_tokenizer.encode(test_seq)
decoded = kmer_tokenizer.decode(encoded)

print(
    f"\nTest Sequence: {test_seq}",
    f"Encoded IDs:   {encoded}",
    f"Decoded:       {decoded}",
    # Only the decoded prefix is compared, so a shorter decode still matches.
    f"Match: {test_seq[:len(decoded)] == decoded}",
    sep="\n",
)

### K-mer Analysis

In [None]:
# Count k-mer frequencies
from collections import Counter

# Flatten the k-mers of every sequence into one list
all_kmers = [
    kmer
    for seq in sequences
    for kmer in kmer_tokenizer.sequence_to_kmers(seq)
]

kmer_counts = Counter(all_kmers)
top_kmers = kmer_counts.most_common(10)

print("Top 10 most frequent k-mers:")
for kmer, count in top_kmers:
    print(f"  {kmer}: {count:,} occurrences")

# Visualize k-mer distribution: a bar chart of the most common k-mers next to
# a histogram of how often each distinct k-mer occurs.
labels = [kmer for kmer, _ in top_kmers]
counts = [count for _, count in top_kmers]
freq_distribution = list(kmer_counts.values())

fig, (ax_top, ax_hist) = plt.subplots(1, 2, figsize=(12, 4))

ax_top.bar(range(len(counts)), counts)
ax_top.set_xlabel('K-mer')
ax_top.set_ylabel('Frequency')
ax_top.set_title('Top 10 K-mer Frequencies')
ax_top.set_xticks(range(len(labels)))
ax_top.set_xticklabels(labels, rotation=45)

ax_hist.hist(freq_distribution, bins=50, edgecolor='black')
ax_hist.set_xlabel('Frequency')
ax_hist.set_ylabel('Number of K-mers')
ax_hist.set_title('K-mer Frequency Distribution')
ax_hist.set_yscale('log')

fig.tight_layout()
plt.show()

## 4. BPE Tokenization

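Byte-pair encoding (BPE) builds its vocabulary from the data: starting from single bases, it repeatedly merges the most frequent adjacent token pair, so frequently occurring subsequences end up as single tokens.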

In [None]:
# Create and train BPE tokenizer
bpe_tokenizer = DNABPETokenizer(vocab_size=1000, min_frequency=2)

print("Training BPE tokenizer...")
bpe_tokenizer.train(sequences)
print("Training complete!")

# Round-trip check on the same 60-base prefix used in the k-mer cell
test_seq = sequences[0][:60]
encoded_bpe = bpe_tokenizer.encode(test_seq)
decoded_bpe = bpe_tokenizer.decode(encoded_bpe)

print(
    f"\nTest Sequence: {test_seq}",
    f"Encoded IDs:   {encoded_bpe}",
    f"Decoded:       {decoded_bpe}",
    f"Match: {test_seq == decoded_bpe}",
    sep="\n",
)

# Compare compression. `encoded` is the k-mer encoding produced by the k-mer
# tokenization cell, so that cell must have been run first.
print(
    f"\nCompression comparison:",
    f"  Original length: {len(test_seq)} bases",
    f"  K-mer tokens:    {len(encoded)} tokens",
    f"  BPE tokens:      {len(encoded_bpe)} tokens",
    f"  BPE compression: {len(test_seq) / len(encoded_bpe):.2f}x",
    sep="\n",
)

## 5. One-Hot Encoding

One-hot encoding represents each base as a 4-dimensional vector.

In [None]:
# Create one-hot encoder
onehot_encoder = OneHotEncoder(include_ambiguous=True)

# Encode a short sequence that includes the ambiguous base N
test_seq = "ATGCN"
encoded_onehot = onehot_encoder.encode(test_seq)

print(
    "One-hot encoding example:",
    f"Sequence: {test_seq}",
    f"\nEncoded shape: {encoded_onehot.shape}",
    f"\nEncoding (A, C, G, T):",
    sep="\n",
)
for i, base in enumerate(test_seq):
    print(f"  {base}: {encoded_onehot[i]}")

# Visualize one-hot encoding as a heat map (rows: channels, columns: positions).
# NOTE(review): the tick labels assume exactly four channels (A, C, G, T);
# confirm the encoder's output width when include_ambiguous=True.
fig, ax = plt.subplots(figsize=(10, 3))
heatmap = ax.imshow(encoded_onehot.T, cmap='Blues', aspect='auto')
fig.colorbar(heatmap, ax=ax, label='Value')
ax.set_xlabel('Position')
ax.set_ylabel('Base (A, C, G, T)')
ax.set_title('One-Hot Encoding Visualization')
ax.set_yticks([0, 1, 2, 3])
ax.set_yticklabels(['A', 'C', 'G', 'T'])
ax.set_xticks(range(len(test_seq)))
ax.set_xticklabels(list(test_seq))
fig.tight_layout()
plt.show()

# Decode the matrix back to a string and compare with the input
decoded_onehot = onehot_encoder.decode(encoded_onehot)
print(
    f"\nDecoded: {decoded_onehot}",
    f"Match: {test_seq == decoded_onehot}",
    sep="\n",
)

## 6. Prepare Dataset and DataLoader

Choose your encoding method: 'kmer', 'bpe', or 'onehot'

In [None]:
# Configuration
ENCODING_TYPE = 'kmer'  # Change to 'bpe' or 'onehot' to try different methods
MAX_SEQ_LENGTH = 128
BATCH_SIZE = 16
TRAIN_FRACTION = 0.8    # share of the *unique* sequences used for training

# Choose tokenizer based on encoding type. An unknown value is rejected here
# so that a typo is not silently treated as one-hot encoding.
if ENCODING_TYPE == 'kmer':
    tokenizer = kmer_tokenizer
elif ENCODING_TYPE == 'bpe':
    tokenizer = bpe_tokenizer
elif ENCODING_TYPE == 'onehot':
    tokenizer = None  # the one-hot path does not use a tokenizer
else:
    raise ValueError(
        f"Unknown ENCODING_TYPE {ENCODING_TYPE!r}; expected 'kmer', 'bpe' or 'onehot'"
    )

# Split data on unique sequences. `sequences` may contain exact duplicates
# (the example data is 8 sequences replicated 10 times); a purely positional
# split would then put copies of every training sequence into the validation
# set, so validation loss would not measure generalisation. All copies of a
# sequence are kept on the same side of the split instead.
unique_sequences = list(dict.fromkeys(sequences))  # order-preserving de-dup
split_idx = max(1, int(TRAIN_FRACTION * len(unique_sequences)))
train_unique = set(unique_sequences[:split_idx])
train_sequences = [s for s in sequences if s in train_unique]
val_sequences = [s for s in sequences if s not in train_unique]

print(f"Encoding method: {ENCODING_TYPE}")
print(f"Training sequences: {len(train_sequences)}")
print(f"Validation sequences: {len(val_sequences)}")

# Create datasets
train_dataset = DNASequenceDataset(
    train_sequences,
    tokenizer,
    max_length=MAX_SEQ_LENGTH,
    encoding_type=ENCODING_TYPE
)

val_dataset = DNASequenceDataset(
    val_sequences,
    tokenizer,
    max_length=MAX_SEQ_LENGTH,
    encoding_type=ENCODING_TYPE
)

# Create dataloaders; the collate function depends on the encoding type
collate_fn = collate_fn_onehot if ENCODING_TYPE == 'onehot' else collate_fn_tokens

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 7. Create Transformer Model

In [None]:
# Model hyperparameters
D_MODEL = 128
NHEAD = 4
NUM_LAYERS = 2
DIM_FEEDFORWARD = 512
DROPOUT = 0.1

# Keyword arguments shared by both model variants
shared_model_kwargs = dict(
    d_model=D_MODEL,
    nhead=NHEAD,
    num_layers=NUM_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
)

# Create model based on encoding type
if ENCODING_TYPE == 'onehot':
    # One-hot input: four input channels and four output classes
    vocab_size = 4
    model = DNATransformerOneHot(
        input_dim=4,
        num_classes=4,
        **shared_model_kwargs,
    )
else:
    # Token input: the embedding size follows the tokenizer vocabulary
    vocab_size = len(tokenizer.vocab)
    model = DNATransformerEncoder(
        vocab_size=vocab_size,
        max_seq_length=MAX_SEQ_LENGTH,
        **shared_model_kwargs,
    )

model = model.to(device)

# Count trainable parameters
trainable_sizes = (p.numel() for p in model.parameters() if p.requires_grad)
num_params = sum(trainable_sizes)

print(
    f"Model: {model.__class__.__name__}",
    f"Vocabulary size: {vocab_size:,}",
    f"Model parameters: {num_params:,}",
    f"\nModel architecture:",
    sep="\n",
)
print(model)

## 8. Train the Model

In [None]:
# Training setup
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Halve the learning rate when validation loss has not improved for more than
# `patience` epochs. The `verbose` argument is deprecated in current PyTorch,
# so the learning rate is printed explicitly inside the loop instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

# Training loop
train_losses = []
val_losses = []
best_val_loss = float('inf')

print(f"Training for {NUM_EPOCHS} epochs...")
print("=" * 70)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print("-" * 70)

    # Train
    train_loss = train_epoch(
        model, train_loader, optimizer, criterion, device, ENCODING_TYPE
    )
    train_losses.append(train_loss)

    # Validate
    val_loss = evaluate(
        model, val_loader, criterion, device, ENCODING_TYPE
    )
    val_losses.append(val_loss)

    # Learning rate scheduling (driven by the validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    print(f"LR:         {current_lr:.2e}")

    # Save best model: checkpoint whenever validation loss reaches a new minimum
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, f'best_model_{ENCODING_TYPE}.pt')
        print("✓ Saved best model")

print("\n" + "=" * 70)
print(f"Training complete! Best validation loss: {best_val_loss:.4f}")

## 9. Visualize Training Results

In [None]:
# Plot training curves. Both panels share the same 1-based epoch axis so that
# a point at x = n refers to epoch n in either plot.
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 4))

# Left panel: training and validation loss per epoch
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label='Train Loss', marker='o', linewidth=2)
plt.plot(epochs, val_losses, label='Val Loss', marker='s', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f'Training Curves - {ENCODING_TYPE.upper()} Encoding')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: validation minus training loss per epoch
plt.subplot(1, 2, 2)
overfitting = np.array(val_losses) - np.array(train_losses)
plt.plot(epochs, overfitting, marker='o', linewidth=2, color='red')
plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
plt.xlabel('Epoch')
plt.ylabel('Validation - Training Loss')
plt.title('Overfitting Metric')
plt.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): the figure is written to disk while it is still current.
plt.savefig(f'training_curves_{ENCODING_TYPE}.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Plot saved: training_curves_{ENCODING_TYPE}.png")

## 10. Test Model Predictions

In [None]:
# Load best model. map_location remaps the stored tensors onto the active
# device, so a checkpoint written on a GPU machine also loads on a CPU-only one.
checkpoint = torch.load(f'best_model_{ENCODING_TYPE}.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Test on a few sequences
test_sequences = val_sequences[:3]

# Loop-invariant helpers for the one-hot path, created once up front
bases = ['A', 'C', 'G', 'T']
if ENCODING_TYPE == 'onehot':
    encoder = OneHotEncoder()

print("Testing model predictions:")
print("=" * 70)

for i, test_seq in enumerate(test_sequences):
    test_seq = test_seq[:100]  # Use first 100 bases

    print(f"\nTest Sequence {i+1}:")
    print("-" * 70)

    if ENCODING_TYPE == 'onehot':
        # One-hot encoding: add a batch dimension and move to the model device
        test_input = encoder.encode_tensor(test_seq).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(test_input)

        # Most likely class per position, mapped back to a base letter
        predicted = torch.argmax(output, dim=-1)[0]
        predicted_seq = ''.join(bases[idx] for idx in predicted.cpu().tolist())
    else:
        # Token-based encoding
        test_input = torch.tensor(tokenizer.encode(test_seq)).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(test_input)

        # Most likely token per position, decoded back to a base string
        predicted = torch.argmax(output, dim=-1)[0]
        predicted_seq = tokenizer.decode(predicted.cpu().tolist())

    # Calculate accuracy over the overlapping prefix of both strings. An empty
    # overlap (e.g. the prediction decodes to an empty string) is reported as
    # 0% instead of raising ZeroDivisionError.
    min_len = min(len(test_seq), len(predicted_seq))
    matches = sum(a == b for a, b in zip(test_seq, predicted_seq))
    accuracy = matches / min_len if min_len else 0.0

    print(f"Original:  {test_seq[:50]}...")
    print(f"Predicted: {predicted_seq[:50]}...")
    print(f"Accuracy:  {accuracy:.2%} ({matches}/{min_len} bases correct)")

## 11. Save Models and Tokenizers

In [None]:
# Save tokenizer/encoder. Only the k-mer and BPE tokenizers are written to
# disk here; nothing is saved for the one-hot case.
if ENCODING_TYPE == 'kmer':
    kmer_vocab_path = f'kmer_vocab_k{k}.json'
    kmer_tokenizer.save_vocab(kmer_vocab_path)
    print(f"Saved: kmer_vocab_k{k}.json")
elif ENCODING_TYPE == 'bpe':
    bpe_path = f'bpe_tokenizer.json'
    bpe_tokenizer.save(bpe_path)
    print(f"Saved: bpe_tokenizer.json")

# The model checkpoint itself was written by the training cell.
print(
    f"Model checkpoint: best_model_{ENCODING_TYPE}.pt",
    "\nAll files saved successfully!",
    sep="\n",
)

## 12. Summary and Comparison

### Encoding Method Comparison

| Method | Pros | Cons | Best For |
|--------|------|------|----------|
| **K-mer** | Fast, interpretable, good for motifs | Fixed vocabulary, may miss long-range patterns | Genomic sequences, motif discovery |
| **BPE** | Learns optimal units, flexible vocabulary | Requires training, less interpretable | Large diverse datasets, compression |
| **One-hot** | Simple, preserves base information | High memory, no compression | Short sequences, base-level predictions |

### Next Steps

1. **Experiment with different k values** (3, 6, 9) for k-mer encoding
2. **Try different stride values** to balance sequence length and information
3. **Increase model size** (d_model, num_layers) for better performance
4. **Train on larger datasets** for improved generalization
5. **Fine-tune for specific tasks** (e.g., promoter prediction, variant calling)
