# Transformers in Biomedicine: From Clinical Language to Genomics

**An Interactive Notebook Based on the Lecture by Vivek Natarajan, Google Health AI**

## Section 1: Overview & Prerequisites

### Summary of the Lecture Topic

This notebook explores the transformative impact of Transformer architectures and Large Language Models (LLMs) on the field of biomedicine, based on a lecture by Vivek Natarajan. The core thesis is that sequences are a ubiquitous data modality in biomedicine—from clinical notes and electronic health records to proteins (sequences of amino acids) and genomes (sequences of nucleotides). Transformers, with their ability to model complex, long-range dependencies, represent a powerful tool for this domain.

We will delve into several key research papers discussed in the lecture, covering a spectrum of applications:
1.  **Clinical Applications:** How LLMs like Med-PaLM are being aligned for medical question answering, requiring innovations in benchmarks (`MultiMedQA`), evaluation frameworks, and alignment techniques (Instruction Prompt Tuning).
2.  **Proteomics:** How efficient Transformer architectures like the Performer can model long protein sequences, and how models like ProtNLM can annotate protein functions at a massive scale.
3.  **Genomics:** How Transformers are used in `DeepConsensus` to improve the accuracy of DNA sequencing and in `Enformer` to predict gene expression by modeling long-range interactions in the genome.

### Prerequisite Knowledge

#### Mathematical Concepts
- **Linear Algebra:** Vector and matrix operations, particularly the dot product, which is the foundation of self-attention.
- **Probability & Statistics:** Understanding of the Softmax function for converting scores into a probability distribution.
- **Calculus:** Basic understanding of gradients for model training (conceptual).
- **Low-Rank Matrix Approximation:** Conceptual understanding for the Performer model.

#### Machine Learning & Computer Science Concepts
- **Neural Networks:** Fundamentals of deep learning and backpropagation.
- **Sequence Models:** Basic knowledge of why models like RNNs struggle with long sequences (vanishing gradients).
- **The Transformer Architecture:** Self-Attention, Multi-Head Attention, Positional Encodings, Encoder-Decoder structure.
- **Large Language Models (LLMs):** The pre-training and fine-tuning paradigm.
- **Prompting Techniques:** Few-shot prompting, Chain of Thought, Self-Consistency (conceptual).
- **Parameter-Efficient Fine-Tuning (PEFT):** Specifically, the concept of Prompt Tuning.
- **Convolutional Neural Networks (CNNs):** Understanding their use in sequence modeling and the concept of a receptive field.

### Hierarchy of Topics

1.  **Mathematical Foundations:** We'll start by implementing the core mechanism of Transformers—the self-attention mechanism—from scratch.
2.  **Prerequisite Algorithms:** We'll build a basic Transformer block and review CNNs for sequence data to set the stage for later models.
3.  **Core Research: Clinical Applications (Med-PaLM):** We will explore Instruction Prompt Tuning, the technique used to align a general LLM to the medical domain.
4.  **Core Research: Proteomics (Performer & ProtNLM):** We'll implement the core idea behind the Performer's efficient attention and discuss the T5-based approach of ProtNLM.
5.  **Core Research: Genomics (Enformer):** We'll build a simplified model to demonstrate the power of Transformers over CNNs for capturing long-range genomic interactions.
6.  **Experimental Analysis:** We will reproduce toy versions of the experiments discussed, such as evaluating model answers and comparing the performance of CNNs vs. Transformers on synthetic genomic data.
7.  **Research Context & Extensions:** We'll conclude by summarizing the key takeaways and future directions outlined in the lecture.

### Learning Objectives
- **Understand** why Transformers are exceptionally well-suited for diverse biomedical data.
- **Implement** the self-attention mechanism from scratch to grasp its inner workings.
- **Grasp** the concept of Instruction Prompt Tuning for domain-specific model alignment.
- **Appreciate** the architectural innovations required to apply Transformers to long sequences in genomics and proteomics.
- **Reproduce** toy versions of the experiments that illustrate why these models outperform earlier approaches in their respective domains.

**Estimated Time:** 2-3 hours.

## Section 2: Mathematical Foundations

### The Heart of the Transformer: Self-Attention

The core innovation of the Transformer is the self-attention mechanism. It allows the model to weigh the importance of different words (or tokens, amino acids, nucleotides) in the input sequence when processing a specific word. It computes a representation of a token by relating it to all other tokens in the sequence.

The formula is:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Where:
- $Q$ (Query), $K$ (Key), and $V$ (Value) are matrices derived from the input embeddings.
- $d_k$ is the dimension of the key vectors, used for scaling.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output

# Set seed for reproducibility
torch.manual_seed(42)

def educational_self_attention(x, d_model, d_k):
    """
    Clear implementation of self-attention for understanding.
    - Based directly on the mathematical formula.
    - Extensive comments explaining each step.
    - No black-box libraries for the core logic.
    
    Args:
        x (Tensor): Input tensor of shape (seq_len, d_model).
        d_model (int): The embedding dimension.
        d_k (int): The dimension for Key/Query vectors.
    """
    seq_len = x.shape[0]
    
    # 1. Create linear projections for Query, Key, and Value
    # In a real model, these would be learned weight matrices (nn.Linear)
    W_q = torch.randn(d_model, d_k)
    W_k = torch.randn(d_model, d_k)
    W_v = torch.randn(d_model, d_k) # d_v is often same as d_k
    
    # 2. Project the input into Query, Key, and Value spaces
    Q = x @ W_q  # (seq_len, d_k)
    K = x @ W_k  # (seq_len, d_k)
    V = x @ W_v  # (seq_len, d_k)
    
    # 3. Calculate attention scores: Q * K^T
    # This measures the similarity between each query and all keys.
    scores = Q @ K.T  # (seq_len, seq_len)
    
    # 4. Scale the scores to stabilize gradients
    scaled_scores = scores / np.sqrt(d_k)
    
    # 5. Apply softmax to get attention weights (probabilities)
    # This makes the weights for each token sum to 1.
    attention_weights = F.softmax(scaled_scores, dim=-1)
    
    # 6. Apply attention weights to the Value vectors
    # This creates a weighted sum of values, where the weights are the attention scores.
    output = attention_weights @ V  # (seq_len, d_k)
    
    return output, attention_weights

# --- Example Usage ---
seq_len = 5 # e.g., 5 amino acids in a protein sequence
d_model = 32 # embedding dimension for each amino acid
d_k = 16 # dimension of Q, K, V projections

# Create a dummy input sequence (random stand-ins for the embeddings of 5 amino acids)
input_sequence = torch.randn(seq_len, d_model)

output, weights = educational_self_attention(input_sequence, d_model, d_k)

print("Input Shape:", input_sequence.shape)
print("Output Shape:", output.shape)
print("Attention Weights Shape:", weights.shape)

def interactive_attention_explorer():
    """
    Plot the attention weights computed above as a static heatmap.
    Reads the global `weights` and `seq_len` from the example usage.
    """
    plt.figure(figsize=(6, 6))
    sns.heatmap(weights.detach().numpy(), annot=True, cmap='viridis', xticklabels=range(seq_len), yticklabels=range(seq_len))
    plt.title("Self-Attention Weights")
    plt.xlabel("Key Positions")
    plt.ylabel("Query Positions")
    plt.show()

interactive_attention_explorer()
print("\nEach row shows how much the token at that row's index attends to every token in the sequence (itself included); each row sums to 1.")
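
The role of the $\sqrt{d_k}$ factor can be checked numerically. The sketch below assumes query and key entries are independent standard normal values (an idealization, not something the lecture specifies). Under that assumption, raw dot products have variance close to $d_k$, while scaled ones stay near 1, which keeps the softmax away from saturation. The helper name `dot_product_variance` is illustrative.

```python
import torch

def dot_product_variance(dim, num_pairs=10_000, seed=0):
    """Return (variance of q.k, variance of q.k / sqrt(dim)) for random q, k."""
    gen = torch.Generator().manual_seed(seed)
    q = torch.randn(num_pairs, dim, generator=gen)
    k = torch.randn(num_pairs, dim, generator=gen)
    raw = (q * k).sum(dim=-1)      # one dot product per (q, k) pair
    scaled = raw / dim ** 0.5      # the scaling used in the attention formula
    return raw.var().item(), scaled.var().item()

for dim in (16, 64, 256):
    raw_var, scaled_var = dot_product_variance(dim)
    print(f"d_k={dim:>3}: raw variance = {raw_var:7.1f}, scaled variance = {scaled_var:.2f}")
```

Without the scaling, larger key dimensions would push the softmax toward near one-hot rows with very small gradients.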

## Section 3: Prerequisite Algorithms

### Prerequisite 1: A Basic Transformer Block

A full Transformer is built by stacking several blocks. A single block typically contains:
1.  **Multi-Head Attention:** An evolution of self-attention where the attention mechanism is run multiple times in parallel with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces.
2.  **Add & Norm (Layer Normalization):** A residual connection followed by layer normalization, which helps stabilize training.
3.  **Feed-Forward Network:** A simple position-wise fully connected network applied to each position separately and identically.
4.  **Add & Norm:** Another residual connection and layer normalization.

In [None]:
class EducationalTransformerBlock(nn.Module):
    """
    Clear implementation of a Transformer block for understanding.
    - Based directly on the "Attention Is All You Need" paper.
    - Uses PyTorch's optimized layers but shows how they connect.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EducationalTransformerBlock, self).__init__()
        
        # Multi-Head Attention Layer
        self.multi_head_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        
        # Feed-Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout for regularization
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        # 1. Multi-Head Attention followed by Add & Norm
        attn_output, _ = self.multi_head_attn(src, src, src) # Q, K, V are all the same 'src'
        src = src + self.dropout1(attn_output) # Residual connection
        src = self.norm1(src) # Layer Normalization
        
        # 2. Feed-Forward Network followed by Add & Norm
        ffn_output = self.ffn(src)
        src = src + self.dropout2(ffn_output) # Residual connection
        src = self.norm2(src) # Layer Normalization
        
        return src

# --- Example Usage ---
batch_size = 1
seq_len = 10 # 10 tokens
d_model = 512 # Model dimension
num_heads = 8 # Number of attention heads
d_ff = 2048 # Dimension of the feed-forward layer

transformer_block = EducationalTransformerBlock(d_model, num_heads, d_ff)

# Input needs to be (batch_size, seq_len, d_model)
dummy_input = torch.randn(batch_size, seq_len, d_model)
output = transformer_block(dummy_input)

print("Input Shape:", dummy_input.shape)
print("Output Shape:", output.shape)
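
What a multi-head attention layer does internally can be sketched by hand: project the input, split the model dimension into heads, attend within each head, then concatenate. This is a minimal illustration with random, untrained projection matrices and no output projection or dropout, so it shows the data flow rather than reproducing `nn.MultiheadAttention` exactly. The function name `manual_multi_head_attention` is an assumption made for this example.

```python
import torch
import torch.nn.functional as F

def manual_multi_head_attention(x, num_heads, seed=0):
    """x: (seq_len, d_model). Returns (output, per-head attention weights)."""
    seq_len, d_model = x.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    gen = torch.Generator().manual_seed(seed)

    # Random stand-ins for the learned projection matrices
    W_q, W_k, W_v = (torch.randn(d_model, d_model, generator=gen) / d_model ** 0.5
                     for _ in range(3))

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return t.view(seq_len, num_heads, d_head).transpose(0, 1)

    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)

    # Scaled dot-product attention, computed independently for every head
    scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
    head_weights = F.softmax(scores, dim=-1)   # (num_heads, seq_len, seq_len)
    per_head = head_weights @ V                # (num_heads, seq_len, d_head)

    # Concatenate the heads back into the model dimension
    out = per_head.transpose(0, 1).reshape(seq_len, d_model)
    return out, head_weights

mha_out, mha_weights = manual_multi_head_attention(torch.randn(10, 64), num_heads=8)
print("Output shape:", mha_out.shape)
print("Per-head attention weights shape:", mha_weights.shape)
```

Each head works in a `d_model / num_heads` subspace, so the total cost is similar to one full-width head while allowing several distinct attention patterns.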

### Prerequisite 2: CNNs for Sequence Data

Before Transformers, CNNs were often used for sequence tasks. They use kernels (filters) that slide across the sequence, capturing local patterns. Their primary limitation, especially relevant for the `Enformer` paper, is their **fixed and local receptive field**. A deep stack of CNN layers is required to see tokens that are far apart, making it difficult to model long-range dependencies efficiently.

In [None]:
def visualize_receptive_field():
    """Visualize the limited receptive field of a CNN vs a Transformer."""
    
    # A CNN with a kernel size of 3 only sees its immediate neighbors.
    cnn_receptive_field = np.zeros((10, 10))
    for i in range(10):
        cnn_receptive_field[i, max(0, i-1):min(10, i+2)] = 1
        
    # A Transformer's self-attention can see every other token from the start.
    transformer_receptive_field = np.ones((10, 10))
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    sns.heatmap(cnn_receptive_field, ax=ax1, cbar=False, cmap='Reds', linewidths=.5)
    ax1.set_title("CNN Receptive Field (1 Layer, Kernel=3)")
    ax1.set_xlabel("Input Sequence")
    ax1.set_ylabel("Output Neuron")
    
    sns.heatmap(transformer_receptive_field, ax=ax2, cbar=False, cmap='Blues', linewidths=.5)
    ax2.set_title("Transformer Receptive Field (1 Layer)")
    ax2.set_xlabel("Input Sequence")
    ax2.set_ylabel("Output Neuron")
    
    plt.suptitle("Receptive Field Comparison")
    plt.show()

visualize_receptive_field()
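
The claim that a deep stack of convolutions is needed to connect distant positions can be made concrete with a little arithmetic. The sketch below assumes stride-1 convolutions with no pooling, and uses a 100,000-position span (the interaction distance this notebook quotes for Enformer) as the target. The function names are illustrative.

```python
import math

def receptive_field(num_layers, kernel_size, dilations=None):
    """Receptive field of stacked stride-1 1D convolutions."""
    if dilations is None:
        dilations = [1] * num_layers
    return 1 + sum((kernel_size - 1) * d for d in dilations)

def layers_needed(target, kernel_size):
    """Smallest number of undilated layers whose receptive field is >= target."""
    return math.ceil((target - 1) / (kernel_size - 1))

def dilated_layers_needed(target, kernel_size):
    """Same question when the dilation doubles at every layer (1, 2, 4, ...)."""
    n = 0
    while receptive_field(n, kernel_size, [2 ** i for i in range(n)]) < target:
        n += 1
    return n

span = 100_000
print("Receptive field of 3 layers, kernel 3:", receptive_field(3, 3))
print("Undilated layers needed to span", span, "positions:", layers_needed(span, 3))
print("Layers needed with doubling dilation:", dilated_layers_needed(span, 3))
```

Pooling and dilation shrink the required depth considerably, but the path between two distant positions still passes through many layers, whereas a single self-attention layer connects them directly.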

## Section 4: Core Research Content

### 4.1 Med-PaLM: Aligning LLMs with Instruction Prompt Tuning

A key challenge highlighted in the lecture was that general-purpose LLMs like FLAN-PaLM, despite encoding significant medical knowledge, are not directly usable in clinical settings due to issues like hallucination and a lack of caution. The solution presented was **Instruction Prompt Tuning**, a parameter-efficient method.

**How it works:**
1.  The large base LLM (e.g., FLAN-PaLM) is **frozen**. Its billions of parameters are not updated.
2.  A small set of new, learnable vectors (the "soft prompt" or "prompt embedding") is added in front of the input embeddings.
3.  During training, only these few prompt vectors are updated using a small, high-quality dataset of expert-written examples (e.g., medical questions with ideal answers).
4.  At inference time, these learned prompt vectors are prepended to the actual user input, guiding the frozen LLM to generate responses in the desired style (e.g., safe, informative, cautious).

This is highly efficient in terms of both data and computation compared to full fine-tuning.

In [None]:
class EducationalPromptTunedModel(nn.Module):
    """
    A simplified model to demonstrate the concept of prompt tuning.
    """
    def __init__(self, frozen_llm, prompt_len=10, d_model=512):
        super().__init__()
        self.frozen_llm = frozen_llm
        self.prompt_len = prompt_len
        
        # The ONLY trainable part of the model
        self.soft_prompt = nn.Parameter(torch.randn(1, prompt_len, d_model))
        
        # Freeze the base LLM
        for param in self.frozen_llm.parameters():
            param.requires_grad = False

    def forward(self, input_embeddings):
        # Prepend the learned soft prompt to the input embeddings
        batch_size = input_embeddings.shape[0]
        # Repeat the prompt for each item in the batch
        prompt = self.soft_prompt.expand(batch_size, -1, -1)
        
        # Concatenate the soft prompt with the actual input
        combined_input = torch.cat([prompt, input_embeddings], dim=1)
        
        # Pass the combined input through the frozen LLM
        output = self.frozen_llm(combined_input)
        return output

# --- Example Usage ---
# 1. Create a "frozen" LLM (a simple Transformer block in our case)
frozen_llm = EducationalTransformerBlock(d_model=512, num_heads=8, d_ff=2048)

# 2. Create our prompt-tuned model
med_palm_concept_model = EducationalPromptTunedModel(frozen_llm, prompt_len=20, d_model=512)

# 3. Check which parameters are trainable
total_params = 0
trainable_params = 0
for name, param in med_palm_concept_model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        print(f"Trainable parameter: {name} with size {param.shape}")
        trainable_params += param.numel()

print(f"\nTotal Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
print(f"Percentage of Trainable Params: {100 * trainable_params / total_params:.4f}%")

# This demonstrates the extreme parameter efficiency of the method.
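
A single optimization step makes the "only the prompt is trained" idea tangible. This is a minimal sketch under loud assumptions: the frozen model is a hypothetical one-layer stand-in, the data and targets are random, and the mean-pooled squared-error loss is chosen only for brevity. None of this reflects how Med-PaLM was actually trained.

```python
import torch
import torch.nn as nn

toy_dim, toy_prompt_len = 16, 4

# Hypothetical stand-in for a frozen LLM: one position-wise linear layer
toy_frozen_llm = nn.Linear(toy_dim, toy_dim)
for p in toy_frozen_llm.parameters():
    p.requires_grad = False

toy_soft_prompt = nn.Parameter(torch.randn(1, toy_prompt_len, toy_dim))
toy_optimizer = torch.optim.Adam([toy_soft_prompt], lr=1e-2)  # only the prompt is optimized

toy_inputs = torch.randn(8, 5, toy_dim)   # (batch, seq_len, dim), random toy data
toy_targets = torch.randn(8, toy_dim)     # random toy regression targets

weights_before = toy_frozen_llm.weight.clone()
prompt_before = toy_soft_prompt.detach().clone()

# Forward pass: prepend the prompt, run the frozen model, pool over positions
prompt = toy_soft_prompt.expand(toy_inputs.shape[0], -1, -1)
hidden = toy_frozen_llm(torch.cat([prompt, toy_inputs], dim=1))
loss = ((hidden.mean(dim=1) - toy_targets) ** 2).mean()

toy_optimizer.zero_grad()
loss.backward()          # gradients flow through the frozen layer to the prompt
toy_optimizer.step()

print("Frozen weights changed:", not torch.equal(toy_frozen_llm.weight, weights_before))
print("Soft prompt changed:   ", not torch.equal(toy_soft_prompt.detach(), prompt_before))
```

The frozen layer still takes part in backpropagation (gradients pass through it), but it accumulates no gradient of its own and is never updated.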

### 4.2 Performer: Efficient Attention with Low-Rank Approximation

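For long biological sequences like proteins, the quadratic cost of the $QK^T$ matrix multiplication in self-attention is a bottleneck. The Performer paper introduces a way to approximate the softmax attention kernel with a low-rank factorization built from random features.

Instead of computing the full $(N \times N)$ attention matrix, it maps $Q$ and $K$ into a randomized feature space of dimension $r$. This changes the computation from $O(N^2 d)$ to $O(N r d)$, where $N$ is sequence length, $d$ is the feature dimension, and $r$ is the number of random features ($r \ll N$). For fixed $r$ and $d$, the cost is linear in sequence length.

The core idea relies on the fact that the (unnormalized) softmax kernel can be written as an expectation over random projections:
$$ \exp(q_i^T k_j) = E_{\omega \sim \mathcal{N}(0, I_d)}\left[\exp\left(\omega^T q_i - \frac{\|q_i\|^2}{2}\right)\exp\left(\omega^T k_j - \frac{\|k_j\|^2}{2}\right)\right] \approx \phi(q_i)^T\phi(k_j) $$
where $\phi$ is a feature map built from $r$ sampled projections $\omega$, and the $1/\sqrt{d_k}$ scaling is absorbed into $q_i$ and $k_j$. The attention output for position $i$ then becomes $\phi(q_i)^T \sum_j \phi(k_j) v_j^T$ divided by the normalizer $\phi(q_i)^T \sum_j \phi(k_j)$, and both sums over $j$ are computed once and shared by every query.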
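
The random-feature idea can be sanity-checked on a single query-key pair. The sketch below uses positive random features of the form $\exp(\omega^T x - \|x\|^2/2)$ with Gaussian $\omega$ and compares the estimate against the exact value of $\exp(q^T k)$. The small vector norms and the large number of features are choices made here to keep the Monte Carlo error low; the helper name `positive_random_features` is illustrative.

```python
import torch

def positive_random_features(x, omega):
    """phi(x) = exp(omega^T x - ||x||^2 / 2) / sqrt(r), for x: (n, d), omega: (d, r)."""
    r = omega.shape[1]
    half_sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2
    return torch.exp(x @ omega - half_sq_norm) / r ** 0.5

gen = torch.Generator().manual_seed(0)
d, r = 8, 20_000
q = 0.2 * torch.randn(1, d, generator=gen)
k = 0.2 * torch.randn(1, d, generator=gen)
omega = torch.randn(d, r, generator=gen)   # one Gaussian projection per random feature

exact = torch.exp(q @ k.T).item()
approx = (positive_random_features(q, omega) @ positive_random_features(k, omega).T).item()

print(f"exact exp(q.k):          {exact:.4f}")
print(f"random-feature estimate: {approx:.4f}")
```

The estimate gets noisier as the norms of `q` and `k` grow, which is one reason practical implementations pay attention to how the random projections are drawn.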

In [None]:
def educational_performer_attention(Q, K, V, num_random_features):
    """
    Educational implementation of the Performer's core idea.
    - Approximates softmax attention with random feature maps.
    - Avoids explicit computation of the (N x N) attention matrix.
    """
    seq_len, d_k = Q.shape
    
    # 1. Create a random projection matrix
    # In the actual paper, this is drawn from a specific distribution.
    random_projection_matrix = torch.randn(d_k, num_random_features)
    
    # 2. Define the feature map phi(x). Here we use a simplified version.
    # The paper's softmax approximation uses positive random features,
    # exp(w^T x - ||x||^2 / 2); ReLU is a simpler stand-in kernel used here.
    def feature_map(x):
        # Project to the random feature space
        projected = x @ random_projection_matrix
        # A non-negative non-linearity, so the normalizer below stays non-negative
        return F.relu(projected) / np.sqrt(num_random_features)

    # 3. Map Q and K to the random feature space
    Q_prime = feature_map(Q) # (seq_len, num_random_features)
    K_prime = feature_map(K) # (seq_len, num_random_features)
    
    # 4. Compute the attention output without forming the (N x N) matrix
    # The key insight is to change the order of matrix multiplication
    # Standard: (Q' @ K'.T) @ V
    # Performer: Q' @ (K'.T @ V)
    # This is possible because (A B) C = A (B C)
    kv_product = K_prime.T @ V # (num_random_features, d_k)
    numerator = Q_prime @ kv_product # (seq_len, d_k)
    
    # 5. Row normalization (the role of softmax's denominator), also linear in N
    normalizer = Q_prime @ K_prime.sum(dim=0) # (seq_len,)
    output = numerator / (normalizer.unsqueeze(-1) + 1e-6)
    
    return output

# --- Comparison ---
seq_len = 1024 # A longer sequence where the difference matters
d_model = 64
d_k = 32
num_random_features = 128 # r << seq_len

input_long_seq = torch.randn(seq_len, d_model)
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_k)
Q = input_long_seq @ W_q
K = input_long_seq @ W_k
V = input_long_seq @ W_v

print(f"Sequence Length (N): {seq_len}")
print(f"Random Features (r): {num_random_features}")
print("\n--- Theoretical Complexity ---")
print(f"Standard Attention: O(N^2 * d) ~ {seq_len**2 * d_k:,} multiply-adds")
print(f"Performer Attention: O(N * r * d) ~ {seq_len * num_random_features * d_k:,} multiply-adds")

performer_output = educational_performer_attention(Q, K, V, num_random_features)
print("\nPerformer Output Shape:", performer_output.shape)
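
The reordering trick is exact, not an approximation: once the features are fixed, both multiplication orders give the same normalized output. The sketch below checks this with random non-negative matrices standing in for the feature-mapped queries and keys. All names and sizes are illustrative, and double precision is used so the comparison is not clouded by rounding.

```python
import torch

gen = torch.Generator().manual_seed(0)
n, r, d = 512, 64, 32
Qp = torch.rand(n, r, generator=gen, dtype=torch.float64)     # stand-in for phi(Q), non-negative
Kp = torch.rand(n, r, generator=gen, dtype=torch.float64)     # stand-in for phi(K), non-negative
Vals = torch.randn(n, d, generator=gen, dtype=torch.float64)  # stand-in for V

# Quadratic path: materialize the (n x n) matrix, normalize its rows, apply to V
A = Qp @ Kp.T
quadratic = (A / A.sum(dim=-1, keepdim=True)) @ Vals

# Linear path: never form the (n x n) matrix
numerator = Qp @ (Kp.T @ Vals)        # (n, d)
denominator = Qp @ Kp.sum(dim=0)      # (n,)
linear = numerator / denominator.unsqueeze(-1)

max_diff = (quadratic - linear).abs().max().item()
print(f"Largest absolute difference between the two paths: {max_diff:.2e}")
```

The approximation error in Performer-style attention therefore comes entirely from the feature map, not from the change in multiplication order.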

### 4.3 Enformer: Predicting Gene Expression with Transformers

The Enformer model was designed to predict gene expression from DNA sequences, a task that requires modeling extremely long-range interactions (up to 100,000 base pairs away). As we saw earlier, CNNs alone are ill-suited for this because their receptive field grows slowly with depth.

The Enformer architecture combines the strengths of both models:
1.  **CNN Stem:** A few layers of CNNs at the beginning to learn local patterns and downsample the very long input sequence, making it computationally tractable for the Transformer layers.
2.  **Transformer Body:** A stack of Transformer blocks that can then model the global, long-range interactions between the features extracted by the CNN stem.

This hybrid approach allows the model to capture the influence of distant *enhancer* regions on gene *promoters*.

In [None]:
class EducationalEnformer(nn.Module):
    """
    A simplified, educational version of the Enformer architecture.
    """
    def __init__(self, d_model=128, n_heads=4, n_layers=2, d_ff=512):
        super().__init__()
        # In the real paper, input is one-hot encoded DNA (4 channels)
        # Here, we'll assume a single channel for simplicity.
        
        # 1. CNN Stem to downsample and extract local features
        self.cnn_stem = nn.Sequential(
            # Conv -> Pool -> Conv -> Pool
            nn.Conv1d(in_channels=1, out_channels=d_model//2, kernel_size=15, padding='same'),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=d_model//2, out_channels=d_model, kernel_size=15, padding='same'),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )
        
        # 2. Transformer Body to model long-range interactions
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=n_heads, 
            dim_feedforward=d_ff,
            batch_first=True
        )
        self.transformer_body = nn.TransformerEncoder(transformer_layer, num_layers=n_layers)
        
        # 3. Final head to predict gene expression
        self.prediction_head = nn.Linear(d_model, 1)

    def forward(self, dna_sequence):
        # Input shape: (batch_size, channels, seq_len)
        
        # Pass through CNN stem
        x = self.cnn_stem(dna_sequence)
        
        # Reshape for Transformer: (batch_size, seq_len, d_model)
        x = x.permute(0, 2, 1)
        
        # Pass through Transformer body
        x = self.transformer_body(x)
        
        # Use the representation of the central part of the sequence for prediction
        # (A common technique in genomics models)
        center_idx = x.shape[1] // 2
        gene_representation = x[:, center_idx, :]
        
        # Final prediction
        prediction = self.prediction_head(gene_representation)
        return prediction

# --- Example Usage ---
batch_size = 4
dna_length = 4096 # A much shorter sequence than the real paper (200k)

model = EducationalEnformer()
# Input needs shape (batch, channels, length)
dummy_dna = torch.randn(batch_size, 1, dna_length)
output = model(dummy_dna)

print("Input DNA sequence shape:", dummy_dna.shape)
print("Final gene expression prediction shape:", output.shape)
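
The reason for downsampling before the Transformer body is easy to quantify. The sketch below uses this notebook's toy settings (a 4,096-position input and two pooling layers of size 2) and counts the entries in one attention matrix with and without the stem. The helper name is illustrative, and the real Enformer's dimensions differ.

```python
def length_after_pooling(seq_len, pool_sizes):
    """Sequence length after successive max-pooling layers (stride == kernel size)."""
    for pool in pool_sizes:
        seq_len //= pool   # each pooling layer floors the length
    return seq_len

raw_len = 4096
pooled_len = length_after_pooling(raw_len, pool_sizes=[2, 2])

entries_without_stem = raw_len ** 2      # attention matrix over raw positions
entries_with_stem = pooled_len ** 2      # attention matrix over pooled positions

print(f"Transformer sequence length: {raw_len} -> {pooled_len}")
print(f"Attention matrix entries per head: {entries_without_stem:,} -> {entries_with_stem:,}")
print(f"Reduction factor: {entries_without_stem // entries_with_stem}x")
```

Because attention cost is quadratic in length, a 4x reduction in sequence length gives a 16x reduction in attention matrix size.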

## Section 5: Experimental Analysis

### 5.1 Reproducing the Med-PaLM Evaluation Concept

The Med-PaLM paper emphasized that standard metrics like accuracy are insufficient for medical applications. They developed a human evaluation framework where clinicians and lay users rated model responses on multiple axes. We can simulate this process interactively.

Below, we compare a hypothetical "unaligned" FLAN-PaLM response with an "aligned" Med-PaLM response. Use the sliders to rate each response according to the criteria.

In [None]:
def interactive_medical_qa_evaluation():
    """
    Interactive widgets to simulate the evaluation of medical QA responses.
    """
    question = "**User Question:** I have a persistent cough and occasional shortness of breath. What could it be?"
    
    flan_palm_answer = ("**FLAN-PaLM (Unaligned) Style Answer:**\n" 
                        "You have bronchitis. Take cough syrup.")
    
    med_palm_answer = ("**Med-PaLM (Aligned) Style Answer:**\n" 
                       "A persistent cough and shortness of breath can have several potential causes, ranging from common respiratory infections to more serious conditions. It is not possible to provide a diagnosis without a full medical evaluation. Common causes include bronchitis, asthma, or even acid reflux. However, it's very important to see a healthcare professional to rule out other possibilities. They can perform a physical examination and may recommend further tests if needed. You should consult a doctor for an accurate diagnosis and treatment plan.")

    print(question)
    print("-"*50)
    print(flan_palm_answer)
    print("\nRate the FLAN-PaLM Answer:")

    # Create widgets for FLAN-PaLM ratings
    style = {'description_width': 'initial'}
    fp_factuality = widgets.IntSlider(description='Factuality (1-5)', min=1, max=5, value=2, style=style)
    fp_helpfulness = widgets.IntSlider(description='Helpfulness (1-5)', min=1, max=5, value=2, style=style)
    fp_safety = widgets.IntSlider(description='Potential for Harm (1=High, 5=Low)', min=1, max=5, value=1, style=style)
    
    # Create widgets for Med-PaLM ratings
    mp_factuality = widgets.IntSlider(description='Factuality (1-5)', min=1, max=5, value=5, style=style)
    mp_helpfulness = widgets.IntSlider(description='Helpfulness (1-5)', min=1, max=5, value=5, style=style)
    mp_safety = widgets.IntSlider(description='Potential for Harm (1=High, 5=Low)', min=1, max=5, value=5, style=style)

    # Display everything
    display(fp_factuality, fp_helpfulness, fp_safety)
    
    print("\n" + "-"*50)
    print(med_palm_answer)
    print("\nRate the Med-PaLM Answer:")
    display(mp_factuality, mp_helpfulness, mp_safety)
    
    print("\nThis simulation highlights the qualitative gap that alignment techniques like Instruction Prompt Tuning aim to close.")

interactive_medical_qa_evaluation()
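
Once several raters have scored a response, the ratings need to be aggregated per axis. The sketch below shows one simple way to do that with a mean. The numbers are hypothetical and purely illustrative (they are not results from the Med-PaLM paper), and the axis names and the `summarize` helper are assumptions for this example.

```python
from statistics import mean

# Hypothetical ratings from three raters on a 1-5 scale (illustrative numbers only)
ratings = {
    "FLAN-PaLM": {"factuality": [2, 3, 2], "helpfulness": [2, 2, 3], "low_harm": [1, 2, 1]},
    "Med-PaLM": {"factuality": [5, 4, 5], "helpfulness": [4, 5, 5], "low_harm": [5, 5, 4]},
}

def summarize(ratings_by_model):
    """Average each axis over raters, rounded to two decimals."""
    return {
        model: {axis: round(mean(scores), 2) for axis, scores in axes.items()}
        for model, axes in ratings_by_model.items()
    }

summary = summarize(ratings)
for model, axes in summary.items():
    print(model, axes)
```

A real evaluation would also report rater agreement and confidence intervals, since a mean over a handful of raters hides how much they disagree.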

### 5.2 Enformer Experiment: Comparing CNN vs. Transformer on a Toy Genomic Task

We'll create a synthetic dataset to illustrate the core finding of the Enformer paper: Transformers can model long-range dependencies that CNNs with limited receptive fields struggle to capture.

**Task:** Predict a binary output (gene expression: ON/OFF).
**Rule:** The output is ON (1) if and only if a specific sequence pattern ("promoter") is present at the center of the sequence AND another specific pattern ("enhancer") is present far away. Otherwise, the output is OFF (0).

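Because the promoter and the enhancer are each inserted with probability 0.5, only about 25% of samples are ON, so a model that always predicts OFF already scores about 75%. A model has to beat that baseline to show it has learned the interaction. A CNN whose convolutional receptive field is small cannot see both patterns within a single filter, while a Transformer's attention can relate them directly.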

In [None]:
def create_synthetic_genome_data(num_samples, seq_len, promoter, enhancer):
    """Creates a synthetic dataset for the Enformer experiment."""
    X = torch.randint(0, 4, (num_samples, seq_len)) # 4 bases: A, C, G, T
    y = torch.zeros(num_samples, 1)
    
    center = seq_len // 2
    enhancer_pos = 50
    
    for i in range(num_samples):
        has_promoter = (np.random.rand() > 0.5)
        has_enhancer = (np.random.rand() > 0.5)
        
        if has_promoter:
            X[i, center:center+len(promoter)] = torch.tensor(promoter)
        if has_enhancer:
            X[i, enhancer_pos:enhancer_pos+len(enhancer)] = torch.tensor(enhancer)
            
        if has_promoter and has_enhancer:
            y[i] = 1.0
    
    # Convert to one-hot encoding for models
    X_onehot = F.one_hot(X, num_classes=4).float().permute(0, 2, 1)
    return X_onehot, y

# Simple CNN model for baseline comparison
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(4, 16, 7, padding='same')
        self.pool = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(16, 32, 7, padding='same')
        self.fc = nn.Linear(32 * (2048//16), 1) # Length is hardcoded after pooling

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        return self.fc(x)

# --- Training Loop ---
def train_model(model, data, labels, epochs=20):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # Evaluate
    with torch.no_grad():
        preds = torch.sigmoid(model(data)) > 0.5
        accuracy = (preds.float() == labels).float().mean().item()
    return accuracy

print("Running toy Enformer experiment... This may take a minute.")
# Create data
seq_len = 2048
promoter = [0, 1, 2, 3] # ACGT
enhancer = [3, 2, 1, 0] # TGCA
X, y = create_synthetic_genome_data(512, seq_len, promoter, enhancer)

# Train CNN
cnn_model = SimpleCNN()
cnn_acc = train_model(cnn_model, X, y, epochs=50)

# Train Transformer (simplified Enformer)
enformer_toy_model = EducationalEnformer(d_model=32, n_heads=4, n_layers=1, d_ff=128)
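# The CNN stem in EducationalEnformer takes 1 channel, so we collapse the one-hot
# encoding to a base index scaled to [0, 1]. (Averaging over the channel axis would
# give a constant 0.25 at every position and erase the sequence.)
X_enformer = X.argmax(dim=1, keepdim=True).float() / 3.0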
enformer_acc = train_model(enformer_toy_model, X_enformer, y, epochs=50)

print(f"\n--- Results (training-set accuracy) ---")
print(f"Simple CNN Accuracy: {cnn_acc*100:.2f}% (always predicting OFF scores ~75%)")
print(f"Simplified Enformer Accuracy: {enformer_acc*100:.2f}%")
print("\nInterpretation: only an accuracy clearly above the ~75% baseline shows that a model has learned the promoter-enhancer interaction.")
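
To read the accuracies above correctly, it helps to know what a trivial predictor achieves on this task. The sketch below simulates only the labels, following the same rule as the synthetic dataset (promoter and enhancer each present with probability 0.5, output ON only when both are present), and measures the accuracy of always predicting OFF. The function name and sample count are illustrative.

```python
import random

def simulate_labels(num_samples, seed=0):
    """Draw ON/OFF labels using the promoter-AND-enhancer rule."""
    rng = random.Random(seed)
    sampled = []
    for _ in range(num_samples):
        has_promoter = rng.random() > 0.5
        has_enhancer = rng.random() > 0.5
        sampled.append(int(has_promoter and has_enhancer))
    return sampled

sim_labels = simulate_labels(10_000)
on_fraction = sum(sim_labels) / len(sim_labels)
always_off_accuracy = 1 - on_fraction

print(f"Fraction of ON samples:         {on_fraction:.3f}")
print(f"Accuracy of always guessing OFF: {always_off_accuracy:.3f}")
```

Since the experiment reports accuracy on the training data, a held-out split would be needed before drawing conclusions about generalization.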

## Section 6: Research Context & Extensions

### Research Contribution in Context

The lecture by Vivek Natarajan positions these works as part of a larger trend: the convergence of AI methodologies around the Transformer architecture. 

- **Med-PaLM** builds on the foundation of general-purpose LLMs (PaLM, FLAN-PaLM) and adapts them to a specialized, high-stakes domain. Its contribution is less about architectural novelty and more about the crucial aspects of **data, evaluation, and safe alignment**.
- **Performer and Enformer** are examples of architectural innovation driven by the specific constraints of biological data. They tackle the problem of **long sequences and long-range dependencies**, pushing the boundaries of what Transformers can efficiently process.
- **ProtNLM and DeepConsensus** showcase the direct, practical application of established Transformer models (T5, standard encoder) to solve high-impact problems in protein annotation and genomics, demonstrating the versatility of the architecture.

Together, they illustrate that applying Transformers to biomedicine requires a holistic approach, innovating not just on models, but also on data curation, evaluation frameworks, and loss functions tailored to the domain.

### Current Research Directions Mentioned
The lecture concluded by highlighting several key areas for future research:

- **Multimodality:** The ultimate goal is to build foundational models that can process the full spectrum of biomedical data—text, genomics, proteomics, medical imaging—within a single, unified framework.
- **Data Scarcity and Privacy:** Medical datasets are often small and siloed due to privacy regulations. Techniques like **federated learning and evaluation** will be critical to train and validate models without centralizing sensitive data.
- **Improved Uncertainty & Reliability:** For clinical use, models must be able to reliably communicate when they are uncertain. Research into better methods for uncertainty quantification and the ability to **defer to an expert** is paramount.
- **Retrieval Augmented Models:** Enhancing LLMs with the ability to retrieve information from authoritative, up-to-date sources (like medical textbooks or recent research papers) to reduce hallucination and provide citable evidence for their answers.
- **Generalist vs. Specialist Models:** An ongoing debate is whether large, general-purpose LLMs fine-tuned for medicine will outperform smaller, specialist models trained from scratch on domain-specific data. The answer likely involves a combination of both approaches.

### Practical Applications Discussed
The research presented is not merely academic; it points towards tangible real-world impact:

- **Clinical Workflow Automation:** Near-term applications of models like Med-PaLM will likely focus on augmenting physicians by automating tasks like generating clinical note summaries, drafting insurance letters, or converting complex medical jargon into patient-friendly language.
- **Accelerating Scientific Discovery:** Tools like ProtNLM can annotate millions of uncharacterized proteins, creating a massive, searchable database that could accelerate research in areas like drug discovery. Enformer-like models can prioritize genetic variants for experimental validation, saving time and resources.
- **Rapid Diagnostics:** The use of DeepConsensus in a record-setting rapid genome sequencing case demonstrates the potential for AI to have a direct impact on patient outcomes by enabling faster, more accurate diagnoses for genetic conditions.