**Developer** : Ryan Mehra

**Competition** : [Stanford RNA 3D Folding](https://www.kaggle.com/competitions/stanford-rna-3d-folding/)

**Approach**

v1-mvp: Build end-to-end CNN as experimental start

v2-mvp: Compare physics-based protocol FARFAR2 with Kaggle's deep‑learning model (RibonanzaNet2)

# I. Libraries

In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

# II. Config

In [2]:
#set seed for everything
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [3]:
config = {
    "seed": 0,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 10,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
    "mixed_precision": "bf16",
    "model_config_path": "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs": 10,
    "cos_epoch": 5,
    "loss_power_scale": 1.0,
    "max_cycles": 1,
    "grad_clip": 0.1,
    "gradient_accumulation_steps": 1,
    "d_clamp": 30,
    "max_len_filter": 9999999,
    "min_len_filter": 10, 
    "structural_violation_epoch": 50,
    "balance_weight": False,
}

# III. Data Prepration

In [4]:
# Load data

train_sequences=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_sequences.csv")
train_labels=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_labels.csv")

validation_sequences=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/validation_sequences.csv")
validation_labels=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/validation_labels.csv")

test_sequences=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/test_sequences.csv")

In [5]:
train_labels["pdb_id"] = train_labels["ID"].apply(lambda x: x.split("_")[0]+'_'+x.split("_")[1])
validation_labels["pdb_id"] = validation_labels["ID"].apply(lambda x: x.split("_")[0])

In [6]:
train_sequences.shape, train_labels.shape, validation_sequences.shape, validation_labels.shape, test_sequences.shape

((844, 5), (137095, 7), (12, 5), (2515, 124), (12, 5))

In [7]:
train_sequences.head(1)

Unnamed: 0,target_id,sequence,temporal_cutoff,description,all_sequences
0,1SCL_A,GGGUGCUCAGUACGAGAGGAACCGCACCC,1995-01-26,"THE SARCIN-RICIN LOOP, A MODULAR RNA",>1SCL_1|Chain A|RNA SARCIN-RICIN LOOP|Rattus n...


In [8]:
train_labels.head(1)

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,pdb_id
0,1SCL_A_1,G,1,13.76,-25.974001,0.102,1SCL_A


In [9]:
validation_sequences.head(1)

Unnamed: 0,target_id,sequence,temporal_cutoff,description,all_sequences
0,R1107,GGGGGCCACAGCAGAAGCGUUCACGUCGCAGCCCCUGUCAGCCAUU...,2022-05-28,CPEB3 ribozyme\nHuman\nhuman CPEB3 HDV-like ri...,>7QR4_1|Chain A|U1 small nuclear ribonucleopro...


In [10]:
## Validation Labels has many coordinates, we do not have this in the training set, for the first run we will ignore the rest and just pick the first XYZ set
validation_labels.head(1)

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,...,x_38,y_38,z_38,x_39,y_39,z_39,x_40,y_40,z_40,pdb_id
0,R1107_1,G,1,-5.499,8.52,8.605,-1e+18,-1e+18,-1e+18,-1e+18,...,-1e+18,-1e+18,-1e+18,-1e+18,-1e+18,-1e+18,-1e+18,-1e+18,-1e+18,R1107


In [11]:
test_sequences.head(1)

Unnamed: 0,target_id,sequence,temporal_cutoff,description,all_sequences
0,R1107,GGGGGCCACAGCAGAAGCGUUCACGUCGCAGCCCCUGUCAGCCAUU...,2022-05-28,CPEB3 ribozyme\nHuman\nhuman CPEB3 HDV-like ri...,>7QR4_1|Chain A|U1 small nuclear ribonucleopro...


In [11]:
_tmp = pd.DataFrame()
_tmp['temporal_cutoff'] = pd.to_datetime(train_sequences['temporal_cutoff'])

year_counts = (
    _tmp
    .groupby(_tmp['temporal_cutoff'].dt.year)
    .size()
    .rename('count')
)
print(year_counts)

temporal_cutoff
1995     6
1996     4
1997    17
1998    20
1999    13
2000    15
2001    24
2002    17
2003    30
2004    19
2005    30
2006    44
2007    30
2008    39
2009     9
2010    26
2011    17
2012    17
2013    21
2014    77
2015    26
2016    24
2017    25
2018    19
2019    33
2020    47
2021    28
2022    47
2023    40
2024    80
Name: count, dtype: int64


In [12]:
## Build all coordinates as list per target_id

all_xyz_coord_trng = []
all_xyz_coord_val = []

for pdb_id in tqdm(train_sequences['target_id']):
    df = train_labels[train_labels["pdb_id"] == pdb_id]
    xyz = df[['x_1','y_1','z_1']].to_numpy().astype('float32')

    # 1) Build a mask array, initialized to False
    mask = np.zeros_like(xyz, dtype=bool)
    finite_mask = np.isfinite(xyz)

    # 2) Only compare where values are finite, write into mask
    np.less(xyz, -1e17, out=mask, where=finite_mask)

    # 3) Assign NaN to all positions flagged by mask
    xyz[mask] = np.nan

    all_xyz_coord_trng.append(xyz)


for pdb_id in tqdm(validation_sequences['target_id']):
    df = validation_labels[validation_labels["pdb_id"] == pdb_id]
    xyz = df[['x_1','y_1','z_1']].to_numpy().astype('float32')

    # 1) Build a mask array, initialized to False
    mask = np.zeros_like(xyz, dtype=bool)
    finite_mask = np.isfinite(xyz)

    # 2) Only compare where values are finite, write into mask
    np.less(xyz, -1e17, out=mask, where=finite_mask)

    # 3) Assign NaN to all positions flagged by mask
    xyz[mask] = np.nan

    all_xyz_coord_val.append(xyz)

100%|██████████| 844/844 [00:08<00:00, 102.79it/s]
100%|██████████| 12/12 [00:00<00:00, 1070.77it/s]


In [13]:
len(all_xyz_coord_trng), len(all_xyz_coord_val)

(844, 12)

In [14]:
"""
Filter and process data
	•	finds and prints the maximum coordinate-sequence length.
	•	keeps only those RNAs whose coordinate arrays have
    	1.	≤ 50% missing values,
    	2.	length within your configured min(10), max(9999) bounds.
	•	It then filters your sequence labels and coordinate data down to that clean subset.
"""

#### Process is required for only Training Data, expected to have clean Validaton Data

# initialize stats
lengths = [len(xyz) for xyz in all_xyz_coord_trng]
max_len = max(lengths)
min_len = min(lengths)
total = len(all_xyz_coord_trng)

# build filter mask
filter_mask = []
for xyz in all_xyz_coord_trng:
    frac_nan = np.isnan(xyz).mean()
    seq_len = len(xyz)
    keep = (
        (frac_nan <= 0.5) and
        (seq_len < config['max_len_filter']) and
        (seq_len > config['min_len_filter'])
    )
    filter_mask.append(keep)

filter_mask = np.array(filter_mask)
kept_indices = np.nonzero(filter_mask)[0]
dropped = total - len(kept_indices)

# apply filter
train_sequences = train_sequences.loc[kept_indices].reset_index(drop=True)
all_xyz_coord_trng = [all_xyz_coord_trng[i] for i in kept_indices]

# print stats
print(f"Total sequences initially : {total}")
print(f" Kept                    : {len(kept_indices)}")
print(f" Dropped                 : {dropped}")
print(f"Shortest sequence length : {min_len}")
print(f"Longest sequence length  : {max_len}")

Total sequences initially : 844
 Kept                    : 765
 Dropped                 : 79
Shortest sequence length : 3
Longest sequence length  : 4298


In [15]:
#pack data into a dictionary

training_data={
      "sequence":train_sequences['sequence'].to_list(),
      "temporal_cutoff": train_sequences['temporal_cutoff'].to_list(),
      "description": train_sequences['description'].to_list(),
      "all_sequences": train_sequences['all_sequences'].to_list(),
      "xyz": all_xyz_coord_trng
}

validation_data={
      "sequence":validation_sequences['sequence'].to_list(),
      "temporal_cutoff": validation_sequences['temporal_cutoff'].to_list(),
      "description": validation_sequences['description'].to_list(),
      "all_sequences": validation_sequences['all_sequences'].to_list(),
      "xyz": all_xyz_coord_val
}

In [16]:
next(iter(training_data['sequence'])), next(iter(training_data['temporal_cutoff'])), next(iter(training_data['description'])), next(iter(training_data['all_sequences'])), next(iter(training_data['xyz']))

('GGGUGCUCAGUACGAGAGGAACCGCACCC',
 '1995-01-26',
 'THE SARCIN-RICIN LOOP, A MODULAR RNA',
 '>1SCL_1|Chain A|RNA SARCIN-RICIN LOOP|Rattus norvegicus (10116)\nGGGUGCUCAGUACGAGAGGAACCGCACCC\n',
 array([[ 13.76 , -25.974,   0.102],
        [  9.31 , -29.638,   2.669],
        [  5.529, -27.813,   5.878],
        [  2.678, -24.901,   9.793],
        [  1.827, -20.136,  11.793],
        [  2.04 , -14.908,  11.771],
        [  1.107, -11.513,   7.517],
        [  2.991,  -6.406,   4.783],
        [  0.896,  -1.193,   7.608],
        [  0.228,   2.646,   9.128],
        [  4.329,   2.718,   4.804],
        [  5.165,   4.792,  -0.914],
        [  2.61 ,   9.495,  -2.308],
        [  1.174,  13.829,   0.201],
        [  1.58 ,  20.115,   3.76 ],
        [ -1.575,  16.928,   5.897],
        [ -6.051,  14.762,   5.224],
        [ -5.554,  10.415,   4.309],
        [ -3.107,   6.405,   2.12 ],
        [ -1.41 ,   3.335,  -2.655],
        [  1.866,  -0.716,  -4.333],
        [  3.655,  -4.444,  -2.4

In [17]:
next(iter(validation_data['sequence'])), next(iter(validation_data['temporal_cutoff'])), next(iter(validation_data['description'])), next(iter(validation_data['all_sequences'])), next(iter(validation_data['xyz']))

('GGGGGCCACAGCAGAAGCGUUCACGUCGCAGCCCCUGUCAGCCAUUGCACUCCGGCUGCGAAUUCUGCU',
 '2022-05-28',
 'CPEB3 ribozyme\nHuman\nhuman CPEB3 HDV-like ribozyme',
 '>7QR4_1|Chain A|U1 small nuclear ribonucleoprotein A|Homo sapiens (9606)\nRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKM\n>7QR4_2|Chain B|RNA CPEB3 ribozyme|Homo sapiens (9606)\nGGGGGCCACAGCAGAAGCGUUCACGUCGCAGCCCCUGUCAGCCAUUGCACUCCGGCUGCGAAUUCUGCU',
 array([[ -5.499,   8.52 ,   8.605],
        [ -5.826,  10.453,  14.01 ],
        [ -5.849,  14.768,  17.585],
        [ -5.784,  19.985,  18.666],
        [ -5.755,  25.533,  17.133],
        [ -6.227,  30.093,  13.965],
        [ -9.016,  37.03 ,  11.306],
        [ -9.026,  31.554,   8.725],
        [-13.912,  30.908,   8.347],
        [-22.273,  33.251,   7.105],
        [-25.752,  28.854,   8.548],
        [-28.567,  25.027,   6.709],
        [-30.613,  22.207,   2.6  ],
        [-30.474,  20.334,  -2.326],
        [-27.767,  19.594,  -7.189],
  

# IV. Training Data Prepration

In [18]:
## No need to split from the training set, as we have validaton set 
# all_index = np.arange(len(data['sequence']))
# cutoff_date = pd.Timestamp(config['cutoff_date'])
# test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])
# train_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) <= cutoff_date]
# test_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) > cutoff_date and pd.Timestamp(d) <= test_cutoff_date]

In [19]:
# print(f"Train size: {len(train_index)}")
# print(f"Test size: {len(test_index)}")

print(f"Train size: {len(training_data['sequence'])}")
print(f"Validation size: {len(validation_data['sequence'])}")

Train size: 765
Validation size: 12


**Pytorch Dataset**

In [12]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp,s):
    ct_matrix=np.zeros((len(s),len(s)))
    for b in bp:
        ct_matrix[b[0]-1,b[1]-1]=1
    return ct_matrix


class RNA3D_Dataset(torch.utils.data.Dataset):
    def __init__(self, data: dict, config: dict):
        """
        data: dict of lists, keys include:
              'sequence' (list of str), 'xyz' (list of Nx3 arrays), etc.
        config: dict with at least 'max_len' key
        """
        self.data   = data
        self.config = config

        # build token map for known nucleotides
        self.tokens = {nt: i for i, nt in enumerate('ACGU')}
        # assign an ID for unknown tokens
        self.UNK_ID = len(self.tokens)

    def __len__(self):
        return len(self.data['sequence'])
    
    def __getitem__(self, idx):
        # --- sequence to IDs, unknown → UNK_ID ---
        seq_str = self.data['sequence'][idx]
        seq_ids = [ self.tokens.get(nt, self.UNK_ID) for nt in seq_str ]
        sequence = torch.tensor(seq_ids, dtype=torch.long)

        # --- xyz list → tensor ---
        xyz_arr = np.array(self.data['xyz'][idx], dtype=np.float32)
        xyz     = torch.tensor(xyz_arr,   dtype=torch.float32)

        # --- optional random crop if too long ---
        max_len = self.config['max_len']
        if len(sequence) > max_len:
            start = np.random.randint(0, len(sequence) - max_len + 1)
            end   = start + max_len
            sequence = sequence[start:end]
            xyz       = xyz[start:end]
        
        return {
            'sequence': sequence,
            'xyz':       xyz
        }
        
# class RNA3D_Dataset(Dataset):
#     # def __init__(self,indices,data):
#     def __init__(self,data):
#         # self.indices=indices
#         self.data=data
#         self.tokens={nt:i for i,nt in enumerate('ACGU')}

#     def __len__(self):
#         return len(self.data['sequence'])
    
#     def __getitem__(self, idx):

#         # idx=self.indices[idx]
#         sequence=[self.tokens[nt] for nt in (self.data['sequence'][idx])]
#         sequence=np.array(sequence)
#         sequence=torch.tensor(sequence)

#         #get C1' xyz
#         xyz=self.data['xyz'][idx]
#         xyz=torch.tensor(np.array(xyz))


#         if len(sequence)>config['max_len']:
#             crop_start=np.random.randint(len(sequence)-config['max_len'])
#             crop_end=crop_start+config['max_len']

#             sequence=sequence[crop_start:crop_end]
#             xyz=xyz[crop_start:crop_end]
        

#         return {'sequence':sequence,
#                 'xyz':xyz}

In [25]:
# train_dataset=RNA3D_Dataset(train_index,data)
# val_dataset=RNA3D_Dataset(test_index,data)

train_dataset=RNA3D_Dataset(training_data, config)
val_dataset=RNA3D_Dataset(validation_data, config)

In [26]:
train_dataset.__getitem__(0), val_dataset.__getitem__(0)

({'sequence': tensor([2, 2, 2, 3, 2, 1, 3, 1, 0, 2, 3, 0, 1, 2, 0, 2, 0, 2, 2, 0, 0, 1, 1, 2,
          1, 0, 1, 1, 1]),
  'xyz': tensor([[ 13.7600, -25.9740,   0.1020],
          [  9.3100, -29.6380,   2.6690],
          [  5.5290, -27.8130,   5.8780],
          [  2.6780, -24.9010,   9.7930],
          [  1.8270, -20.1360,  11.7930],
          [  2.0400, -14.9080,  11.7710],
          [  1.1070, -11.5130,   7.5170],
          [  2.9910,  -6.4060,   4.7830],
          [  0.8960,  -1.1930,   7.6080],
          [  0.2280,   2.6460,   9.1280],
          [  4.3290,   2.7180,   4.8040],
          [  5.1650,   4.7920,  -0.9140],
          [  2.6100,   9.4950,  -2.3080],
          [  1.1740,  13.8290,   0.2010],
          [  1.5800,  20.1150,   3.7600],
          [ -1.5750,  16.9280,   5.8970],
          [ -6.0510,  14.7620,   5.2240],
          [ -5.5540,  10.4150,   4.3090],
          [ -3.1070,   6.4050,   2.1200],
          [ -1.4100,   3.3350,  -2.6550],
          [  1.8660,  -0.7160,  

In [27]:
import plotly.graph_objects as go
import numpy as np



# Example: Generate an Nx3 matrix
xyz = train_dataset[200]['xyz']  # Replace this with your actual Nx3 data
N = len(xyz)


for _ in range(2): #plot twice because it doesnt show up on first try for some reason
    # Extract columns
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    
    # Create the 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers+lines', #'markers',
        marker=dict(
            size=5,
            color=z,  # Coloring based on z-value
            colorscale='Viridis',  # Choose a colorscale
            opacity=0.8
        )
    )])
    
    # Customize layout
    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        title="3D Scatter Plot"
    )

fig.show()

In [28]:
### Do this later post defining batch sizes 

# ## Create dataloader instances 

# train_loader=DataLoader(train_dataset,batch_size=1,shuffle=True)
# val_loader=DataLoader(val_dataset,batch_size=1,shuffle=False)

In [29]:
#! pip install einops

# V. Create Custom Model Instance

We will add a linear layer to predict xyz of C1' atoms on the base /kaggle/input/ribonanzanet2d-final 



In [13]:
import sys

sys.path.append("/kaggle/input/ribonanzanet2d-final")

from Network import *
import yaml



class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)



class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config, pretrained=False):
        config.dropout=0.1
        super(finetuned_RibonanzaNet, self).__init__(config)
        if pretrained:
            self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))
        # self.ct_predictor=nn.Sequential(nn.Linear(64,256),
        #                                 nn.ReLU(),
        #                                 nn.Linear(256,64),
        #                                 nn.ReLU(),
        #                                 nn.Linear(64,1)) 
        self.dropout=nn.Dropout(0.0)
        self.xyz_predictor=nn.Linear(256,3)


    
    def forward(self,src):
        
        #with torch.no_grad():
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))


        xyz=self.xyz_predictor(sequence_features)

        return xyz

In [31]:
## Available GPUs 
print("GPUs available:", torch.cuda.device_count())

GPUs available: 1


In [15]:
from pprint import pprint
cfg = load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml")

## Update the batch size to new value
_batch_size= 5

cfg.batch_size = _batch_size
cfg.entries['batch_size'] = _batch_size

## Update the GPUs to multiple if multiple available 
if torch.cuda.device_count() > 1:
    cfg.gpu_id = "0,1"
    cfg.entries['gpu_id'] = "0,1"
    
pprint(vars(cfg))

{'batch_size': 5,
 'bpp_file_folder': '../../input/bpp_files/',
 'dropout': 0.05,
 'entries': {'batch_size': 5,
             'bpp_file_folder': '../../input/bpp_files/',
             'dropout': 0.05,
             'epochs': 40,
             'fold': 0,
             'gpu_id': '0',
             'gradient_accumulation_steps': 2,
             'input_dir': '../../input/',
             'k': 9,
             'learning_rate': 0.001,
             'nclass': 2,
             'nfolds': 6,
             'nhead': 8,
             'ninp': 256,
             'nlayers': 9,
             'ntoken': 5,
             'optimizer': 'ranger',
             'pairwise_dimension': 64,
             'test_batch_size': 8,
             'use_bpp': False,
             'use_grad_checkpoint': True,
             'use_triangular_attention': False,
             'weight_decay': 0.0001},
 'epochs': 40,
 'fold': 0,
 'gpu_id': '0',
 'gradient_accumulation_steps': 2,
 'input_dir': '../../input/',
 'k': 9,
 'learning_rate': 0.001,
 'nclas

In [21]:
## Create dataloader instances 

# train_loader=DataLoader(train_dataset,batch_size=1,shuffle=True)
# val_loader=DataLoader(val_dataset,batch_size=1,shuffle=False)

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch is a list of dicts, e.g. {'sequence': Tensor[L], 'xyz': Tensor[L,3], …}
    seqs = [torch.tensor(item['sequence']) for item in batch]
    xyzs = [torch.tensor(item['xyz'], dtype=torch.float32) for item in batch]

    # pad to the max length in this batch
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=0)       # or pad_token
    xyzs_padded = pad_sequence(xyzs, batch_first=True, padding_value=float('nan'))

    # collect any other fields you need, e.g. labels
    # labels = torch.stack([item['label'] for item in batch], 0)

    return {
        'sequence': seqs_padded,
        'xyz':       xyzs_padded,
        # 'label':    labels,
    }



In [None]:
# then in your DataLoader:
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_collate
)

val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_collate
)

