# Sub 8: N_CYCLE=4 + Multi-Seed + All Templates + RibonanzaNet2

**Hypothesis:** N_CYCLE=4 is RNAPro's default (we've been using 10, which is 2.5x slower).
The default was chosen by NVIDIA as the best quality/speed tradeoff. Running at N_CYCLE=4
frees GPU time for multi-seed diversity (3 seeds x 5 samples = 15 candidates).

Strategy:
1. Run TBM to generate 5 templates
2. Run RNAPro with N_CYCLE=4, template_idx=4, seeds=42,101,202, N_SAMPLE=5
3. Each target runs ~2.5x faster (4 instead of 10 recycling cycles), which leaves time to process more targets at full quality
4. 15 candidates per target, best 5 selected by ranking_score

In [None]:
import os
# Raw value of KAGGLE_IS_COMPETITION_RERUN (None when unset); later cells use
# it only for its truthiness. Presumably Kaggle sets it only for the scoring
# rerun on the hidden test set.
IS_SCORING_RUN = os.getenv('KAGGLE_IS_COMPETITION_RERUN')
print('Scoring run:', IS_SCORING_RUN)

## Section A: Setup RNAPro

In [None]:
# Copy the RNAPro source tree and the 500M checkpoint from the input dataset
# into /kaggle/working: the next cell installs the repo in editable mode from
# there, and the launch script loads "../rnapro-private-best-500m.ckpt"
# relative to /kaggle/working/RNAPro.
!cp -r /kaggle/input/rnapro-src/RNAPro /kaggle/working/
!cp /kaggle/input/rnapro-src/rnapro-private-best-500m.ckpt /kaggle/working/

In [None]:
# Install RNAPro in editable mode without resolving its declared dependencies
# (--no-deps); the ones it actually needs are installed explicitly below.
%cd /kaggle/working/RNAPro
!pip install -e . --no-deps
# Install critical RNAPro dependencies not in Kaggle base image
# NOTE(review): this pip call downloads from PyPI. If the scoring environment
# has no internet access it will fail, and a failing `!` command does not stop
# the notebook. RNAPro inference would then fail and Section F falls back to
# TBM. Consider offline wheels (as done for biopython in the next cell).
!pip install biotite==1.4.0 rdkit-pypi
%cd /kaggle/working

In [None]:
# Offline install of Biopython (Bio.Align.PairwiseAligner is used by the TBM
# section) from a pre-downloaded CPython 3.12 wheel; --no-index avoids PyPI.
!pip install --no-index /kaggle/input/biopython-cp312/biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl

## Section B: Improved TBM (6feb_v1 + validation data)

In [None]:
import pandas as pd
import numpy as np
import random
import time
import warnings
import os, sys

# Silence all warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Competition data root; the trailing slash lets file names be appended directly.
DATA_PATH = '/kaggle/input/stanford-rna-3d-folding-2/'
train_seqs = pd.read_csv(f"{DATA_PATH}train_sequences.csv")
test_seqs = pd.read_csv(f"{DATA_PATH}test_sequences.csv")
train_labels = pd.read_csv(f"{DATA_PATH}train_labels.csv")

# Fold validation structures into the template library when both validation
# files are present; otherwise use train-only templates.
try:
    validation_seqs = pd.read_csv(f"{DATA_PATH}validation_sequences.csv")
    validation_labels = pd.read_csv(f"{DATA_PATH}validation_labels.csv")
except FileNotFoundError:
    combined_seqs = train_seqs
    validation_labels = None
    print("Validation not found, using train only")
else:
    combined_seqs = pd.concat([train_seqs, validation_seqs], ignore_index=True)
    print(f"Combined: {len(combined_seqs)} template sequences (train+validation)")

# Make the competition helper modules (e.g. parse_fasta_py) importable.
sys.path.append(os.path.join(DATA_PATH, "extra"))

# Prefer the competition-provided FASTA parser (from DATA_PATH/extra on
# sys.path); fall back to a small local parser if anything in this try fails.
try:
    import typing as _typing
    import builtins as _builtins
    # HACK: publish typing.Dict/Tuple/List as builtins before importing
    # parse_fasta_py. Presumably that module uses these names without
    # importing them. The patch stays in place even if the import fails.
    _builtins.Dict  = getattr(_typing, "Dict")
    _builtins.Tuple = getattr(_typing, "Tuple")
    _builtins.List  = getattr(_typing, "List")
    from parse_fasta_py import parse_fasta as _parse_fasta_raw
    def parse_fasta(fasta_content: str) -> dict:
        """Parse FASTA text into a {header_id: sequence} dict via the competition parser.

        When the raw parser returns a tuple for an entry, only its first
        element is kept (assumed to be the sequence; TODO confirm against
        parse_fasta_py).
        """
        d = _parse_fasta_raw(fasta_content)
        out = {}
        for k, v in d.items():
            out[k] = v[0] if isinstance(v, tuple) else v
        return out
except Exception:
    def parse_fasta(fasta_content: str) -> dict:
        """Parse FASTA text into a {header_id: sequence} dict.

        The id is the first whitespace-separated token after '>'. Sequence
        lines are concatenated with spaces removed. Blank lines and any text
        before the first header are ignored. A repeated id overwrites the
        earlier entry.
        """
        out = {}
        cur = None
        seq_parts = []
        for line in str(fasta_content).splitlines():
            line = line.strip()
            if not line: continue
            if line.startswith(">"):
                # Flush the previous record before starting a new one.
                if cur is not None:
                    out[cur] = "".join(seq_parts)
                cur = line[1:].split()[0]
                seq_parts = []
            else:
                seq_parts.append(line.replace(" ", ""))
        # Flush the final record.
        if cur is not None:
            out[cur] = "".join(seq_parts)
        return out

def parse_stoichiometry(stoich: str):
    """Parse a stoichiometry string like ``"A:2;B:1"`` into ``[("A", 2), ("B", 1)]``.

    Returns an empty list for missing (NaN/None) or blank input. Empty parts
    (e.g. from a trailing or doubled ``;``) and surrounding whitespace are
    ignored, so they no longer abort parsing.

    Raises:
        ValueError: if a non-empty part is not of the form ``chain:count`` or
            the count is not an integer.
    """
    if pd.isna(stoich) or str(stoich).strip() == "":
        return []
    out = []
    for part in str(stoich).split(';'):
        part = part.strip()
        if not part:
            # Tolerate separators without content, e.g. "A:1;" or "A:1;;B:2".
            continue
        ch, cnt = part.split(':')
        out.append((ch.strip(), int(cnt)))
    return out

