In [22]:
import pandas as pd
import numpy as np
import torch
from ndlinear import NdLinear
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

nt_to_idx = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4}


In [None]:
class RNAStruct(Dataset):
    def __init__(self, seq_csv, labels_csv, mean=None, std=None, normalize=True):
        seq_df = pd.read_csv(seq_csv)
        seq_map = {row['target_id']: row['sequence'].strip() for _, row in seq_df.iterrows()}

        labels_df = pd.read_csv(labels_csv)
        self.data = {}
        for _, row in labels_df.iterrows():
            target_id = row['ID'].rsplit('_', 1)[0]
            idx = int(row['resid']) - 1
            if target_id not in self.data:
                self.data[target_id] = []
            self.data[target_id].append((idx, row['resname'], row['x_1'], row['y_1'], row['z_1']))

        self.samples = []
        for target_id in self.data:
            dat = sorted(self.data[target_id], key=lambda x: x[0])
            seq = seq_map[target_id]
            coords = np.array([[x[2], x[3], x[4]] for x in dat], dtype=np.float32)
            if len(seq) == len(coords) and np.isfinite(coords).all():
                seq_idx = np.array([nt_to_idx.get(nt, 4) for nt in seq], dtype=np.int64)
                self.samples.append((target_id, seq_idx, coords))
            else:
                print(f"Bad entry removed: {target_id} (len/coords/finite)")

        all_coords = np.concatenate([coords for _, _, coords in self.samples], axis=0)
        self.mean = mean if mean is not None else all_coords.mean(axis=0)
        self.std = std if std is not None else all_coords.std(axis=0)
        self.normalize = normalize

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        target_id, seq_idx, coords = self.samples[idx]
        length = len(seq_idx)
        if self.normalize:
            norm_coords = (coords - self.mean) / self.std
            return (
                torch.LongTensor(seq_idx),
                torch.tensor(norm_coords, dtype=torch.float32),
                length,
                target_id
            )
        else:
            return (
                torch.LongTensor(seq_idx),
                torch.tensor(coords, dtype=torch.float32),
                length,
                target_id
            )



In [24]:
#collate function
def rna_collate(batch):
    seqs, coords, lengths, target_ids = zip(*batch)
    lengths = torch.tensor(lengths)
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=4)
    coords_padded = pad_sequence(coords, batch_first=True, padding_value=0)
    return seqs_padded, coords_padded, lengths, target_ids


# Model
class RNA3DNet(nn.Module):
    def __init__(self, vocab_size=5, emb_dim=128, num_layers=6, nhead=8, ff_dim=256, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=4)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(emb_dim)
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(emb_dim, 3)
        )

    def forward(self, seq, lengths):
        x = self.embed(seq)
        mask = (seq == 4) 
        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.ln(x)
        coords = self.fc(x)
        return coords


In [25]:
def validate(model, dataloader, device):
    model.eval()
    total_pts, total_rmsd = 0, 0
    with torch.no_grad():
        for seqs, coords, lengths, _ in dataloader:
            seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
            pred_coords = model(seqs, lengths)
            mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]
            diff = ((pred_coords - coords) ** 2).sum(2).sqrt()
            rmsd = (diff * mask).sum() / mask.sum()
            total_rmsd += rmsd.item() * mask.sum().item()
            total_pts += mask.sum().item()
    mean_rmsd = total_rmsd / total_pts
    print(f"Validation RMSD: {mean_rmsd:.4f}")
    model.train()
    return mean_rmsd



In [None]:
# training loop
def train_model(model, dataloader, device, epochs=100, lr=1e-3, val_loader=None, scheduler=None):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss(reduction='none')
    best_val = float('inf')

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_points = 0
        for seqs, coords, lengths, _ in dataloader:
            seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
            optimizer.zero_grad()
            pred_coords = model(seqs, lengths)
            mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]
            mse = loss_fn(pred_coords, coords).sum(2)
            loss = (mse * mask).sum() / mask.sum()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * mask.sum().item()
            total_points += mask.sum().item()
        if scheduler is not None:
            scheduler.step()
        print(f"Epoch {epoch+1} Loss: {running_loss/total_points:.5f}")
        if val_loader is not None:
            val_rmsd = validate(model, val_loader, device)
            if val_rmsd < best_val:
                best_val = val_rmsd
                torch.save(model.state_dict(), 'best_model.pt')
                print(f"Epoch {epoch+1}")