In [34]:
# model=finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=True).cuda()

# instantiate on CPU first
model = finetuned_RibonanzaNet(cfg, pretrained=True)

# wrap in DataParallel (uses all available GPUs by default)
model = torch.nn.DataParallel(model)

# then move to CUDA
model = model.cuda()

# after wrapping in DataParallel
# print("Model sees config:", model.module.cfg.batch_size, model.module.cfg.gpu_id)

print("GPUs visible:", torch.cuda.device_count())

print("DataParallel device IDs:", model.device_ids)
print("First parameter on device:", next(model.parameters()).device)

constructing 9 ConvTransformerEncoderLayers



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



GPUs visible: 1
DataParallel device IDs: [0]
First parameter on device: cuda:0


**Define Loss Function**

we will use dRMSD loss on the predicted xyz. the loss function is invariant to translations, rotations, and reflections. because dRMSD is invariant to reflections, it cannot distinguish chiral structures, so there may be better loss functions

In [35]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    return (torch.square(X[:,None]-Y[None,:])+epsilon).sum(-1).sqrt()


def dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    if d_clamp is not None:
        rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).clip(0,d_clamp**2)
    else:
        rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon

    return rmsd.sqrt().mean()/Z

def local_dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=30):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=(~torch.isnan(gt_dm))*(gt_dm<d_clamp)
    mask[torch.eye(mask.shape[0]).bool()]=False



    rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    # rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).sqrt()/Z
    #rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])/Z
    return rmsd.sqrt().mean()/Z

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])

    return rmsd.mean()/Z