def get_chain_segments(row):
    """Return half-open ``(start, end)`` residue ranges, one per chain copy.

    ``row`` needs 'sequence' and may carry 'stoichiometry' (e.g. "A:2;B:1")
    and 'all_sequences' (FASTA text keyed by chain id). Chains are laid out in
    stoichiometry order, each repeated ``count`` times. The whole sequence is
    returned as a single segment whenever the metadata is missing or blank,
    cannot be parsed, names an unknown chain, or adds up to a length that
    differs from 'sequence'.
    """
    seq = row['sequence']
    whole = [(0, len(seq))]
    stoich = row.get('stoichiometry', '')
    all_seq = row.get('all_sequences', '')
    if (pd.isna(stoich) or pd.isna(all_seq)
            or not str(stoich).strip() or not str(all_seq).strip()):
        return whole
    try:
        chains = parse_fasta(all_seq)
        layout = parse_stoichiometry(stoich)
        segments = []
        cursor = 0
        for chain_id, copies in layout:
            chain_seq = chains.get(chain_id)
            if chain_seq is None:
                return whole
            for _ in range(copies):
                width = len(chain_seq)
                segments.append((cursor, cursor + width))
                cursor += width
        return segments if cursor == len(seq) else whole
    except Exception:
        return whole

def build_segments_map(df):
    """Index chain segments and stoichiometry strings by ``target_id``.

    Returns ``(seg_map, stoich_map)``. ``seg_map[tid]`` is
    ``get_chain_segments`` applied to the row, and ``stoich_map[tid]`` is the
    row's stoichiometry as a string ('' when missing or NaN). A repeated
    target_id keeps the values from its last row.
    """
    seg_map = {}
    stoich_map = {}
    for _, record in df.iterrows():
        key = record['target_id']
        seg_map[key] = get_chain_segments(record)
        raw = record.get('stoichiometry', '')
        stoich_map[key] = '' if pd.isna(raw) else str(raw)
    return seg_map, stoich_map

# Per-target chain segments and stoichiometry strings. Within this notebook
# only test_segs_map is read (by predict_rna_structures and
# adaptive_rna_constraints); the train maps are not used afterwards.
train_segs_map, train_stoich_map = build_segments_map(train_seqs)
test_segs_map, test_stoich_map = build_segments_map(test_seqs)

def process_labels(labels_df):
    """Group per-residue label rows into one coordinate array per target.

    Row IDs look like ``<target_id>_<resid>``. Everything before the last '_'
    is the key. Each value holds the rows' (x_1, y_1, z_1) ordered by 'resid',
    as an ``(n_residues, 3)`` array.
    """
    target_keys = labels_df['ID'].str.rsplit('_', n=1).str[0]
    return {
        key: grp.sort_values('resid')[['x_1', 'y_1', 'z_1']].values
        for key, grp in labels_df.groupby(target_keys)
    }

# Template coordinate library: train structures, extended (and overridden on
# shared IDs) by validation structures when those were loaded.
train_coords_dict = process_labels(train_labels)
if validation_labels is not None:
    valid_coords_dict = process_labels(validation_labels)
    train_coords_dict |= valid_coords_dict
print(f"Total template structures: {len(train_coords_dict)}")
print(f"Test targets: {len(test_seqs)}")

In [None]:
from Bio.Align import PairwiseAligner

# Global aligner shared by find_similar_sequences and adapt_template_to_query.
# Scoring: +2 per match, -1.6 per mismatch, affine gaps (open -8, extend -0.4).
aligner = PairwiseAligner()
aligner.mode = 'global'
aligner.match_score = 2
aligner.mismatch_score = -1.6
aligner.open_gap_score   = -8
aligner.extend_gap_score = -0.4
# End gaps on both sequences are penalised exactly like internal gaps. These
# explicit settings are presumably redundant with open/extend_gap_score above
# (which Biopython applies to every gap type) but make the intent explicit.
aligner.query_left_open_gap_score  = -8
aligner.query_left_extend_gap_score = -0.4
aligner.query_right_open_gap_score = -8
aligner.query_right_extend_gap_score = -0.4
aligner.target_left_open_gap_score = -8
aligner.target_left_extend_gap_score = -0.4
aligner.target_right_open_gap_score = -8
aligner.target_right_extend_gap_score = -0.4

def find_similar_sequences(query_seq, train_seqs_df, train_coords_dict, top_n=5):
    """Rank library templates by global-alignment similarity to ``query_seq``.

    Only templates that have coordinates in ``train_coords_dict`` and whose
    length differs from the query by at most 30% of the longer one are scored.
    The score is ``aligner.score`` divided by ``2 * min(len(query), len(template))``,
    i.e. by a perfect match of the shorter sequence (match_score is 2).

    Returns up to ``top_n`` ``(target_id, sequence, normalized_score, coords)``
    tuples, best first; ties keep library order.
    """
    qlen = len(query_seq)
    scored = []
    for _, rec in train_seqs_df.iterrows():
        tid = rec['target_id']
        tseq = rec['sequence']
        if tid not in train_coords_dict:
            continue
        tlen = len(tseq)
        if abs(tlen - qlen) / max(tlen, qlen) > 0.3:
            continue
        norm = aligner.score(query_seq, tseq) / (2 * min(qlen, tlen))
        scored.append((tid, tseq, norm, train_coords_dict[tid]))
    scored.sort(key=lambda cand: cand[2], reverse=True)
    return scored[:top_n]

def adapt_template_to_query(query_seq, template_seq, template_coords):
    """Map template coordinates onto the query via one optimal global alignment.

    Aligned template residues are copied to their query positions. Every query
    position whose x is still NaN (alignment gap, or NaN in the template) is
    then filled left to right; positions filled earlier in the pass count as
    known:
      * between two known positions: linear interpolation;
      * after the last known position: previous position + [3, 0, 0];
      * before the first known position: next known position + [3, 0, 0]
        (later leading positions are then interpolated);
      * no known positions at all: [i*3, 0, 0].
    The offsets are presumably in Angstrom. Remaining NaNs (e.g. in y/z
    only, since only x is checked) are zeroed by ``np.nan_to_num``.

    Returns a ``(len(query_seq), 3)`` float array.
    """
    alignment = next(iter(aligner.align(query_seq, template_seq)))
    new_coords = np.full((len(query_seq), 3), np.nan)
    for (q_start, q_end), (t_start, t_end) in zip(*alignment.aligned):
        t_chunk = template_coords[t_start:t_end]
        # Skip the block when the coordinate array is shorter than the aligned
        # template span (template_coords shorter than template_seq).
        if len(t_chunk) == (q_end - q_start):
            new_coords[q_start:q_end] = t_chunk
    for i in range(len(new_coords)):
        if np.isnan(new_coords[i, 0]):
            prev_v = next((j for j in range(i-1, -1, -1) if not np.isnan(new_coords[j, 0])), -1)
            next_v = next((j for j in range(i+1, len(new_coords)) if not np.isnan(new_coords[j, 0])), -1)
            if prev_v >= 0 and next_v >= 0:
                w = (i - prev_v) / (next_v - prev_v)
                new_coords[i] = (1-w)*new_coords[prev_v] + w*new_coords[next_v]
            elif prev_v >= 0: new_coords[i] = new_coords[prev_v] + [3, 0, 0]
            elif next_v >= 0: new_coords[i] = new_coords[next_v] + [3, 0, 0]
            else: new_coords[i] = [i*3, 0, 0]
    return np.nan_to_num(new_coords)

