In [23]:
import pandas as pd
import numpy as np
import torch
from ndlinear import NdLinear
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

nt_to_idx = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4}


In [24]:
class RNAStruct(Dataset):
    def __init__(self, seq_csv, labels_csv):
        seq_df = pd.read_csv(seq_csv)
        seq_map = {row['target_id']: row['sequence'].strip() for _, row in seq_df.iterrows()}

        labels_df = pd.read_csv(labels_csv)
        self.data = {}
        for _, row in labels_df.iterrows():
            target_id = row['ID'].rsplit('_', 1)[0]
            idx = int(row['resid']) - 1
            if target_id not in self.data:
                self.data[target_id] = []
            self.data[target_id].append((idx, row['resname'], row['x_1'], row['y_1'], row['z_1']))

        self.samples = []
        all_coords = []
        for target_id in self.data:
            dat = sorted(self.data[target_id], key=lambda x: x[0])
            seq = seq_map[target_id]
            coords = np.array([[x[2], x[3], x[4]] for x in dat], dtype=np.float32)
            if len(seq) == len(coords) and np.isfinite(coords).all():
                seq_idx = np.array([nt_to_idx.get(nt, 4) for nt in seq], dtype=np.int64)
                self.samples.append((target_id, seq_idx, coords))
                all_coords.append(coords)
            else:
                print(f"Bad entry removed: {target_id} (len/coords/finite)")

        all_coords = np.concatenate(all_coords, axis=0)
        self.mean = np.mean(all_coords, axis=0)
        self.std = np.std(all_coords, axis=0)

    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        target_id, seq_idx, coords = self.samples[idx]
        norm_coords = (coords - self.mean) / self.std
        return torch.LongTensor(seq_idx), torch.tensor(norm_coords), len(seq_idx), target_id


In [25]:
#collate function
def rna_collate(batch):
    seqs, coords, lengths, ids = zip(*batch)
    lengths = torch.tensor(lengths)
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=4)
    coords_padded = pad_sequence(coords, batch_first=True, padding_value=0)
    return seqs_padded, coords_padded, lengths, ids


# Model
class RNA3DNet(nn.Module):
    def __init__(self, vocab_size=5, emb_dim=128, num_layers=6, nhead=8, ff_dim=256, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=4)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(emb_dim)
        self.fc = nn.Linear(emb_dim, 3)
    def forward(self, seq, lengths):
        x = self.embed(seq)
        mask = (seq == 4)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.ln(x)
        coords = self.fc(x)
        return coords

In [26]:
def validate(model, val_loader, device):
    model.eval()
    total_loss = 0.0
    total_points = 0

    loss_fn = nn.MSELoss(reduction='none')

    for seqs, coords, lengths, _ in val_loader:
        seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
        pred_coords = model(seqs, lengths)

        mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]

        mse = loss_fn(pred_coords, coords).sum(2)
        loss = (mse * mask).sum() / mask.sum()
        total_loss += loss.item() * mask.sum().item()
        total_points += mask.sum().item()
    rmsd = (total_loss / total_points) ** 0.5
    return rmsd


In [None]:
def train_model(model, dataloader, device, epochs=30, lr=1e-3, val_loader=None):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)
    loss_fn = nn.MSELoss(reduction='none')
    best_val_rmsd = float('inf')

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_points = 0
        for seqs, coords, lengths, _ in dataloader:
            seqs, coords, lengths = seqs.to(device), coords.to(device), lengths.to(device)
            optimizer.zero_grad()
            pred_coords = model(seqs, lengths)
            mask = torch.arange(seqs.size(1), device=device)[None, :] < lengths[:, None]
            mse = loss_fn(pred_coords, coords).sum(2)
            loss = (mse * mask).sum() / mask.sum()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * mask.sum().item()
            total_points += mask.sum().item()

        epoch_loss = running_loss / total_points
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.5f}")

        if val_loader is not None:
            val_rmsd = validate(model, val_loader, device)
            if val_rmsd is not None and val_rmsd < best_val_rmsd:
                best_val_rmsd = val_rmsd
                torch.save(model.state_dict(), 'best_model.pt')
                print(f"New best checkpoint (epoch {epoch+1}, RMSD {val_rmsd:.4f})")

        scheduler.step()

    
    torch.save(model.state_dict(), 'best_model.pt')
    print("Final model saved as best_model.pt")