In [27]:
#Submission Function
def submission(model, seq_csv, submission_csv, device, mean, std):
    seq_df = pd.read_csv(seq_csv)
    model.eval()
    submissions = []
    for idx, row in seq_df.iterrows():
        target_id, sequence = row['target_id'], row['sequence'].strip()
        seq_idx = torch.LongTensor([nt_to_idx.get(nt, 4) for nt in sequence]).unsqueeze(0).to(device)
        lengths = torch.tensor([len(sequence)]).to(device)
        with torch.no_grad():
            coords = model(seq_idx, lengths)[0][:len(sequence)].cpu().numpy()
            coords = coords * std + mean
        for i, (nt, (x, y, z)) in enumerate(zip(sequence, coords)):
            id_out = f"{target_id}_{i+1}"
            submissions.append({
                "ID": id_out,
                "resname": nt,
                "resid": i+1,
                "x_1": x, "y_1": y, "z_1": z,
                "x_2": 0.0, "y_2": 0.0, "z_2": 0.0,
                "x_3": 0.0
            })
    sub_df = pd.DataFrame(submissions)
    sub_df.to_csv(submission_csv, index=False)
    print(f"Submission CSV written: {submission_csv}")



In [28]:
#Training 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

train_dataset = RNAStruct("rna-folding/train_sequences_clean.csv", "rna-folding/train_labels_clean.csv")
val_dataset = RNAStruct("rna-folding/validation_sequences.csv", "rna-folding/validation_labels.csv")

mean = train_dataset.mean
std = train_dataset.std

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=rna_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=rna_collate)

model = RNA3DNet().to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2) #CosineAnnealingLR(optimizer, T_max=100)

train_model(model, train_loader, device, epochs=30, lr=1e-5, val_loader=val_loader)

model.load_state_dict(torch.load('best_model.pt'))

submission(model, "rna-folding/test_sequences.csv", "rna-folding/submission.csv", device, mean, std)




Epoch 1 Loss: 4.29305
Validation RMSD: 0.7969
Epoch 1 (RMSD: 0.7969)
Epoch 2 Loss: 3.70254
Validation RMSD: 0.5413
Epoch 2 (RMSD: 0.5413)
Epoch 3 Loss: 3.51538
Validation RMSD: 0.4722
Epoch 3 (RMSD: 0.4722)
Epoch 4 Loss: 3.43061
Validation RMSD: 0.4680
Epoch 4 (RMSD: 0.4680)
Epoch 5 Loss: 3.41824
Validation RMSD: 0.4176
Epoch 5 (RMSD: 0.4176)
Epoch 6 Loss: 3.40111
Validation RMSD: 0.4111
Epoch 6 (RMSD: 0.4111)
Epoch 7 Loss: 3.37685
Validation RMSD: 0.3853
Epoch 7 (RMSD: 0.3853)
Epoch 8 Loss: 3.35641
Validation RMSD: 0.3971
Epoch 9 Loss: 3.34937
Validation RMSD: 0.4122
Epoch 10 Loss: 3.33621
Validation RMSD: 0.4468
Epoch 11 Loss: 3.34375
Validation RMSD: 0.4081
Epoch 12 Loss: 3.31447
Validation RMSD: 0.3344
Epoch 12 (RMSD: 0.3344)
Epoch 13 Loss: 3.30657
Validation RMSD: 0.3692
Epoch 14 Loss: 3.30667
Validation RMSD: 0.3760
Epoch 15 Loss: 3.29721
Validation RMSD: 0.3636
Epoch 16 Loss: 3.28375
Validation RMSD: 0.3337
Epoch 16 (RMSD: 0.3337)
Epoch 17 Loss: 3.30761
Validation RMSD: 0.3526
E