# Stanford RNA 3D Folding - Colab Starter

This notebook implements a **RibonanzaNet-based** model for RNA 3D structure prediction, targeting the Stanford RNA 3D Folding competitions.

## Features
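- **RibonanzaNet Backbone**: RibonanzaNet-inspired stack of Transformer blocks (multi-head self-attention + feed-forward) over the RNA sequence; pairwise features are formed in the distogram head.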
- **Geometry Head**: Predicts torsion angles.
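- **NeRF Module**: Reconstructs 3D coordinates from angles. In this starter it is a simplified placeholder that cumulatively sums per-residue step vectors, not a full NeRF chain extension.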
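- **End-to-End Differentiable**: Coordinates are differentiable functions of the predicted angles, so coordinate-based losses can be added; the demo loss below uses pairwise distances and torsion angles.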

## Instructions
1. A GPU is recommended (Runtime -> Change runtime type -> T4 GPU); the notebook falls back to CPU when no GPU is available.
2. Run all cells to install dependencies and start training on mock data.

In [None]:
# Install dependencies (note: biopython is installed here but is not imported by the code cells in this notebook)
!pip install torch numpy biopython

In [None]:
import torch
# Pick the compute device once: CUDA when PyTorch can see a GPU, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Report the choice; the GPU name is only queried when CUDA was selected.
if device == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

### 1. Model Architecture (`rna_model.py`)

In [None]:
%%writefile rna_model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold ``sin(position * freq)`` and odd feature
    columns hold ``cos(position * freq)``. The table is precomputed for
    ``max_len`` positions and stored as the non-trainable buffer ``pe`` with
    shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_len: Number of positions to precompute. Inputs must not contain
            more positions than this.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than
        # even (sine) columns, so drop the surplus frequency. For an even
        # d_model this slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1), :]

class RibonanzaBlock(nn.Module):
    """Transformer-style block: self-attention, then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature width of the input and output.
        n_heads: Number of attention heads.
        dropout: Dropout probability used inside the attention, inside the
            feed-forward network, and on both residual branches.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        hidden = d_model * 4  # width of the feed-forward expansion layer
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block to ``x``.

        ``mask`` is forwarded unchanged as ``key_padding_mask`` to the
        attention layer; ``None`` disables masking.
        """
        # Only the attention output is needed; the attention weights are dropped.
        attended = self.attn(x, x, x, key_padding_mask=mask)[0]
        x = self.norm1(x + self.dropout(attended))
        ff_out = self.ff(x)
        return self.norm2(x + self.dropout(ff_out))

class GeometryHead(nn.Module):
    """Linear head that maps each position's features to 14 raw values.

    ``RNAModel.forward`` reshapes these 14 values into 7 pairs and converts
    each pair to an angle with ``atan2``.

    Args:
        d_model: Feature width of the incoming representation.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=14)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``d_model`` to 14."""
        raw_outputs = self.proj(x)
        return raw_outputs

class DistogramHead(nn.Module):
    """Predict one non-negative scalar for every ordered pair of positions.

    Each position is projected twice (a "row" view and a "column" view); the
    pair feature for ``(i, j)`` is the mean of row projection ``i`` and column
    projection ``j``, which a final linear layer followed by ReLU turns into a
    scalar.

    NOTE: the two projections are separate layers, so the output for
    ``(i, j)`` is not constrained to equal the output for ``(j, i)``.

    Args:
        d_model: Feature width of the incoming representation.
    """

    def __init__(self, d_model):
        super().__init__()
        half = d_model // 2
        self.proj_i = nn.Linear(d_model, half)
        self.proj_j = nn.Linear(d_model, half)
        self.out = nn.Linear(half, 1)

    def forward(self, x):
        """Map ``x`` of shape ``(B, L, D)`` to a ``(B, L, L)`` tensor."""
        # Unpacking doubles as a check that the input is exactly 3-dimensional.
        batch, length, _ = x.shape
        rows = self.proj_i(x).unsqueeze(2).expand(batch, length, length, -1)
        cols = self.proj_j(x).unsqueeze(1)  # broadcasts against `rows`
        pair_rep = (rows + cols) / 2
        scores = self.out(pair_rep)
        return F.relu(scores).squeeze(-1)

def nerf_build(torsions):
    """Turn per-residue angles into 3D coordinates by accumulating steps.

    NOTE: despite the name, this is a simplified stand-in rather than a full
    NeRF chain extension. Only the first two angles of each residue are used:
    the step for a residue is ``(cos(a0), sin(a0), sin(a1))`` and the
    coordinates are the running sum of those steps along dimension 1.

    Args:
        torsions: Tensor indexed as ``[batch, residue, angle]`` with at least
            two angles per residue.

    Returns:
        Tensor of shape ``(batch, residue, 3)``.
    """
    first_angle = torsions[:, :, 0]
    second_angle = torsions[:, :, 1]
    steps = torch.stack(
        (first_angle.cos(), first_angle.sin(), second_angle.sin()),
        dim=-1,
    )
    return steps.cumsum(dim=1)

class RNAModel(nn.Module):
    """Sequence encoder with torsion, pairwise-distance and coordinate outputs.

    Token ids are embedded, given sinusoidal position encodings, and passed
    through ``n_layers`` ``RibonanzaBlock`` layers. Two heads then read the
    encoded sequence: ``GeometryHead`` (torsion angles) and ``DistogramHead``
    (pairwise values). Coordinates are derived from the angles by
    ``nerf_build``.

    Args:
        d_model: Embedding / hidden width.
        n_layers: Number of encoder blocks.
        n_heads: Attention heads per block.
        vocab_size: Number of distinct token ids accepted by the embedding.
    """

    def __init__(self, d_model=128, n_layers=4, n_heads=4, vocab_size=5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model)
        encoder_layers = []
        for _ in range(n_layers):
            encoder_layers.append(RibonanzaBlock(d_model, n_heads))
        self.blocks = nn.ModuleList(encoder_layers)
        self.geometry = GeometryHead(d_model)
        self.distogram = DistogramHead(d_model)

    def forward(self, seq, mask=None):
        """Run the model on a batch of token ids.

        Args:
            seq: Integer token ids indexed as ``[batch, position]``.
            mask: Optional tensor; positions where it equals 0 are passed to
                the attention layers as padding. ``None`` disables masking.

        Returns:
            Dict with keys ``'pred_dists'``, ``'torsion_angles'`` and
            ``'coords'``.
        """
        hidden = self.pos_enc(self.embedding(seq))

        # The blocks expect True at padded positions, i.e. where mask is 0.
        padding_mask = None if mask is None else (mask == 0)
        for block in self.blocks:
            hidden = block(hidden, mask=padding_mask)

        batch, length = hidden.shape[0], hidden.shape[1]
        # 14 raw outputs per position -> 7 two-component vectors.
        torsion_sc = self.geometry(hidden).view(batch, length, 7, 2)
        pred_dists = self.distogram(hidden)

        # Component 0 is the first atan2 argument, component 1 the second.
        first_comp, second_comp = torsion_sc.unbind(dim=-1)
        torsion_angles = torch.atan2(first_comp, second_comp)
        coords = nerf_build(torsion_angles)

        return {
            'pred_dists': pred_dists,
            'torsion_angles': torsion_angles,
            'coords': coords
        }

### 2. Training Utilities (`colab_train.py`)

In [None]:
%%writefile colab_train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time

class RNADataset(Dataset):
    """In-memory dataset of randomly generated (mock) samples.

    All samples are created once in the constructor. Each sample is a dict:

    * ``'seq'``: ``seq_len`` integer tokens drawn from ``{0, 1, 2, 3}``.
    * ``'true_dists'``: ``(seq_len, seq_len)`` values uniform in ``[0, 20)``.
    * ``'true_torsions'``: ``(seq_len, 7)`` values uniform in ``[-pi, pi)``.

    Args:
        n_samples: Number of samples to generate.
        seq_len: Length of every generated sequence.
    """

    def __init__(self, n_samples=100, seq_len=50):
        self.n_samples = n_samples
        self.seq_len = seq_len
        self.data = self._generate_mock_data()

    def _generate_mock_data(self):
        """Return a list of ``n_samples`` random sample dicts."""
        length = self.seq_len
        # Dict values are evaluated in key order, so each sample draws its
        # sequence first, then distances, then torsions.
        return [
            {
                'seq': torch.randint(0, 4, (length,)),
                'true_dists': torch.rand(length, length) * 20.0,
                'true_torsions': torch.rand(length, 7) * 2 * np.pi - np.pi,
            }
            for _ in range(self.n_samples)
        ]

    def __len__(self):
        """Number of samples requested at construction time."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return the pre-generated sample stored at ``idx``."""
        return self.data[idx]

def loss_fn(pred_out, batch):
    """Sum of a distance MSE term and a periodic torsion-angle term.

    Args:
        pred_out: Dict with ``'pred_dists'`` and ``'torsion_angles'`` tensors.
        batch: Dict with ``'true_dists'`` and ``'true_torsions'`` tensors;
            they are moved to the device of the matching prediction.

    Returns:
        Scalar tensor ``mse(pred_dists, true_dists) + mean(1 - cos(delta))``
        where ``delta`` is predicted minus true torsion angle.
    """
    target_dists = batch['true_dists'].to(pred_out['pred_dists'].device)
    target_torsions = batch['true_torsions'].to(pred_out['torsion_angles'].device)

    dist_loss = nn.functional.mse_loss(pred_out['pred_dists'], target_dists)

    # 1 - cos(delta) is zero when angles agree and is unaffected by 2*pi wraps.
    delta = pred_out['torsion_angles'] - target_torsions
    torsion_loss = (1 - torch.cos(delta)).mean()

    return dist_loss + torsion_loss

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(seq)``; it is switched to train mode.
        loader: Iterable of batch dicts containing at least ``'seq'`` plus
            the keys ``loss_fn`` reads. It does not need to support ``len()``.
        optimizer: Optimizer stepped once per batch.
        device: Device the ``'seq'`` tensor is moved to before the forward pass.

    Returns:
        The average of the per-batch losses as a float, or ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        seq = batch['seq'].to(device)
        optimizer.zero_grad()
        output = model(seq)
        loss = loss_fn(output, batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches ourselves: avoids dividing by zero on an empty loader and
    # works for iterables that do not define __len__.
    if n_batches == 0:
        return 0.0
    return total_loss / n_batches

def validate(model, loader, device):
    """Evaluate ``model`` over ``loader`` without gradients; return mean batch loss.

    Args:
        model: Module called as ``model(seq)``; it is switched to eval mode.
        loader: Iterable of batch dicts containing at least ``'seq'`` plus
            the keys ``loss_fn`` reads. It does not need to support ``len()``.
        device: Device the ``'seq'`` tensor is moved to before the forward pass.

    Returns:
        The average of the per-batch losses as a float, or ``0.0`` when the
        loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            seq = batch['seq'].to(device)
            output = model(seq)
            loss = loss_fn(output, batch)
            total_loss += loss.item()
            n_batches += 1
    # Count batches ourselves: avoids dividing by zero on an empty loader and
    # works for iterables that do not define __len__.
    if n_batches == 0:
        return 0.0
    return total_loss / n_batches

def main_train_loop(model, epochs=5, batch_size=4, device='cpu'):
    """Train ``model`` on mock data and return its prediction for one sample.

    A 50-sample ``RNADataset`` (sequence length 30) is split 80/20 into
    training and validation subsets, so the reported validation loss is
    measured on samples that were not used for optimisation.

    Args:
        model: Network to train. It is not moved by this function, so it
            must already be on ``device``.
        epochs: Number of passes over the training subset.
        batch_size: Batch size for both the training and validation loaders.
        device: Device that input tensors are moved to.

    Returns:
        The model output for the first dataset sample (as a batch of one),
        computed in eval mode without gradients.
    """
    # Local import keeps this function independent of the module import block.
    from torch.utils.data import random_split

    print(f"Starting training on {device}...")
    dataset = RNADataset(n_samples=50, seq_len=30)

    # Hold out a fifth of the samples (at least one) for validation.
    n_val = max(1, len(dataset) // 5)
    train_set, val_set = random_split(dataset, [len(dataset) - n_val, n_val])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    start_time = time.time()
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss = validate(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Training finished in {time.time() - start_time:.2f}s")

    # Produce a demo prediction for the first sample, adding a batch dimension.
    model.eval()
    sample_seq = dataset[0]['seq'].unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(sample_seq)
    return pred

### 3. Run Training

In [None]:
from rna_model import RNAModel
from colab_train import main_train_loop

# Build the network with the notebook's default hyperparameters, then place
# it on the device selected in the setup cell.
model = RNAModel(d_model=128, n_layers=4, n_heads=4)
model = model.to(device)

# Train on the synthetic dataset; the return value is the model's output for
# one sample sequence.
pred_output = main_train_loop(model, epochs=5, batch_size=8, device=device)

### 4. Visualize Output
Here we simply print the shape of the predicted coordinates and the coordinates of the first five residues. In a real scenario, you would export the coordinates to a PDB file and visualize the structure with `nglview` or `py3Dmol`.

In [None]:
# Overall shape of the predicted coordinate tensor.
print("Predicted Coordinates Shape:", pred_output['coords'].shape)

# First five residues of the first (only) sequence, moved to host memory and
# converted to a NumPy array for printing.
print(
    "Sample Coordinates (first 5 residues):\n",
    pred_output['coords'][0, :5].cpu().numpy(),
)