def adaptive_rna_constraints(coordinates, target_id, confidence=1.0, passes=2):
    """Gently regularise a predicted backbone toward RNA-like geometry.

    Works on a copy of ``coordinates`` (N x 3), one chain segment at a time,
    taking segments from ``test_segs_map[target_id]`` (the whole sequence if
    the target is absent). The correction strength falls as ``confidence``
    rises: ``max(0.80 * (1 - min(confidence, 0.98)), 0.02)``. Each of
    ``passes`` iterations applies, per segment of length >= 3:
      1. a pull of consecutive (i, i+1) distances toward 5.95;
      2. a weaker pull of (i, i+2) distances toward 10.2;
      3. Laplacian smoothing of interior points;
      4. for segments of length >= 25, repulsion between points closer than
         3.3 that are more than 2 apart in sequence. This is evaluated on at
         most 160 evenly spaced points when the segment is longer than 220.
    The target distances are presumably C1'-C1' distances in Angstrom.

    Returns the adjusted (N, 3) array; ``coordinates`` itself is not modified.
    """
    coords = coordinates.copy()
    segments = test_segs_map.get(target_id, [(0, len(coords))])
    strength = 0.80 * (1.0 - min(confidence, 0.98))
    strength = max(strength, 0.02)
    for _ in range(passes):
        for (s, e) in segments:
            # Basic slice -> view: the in-place updates below write straight into coords.
            X = coords[s:e]
            L = e - s
            if L < 3: continue
            # Bond i,i+1 to ~5.95A
            d = X[1:] - X[:-1]
            dist = np.linalg.norm(d, axis=1) + 1e-5
            scale = (5.95 - dist) / dist
            adj = (d * scale[:, None]) * (0.22 * strength)
            X[:-1] -= adj; X[1:] += adj
            # Soft i,i+2 to ~10.2A
            d2 = X[2:] - X[:-2]
            dist2 = np.linalg.norm(d2, axis=1) + 1e-6
            scale2 = (10.2 - dist2) / dist2
            adj2 = (d2 * scale2[:, None]) * (0.10 * strength)
            X[:-2] -= adj2; X[2:] += adj2
            # Laplacian smoothing
            lap = 0.5 * (X[:-2] + X[2:]) - X[1:-1]
            X[1:-1] += (0.06 * strength) * lap
            # Self-avoidance
            if L >= 25:
                # Subsample long segments to bound the O(k^2) pairwise work.
                k = min(L, 160) if L > 220 else L
                idx = np.linspace(0, L-1, k).astype(int) if k < L else np.arange(L)
                P = X[idx]
                diff = P[:, None, :] - P[None, :, :]
                distm = np.linalg.norm(diff, axis=2) + 1e-6
                sep = np.abs(idx[:, None] - idx[None, :])
                mask = (sep > 2) & (distm < 3.3)
                if np.any(mask):
                    # Positive force along (P_i - P_j) pushes clashing pairs apart.
                    force = (3.3 - distm) / distm
                    vec = (diff * force[:, :, None] * mask[:, :, None]).sum(axis=1)
                    X[idx] += (0.015 * strength) * vec
            # Redundant self-assignment (X is already a view of coords), kept as-is.
            coords[s:e] = X
    return coords

In [None]:
def _rotmat(axis, ang):
    axis = np.asarray(axis, float)
    axis = axis / (np.linalg.norm(axis) + 1e-12)
    x, y, z = axis
    c, s = np.cos(ang), np.sin(ang)
    C = 1.0 - c
    return np.array([[c+x*x*C, x*y*C-z*s, x*z*C+y*s],
                     [y*x*C+z*s, c+y*y*C, y*z*C-x*s],
                     [z*x*C-y*s, z*y*C+x*s, c+z*z*C]], dtype=float)

def apply_hinge(coords, seg, rng, max_angle_deg=25):
    """Rigidly bend one segment about a random interior hinge residue.

    The hinge is at least 10 residues from either end of ``seg`` (start, end).
    Every residue after it, up to ``end``, is rotated about the hinge
    position by a random axis and an angle drawn uniformly from
    ±``max_angle_deg`` degrees. Segments shorter than 30 residues are returned
    unchanged: the very same object, with no random numbers drawn. Otherwise
    a new array is returned.
    """
    start, end = seg
    if end - start < 30:
        return coords
    # Draw order (integers, normal, uniform) fixes the random stream per seed.
    hinge = start + int(rng.integers(10, end - start - 10))
    axis_vec = rng.normal(size=3)
    angle = np.deg2rad(float(rng.uniform(-max_angle_deg, max_angle_deg)))
    rot = _rotmat(axis_vec, angle)
    out = coords.copy()
    origin = out[hinge].copy()
    tail = slice(hinge + 1, end)
    out[tail] = (out[tail] - origin) @ rot.T + origin
    return out

def jitter_chains(coords, segments, rng, max_angle_deg=12, max_trans=1.5):
    """Apply a small random rigid-body motion to every chain segment.

    Each segment is rotated about its own centroid (random axis, angle drawn
    uniformly from ±``max_angle_deg`` degrees). It is then shifted along a
    random direction by a distance drawn uniformly from [0, ``max_trans``).
    Finally the whole structure is translated back so its overall centroid is
    unchanged. Returns a new array.
    """
    out = coords.copy()
    centroid0 = out.mean(axis=0, keepdims=True)
    for start, stop in segments:
        # Fixed RNG draw order per segment: axis, angle, direction, distance.
        axis_vec = rng.normal(size=3)
        angle = np.deg2rad(float(rng.uniform(-max_angle_deg, max_angle_deg)))
        rot = _rotmat(axis_vec, angle)
        direction = rng.normal(size=3)
        step = direction / (np.linalg.norm(direction) + 1e-10) * float(rng.uniform(0.0, max_trans))
        chain = out[start:stop]
        center = chain.mean(axis=0, keepdims=True)
        out[start:stop] = (chain - center) @ rot.T + center + step
    out -= out.mean(axis=0, keepdims=True) - centroid0
    return out

