In [None]:
import pandas as pd
import numpy as np
import pickle
import math
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CyclicLR
from torch.utils.data import Dataset, DataLoader, Subset
from Bio import SeqIO
import os
import optuna
from torch.nn.utils.rnn import pad_sequence


nt_to_idx = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4}


In [167]:
class Config:
    vocab_size = len(nt_to_idx)
    max_len = 512
    emb_dim = 128
    num_layers = 4
    nhead = 16
    ff_dim = 1024
    dropout = 0.10055310921251336
    batch_size = 4
    epochs = 5
    lr = 0.0006706453668054184
    pad_idx = 4
    
    use_scheduler = True
    seed = 42

    def set_seed(self):
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

config = Config()
config.set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else"cpu")
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])

In [168]:
class RNAStruct(Dataset):
    def __init__(self, seq_csv, labels_csv, msa_dir="MSA", mean=None, std=None, normalize=True, train_mode=False, noise_std=0.0):
        import pandas as pd
        seq_df = pd.read_csv(seq_csv)
        seq_map = {row['target_id']: row['sequence'].strip() for _, row in seq_df.iterrows()}
        self.train_mode = train_mode
        self.noise_std = noise_std
        self.msa_dir = msa_dir

        labels_df = pd.read_csv(labels_csv)
        self.data = {}
        for _, row in labels_df.iterrows():
            target_id = row['ID'].rsplit('_', 1)[0]
            idx = int(row['resid']) - 1
            if target_id not in self.data:
                self.data[target_id] = []
            self.data[target_id].append((idx, row['resname'], row['x_1'], row['y_1'], row['z_1']))

        self.samples = []
        for target_id in self.data:
            dat = sorted(self.data[target_id], key=lambda x: x[0])
            seq = seq_map[target_id]
            coords = np.array([[x[2], x[3], x[4]] for x in dat], dtype=np.float32)
            if len(seq) == len(coords) and np.isfinite(coords).all():
                seq_idx = np.array([nt_to_idx.get(nt, 4) for nt in seq], dtype=np.int64)
                base_onehot = np.eye(4)[seq_idx]
                rel_pos = np.arange(len(seq), dtype=np.float32) / (max(len(seq) - 1, 1))
                rel_pos = rel_pos[:, None]
                msa_profile = self._load_msa_profile(target_id, seq)  # (L, 4)
                features = np.concatenate([base_onehot, rel_pos, msa_profile], axis=1)  # shape (L,9)
                self.samples.append((target_id, seq_idx, coords, features))
            else:
                print(f"Bad entry removed: {target_id}")

        all_coords = np.concatenate([coords for _, _, coords, _ in self.samples], axis=0)
        self.mean = mean if mean is not None else all_coords.mean(axis=0)
        self.std = std if std is not None else all_coords.std(axis=0)
        self.normalize = normalize

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        target_id, seq_idx, coords, features = self.samples[idx]
        length = len(seq_idx)
        if self.normalize:
            mean = self.mean.numpy() if isinstance(self.mean, torch.Tensor) else self.mean
            std = self.std.numpy() if isinstance(self.std, torch.Tensor) else self.std
            norm_coords = (coords - mean) / std
            if self.train_mode and self.noise_std > 0:
                norm_coords = norm_coords + np.random.normal(0, self.noise_std, size=norm_coords.shape).astype(np.float32)
            return (
                torch.LongTensor(seq_idx),
                torch.tensor(norm_coords, dtype=torch.float32),
                length,
                target_id,
                torch.tensor(features, dtype=torch.float32)
            )
        else:
            return (
                torch.LongTensor(seq_idx),
                torch.tensor(coords, dtype=torch.float32),
                length,
                target_id,
                torch.tensor(features, dtype=torch.float32)
            )

    def get_mean_std(self,):
        return self.mean, self.std

    def _load_msa_profile(self, target_id, seq):
        msa_path = os.path.join(self.msa_dir, f"{target_id}.MSA.fasta")
        profile = np.zeros((len(seq), 4), dtype=np.float32)
        if not os.path.exists(msa_path):
            return profile
        msa = [str(record.seq) for record in SeqIO.parse(msa_path, 'fasta')]
        nt_to_col = {'A':0, 'C':1, 'G':2, 'U':3}
        for i, s in enumerate(zip(*msa)):
            total = 0
            for nt in 'ACGU':
                c = s.count(nt)
                profile[i, nt_to_col[nt]] = c
                total += c
            if total > 0:
                profile[i] /= total
        return profile