import torch

def align_svd_mae(input, target, Z=10):
    """
    Align input (Nx3) to target (Nx3) via Procrustes (SVD) in float32,
    then compute MAE / Z.
    """
    assert input.shape == target.shape, "Input and target must match"

    # 1) Mask out NaNs
    mask = ~torch.isnan(target.sum(-1))
    inp = input[mask].float()   # cast to float32
    tgt = target[mask].float()  # cast to float32

    # 2) Compute and remove centroids
    c_inp = inp.mean(dim=0, keepdim=True)
    c_tgt = tgt.mean(dim=0, keepdim=True)
    inp_c = inp - c_inp
    tgt_c = tgt - c_tgt

    # 3) Covariance matrix
    cov = inp_c.t() @ tgt_c

    # 4) SVD in float32
    #    Detach so no gradients flow through the SVD
    with torch.no_grad():
        U, S, Vt = torch.svd(cov)
        R = Vt @ U.t()
        # fix potential reflection
        if torch.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt @ U.t()

    # 5) Rotate back and re-add centroid
    #    (R is already float32, inp_c is float32)
    aligned = inp_c @ R.t() + c_tgt

    # 6) MAE loss (float32)
    loss = torch.abs(aligned - tgt).mean() / Z

    return loss
    
# def align_svd_mae(input, target, Z=10):
#     """
#     Aligns the input (Nx3) to target (Nx3) using SVD-based Procrustes alignment
#     and computes RMSD loss.
    