def smooth_wiggle(coords, segments, rng, amp=0.8):
    """Add a smooth, low-frequency random displacement to each long segment.

    For every segment of at least 20 residues, 6 control displacements drawn
    from N(0, ``amp``) per axis are spaced evenly along the segment. They are
    linearly interpolated to every residue. Shorter segments are untouched and
    consume no random numbers. Returns a new array.
    """
    out = coords.copy()
    long_segments = ((a, b) for a, b in segments if b - a >= 20)
    for start, stop in long_segments:
        n = stop - start
        knots = np.linspace(0, n - 1, 6)
        knot_disp = rng.normal(0, amp, size=(6, 3))
        pos = np.arange(n)
        out[start:stop] += np.column_stack(
            [np.interp(pos, knots, knot_disp[:, axis]) for axis in range(3)]
        )
    return out

def predict_rna_structures(row, train_seqs_df, train_coords_dict, n_predictions=5):
    tid = row['target_id']
    seq = row['sequence']
    segments = test_segs_map.get(tid, [(0, len(seq))])
    cands = find_similar_sequences(query_seq=seq, train_seqs_df=train_seqs_df,
                                   train_coords_dict=train_coords_dict, top_n=40)
    predictions, used = [], set()
    for i in range(n_predictions):
        seed = (abs(hash(tid)) + i * 10005) % (2**32)
        rng = np.random.default_rng(seed)
        if not cands:
            coords = np.zeros((len(seq), 3), dtype=float)
            for (s, e) in segments:
                for j in range(s+1, e): coords[j] = coords[j-1] + [5.95, 0, 0]
            predictions.append(coords); continue
        if i == 0:
            t_id, t_seq, sim, t_coords = cands[0]
        else:
            K = min(12, len(cands))
            sims = np.array([cands[k][2] for k in range(K)], float)
            w = np.exp((sims - sims.max()) / 0.10)
            for k in range(K):
                if cands[k][0] in used: w[k] *= 0.10
            w /= (w.sum() + 1e-10)
            k = int(rng.choice(np.arange(K), p=w))
            t_id, t_seq, sim, t_coords = cands[k]
        used.add(t_id)
        adapted = adapt_template_to_query(query_seq=seq, template_seq=t_seq, template_coords=t_coords)
        if i == 0: X = adapted
        elif i == 1: X = adapted + rng.normal(0, max(0.01, (0.40-sim)*0.06), adapted.shape)
        elif i == 2:
            longest = max(segments, key=lambda se: se[1]-se[0])
            X = apply_hinge(adapted, longest, rng, max_angle_deg=22)
        elif i == 3: X = jitter_chains(adapted, segments, rng, max_angle_deg=10, max_trans=1.0)
        else: X = smooth_wiggle(adapted, segments, rng, amp=0.8)
        predictions.append(adaptive_rna_constraints(X, tid, confidence=sim, passes=2))
    return predictions

In [None]:
# Generate TBM predictions: one row per residue holding 5 candidate
# structures (x_k, y_k, z_k for k = 1..5), written to submission_tbm.csv.
# That file is both the TBM fallback submission (Section F) and the source
# converted into RNAPro templates (Section C).
all_predictions = []
start_time = time.time()
for idx, row in test_seqs.iterrows():
    if idx % 10 == 0: print(f"TBM: {idx}/{len(test_seqs)} | {time.time()-start_time:.1f}s")
    tid, seq = row['target_id'], row['sequence']
    preds = predict_rna_structures(row, combined_seqs, train_coords_dict)
    for j in range(len(seq)):
        res = {'ID': f"{tid}_{j+1}", 'resname': seq[j], 'resid': j+1}
        for i in range(5):
            res[f'x_{i+1}'], res[f'y_{i+1}'], res[f'z_{i+1}'] = preds[i][j]
        all_predictions.append(res)

sub_tbm = pd.DataFrame(all_predictions)
# Column order: ID, resname, resid, x_1, y_1, z_1, ..., x_5, y_5, z_5.
cols = ['ID', 'resname', 'resid'] + [f'{c}_{i}' for i in range(1,6) for c in ['x','y','z']]
coord_cols = [c for c in cols if c.startswith(('x_','y_','z_'))]
# Clip to [-999.999, 9999.999]; presumably the 8.3 fixed-width coordinate
# limits of the PDB format (TODO confirm against the template converter).
sub_tbm[coord_cols] = sub_tbm[coord_cols].clip(-999.999, 9999.999)
sub_tbm[cols].to_csv('/kaggle/working/submission_tbm.csv', index=False)
print(f"TBM done in {time.time()-start_time:.1f}s, shape: {sub_tbm.shape}")

## Section C: Convert TBM templates to .pt for RNAPro

In [None]:
# Convert the 5 TBM structures per target (submission_tbm.csv) into
# templates.pt for RNAPro's precomputed-template input. NOTE(review): the
# launch script reads ./release_data/kaggle/templates.pt; confirm the converter
# writes there.
%cd /kaggle/working/RNAPro
!python preprocess/convert_templates_to_pt_files.py --input_csv /kaggle/working/submission_tbm.csv --output_name templates.pt
%cd /kaggle/working

## Section D: CCD cache setup

In [None]:
# Stage the chemical-component (CCD) cache files, components.cif and its
# RDKit pickle, under RNAPro/release_data/ccd_cache/ (presumably RNAPro's
# default lookup location). `$DIST` in the shell lines expands the Python
# variable.
DIST = "/kaggle/working/RNAPro/release_data/ccd_cache/"
os.makedirs(DIST, exist_ok=True)
!cp /kaggle/input/rnapro-ccd-cache/ccd_cache/components.cif $DIST
!cp /kaggle/input/rnapro-ccd-cache/ccd_cache/components.cif.rdkit_mol.pkl $DIST
print("CCD cache ready")

## Section E: RNAPro Inference

In [None]:
# Prepare sequences CSV (full test set for scoring run, head(5) for dev)
import pandas as pd
# The RNAPro runner reads its target list from sample_sequences.csv
# (SEQUENCES_CSV in the launch script). Dev runs keep only the first 5 targets.
df = pd.read_csv("/kaggle/input/stanford-rna-3d-folding-2/test_sequences.csv")
df = df if IS_SCORING_RUN else df.head(5)
df.to_csv('/kaggle/working/sample_sequences.csv', index=False)
print(f"Prepared {len(df)} sequences for RNAPro inference")