In [28]:
#Submission Function
def submission(model, seq_csv, submission_csv, device, mean, std):
    seq_df = pd.read_csv(seq_csv)
    model.eval()
    submissions = []
    for idx, row in seq_df.iterrows():
        target_id, sequence = row['target_id'], row['sequence'].strip()
        seq_idx = torch.LongTensor([nt_to_idx.get(nt, 4) for nt in sequence]).unsqueeze(0).to(device)
        lengths = torch.tensor([len(sequence)]).to(device)
        with torch.no_grad():
            coords = model(seq_idx, lengths)[0][:len(sequence)].cpu().numpy()
            coords = coords * std + mean
        for i, (nt, (x, y, z)) in enumerate(zip(sequence, coords)):
            id_out = f"{target_id}_{i+1}"
            submissions.append({
                "ID": id_out,
                "resname": nt,
                "resid": i+1,
                "x_1": x, "y_1": y, "z_1": z,
                "x_2": 0.0, "y_2": 0.0, "z_2": 0.0,
                "x_3": 0.0 
            })
    sub_df = pd.DataFrame(submissions)
    sub_df.to_csv(submission_csv, index=False)
    print(f"Submission CSV written: {submission_csv}")


In [None]:
#Training 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

train_dataset = RNAStruct("rna-folding/train_sequences_clean.csv", "rna-folding/train_labels_clean.csv")
val_dataset = RNAStruct("rna-folding/validation_sequences.csv", "rna-folding/validation_labels.csv")

mean = train_dataset.mean
std = train_dataset.std

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=rna_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=rna_collate)

model = RNA3DNet().to(device)
train_model(model, train_loader, device, epochs=30, lr=1e-5, val_loader=val_loader)

model.load_state_dict(torch.load('best_model.pt'))

# Generate predictions
submission(model, "rna-folding/test_sequences.csv", "rna-folding/submission.csv", device, mean, std)


Epoch 1 Loss: 3.67628
New best checkpoint (epoch 1, RMSD 1.7766)
Epoch 2 Loss: 3.22439
New best checkpoint (epoch 2, RMSD 1.7514)
Epoch 3 Loss: 3.20590
New best checkpoint (epoch 3, RMSD 1.7487)
Epoch 4 Loss: 3.19225
New best checkpoint (epoch 4, RMSD 1.7398)
Epoch 5 Loss: 3.17232
Epoch 6 Loss: 3.15293
New best checkpoint (epoch 6, RMSD 1.7386)
Epoch 7 Loss: 3.10487
New best checkpoint (epoch 7, RMSD 1.7353)
Epoch 8 Loss: 3.10308
New best checkpoint (epoch 8, RMSD 1.7350)
Epoch 9 Loss: 3.11389
Epoch 10 Loss: 3.08420
New best checkpoint (epoch 10, RMSD 1.7345)
Epoch 11 Loss: 3.09478
Epoch 12 Loss: 3.06061
New best checkpoint (epoch 12, RMSD 1.7328)
Epoch 13 Loss: 3.06469
Epoch 14 Loss: 3.06521
Epoch 15 Loss: 3.05548
Epoch 16 Loss: 3.08788
Epoch 17 Loss: 3.04885
Epoch 18 Loss: 3.04073
Epoch 19 Loss: 3.02710
Epoch 20 Loss: 3.01367
Epoch 21 Loss: 3.01015
Epoch 22 Loss: 3.01390
Epoch 23 Loss: 3.00518
Epoch 24 Loss: 2.99706
Epoch 25 Loss: 2.98457
Epoch 26 Loss: 3.00606
Epoch 27 Loss: 2.99619

  output = torch._nested_tensor_from_mask(