#     Args:
#         input (torch.Tensor): Nx3 tensor representing the input points.
#         target (torch.Tensor): Nx3 tensor representing the target points.
    
#     Returns:
#         aligned_input (torch.Tensor): Nx3 aligned input.
#         rmsd_loss (torch.Tensor): RMSD loss.
#     """
#     assert input.shape == target.shape, "Input and target must have the same shape"

#     #mask 
#     mask=~torch.isnan(target.sum(-1))

#     input=input[mask]
#     target=target[mask]
    
#     # Compute centroids
#     centroid_input = input.mean(dim=0, keepdim=True)
#     centroid_target = target.mean(dim=0, keepdim=True)

#     # Center the points
#     input_centered = input - centroid_input.detach()
#     target_centered = target - centroid_target

#     # Compute covariance matrix
#     cov_matrix = input_centered.T @ target_centered

#     # SVD to find optimal rotation
#     U, S, Vt = torch.svd(cov_matrix)

#     # Compute rotation matrix
#     R = Vt @ U.T

#     # Ensure a proper rotation (det(R) = 1, no reflection)
#     if torch.det(R) < 0:
#         Vt[-1, :] *= -1
#         R = Vt @ U.T

#     # Rotate input
#     aligned_input = (input_centered @ R.T.detach()) + centroid_target.detach()