In [None]:
%%writefile /kaggle/working/RNAPro/runner/inference.py
import os
import shutil
import logging
import traceback
import warnings
import argparse
from contextlib import nullcontext
from os.path import join as opjoin
from typing import Any, Mapping
import glob as glob_mod

import json
import torch
import pandas as pd
import numpy as np
from biotite.structure.io import pdbx

from configs.configs_base import configs as configs_base
from configs.configs_data import data_configs
from configs.configs_inference import inference_configs
from runner.dumper import DataDumper

from rnapro.config import parse_sys_args
from rnapro.config.config import ConfigManager, ArgumentNotSet
from rnapro.data.infer_data_pipeline import get_inference_dataloader
from rnapro.model.RNAPro import RNAPro
from rnapro.utils.distributed import DIST_WRAPPER
from rnapro.utils.seed import seed_everything
from rnapro.utils.torch_utils import to_device

# Silence noisy library warnings in the inference output.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

logger = logging.getLogger(__name__)
# Configures the root logger at import time with level WARNING: INFO/DEBUG
# records from this module's logger are dropped.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("rnapro.data").setLevel(logging.WARNING)
logging.getLogger("rnapro").setLevel(logging.WARNING)


def parse_configs(configs, arg_str=None, fill_required_with_null=False):
    """Parse CLI options over the RNAPro config tree and return the merged config.

    Every config key becomes a string-typed ``--<key>`` option whose default
    is a fresh ``ArgumentNotSet()`` (presumably so ``merge_configs`` keeps the
    config default). Two notebook-only options, ``--max_len`` (int, default
    10000) and ``--n_template_combos`` (int, default 1), are copied onto the
    result. Arguments come from ``arg_str`` split on whitespace, or from
    ``sys.argv`` when ``arg_str`` is empty or None.
    """
    manager = ConfigManager(configs, fill_required_with_null=fill_required_with_null)
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_len", type=int, default=10000, required=False)
    parser.add_argument("--n_template_combos", type=int, default=1, required=False)
    for key, (_dtype, _default, _allow_none, is_required) in manager.config_infos.items():
        parser.add_argument(f"--{key}", type=str, default=ArgumentNotSet(), required=is_required)
    argv = arg_str.split() if arg_str else None
    parsed = parser.parse_args(argv)
    merged = manager.merge_configs(vars(parsed))
    merged.max_len = parsed.max_len
    merged.n_template_combos = parsed.n_template_combos
    return merged


class dotdict(dict):
    """A dict whose keys can also be read, set and deleted as attributes.

    Reading a missing key as an attribute raises AttributeError, so
    ``hasattr`` and ``getattr(..., default)`` behave normally.
    """
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(name)