In [None]:
def compute_features_for_test(seq, target_id, msa_dir="MSA"):
    seq_idx = np.array([nt_to_idx.get(nt, 4) for nt in seq], dtype=np.int64)
    base_onehot = np.eye(4)[np.clip(seq_idx,0,3)]
    rel_pos = np.arange(len(seq), dtype=np.float32) / (max(len(seq) - 1, 1))
    rel_pos = rel_pos[:, None]
    profile = np.zeros((len(seq), 4), dtype=np.float32)
    msa_path = os.path.join(msa_dir, f"{target_id}.MSA.fasta")
    if os.path.exists(msa_path):
        from Bio import SeqIO
        msa = [str(record.seq) for record in SeqIO.parse(msa_path, 'fasta')]
        nt_to_col = {'A':0, 'C':1, 'G':2, 'U':3}
        for i, s in enumerate(zip(*msa)):
            total = 0
            for nt in 'ACGU':
                c = s.count(nt)
                profile[i, nt_to_col[nt]] = c
                total += c
            if total > 0:
                profile[i] /= total
    features = np.concatenate([base_onehot, rel_pos, profile], axis=1)
    return features

In [170]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        seq_len = x.size(1)
        device = x.device
        position = torch.arange(0, seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device).float() * (-math.log(10000.0) / self.d_model)
        )
        pe = torch.zeros(seq_len, self.d_model, device=device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return x + pe.unsqueeze(0)


In [171]:
def rna_collate(batch):
    seqs, coords, lengths, target_ids, features = zip(*batch)
    lengths = torch.tensor(lengths)
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=4)
    coords_padded = pad_sequence(coords, batch_first=True, padding_value=0)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0)
    return seqs_padded, coords_padded, lengths, target_ids, features_padded

class RNA3DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.emb_dim, padding_idx=config.pad_idx)
        self.extra_linear = nn.Linear(9, 32)
        self.extra_bn = nn.BatchNorm1d(32)

        self.pos_enc = PositionalEncoding(config.emb_dim + 32)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.emb_dim + 32, 
            nhead=config.nhead, 
            dim_feedforward=config.ff_dim * 2,
            dropout=config.dropout, 
            batch_first=True, 
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers + 2) # More layers

        self.ln = nn.LayerNorm(config.emb_dim + 32)

        self.fc1 = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(config.emb_dim + 32, 64),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(64, 3)
        )
    def forward(self, seq, lengths, features, noise_std=0.0):
        x = self.embed(seq)
        batch_size, seq_len, _ = features.size()
        extra = self.extra_linear(features).view(batch_size * seq_len, -1)
        extra = self.extra_bn(extra).view(batch_size, seq_len, -1)
        x = torch.cat([x, extra], dim=-1)

        if noise_std > 0 and self.training:
            x = x + torch.randn_like(x) * noise_std

        x = self.pos_enc(x)
        mask = (seq == config.pad_idx)

        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.ln(x)

        x = self.fc1(x)
        coords = self.fc2(x)
        return coords


In [172]:
def compute_tm_score(pred_coords, true_coords, Lref=None):
    assert pred_coords.shape == true_coords.shape
    L = pred_coords.shape[0]
    Lref = Lref if Lref is not None else L
    dists = np.linalg.norm(pred_coords - true_coords, axis=1)
    
    # d0 depends on Lref
    if Lref >= 30:
        d0 = 1.24 * (Lref - 15) ** (1/3) - 1.8
    elif Lref >= 24:
        d0 = 0.7
    elif Lref >= 20:
        d0 = 0.6
    elif Lref >= 16:
        d0 = 0.5
    elif Lref >= 12:
        d0 = 0.4
    else:
        d0 = 0.3

    score = (1 / Lref) * np.sum(1 / (1 + (dists / d0) ** 2))
    return score


In [173]:
def kabsch_align(P, Q):
    C = np.dot(P.T, Q)
    V, S, Wt = np.linalg.svd(C)
    d = (np.linalg.det(V) * np.linalg.det(Wt)) < 0.0
    if d:
        V[:, -1] = -V[:, -1]
    U = np.dot(V, Wt)
    return np.dot(P, U)



def validate(model, dataloader, device):
    model.eval()
    total_pts, total_rmsd, total_tm = 0, 0, 0
    with torch.no_grad():
        for seqs, coords, lengths, _, features  in dataloader:
            seqs = seqs.to(device)
            coords = coords.to(device)
            lengths = lengths.to(device)
            features = features.to(device)
            pred_coords = model(seqs, lengths, features=features)
            for i in range(seqs.size(0)):
                L = lengths[i].item()
                pred_i = pred_coords[i, :L].cpu().numpy()
                true_i = coords[i, :L].cpu().numpy()
                pred_aligned = kabsch_align(pred_i - pred_i.mean(axis=0), true_i - true_i.mean(axis=0))

                dists = np.linalg.norm(pred_aligned - true_i, axis=1)
                rmsd = np.sqrt((dists ** 2).mean())
                tm = compute_tm_score(pred_aligned, true_i, Lref=L)

                total_rmsd += rmsd * L
                total_tm += tm * L
                total_pts += L


    mean_rmsd = total_rmsd / total_pts
    mean_tm = total_tm / total_pts
    model.train()
    return mean_rmsd, mean_tm