#     # # Compute RMSD loss
#     # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())

#     # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())
    
#     # return aligned_input, rmsd_loss
#     return torch.abs(aligned_input-target).mean()/Z

**Training Loop**

In [36]:
# from tqdm import tqdm
# from torch.amp import GradScaler
# # from torch.cuda.amp import autocast, GradScaler

# epochs=50
# cos_epoch=35


# best_loss=np.inf
# optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.0, lr=0.0001) #no weight decay following AF

# batch_size=_batch_size

# #for cycle in range(2):

# criterion=torch.nn.BCEWithLogitsLoss(reduction='none')

# scaler = GradScaler()

# schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(epochs-cos_epoch)*len(train_loader)//batch_size)

# best_val_loss=99999999999
# for epoch in range(epochs):
#     model.train()
#     tbar=tqdm(train_loader)
#     total_loss=0
#     oom=0
#     for idx, batch in enumerate(tbar):
#         #try:

#         sequence=batch['sequence'].cuda()
#         gt_xyz=batch['xyz'].cuda().squeeze()

#         #with torch.autocast(device_type='cuda', dtype=torch.float16):
#         pred_xyz=model(sequence).squeeze()
        
#         loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
#              #local_dRMSD(pred_xyz,pred_xyz,gt_xyz,gt_xyz)