class InferenceRunner(object):
    """Builds the RNAPro model, loads its checkpoint and runs inference on batches.

    Construction selects the device, creates ``dump_dir`` and ``dump_dir/ERR``,
    instantiates the model, loads the weights and creates a ``DataDumper``
    that writes predictions under ``dump_dir``.
    """

    def __init__(self, configs):
        self.configs = configs
        self.init_env()
        self.init_basics()
        self.init_model()
        self.load_checkpoint()
        self.init_dumper(
            need_atom_confidence=configs.need_atom_confidence,
            sorted_by_ranking_score=configs.sorted_by_ranking_score,
        )

    def init_env(self):
        """Pick the compute device and validate DeepSpeed-attention prerequisites.

        Raises:
            RuntimeError: if ``use_deepspeed_evo_attention`` is enabled but the
                CUTLASS_PATH environment variable is not set.
        """
        self.print(
            f"Distributed environment: world size: {DIST_WRAPPER.world_size}, "
            + f"global rank: {DIST_WRAPPER.rank}, local rank: {DIST_WRAPPER.local_rank}"
        )
        self.use_cuda = torch.cuda.device_count() > 0
        if self.use_cuda:
            self.device = torch.device("cuda:{}".format(DIST_WRAPPER.local_rank))
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        if self.configs.use_deepspeed_evo_attention:
            env = os.getenv("CUTLASS_PATH", None)
            self.print(f"env: {env}")
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if env is None:
                raise RuntimeError(
                    "use_deepspeed_evo_attention is enabled but CUTLASS_PATH is not set"
                )

    def init_basics(self):
        """Create ``dump_dir`` and its ``ERR`` subdirectory for per-sample error logs."""
        self.dump_dir = self.configs.dump_dir
        self.error_dir = opjoin(self.dump_dir, "ERR")
        os.makedirs(self.dump_dir, exist_ok=True)
        os.makedirs(self.error_dir, exist_ok=True)

    def init_model(self):
        """Instantiate RNAPro on ``self.device`` and report its parameter count."""
        self.model = RNAPro(self.configs).to(self.device)
        num_params = sum(p.numel() for p in self.model.parameters())
        self.print(f"Total number of parameters: {num_params:,}")

    def load_checkpoint(self):
        """Load ``configs.load_checkpoint_path`` strictly and switch the model to eval mode.

        Keys saved from a wrapped model (``module.`` prefix, e.g. DDP) are
        unwrapped first.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        checkpoint_path = self.configs.load_checkpoint_path
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        self.print(f"Loading from {checkpoint_path}")
        # NOTE(review): torch>=2.6 defaults to weights_only=True; if this trusted
        # local checkpoint stores non-tensor objects, pass weights_only=False.
        checkpoint = torch.load(checkpoint_path, self.device)
        state_dict = checkpoint["model"]
        # next(iter(...), "") avoids an IndexError on an empty state dict;
        # load_state_dict(strict=True) then reports the missing keys instead.
        if next(iter(state_dict), "").startswith("module."):
            state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
        self.model.load_state_dict(state_dict=state_dict, strict=True)
        self.model.eval()

    def init_dumper(self, need_atom_confidence=False, sorted_by_ranking_score=True):
        """Create the DataDumper that writes predictions under ``dump_dir``."""
        self.dumper = DataDumper(
            base_dir=self.dump_dir,
            need_atom_confidence=need_atom_confidence,
            sorted_by_ranking_score=sorted_by_ranking_score,
        )

    @torch.no_grad()
    def predict(self, data):
        """Run one inference forward pass and return the model's first output.

        ``data`` is moved to ``self.device``. When CUDA is available the pass
        runs under autocast with ``configs.dtype`` ("fp32", "bf16" or
        "fp16"); otherwise no autocast is used.
        """
        eval_precision = {
            "fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16,
        }[self.configs.dtype]
        enable_amp = (
            torch.autocast(device_type="cuda", dtype=eval_precision)
            if torch.cuda.is_available() else nullcontext()
        )
        data = to_device(data, self.device)
        with enable_amp:
            prediction, _, _ = self.model(
                input_feature_dict=data["input_feature_dict"],
                label_full_dict=None, label_dict=None, mode="inference",
            )
        return prediction

    def print(self, msg):
        """Print only on distributed rank 0."""
        if DIST_WRAPPER.rank == 0: print(msg)

    def update_model_configs(self, new_configs):
        """Replace the config object held by the model (e.g. per-target AMP settings)."""
        self.model.configs = new_configs


def update_inference_configs(configs, N_token):
    """Set the per-target ``skip_amp`` flags from the token count, in place.

    * N_token <= 2560: confidence_head=True,  sample_diffusion=True
    * 2560 < N_token <= 3840: confidence_head=False, sample_diffusion=True
    * N_token > 3840: confidence_head=False, sample_diffusion=False

    Presumably True means "run this module without mixed precision", so
    larger inputs fall back to AMP to save memory. Returns ``configs``.
    """
    skip = configs.skip_amp
    skip.confidence_head = not N_token > 2560
    skip.sample_diffusion = not N_token > 3840
    return configs


def infer_predict(runner, configs):
    """Predict and dump every sample of the inference dataloader, once per seed.

    For each seed in ``configs.seeds`` the RNGs are reseeded, then every
    batch is predicted and written by ``runner.dumper`` under that seed.
    Failures are recorded rather than raised:
      * dataloader construction errors go to ``ERR/error.txt`` (then return);
      * per-sample data or prediction errors go to ``ERR/<sample_name>.txt``,
        or ``ERR/unknown.txt`` if the batch could not even be unpacked.
    Errors are also logged at WARNING level so they appear in the run output.
    """
    try:
        dataloader = get_inference_dataloader(configs=configs)
    except Exception as e:
        error_message = f"{e}:\n{traceback.format_exc()}"
        logger.warning(error_message)
        with open(opjoin(runner.error_dir, "error.txt"), "a") as f:
            f.write(error_message)
        return
    for seed in configs.seeds:
        seed_everything(seed=seed, deterministic=configs.deterministic)
        for batch in dataloader:
            # Reset per batch: the handler must never report (or write to the
            # file of) the previous sample, nor hit a NameError when the
            # failure happens before the sample name is known.
            sample_name = "unknown"
            try:
                data, atom_array, data_error_message = batch[0]
                sample_name = data["sample_name"]
                if len(data_error_message) > 0:
                    with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                        f.write(data_error_message)
                    continue
                new_configs = update_inference_configs(configs, data["N_token"].item())
                runner.update_model_configs(new_configs)
                prediction = runner.predict(data)
                runner.dumper.dump(
                    dataset_name="", pdb_id=sample_name, seed=seed,
                    pred_dict=prediction, atom_array=atom_array,
                    entity_poly_type=data["entity_poly_type"],
                )
                torch.cuda.empty_cache()
            except Exception as e:
                error_message = f"{sample_name} {e}:\n{traceback.format_exc()}"
                logger.warning(error_message)
                with open(opjoin(runner.error_dir, f"{sample_name}.txt"), "a") as f:
                    f.write(error_message)
                if hasattr(torch.cuda, "empty_cache"): torch.cuda.empty_cache()


def make_dummy_solution(valid_df):
    """Create an empty per-target solution keyed by target_id.

    Each value is ``dotdict(target_id=..., sequence=..., coord=[])`` with its
    own fresh ``coord`` list, to be filled with predicted structures.
    """
    return dotdict(
        (rec.target_id, dotdict(target_id=rec.target_id, sequence=rec.sequence, coord=[]))
        for _, rec in valid_df.iterrows()
    )


def solution_to_submit_df(solution):
    """Concatenate every target's coordinate table (in solution order) into one DataFrame."""
    frames = [
        coord_to_df(entry.sequence, entry.coord, entry.target_id)
        for entry in solution.values()
    ]
    return pd.concat(frames)


def coord_to_df(sequence, coord, target_id):
    """Build one target's submission rows: ID, resname, resid, then x/y/z per model.

    ``coord`` is a list of per-model arrays indexed as [residue, xyz]. Model
    ``j`` (0-based) fills columns ``x_{j+1}``, ``y_{j+1}``, ``z_{j+1}``.
    """
    resids = range(1, len(sequence) + 1)
    columns = {
        "ID": [f"{target_id}_{r}" for r in resids],
        "resname": list(sequence),
        "resid": list(resids),
    }
    for model_no, xyz in enumerate(coord, start=1):
        for axis_no, axis in enumerate("xyz"):
            columns[f"{axis}_{model_no}"] = xyz[:, axis_no]
    return pd.DataFrame(columns)


def create_input_json(sequence, target_id):
    """Return the RNAPro input spec (a one-element list) for a single-copy RNA chain."""
    chain = {"rnaSequence": {"sequence": sequence, "count": 1}}
    return [{"sequences": [chain], "name": target_id}]


def extract_c1_coordinates(cif_file_path):
    """Return C1' atom coordinates from model 1 of a CIF file, ordered by residue id.

    Returns None, after printing a message, when no C1' atom is present or
    when reading or parsing fails for any reason. ``run_ptx`` replaces None
    with an all-zero structure.
    """
    try:
        with open(cif_file_path, "r") as handle:
            cif = pdbx.CIFFile.read(handle)
        structure = pdbx.get_structure(cif, model=1)
        names = np.char.strip(structure.atom_name.astype(str))
        c1 = structure[names == "C1'"]
        if not len(c1):
            print(f"Warning: No C1' atoms found in {cif_file_path}")
            return None
        return c1[np.argsort(c1.res_id)].coord
    except Exception as err:
        print(f"Error extracting C1' from {cif_file_path}: {err}")
        return None


def process_sequence(sequence, target_id, temp_dir):
    """Write ``<temp_dir>/<target_id>_input.json`` (indent=4), creating temp_dir if needed."""
    os.makedirs(temp_dir, exist_ok=True)
    spec = create_input_json(sequence, target_id)
    out_path = os.path.join(temp_dir, f"{target_id}_input.json")
    with open(out_path, "w") as fh:
        json.dump(spec, fh, indent=4)


