In [58]:
import pandas as pd
import numpy as np
import torch
from ndlinear import NdLinear
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

nt_to_idx = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4}


In [59]:
class RNAStruct(Dataset):
    def __init__(self, seq_csv, labels_csv):
        seq_df = pd.read_csv(seq_csv)
        seq_map = {row['target_id']: row['sequence'].strip() for _, row in seq_df.iterrows()}

        labels_df = pd.read_csv(labels_csv)
        self.data = {}
        for _, row in labels_df.iterrows():
            target_id = row['ID'].rsplit('_', 1)[0]
            idx = int(row['resid']) - 1
            if target_id not in self.data:
                self.data[target_id] = []
            self.data[target_id].append((idx, row['resname'], row['x_1'], row['y_1'], row['z_1']))

        self.samples = []
        all_coords = []
        for target_id in self.data:
            dat = sorted(self.data[target_id], key=lambda x: x[0])
            seq = seq_map[target_id]
            coords = np.array([[x[2], x[3], x[4]] for x in dat], dtype=np.float32)
            if len(seq) == len(coords) and np.isfinite(coords).all():
                seq_idx = np.array([nt_to_idx.get(nt, 4) for nt in seq], dtype=np.int64)
                self.samples.append((target_id, seq_idx, coords))
                all_coords.append(coords)
            else:
                print(f"Bad entry removed: {target_id} (len/coords/finite)")

        all_coords = np.concatenate(all_coords, axis=0)
        self.mean = np.mean(all_coords, axis=0)
        self.std = np.std(all_coords, axis=0)

    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        target_id, seq_idx, coords = self.samples[idx]
        norm_coords = (coords - self.mean) / self.std
        return torch.LongTensor(seq_idx), torch.tensor(norm_coords), len(seq_idx), target_id


In [60]:
#collate function
def rna_collate(batch):
    seqs, coords, lengths, ids = zip(*batch)
    lengths = torch.tensor(lengths)
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=4)
    coords_padded = pad_sequence(coords, batch_first=True, padding_value=0)
    return seqs_padded, coords_padded, lengths, ids


# Model
class RNA3DNet(nn.Module):
    def __init__(self, vocab_size=5, emb_dim=128, hidden_dim=256, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=4)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=3, batch_first=True, 
                            bidirectional=True, dropout=dropout)
        self.fc = nn.Linear(hidden_dim*2, 3)

    def forward(self, seq, lengths):
        x = self.embed(seq)
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        coords = self.fc(output)
        return coords


In [61]:
def validate(model, val_loader, device):
    model.eval()
    total_loss = 0.0
    total_points = 0

    loss_fn = nn.MSELoss(reduction='none')

    for seqs, coords, lengths, _ in val_loader:
        seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
        pred_coords = model(seqs, lengths)

        mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]

        mse = loss_fn(pred_coords, coords).sum(2)
        loss = (mse * mask).sum() / mask.sum()
        total_loss += loss.item() * mask.sum().item()
        total_points += mask.sum().item()
    rmsd = (total_loss / total_points) ** 0.5
    return rmsd


In [None]:
# training loop
def train_model(model, dataloader, device, epochs=50, lr=3e-4, val_loader=None):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    

    total_steps = len(dataloader) * epochs
    warmup_steps = min(1000, total_steps // 10) 
    
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0
    
    device_type = 'cuda' if device.type == 'cuda' else 'cpu'

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    
    mse_loss = nn.MSELoss(reduction='none')
    best_val_rmsd = float('inf')
    scaler = torch.amp.GradScaler()
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_points = 0
        
        for seqs, coords, lengths, _ in dataloader:
            seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
            
            optimizer.zero_grad()
            
            with torch.amp.autocast(device_type=device_type):
                pred_coords = model(seqs, lengths)
                mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]
                mse = mse_loss(pred_coords, coords).sum(2)
                loss = (mse * mask).sum() / mask.sum()
            
            scaler.scale(loss).backward()
            
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            scaler.step(optimizer)
            scaler.update()
            
            running_loss += loss.item() * mask.sum().item()
            total_points += mask.sum().item()
        
            
        epoch_loss = running_loss / total_points
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.5f}")
        
        if val_loader is not None:
            val_rmsd = validate(model, val_loader, device)
            print(f"Validation RMSD: {val_rmsd:.5f}")

            scheduler.step(val_rmsd)

            if val_rmsd < best_val_rmsd:
                best_val_rmsd = val_rmsd
                torch.save(model.state_dict(), 'best_model.pt')
                print(f"New best checkpoint (epoch {epoch+1}, RMSD {val_rmsd:.4f})")


In [63]:
#Submission Function
def submission(model, seq_csv, submission_csv, device, mean, std):
    seq_df = pd.read_csv(seq_csv)
    model.eval()
    submissions = []
    for idx, row in seq_df.iterrows():
        target_id, sequence = row['target_id'], row['sequence'].strip()
        seq_idx = torch.LongTensor([nt_to_idx.get(nt, 4) for nt in sequence]).unsqueeze(0).to(device)
        lengths = torch.tensor([len(sequence)]).to(device)
        with torch.no_grad():
            coords = model(seq_idx, lengths)[0][:len(sequence)].cpu().numpy()
            coords = coords * std + mean
        for i, (nt, (x, y, z)) in enumerate(zip(sequence, coords)):
            id_out = f"{target_id}_{i+1}"
            submissions.append({
                "ID": id_out,
                "resname": nt,
                "resid": i+1,
                "x_1": x, "y_1": y, "z_1": z,
                "x_2": 0.0, "y_2": 0.0, "z_2": 0.0,
                "x_3": 0.0 
            })
    sub_df = pd.DataFrame(submissions)
    sub_df.to_csv(submission_csv, index=False)
    print(f"Submission CSV written: {submission_csv}")


In [None]:
#Training 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

train_dataset = RNAStruct("rna-folding/train_sequences_clean.csv", "rna-folding/train_labels_clean.csv")
val_dataset = RNAStruct("rna-folding/validation_sequences.csv", "rna-folding/validation_labels.csv")

mean = train_dataset.mean
std = train_dataset.std

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=rna_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=rna_collate)

model = RNA3DNet().to(device)
train_model(model, train_loader, device, epochs=30, lr=1e-5, val_loader=val_loader)

model.load_state_dict(torch.load('best_model.pt'))

submission(model, "rna-folding/test_sequences.csv", "rna-folding/submission.csv", device, mean, std)


Epoch 1 Loss: 3.00612
Validation RMSD: 1.73245
New best checkpoint (epoch 1, RMSD 1.7325)
Epoch 2 Loss: 3.00450
Validation RMSD: 1.73248
Epoch 3 Loss: 3.00262
Validation RMSD: 1.73290
Epoch 4 Loss: 3.00443
Validation RMSD: 1.73329
Epoch 5 Loss: 3.00346
Validation RMSD: 1.73362
Epoch 6 Loss: 3.00231
Validation RMSD: 1.73389
Epoch 7 Loss: 3.00218
Validation RMSD: 1.73431
Epoch 8 Loss: 3.00599
Validation RMSD: 1.73497
Epoch 9 Loss: 3.00634
Validation RMSD: 1.73511
Epoch 10 Loss: 3.00527
Validation RMSD: 1.73516