In [None]:
# training loop
def train_model(model, dataloader, device, val_loader=None):
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=100, mode='triangular') if config.use_scheduler else None

    best_val = float('inf')
    loss_fn = nn.SmoothL1Loss(reduction='none')
    epoch_losses = []
    val_scores = []

    global_step = 0

    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0
        total_points = 0
        
        for seqs, coords, lengths, _, features in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            seqs = seqs.to(device)
            coords = coords.to(device)
            lengths = lengths.to(device)
            features = features.to(device)
            optimizer.zero_grad()
            pred_coords = model(seqs, lengths, noise_std=0.2,features=features)
            per_nt_loss = loss_fn(pred_coords, coords).sum(-1)
            mask = (seqs != config.pad_idx)
            loss = (per_nt_loss * mask).sum() / mask.sum().clamp(min=1)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * mask.sum().item()
            total_points += mask.sum().item()
            global_step += 1
        epoch_loss = running_loss / total_points
        epoch_losses.append(epoch_loss)

        # Print loss after each epoch
        print(f"\nEpoch {epoch+1} Loss: {epoch_loss:.5f}")

        if val_loader is not None:
            val_rmsd, _ = validate(model, val_loader, device)
            val_scores.append(val_rmsd)  
    
            print(f"Validation RMSD: {val_rmsd:.5f}")

            if scheduler is not None:
                scheduler.step()
            if val_rmsd < best_val:
                best_val = val_rmsd
                torch.save(model.state_dict(), 'setnet.pt')
                val_rmsd, val_tm = validate(model, val_loader, device)
                print(f"New BEST model saved at Epoch {epoch+1} (RMSD: {val_rmsd:.4f})")
                print(f"\nFinal TM-score on validation set: {val_tm:.4f}")
                print(f"Final RMSD on validation set: {val_rmsd:.4f}")
                 
    # plot losses and validation scores
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epoch_losses, label='Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.legend()

    if val_loader is not None:
        plt.subplot(1, 2, 2)
        plt.plot(val_scores, label='Validation RMSD', color='r')
        plt.xlabel('Epochs')
        plt.ylabel('RMSD')
        plt.title('Validation RMSD Over Epochs')
        plt.legend()

    plt.tight_layout()
    plt.show()

In [175]:
def submission(model, seq_csv, submission_csv, device, mean, std, num_predictions=5):
    seq_df = pd.read_csv(seq_csv)
    model.train()
    rows = []
    mean = mean.cpu().numpy() if isinstance(mean, torch.Tensor) else mean
    std = std.cpu().numpy() if isinstance(std, torch.Tensor) else std
    for _, row in seq_df.iterrows():
        target_id, seq = row['target_id'], row['sequence'].strip()
        seq_tensor = torch.LongTensor([nt_to_idx.get(nt, config.pad_idx) for nt in seq]).unsqueeze(0).to(device)
        length = torch.tensor([len(seq)], device=device)

        all_coords = []
        features = compute_features_for_test(seq, target_id, msa_dir="rna-folding/MSA") # or your correct dir
        features = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)  # [1, L, 9]

        for _ in range(num_predictions):
            with torch.no_grad():
                coords = model(seq_tensor, length, features, noise_std=0.1)[0][:length.item()].cpu().numpy()
                coords = coords * std + mean
                all_coords.append(coords)

        for i, nt in enumerate(seq):
            row_data = {
        "ID": f"{target_id}_{i+1}",
        "resname": nt,
        "resid": i+1,
            }
            for j, coords_pred in enumerate(all_coords, start=1):
                row_data[f"x_{j}"] = coords_pred[i][0]
                row_data[f"y_{j}"] = coords_pred[i][1]
                row_data[f"z_{j}"] = coords_pred[i][2]
            rows.append(row_data)


    pd.DataFrame(rows).to_csv(submission_csv, index=False)
    print(f"Submission CSV written: {submission_csv}")

In [None]:
# Training Loop

full_dataset = RNAStruct(
    "rna-folding/train_sequences.cutoff_filtered.csv",
    "rna-folding/train_labels.cutoff_filtered.csv"
)