def run_ptx(target_id, sequence, configs, solution, template_idx, runner):
    temp_dir = f"./{configs.dump_dir}/input"
    output_dir = f"./{configs.dump_dir}/output"
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    process_sequence(sequence=sequence, target_id=target_id, temp_dir=temp_dir)
    configs.input_json_path = os.path.join(temp_dir, f"{target_id}_input.json")
    configs.template_idx = int(template_idx)
    infer_predict(runner, configs)

    # Collect ALL sample CIF files from all seeds (sorted by ranking score)
    base_dir = f"{configs.dump_dir}/{target_id}"
    collected = 0
    for seed_dir in sorted(glob_mod.glob(f"{base_dir}/seed_*")):
        pred_dir = f"{seed_dir}/predictions"
        cif_files = sorted(glob_mod.glob(f"{pred_dir}/{target_id}_sample_*.cif"))
        for cif_file in cif_files:
            coord = extract_c1_coordinates(cif_file)
            if coord is None:
                coord = np.zeros((len(sequence), 3), dtype=np.float32)
            elif coord.shape[0] < len(sequence):
                pad = np.zeros((len(sequence) - coord.shape[0], 3), dtype=np.float32)
                coord = np.concatenate([coord, pad], axis=0)
            elif coord.shape[0] > len(sequence):
                coord = coord[:len(sequence)]
            solution[target_id].coord.append(coord)
            collected += 1
    print(f"    Collected {collected} structures for {target_id} (template_idx={template_idx})")

    # Clean up CIF files to save disk space
    if os.path.exists(base_dir):
        shutil.rmtree(base_dir, ignore_errors=True)


def run():
    """Entry point: parse CLI configs, run RNAPro on every target, write ./submission.csv.

    For each row of ``configs.sequences_csv``: targets longer than
    ``configs.max_len`` get 5 all-zero structures. Otherwise ``run_ptx`` is
    called with template_idx=4, and the collected structures are truncated,
    or padded by repeating the last one (5 zero structures if none were
    collected), to exactly 5. Missing coordinates are filled with 0.0 before
    saving.
    """
    LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
    # NOTE(review): logging.basicConfig is already called at import time in
    # this module, and a repeat call without force=True is a no-op, so
    # LOG_FORMAT is most likely never applied. filemode is also ignored
    # without a filename.
    logging.basicConfig(format=LOG_FORMAT, level=logging.WARNING, datefmt="%Y-%m-%d %H:%M:%S", filemode="w")
    logging.getLogger("rnapro.data").setLevel(logging.WARNING)
    logging.getLogger("rnapro").setLevel(logging.WARNING)
    # Enabled only when the env var is exactly the string "true".
    configs_base["use_deepspeed_evo_attention"] = (
        os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true"
    )
    # Later dicts win on duplicate keys: inference_configs overrides base/data.
    configs = {**configs_base, **{"data": data_configs}, **inference_configs}
    configs = parse_configs(configs=configs, arg_str=parse_sys_args(), fill_required_with_null=True)

    # Parse seeds - handle comma-separated string from shell
    # Accepts "42,101,202", "42", a number, or a list/tuple. Comma items that
    # are not plain digits (e.g. negative seeds) are dropped; any other type
    # falls back to [42].
    seeds = configs.seeds
    if isinstance(seeds, str):
        if ',' in seeds:
            seeds = [int(s.strip()) for s in seeds.split(',') if s.strip().isdigit()]
        else:
            seeds = [int(seeds)]
    elif isinstance(seeds, (int, float)):
        seeds = [int(seeds)]
    elif not isinstance(seeds, (list, tuple)):
        seeds = [42]
    configs.seeds = seeds

    valid_df = pd.read_csv(configs.sequences_csv)
    print(f"\n -> Loaded {len(valid_df)} sequence(s), seeds={configs.seeds}")

    # n_template_combos is only reported; the template loop below is hard-coded.
    n_template_combos = configs.n_template_combos
    print(f" -> Using {n_template_combos} template combination(s)")

    print('\n -> Building model and loading checkpoint')
    runner = InferenceRunner(configs)
    print('\n -> Done, starting inference...')

    solution = make_dummy_solution(valid_df)
    for idx, row in valid_df.iterrows():
        seq_len = len(row.sequence)
        print(f"\n -> Sequence {row.target_id}: len={seq_len}")
        if seq_len > configs.max_len:
            # Too long for inference: emit 5 all-zero structures.
            print(f'Sequence too long ({seq_len} > {configs.max_len}), skipping')
            for _ in range(5):
                solution[row.target_id].coord.append(
                    np.zeros((seq_len, 3), dtype=np.float32)
                )
            continue
        try:
            # Only template_idx=4 is run. NOTE(review): the inline note says this
            # feeds all 5 TBM templates at once; confirm against RNAPro's
            # template_idx handling.
            for template_idx in [4]:  # Use ALL 5 templates at once
                print(f'  template_combo={template_idx}')
                run_ptx(
                    target_id=row.target_id, sequence=row.sequence,
                    configs=configs, solution=solution,
                    template_idx=template_idx, runner=runner,
                )
        except Exception as e:
            print(f"Error processing {row.target_id}: {e}")
            traceback.print_exc()

        # Cap at 5 predictions, keeping the order in which run_ptx appended them.
        # NOTE(review): these are the best 5 overall only if run_ptx ranks the
        # candidates of all seeds together by ranking_score.
        coords = solution[row.target_id].coord
        if len(coords) == 0:
            print(f"  WARNING: No predictions for {row.target_id}, using zeros")
            for _ in range(5):
                coords.append(np.zeros((seq_len, 3), dtype=np.float32))
        # Pad to 5 by repeating the last structure.
        while len(coords) < 5:
            coords.append(coords[-1].copy())
        solution[row.target_id].coord = coords[:5]
        print(f"  Final: {len(solution[row.target_id].coord)} predictions")

    print('\n\n -> Inference done! Saving to submission.csv')
    submit_df = solution_to_submit_df(solution)
    submit_df = submit_df.fillna(0.0)
    submit_df.to_csv("./submission.csv", index=False)


if __name__ == "__main__":
    run()

In [None]:
%%writefile /kaggle/working/RNAPro/rnapro_inference_kaggle.sh

# Use the plain-PyTorch LayerNorm implementation (presumably instead of a
# fused custom kernel).
export LAYERNORM_TYPE=torch

# === SUB 8: N_CYCLE=4 + multi-seed + all templates + RibonanzaNet2 ===
# N_CYCLE=4 is RNAPro default, 2.5x faster than N_CYCLE=10
# 3 seeds x 5 samples = 15 candidates; runner/inference.py keeps 5 per target
# (see run_ptx/run there for how candidates are ordered before the cut).
# SEED is comma-separated; runner/inference.py splits it into a list of ints.
# N_SAMPLE / N_STEP map to --sample_diffusion.N_sample / .N_step and N_CYCLE
# to --model.N_cycle. N_TEMPLATE_COMBOS is parsed but currently only printed.
SEED=42,101,202
N_SAMPLE=5
N_STEP=200
N_CYCLE=4
N_TEMPLATE_COMBOS=1