#         if loss!=loss:
#             stop

        
#         (loss/batch_size).backward()

#         if (idx+1)%batch_size==0 or idx+1 == len(tbar):

#             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
#             optimizer.step()
#             optimizer.zero_grad()
#             # scaler.scale(loss/batch_size).backward()
#             # scaler.unscale_(optimizer)
#             # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
#             # scaler.step(optimizer)
#             # scaler.update()

            
#             if (epoch+1)>cos_epoch:
#                 schedule.step()
#         #schedule.step()
#         total_loss+=loss.item()
        
#         tbar.update(1)
#         tbar.set_description(f"Epoch {epoch + 1} Loss: {total_loss/(idx+1)} OOMs: {oom}")



#         # except Exception:
#         #     #print(Exception)
#         #     oom+=1
#     tbar=tqdm(val_loader)
#     model.eval()
#     val_preds=[]
#     val_loss=0
#     for idx, batch in enumerate(tbar):
#         sequence=batch['sequence'].cuda()
#         gt_xyz=batch['xyz'].cuda().squeeze()

#         with torch.no_grad():
#             pred_xyz=model(sequence).squeeze()
#             loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz)
            
#         val_loss+=loss.item()
#         val_preds.append([gt_xyz.cpu().numpy(),pred_xyz.cpu().numpy()])
#     val_loss=val_loss/len(tbar)
#     print(f"val loss: {val_loss}")
    
    
    
#     if val_loss<best_val_loss:
#         best_val_loss=val_loss
#         best_preds=val_preds
#         torch.save(model.state_dict(),'RibonanzaNet-3D_RM.pt')

#     # 1.053595052265986 train loss after epoch 0
# torch.save(model.state_dict(),'RibonanzaNet-3D-final_RM.pt')

In [37]:
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay
)


epochs    = 50
cos_epoch = 35
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=(epochs - cos_epoch) * len(train_loader) // cfg.batch_size
)
scaler = GradScaler()

# ---- TRAIN & VALIDATION LOOP ----
best_val_loss = float('inf')

for epoch in range(1, epochs + 1):
    # TRAINING
    model.train()
    optimizer.zero_grad(set_to_none=True)
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", unit="batch")
    running_loss = 0.0

    for idx, batch in enumerate(train_bar, start=1):
        seq = batch['sequence'].cuda(non_blocking=True)
        gt  = batch['xyz'].cuda(non_blocking=True).squeeze()

        # 1) compute dRMAE in fp16
        with autocast():
            pred = model(seq).squeeze()
            dR_loss = dRMAE(pred, pred, gt, gt) #+ align_svd_mae(pred, gt)

        # 2) compute alignment loss in fp32
        with autocast(enabled=False):
            rot_loss = align_svd_mae(pred, gt)  # SVD runs in fp32

        loss = dR_loss + rot_loss

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        running_loss += loss.item()
        if idx % 10 == 0:
            train_bar.set_postfix(loss=running_loss / idx)

    # LR SCHEDULER STEP
    if epoch > cos_epoch:
        scheduler.step()

    # VALIDATION
    model.eval()
    val_loss = 0.0
    val_bar = tqdm(val_loader, desc="Validation", unit="batch")
    with torch.no_grad():
        for batch in val_bar:
            seq = batch['sequence'].cuda(non_blocking=True)
            gt  = batch['xyz'].cuda(non_blocking=True).squeeze()
            pred = model(seq).squeeze()
            vloss = dRMAE(pred, pred, gt, gt)
            val_loss += vloss.item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")

    # SAVE BEST MODEL
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'RibonanzaNet-best-rm.pt')
        print(f"  ✨ Saved new best model (val_loss={val_loss:.4f})")

# FINAL SAVE
torch.save(model.state_dict(), 'RibonanzaNet-final-rm.pt')
print("Training complete. Final model saved.")


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.


`torch.cuda.amp.autocast(args...)` is depreca

Epoch 1 Validation Loss: 14.0704
  ✨ Saved new best model (val_loss=14.0704)