all_ids = [target_id for target_id, _, _, _ in full_dataset.samples]
id_to_idx = {t: i for i, (t, _, _, _) in enumerate(full_dataset.samples)}
train_ids, val_ids = train_test_split(all_ids, test_size=0.15, random_state=config.seed)

train_indices = [id_to_idx[t] for t in train_ids if t in id_to_idx]
val_indices   = [id_to_idx[t] for t in val_ids if t in id_to_idx]

valid_coords = []

for i in train_indices:
    target_id, seq_idx, coords, features = full_dataset.samples[i]
    if coords.shape[0] == len(seq_idx) and np.all(np.isfinite(coords)):
        centered = coords - coords.mean(axis=0)
        valid_coords.append(centered)


# Combine and compute
all_coords = np.concatenate(valid_coords, axis=0)
mean = torch.tensor(all_coords.mean(axis=0), dtype=torch.float32)
std = torch.tensor(all_coords.std(axis=0), dtype=torch.float32)

# Save
torch.save({'mean': mean, 'std': std}, 'mean_std.pt')
print("Mean:", mean)
print("Std:", std)


train_dataset = RNAStruct(
    "rna-folding/train_sequences.cutoff_filtered.csv",
    "rna-folding/train_labels.cutoff_filtered.csv",
    mean=mean,
    std=std
)
val_dataset = RNAStruct(
    "rna-folding/train_sequences.cutoff_filtered.csv",
    "rna-folding/train_labels.cutoff_filtered.csv",
    mean=mean,
    std=std
)

train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=rna_collate)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=rna_collate)

model = RNA3DNet().to(device)
train_model(model, train_loader, device, val_loader=val_loader)

stats = torch.load('mean_std.pt')
mean, std = stats['mean'], stats['std']

mean = mean.numpy() if isinstance(mean, torch.Tensor) else mean
std = std.numpy() if isinstance(std, torch.Tensor) else std

print("🚨 Final sanity check:")
print("Mean:", mean)
print("Std:", std)


submission(
    model,
    "rna-folding/test_sequences.csv",
    "rna-folding/submission.csv",
    device,
    mean,
    std
)


Bad entry removed: 1ZDI_S
Bad entry removed: 7MSF_S
Bad entry removed: 5MSF_S
Bad entry removed: 1FOQ_A
Bad entry removed: 1E7K_D
Bad entry removed: 1IBM_Y
Bad entry removed: 1JWC_A
Bad entry removed: 1LS2_B
Bad entry removed: 1OSW_A
Bad entry removed: 1P6V_B
Bad entry removed: 1P6V_D
Bad entry removed: 1QZC_C
Bad entry removed: 1R2W_C
Bad entry removed: 1QZC_B
Bad entry removed: 1QZA_B
Bad entry removed: 1QZB_B
Bad entry removed: 1RY1_E
Bad entry removed: 1Y1Y_P
Bad entry removed: 1ZC8_Z
Bad entry removed: 1ZC8_G
Bad entry removed: 1ZC8_J
Bad entry removed: 1ZC8_F
Bad entry removed: 1ZC8_I
Bad entry removed: 1ZC8_H
Bad entry removed: 1X18_D
Bad entry removed: 1X18_A
Bad entry removed: 1X18_B
Bad entry removed: 1ZN1_B
Bad entry removed: 1YSH_F
Bad entry removed: 2A64_A
Bad entry removed: 2BQ5_S
Bad entry removed: 2BS0_R
Bad entry removed: 2BS0_S
Bad entry removed: 2B2E_R
Bad entry removed: 2B2E_S
Bad entry removed: 2AGN_C
Bad entry removed: 2AGN_A
Bad entry removed: 2IZN_S
Bad entry re

Epoch 1: 100%|██████████| 110/110 [00:37<00:00,  2.92it/s]



Epoch 1 Loss: 7.08808
Validation RMSD: 6.61341
New BEST model saved at Epoch 1 (RMSD: 6.6134)

Final TM-score on validation set: 0.6842
Final RMSD on validation set: 6.6134


Epoch 2: 100%|██████████| 110/110 [00:38<00:00,  2.89it/s]



Epoch 2 Loss: 6.77096
Validation RMSD: 6.61113
New BEST model saved at Epoch 2 (RMSD: 6.6111)

Final TM-score on validation set: 0.6841
Final RMSD on validation set: 6.6111


Epoch 3: 100%|██████████| 110/110 [00:38<00:00,  2.84it/s]



Epoch 3 Loss: 6.70010
Validation RMSD: 6.60866
New BEST model saved at Epoch 3 (RMSD: 6.6087)

Final TM-score on validation set: 0.6842
Final RMSD on validation set: 6.6087


Epoch 4:  10%|█         | 11/110 [00:06<00:57,  1.73it/s]


KeyboardInterrupt: 