# Paths
# Relative paths resolve against the RNAPro repo root (the launching cell cd's there).
DUMP_DIR="../output"
CHECKPOINT_PATH="../rnapro-private-best-500m.ckpt"

# Template/MSA settings
TEMPLATE_DATA="./release_data/kaggle/templates.pt"
RNA_MSA_DIR="/kaggle/input/stanford-rna-3d-folding-2/MSA/"
SEQUENCES_CSV="/kaggle/working/sample_sequences.csv"

# RibonanzaNet2
# Enabled only when its weights dataset is attached; otherwise explicitly disabled.
RIBONANZA_PATH="/kaggle/input/ribonanzanet2/pytorch/alpha/1/"
if [ -d "$RIBONANZA_PATH" ]; then
    echo "RibonanzaNet2 found"
    RIBONANZA_ARGS="--model.use_RibonanzaNet2 true --model.ribonanza_net_path ${RIBONANZA_PATH}"
else
    echo "WARNING: RibonanzaNet2 NOT found"
    RIBONANZA_ARGS="--model.use_RibonanzaNet2 false"
fi

MODEL_NAME="rnapro_base"
mkdir -p "${DUMP_DIR}"

# ${RIBONANZA_ARGS} is deliberately unquoted so it word-splits into separate flags.
# --max_len 1000: longer targets are skipped by the runner and get all-zero structures.
python3 runner/inference.py \
    --model_name "${MODEL_NAME}" \
    --seeds ${SEED} \
    --dump_dir "${DUMP_DIR}" \
    --load_checkpoint_path "${CHECKPOINT_PATH}" \
    --use_msa true \
    --use_template "ca_precomputed" \
    --model.use_template "ca_precomputed" \
    --model.template_embedder.n_blocks 2 \
    ${RIBONANZA_ARGS} \
    --template_data "${TEMPLATE_DATA}" \
    --rna_msa_dir "${RNA_MSA_DIR}" \
    --model.N_cycle ${N_CYCLE} \
    --sample_diffusion.N_sample ${N_SAMPLE} \
    --sample_diffusion.N_step ${N_STEP} \
    --load_strict true \
    --num_workers 0 \
    --triangle_attention "torch" \
    --triangle_multiplicative "torch" \
    --sequences_csv "${SEQUENCES_CSV}" \
    --max_len 1000 \
    --n_template_combos ${N_TEMPLATE_COMBOS}

In [None]:
# Run RNAPro from the repo root (the launch script uses paths relative to it),
# then move its ./submission.csv to submission_rnapro.csv for the merge step.
# If inference fails before writing the CSV, the mv fails, and Section F falls
# back to the TBM-only submission.
%cd /kaggle/working/RNAPro
!bash ./rnapro_inference_kaggle.sh
!mv submission.csv /kaggle/working/submission_rnapro.csv
%cd /kaggle/working

## Section F: Merge RNAPro + TBM

In [None]:
import pandas as pd
import numpy as np
import os

df_tbm = pd.read_csv("/kaggle/working/submission_tbm.csv")

# RNAPro output exists only if the inference script finished and was moved here.
rnapro_path = "/kaggle/working/submission_rnapro.csv"
if os.path.exists(rnapro_path):
    df_rnapro = pd.read_csv(rnapro_path)
    df_seqs = pd.read_csv("/kaggle/input/stanford-rna-3d-folding-2/test_sequences.csv")

    # Long targets (>1000 nt, the --max_len given to RNAPro) are skipped by the
    # runner and come back as zeros.
    long_targets = df_seqs[df_seqs['sequence'].str.len() > 1000]['target_id'].values
    print(f"Long targets to replace with TBM (len > 1000): {len(long_targets)}")

    # Exact target id of each row ("<target_id>_<resid>" -> "<target_id>").
    # Exact matching (unlike a startswith prefix test) cannot confuse target
    # ids that are prefixes of each other, and it is vectorised.
    row_target = df_rnapro['ID'].astype(str).str.rsplit('_', n=1).str[0]

    # Any other target whose RNAPro coordinates are all zero failed or got no
    # predictions (the runner zero-fills those), so TBM is strictly better.
    xyz_cols = [c for c in df_rnapro.columns if c.startswith(('x_', 'y_', 'z_'))]
    zero_rows = df_rnapro[xyz_cols].fillna(0.0).eq(0.0).all(axis=1)
    zero_targets = zero_rows.groupby(row_target).all()
    failed_targets = zero_targets.index[zero_targets.values].difference(long_targets)
    print(f"Other targets with all-zero RNAPro output (replaced with TBM): {len(failed_targets)}")

    replace_targets = set(long_targets).union(failed_targets)
    mask_replace = row_target.isin(replace_targets)

    if mask_replace.sum() > 0:
        print(f"Replacing {mask_replace.sum()} residues with TBM predictions...")
        df_rnapro_idx = df_rnapro.set_index('ID')
        df_tbm_idx = df_tbm.set_index('ID')
        ids_to_update = df_rnapro_idx[mask_replace.values].index
        valid_ids = [i for i in ids_to_update if i in df_tbm_idx.index]
        # .loc assignment aligns on ID and column names.
        df_rnapro_idx.loc[valid_ids] = df_tbm_idx.loc[valid_ids]
        df_final = df_rnapro_idx.reset_index()
    else:
        print("No targets to replace with TBM.")
        df_final = df_rnapro
else:
    print("WARNING: RNAPro inference failed. Using TBM-only submission.")
    df_final = df_tbm

df_final.to_csv("/kaggle/working/submission.csv", index=False)
print(f"Final submission shape: {df_final.shape}")
df_final.head()

## Section G: Validation (dev runs only)

In [None]:
# Quick sanity report on dev runs only: shape, first rows, and how many
# values in the first three coordinate columns (x_1, y_1, z_1 for the files
# written above) are non-zero.
if not IS_SCORING_RUN:
    print("Dev run - checking first 5 sequences only")
    sub = pd.read_csv('/kaggle/working/submission.csv')
    print(f"Submission shape: {sub.shape}")
    print(sub.head(10))
    print("\nNon-zero coordinate check:")
    coord_cols = [c for c in sub.columns if c.startswith(('x_', 'y_', 'z_'))]
    n_rows = len(sub)
    for col, nz in (sub[coord_cols[:3]] != 0).sum().items():
        print(f"  {col}: {nz}/{n_rows} non-zero")