Epoch 2/50: 100%|██████████| 153/153 [07:53<00:00,  3.10s/batch, loss=21.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 2 Validation Loss: 14.1055


Epoch 3/50: 100%|██████████| 153/153 [08:18<00:00,  3.26s/batch, loss=21.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 3 Validation Loss: 14.1357


Epoch 4/50: 100%|██████████| 153/153 [08:06<00:00,  3.18s/batch, loss=21.1]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 4 Validation Loss: 14.1465


Epoch 5/50: 100%|██████████| 153/153 [08:09<00:00,  3.20s/batch, loss=19.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 5 Validation Loss: 14.1737


Epoch 6/50: 100%|██████████| 153/153 [08:00<00:00,  3.14s/batch, loss=20.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 6 Validation Loss: 13.9068
  ✨ Saved new best model (val_loss=13.9068)


Epoch 7/50: 100%|██████████| 153/153 [08:18<00:00,  3.26s/batch, loss=21.4]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 7 Validation Loss: 14.2414


Epoch 8/50: 100%|██████████| 153/153 [07:43<00:00,  3.03s/batch, loss=20.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 8 Validation Loss: 13.9289


Epoch 9/50: 100%|██████████| 153/153 [08:00<00:00,  3.14s/batch, loss=20]  
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 9 Validation Loss: 14.1557


Epoch 10/50: 100%|██████████| 153/153 [08:28<00:00,  3.32s/batch, loss=20.7]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 10 Validation Loss: 14.1236


Epoch 11/50: 100%|██████████| 153/153 [07:47<00:00,  3.06s/batch, loss=19.5]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 11 Validation Loss: 14.1344


Epoch 12/50: 100%|██████████| 153/153 [08:20<00:00,  3.27s/batch, loss=19.5]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 12 Validation Loss: 14.0620


Epoch 13/50: 100%|██████████| 153/153 [08:21<00:00,  3.28s/batch, loss=19.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 13 Validation Loss: 13.9490


Epoch 14/50: 100%|██████████| 153/153 [07:49<00:00,  3.07s/batch, loss=19.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 14 Validation Loss: 14.0978


Epoch 15/50: 100%|██████████| 153/153 [08:11<00:00,  3.21s/batch, loss=19.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 15 Validation Loss: 13.8969
  ✨ Saved new best model (val_loss=13.8969)


Epoch 16/50: 100%|██████████| 153/153 [07:57<00:00,  3.12s/batch, loss=19.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 16 Validation Loss: 14.1233


Epoch 17/50: 100%|██████████| 153/153 [08:07<00:00,  3.19s/batch, loss=18.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 17 Validation Loss: 14.1271


Epoch 18/50: 100%|██████████| 153/153 [08:14<00:00,  3.23s/batch, loss=19.8]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 18 Validation Loss: 14.1336


Epoch 19/50: 100%|██████████| 153/153 [07:55<00:00,  3.10s/batch, loss=20]  
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 19 Validation Loss: 13.9990


Epoch 20/50: 100%|██████████| 153/153 [08:14<00:00,  3.23s/batch, loss=19.5]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 20 Validation Loss: 13.9723


Epoch 21/50: 100%|██████████| 153/153 [08:03<00:00,  3.16s/batch, loss=20.5]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 21 Validation Loss: 13.8945
  ✨ Saved new best model (val_loss=13.8945)


Epoch 22/50: 100%|██████████| 153/153 [08:01<00:00,  3.15s/batch, loss=19.7]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 22 Validation Loss: 14.1691


Epoch 23/50: 100%|██████████| 153/153 [08:00<00:00,  3.14s/batch, loss=19.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 23 Validation Loss: 13.8842
  ✨ Saved new best model (val_loss=13.8842)


Epoch 24/50: 100%|██████████| 153/153 [08:21<00:00,  3.28s/batch, loss=18.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 24 Validation Loss: 14.1495


Epoch 25/50: 100%|██████████| 153/153 [08:00<00:00,  3.14s/batch, loss=18.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 25 Validation Loss: 14.1758


Epoch 26/50: 100%|██████████| 153/153 [07:51<00:00,  3.08s/batch, loss=18.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 26 Validation Loss: 10.2989
  ✨ Saved new best model (val_loss=10.2989)


Epoch 27/50: 100%|██████████| 153/153 [08:40<00:00,  3.40s/batch, loss=19.1]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 27 Validation Loss: 8.9348
  ✨ Saved new best model (val_loss=8.9348)


Epoch 28/50: 100%|██████████| 153/153 [08:06<00:00,  3.18s/batch, loss=19.8]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 28 Validation Loss: 9.8874


Epoch 29/50: 100%|██████████| 153/153 [08:01<00:00,  3.14s/batch, loss=18.5]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 29 Validation Loss: 9.1549


Epoch 30/50: 100%|██████████| 153/153 [08:24<00:00,  3.30s/batch, loss=18.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 30 Validation Loss: 9.2669


Epoch 31/50: 100%|██████████| 153/153 [08:10<00:00,  3.20s/batch, loss=18.6]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 31 Validation Loss: 8.4199
  ✨ Saved new best model (val_loss=8.4199)


Epoch 32/50: 100%|██████████| 153/153 [07:25<00:00,  2.91s/batch, loss=19.3]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 32 Validation Loss: 7.7713
  ✨ Saved new best model (val_loss=7.7713)


Epoch 33/50: 100%|██████████| 153/153 [07:57<00:00,  3.12s/batch, loss=17.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 33 Validation Loss: 8.9476


Epoch 34/50: 100%|██████████| 153/153 [08:06<00:00,  3.18s/batch, loss=17.7]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 34 Validation Loss: 7.9089


Epoch 35/50: 100%|██████████| 153/153 [07:47<00:00,  3.05s/batch, loss=17.8]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 35 Validation Loss: 10.9635


Epoch 36/50: 100%|██████████| 153/153 [08:03<00:00,  3.16s/batch, loss=17.9]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 36 Validation Loss: 9.1817


Epoch 37/50: 100%|██████████| 153/153 [08:26<00:00,  3.31s/batch, loss=18.3]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 37 Validation Loss: 7.9338


Epoch 38/50: 100%|██████████| 153/153 [08:14<00:00,  3.23s/batch, loss=18.1]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 38 Validation Loss: 8.1174


Epoch 39/50: 100%|██████████| 153/153 [08:09<00:00,  3.20s/batch, loss=18]  
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 39 Validation Loss: 8.6104


Epoch 40/50: 100%|██████████| 153/153 [07:40<00:00,  3.01s/batch, loss=18.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 40 Validation Loss: 9.0082


Epoch 41/50: 100%|██████████| 153/153 [07:34<00:00,  2.97s/batch, loss=18.7]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 41 Validation Loss: 8.7924


Epoch 42/50: 100%|██████████| 153/153 [08:01<00:00,  3.15s/batch, loss=18.4]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 42 Validation Loss: 8.7342


Epoch 43/50: 100%|██████████| 153/153 [08:22<00:00,  3.29s/batch, loss=17]  
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 43 Validation Loss: 9.3102


Epoch 44/50: 100%|██████████| 153/153 [08:14<00:00,  3.24s/batch, loss=18.4]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 44 Validation Loss: 8.4229


Epoch 45/50: 100%|██████████| 153/153 [08:38<00:00,  3.39s/batch, loss=18.4]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 45 Validation Loss: 8.9334


Epoch 46/50: 100%|██████████| 153/153 [07:51<00:00,  3.08s/batch, loss=18.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 46 Validation Loss: 9.7374


Epoch 47/50: 100%|██████████| 153/153 [08:22<00:00,  3.28s/batch, loss=17.8]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 47 Validation Loss: 9.0753


Epoch 48/50: 100%|██████████| 153/153 [07:55<00:00,  3.11s/batch, loss=18.2]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 48 Validation Loss: 8.2836


Epoch 49/50: 100%|██████████| 153/153 [07:57<00:00,  3.12s/batch, loss=18.3]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]


Epoch 49 Validation Loss: 9.0319


Epoch 50/50: 100%|██████████| 153/153 [07:58<00:00,  3.13s/batch, loss=16.8]
Validation: 100%|██████████| 3/3 [00:03<00:00,  1.09s/batch]

Epoch 50 Validation Loss: 9.2478
Training complete. Final model saved.





In [None]:
# import torch, gc

# # 1) Delete any large objects you no longer need
# del model
# del optimizer
# del train_loader, val_loader
# # (also delete any large tensors you’re still holding onto)

# # 2) Force Python to collect garbage
# gc.collect()

# # 3) Ask CUDA to release its cached memory
# torch.cuda.empty_cache()

# VI. Submission

In [16]:
## Load model

import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

# 1) Reconstruct model & load best checkpoint
model = finetuned_RibonanzaNet(cfg, pretrained=False)
model = torch.nn.DataParallel(model).cuda()
state = torch.load('/kaggle/working/RibonanzaNet-best-rm.pt', map_location='cuda:0')
model.load_state_dict(state)

constructing 9 ConvTransformerEncoderLayers


  state = torch.load('/kaggle/working/RibonanzaNet-best-rm.pt', map_location='cuda:0')


<All keys matched successfully>

In [23]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# --- assume `test_sequence` is your DataFrame,
#     `config`, `cfg`, `model`, `pad_collate`, & `RNA3D_Dataset` are already in scope

# 1) Build a dict of lists for the Dataset, with dummy xyz
test_data = {
    'sequence':      test_sequences['sequence'].tolist(),
    'xyz':           [np.zeros((config['max_len'], 3), dtype=np.float32)]
                       * len(test_sequences),   # dummy
}
# (we ignore temporal_cutoff / description / all_sequences here)

# 2) Instantiate the Dataset + Loader
test_ds = RNA3D_Dataset(test_data, config)
test_loader = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=pad_collate
)

# 3) Inference
model.eval()
all_preds = []  # will be a list of [L_padded, 3] arrays
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        seq   = batch['sequence'].cuda(non_blocking=True)
        preds = model(seq).cpu().numpy()   # shape (B, L_batch, 3)
        # append each RNA in the batch separately
        for p in preds:
            all_preds.append(p)

# now all_preds[i] is the padded-prediction for test i
# length may vary per-batch, but you'll slice to true L below

# 4) Build submission rows
rows = []
for i, row in test_sequences.iterrows():
    tid     = row['target_id']
    seq_str = row['sequence']
    L       = len(seq_str)
    coords  = all_preds[i][:L]   # slice off the padding → shape [L,3]

    for j, (x,y,z) in enumerate(coords, start=1):
        base = {
            'ID':      f"{tid}_{j}",
            'resname': seq_str[j-1],
            'resid':   j
        }
        # replicate each coordinate 5×
        for k in range(1, 6):
            base[f'x_{k}'] = x
            base[f'y_{k}'] = y
            base[f'z_{k}'] = z
        rows.append(base)

submission_df = pd.DataFrame(rows)
print("Final submission shape:", submission_df.shape)
submission_df.to_csv("submission.csv", index=False)

  seqs = [torch.tensor(item['sequence']) for item in batch]
  xyzs = [torch.tensor(item['xyz'], dtype=torch.float32) for item in batch]
Predicting: 100%|██████████| 3/3 [00:03<00:00,  1.10s/it]

Final submission shape: (2179, 18)



