In [1]:
# 0) Trash the old symlink
!rm -rf /kaggle/working/DRfold2

# 1) Copy your code out of the read-only input into working
!cp -r /kaggle/input/final-dr/DRfold2 /kaggle/working/DRfold2

# 2) Now link in your model_hub under that real folder
!ln -s /kaggle/input/drfold2-models/model_hub /kaggle/working/DRfold2/model_hub

# 3) Verify
!ls -l /kaggle/working/DRfold

ls: cannot access '/kaggle/working/DRfold': No such file or directory


In [2]:
!export PROTENIX_DATA_ROOT_DIR=/kaggle/input/protenix-checkpoints
! mkdir /af3-dev 
! ln -s /kaggle/input/protenix-checkpoints /af3-dev/release_data
! ls /af3-dev/release_data/

components.v20240608.cif		model_v0.2.0.pt
components.v20240608.cif.rdkit_mol.pkl


In [3]:
import Bio

from copy import deepcopy

import pandas as pd
from Bio.PDB import Atom, Model, Chain, Residue, Structure, PDBParser
from Bio import SeqIO
import os, sys
import re
import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
# Wall-clock timestamp taken when this cell runs.
time0=time.time()
# Path of the interpreter running this kernel, printed for reference.
PYTHON = sys.executable
print('PYTHON',PYTHON)


# Kaggle input directory (presumably the competition data - verify).
DATA_KAGGLE_DIR = '/kaggle/input/stanford-rna-3d-folding'


# helper ----
class dotdict(dict):
    """dict whose keys can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: fall back to the keys
        # and report a missing key as a missing attribute.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    # Attribute assignment and deletion go straight to the item protocol.
    __delattr__ = dict.__delitem__
    __setattr__ = dict.__setitem__

# visualisation helper ----
def set_aspect_equal(ax):
    """Give an axes object the same span on x, y and z, keeping each axis centred.

    Reads the current x/y/z limits of ``ax``, takes the largest of the three
    spans, and re-applies limits of that span around each axis' midpoint.
    """
    limits = [ax.get_xlim(), ax.get_ylim(), ax.get_zlim()]
    centres = [np.mean(pair) for pair in limits]

    # Half of the widest span; every axis gets centre +/- this value.
    half_span = max(hi - lo for lo, hi in limits) / 2.0

    setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for apply_limits, centre in zip(setters, centres):
        apply_limits(centre - half_span, centre + half_span)




# xyz df helper --------------------
def get_truth_df(target_id):
    """Return the rows of the global ``LABEL_DF`` whose ``target_id`` matches, re-indexed from 0."""
    is_target = LABEL_DF['target_id'] == target_id
    return LABEL_DF[is_target].reset_index(drop=True)

def parse_output_to_df(output, seq, target_id):
    """Turn per-model coordinates into a one-row-per-residue table.

    Args:
        output: indexable as ``output[model, residue, axis]`` where each element
            exposes ``.item()``; ``len(output)`` is the number of models.
        seq: residue letters; one row is produced per letter.
        target_id: value stored in the ``ID`` column of every row.

    Returns:
        list: ``[DataFrame]`` with columns ``ID, resname, resid`` followed by
        ``x_k, y_k, z_k`` for each model k (rounded to 3 decimals), or ``[]``
        when ``seq`` is empty.
    """
    n_models = len(output)
    records = []
    for pos, base in enumerate(seq):
        record = {'ID': target_id, 'resname': base, 'resid': pos + 1}
        for m in range(n_models):
            for k, axis in enumerate('xyz'):
                record[f'{axis}_{m+1}'] = round(output[m, pos, k].item(), 3)
        records.append(record)

    return [pd.DataFrame(records)] if records else []

def parse_pdb_to_df(pdb_file, target_id):
    """Collect the C1' atom of every A/U/G/C residue in a PDB file.

    Args:
        pdb_file: passed unchanged to ``PDBParser.get_structure``.
        target_id: prefix of the per-residue ``ID`` value (``<target_id>_<resid>``).

    Returns:
        list: one DataFrame per chain (in every model) that contains at least
        one matching residue, with columns ``ID, resname, resid, x_1, y_1, z_1``.
        Each chain object is printed as it is visited.
    """
    structure = PDBParser().get_structure('', pdb_file)
    wanted_residues = ('A', 'U', 'G', 'C')

    frames = []
    for model in structure:
        for chain in model:
            print(chain)
            rows = []
            for residue in chain:
                resname = residue.get_resname()
                # Skip anything that is not a plain nucleotide with a C1' atom.
                if resname not in wanted_residues or "C1'" not in residue:
                    continue
                xyz = residue["C1'"].get_coord()
                resid = residue.get_id()[1]
                # TODO (from the original): detect discontinuous numbering (resid != prev_resid + 1)
                rows.append({
                    'ID': target_id+'_'+str(resid),
                    'resname': resname,
                    'resid': resid,
                    'x_1': xyz[0],
                    'y_1': xyz[1],
                    'z_1': xyz[2],
                })

            if rows:
                frames.append(pd.DataFrame(rows))
    return frames

# usalign helper --------------------
def write_target_line(
    atom_name, atom_serial, residue_name, chain_id, residue_num, x_coord, y_coord, z_coord, occupancy=1.0, b_factor=0.0, atom_type='P'
):
    """
    Format one ATOM record for the simplified PDB files fed to USalign.

    Args:
        atom_name (str): Atom name, left-aligned in a 5-character field.
        atom_serial (int): Atom serial number.
        residue_name (str): Residue name, left-aligned in a 3-character field.
        chain_id (str): Accepted for interface compatibility; not written to the line.
        residue_num (int): Residue number.
        x_coord (float): X coordinate.
        y_coord (float): Y coordinate.
        z_coord (float): Z coordinate.
        occupancy (float, optional): Occupancy value (default: 1.0).
        b_factor (float, optional): B-factor value (default: 0.0).
        atom_type (str, optional): Trailing element field (default: 'P').

    Returns:
        str: The formatted line, terminated by a newline.
    """
    record = (
        f'ATOM  {atom_serial:>5d}  {atom_name:<5s} {residue_name:<3s} {residue_num:>3d}    '
        f'{x_coord:>8.3f}{y_coord:>8.3f}{z_coord:>8.3f}'
        f'{occupancy:>6.2f}{b_factor:>6.2f}           {atom_type}\n'
    )
    return record

def write_xyz_to_pdb(df, pdb_file, xyz_id = 1):
    """Write one C1' ATOM line per resolved row of ``df`` to ``pdb_file``.

    Args:
        df: table with ``resid``, ``resname`` and ``x_<id>``/``y_<id>``/``z_<id>`` columns.
        pdb_file: output path (overwritten).
        xyz_id: which coordinate set to read (suffix of the x/y/z columns).

    Returns:
        int: number of rows written. A row is written only when all three
        coordinates compare greater than -1e17 (so NaN rows are skipped too).
    """
    x_col, y_col, z_col = (f'{axis}_{xyz_id}' for axis in 'xyz')

    written = 0
    with open(pdb_file, 'w') as handle:
        for _, row in df.iterrows():
            x, y, z = row[x_col], row[y_col], row[z_col]
            if not (x > -1e17 and y > -1e17 and z > -1e17):
                continue
            written += 1
            handle.write(write_target_line(
                atom_name="C1'",
                atom_serial=int(row['resid']),
                residue_name=row['resname'],
                chain_id='0',
                residue_num=int(row['resid']),
                x_coord=x,
                y_coord=y,
                z_coord=z,
                atom_type='C',
            ))
    return written

def parse_usalign_for_tm_score(output):
    """Extract the second ``TM-score=`` value from USalign's text output.

    The second value is the one normalised by the reference (second) structure.

    Args:
        output (str): text printed by USalign.

    Returns:
        float: the second TM-score found.

    Raises:
        ValueError: if fewer than two ``TM-score=`` values are present (for
            example empty output because the command failed). Indexing the
            match list directly would raise ``IndexError`` before this check.
    """
    tm_scores = re.findall(r'TM-score=\s+([\d.]+)', output)
    if len(tm_scores) < 2:
        raise ValueError('No TM score found')
    return float(tm_scores[1])

def parse_usalign_for_transform(output):
    """Parse the rotation-matrix table from USalign's text output.

    Rows are collected after the "rotation matrix" header line: every line made
    of an integer index followed by four numeric fields contributes its four
    fields (the index is dropped). Parsing stops at the first blank line after
    the header.

    Returns:
        np.ndarray: one row of four floats per matrix line (empty array when
        no rows are found).
    """
    header = "The rotation matrix to rotate Structure_1 to Structure_2"
    row_pattern = re.compile(r'^\d+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+$')

    rows = []
    in_matrix = False
    for line in output.splitlines():
        if header in line:
            in_matrix = True
            continue
        if not in_matrix:
            continue
        if row_pattern.match(line):
            # First column is the row index; keep the remaining four values.
            rows.append([float(token) for token in line.split()[1:]])
        elif not line.strip():
            break

    return np.array(rows)

def call_usalign(predict_df, truth_df, verbose=1):
    """Align a prediction to the truth with USalign and parse the result.

    Both tables are written (coordinate set 1) to fixed file names in the
    current directory, the ``USALIGN`` executable is run on them through the
    shell, and its text output is parsed.

    Args:
        predict_df: prediction table accepted by ``write_xyz_to_pdb``.
        truth_df: reference table accepted by ``write_xyz_to_pdb``.
        verbose: when equal to 1, the raw USalign output is printed.

    Returns:
        tuple: ``(tm_score, transform)`` from the two ``parse_usalign_*`` helpers.
    """
    predict_pdb = '~predict.pdb'
    truth_pdb = '~truth.pdb'
    for frame, path in ((predict_df, predict_pdb), (truth_df, truth_pdb)):
        write_xyz_to_pdb(frame, path, xyz_id=1)

    command = f'{USALIGN} {predict_pdb} {truth_pdb} -atom " C1\'" -m -'
    output = os.popen(command).read()
    if verbose==1:
        print(output)

    tm_score = parse_usalign_for_tm_score(output)
    return tm_score, parse_usalign_for_transform(output)

# Marker printed once every helper above has been defined without error.
print('HELPER OK!!!')

PYTHON /usr/bin/python3
HELPER OK!!!


In [4]:
# The condition is always true; the whole cell body runs unconditionally.
if 1!=0:
    
    
    from runner.batch_inference import get_default_runner
    from runner.inference import update_inference_configs, InferenceRunner

    from protenix.data.infer_data_pipeline import InferenceDataset

    # Seed numpy and torch (CPU and all CUDA devices) for repeatable sampling.
    np.random.seed(0)
    torch.random.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    class DictDataset(InferenceDataset):
        """InferenceDataset whose inputs are built in memory from a list of sequences.

        ``__init__`` does not call the parent constructor; it only sets
        ``dump_dir``, ``use_msa`` and ``inputs``. Each input holds a single
        ``rnaSequence`` entry with ``count`` 1, and is named ``"query"`` when
        no ``id_list`` is given, otherwise by the matching id.
        """
        def __init__(
            self,
            seq_list: list,
            dump_dir: str,
            id_list: list = None,
            use_msa: bool = False,
        ) -> None:

            self.dump_dir = dump_dir
            self.use_msa = use_msa
            if isinstance(id_list,type(None)):
                self.inputs = [{"sequences": 
                                [{"rnaSequence": 
                                  {"sequence": seq, 
                                   "count": 1}}],
                                "name": "query"} for seq in seq_list]
            else:
                # zip() stops at the shorter list, so extra ids or sequences are ignored.
                self.inputs = [{"sequences": 
                                [{"rnaSequence": 
                                  {"sequence": seq, 
                                   "count": 1}}],
                                "name": i} for i, seq in zip(id_list,seq_list)]
                
    from configs.configs_base import configs as configs_base
    from configs.configs_data import data_configs
    from configs.configs_inference import inference_configs
    from protenix.config.config import parse_configs

    # Enabled only when the environment variable is exactly the string "true".
    configs_base["use_deepspeed_evo_attention"] = (
    os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False) == "true")
    # Inference settings: 10 cycles, 5 diffusion samples, 200 diffusion steps.
    configs_base["model"]["N_cycle"] = 10 #10
    configs_base["sample_diffusion"]["N_sample"] = 5
    configs_base["sample_diffusion"]["N_step"] = 200
    inference_configs['load_checkpoint_path']='/kaggle/input/protenix-checkpoints/model_v0.2.0.pt'
    # Later dicts win on key clashes: inference_configs overrides configs_base.
    configs = {**configs_base, **{"data": data_configs}, **inference_configs}

    configs = parse_configs(
            configs=configs,
            fill_required_with_null=True,
        )
    
    # Global runner reused by later cells.
    runner=InferenceRunner(configs)

train scheduler 16.0
inference scheduler 16.0
Diffusion Module has 16.0


In [5]:
# Corrected Cell 5: proteinx function definition
# This function uses 'runner' and 'configs' that should be initialized globally in Cell 4.
# It also relies on 'DictDataset', 'update_inference_configs', and 'parse_output_to_df'
# being defined/available in the global scope (presumably from Cell 3 and Cell 4).

def _protenix_fallback_df(sequence_string: str, target_id_string: str, num_models: int = 5) -> pd.DataFrame:
    """Build an all-zero coordinate table with the columns of a real prediction.

    One row per residue with ``ID``, ``resname``, ``resid`` and zeroed
    ``x_k, y_k, z_k`` for k = 1..num_models.
    """
    zero_coords = {
        f'{axis}_{model_idx}': 0.0
        for model_idx in range(1, num_models + 1)
        for axis in ('x', 'y', 'z')
    }
    return pd.DataFrame([
        {'ID': target_id_string, 'resname': res_char, 'resid': res_idx + 1, **zero_coords}
        for res_idx, res_char in enumerate(sequence_string)
    ])


def proteinx(sequence_string: str, target_id_string: str) -> pd.DataFrame:
    """
    Predicts RNA structure using Protenix for a single sequence.

    Uses the global ``configs`` and ``runner`` plus ``DictDataset``,
    ``update_inference_configs`` and ``parse_output_to_df`` defined in earlier cells.

    Args:
        sequence_string (str): The RNA sequence.
        target_id_string (str): The target ID for this sequence.
    Returns:
        pd.DataFrame: A DataFrame with predicted coordinates from Protenix (e.g., x_1..x_5).
                      An empty DataFrame is returned when the dataset cannot be
                      built or reports a data error; an all-zero fallback
                      DataFrame (5 models) is returned when parsing yields
                      nothing or any exception is raised.
    """
    print(f"    Protenix: Predicting for {target_id_string}, Length: {len(sequence_string)}")
    try:
        # 1. Create dataset for the single sequence (dump dir must be writable).
        os.makedirs('/kaggle/working/protenix_temp_out', exist_ok=True)
        dataset_protenix = DictDataset(
            seq_list=[sequence_string],
            dump_dir='/kaggle/working/protenix_temp_out',
            id_list=[target_id_string],
            use_msa=False
        )

        if not dataset_protenix.inputs:
            print(f"    Protenix: Failed to create dataset for {target_id_string}")
            return pd.DataFrame() # Return empty DataFrame on failure

        data, atom_array, data_error_message = dataset_protenix[0]

        if data_error_message:
            print(f"    Protenix: Data error for {target_id_string} - {data_error_message}")
            return pd.DataFrame() # Return empty DataFrame on data error

        # 2. Update configs for this token count and predict.
        new_prediction_configs = update_inference_configs(configs, data["N_token"].item())
        runner.update_model_configs(new_prediction_configs)
        prediction_result = runner.predict(data)

        # Keep only atoms whose atom_to_tokatom_idx is 12 (the C1' atom, per the
        # original author's note); result is indexed (sample, residue, xyz).
        predicted_coords_tensor = prediction_result['coordinate'][:, data['input_feature_dict']['atom_to_tokatom_idx'] == 12]

        # 3. Parse to DataFrame (parse_output_to_df returns a list with at most one frame).
        df_list = parse_output_to_df(predicted_coords_tensor, sequence_string, target_id_string)

        if df_list and not df_list[0].empty:
            result_df = df_list[0]
            print(f"    Protenix: Successfully predicted for {target_id_string}. DataFrame shape: {result_df.shape}")
            return result_df

        print(f"    Protenix: Failed to parse output to DataFrame or got empty DataFrame for {target_id_string}")
        return _protenix_fallback_df(sequence_string, target_id_string)

    except Exception as e:
        # Best effort: never let one target abort the run; hand back zeros instead.
        print(f"    Protenix: Critical error during prediction for {target_id_string}: {str(e)}")
        return _protenix_fallback_df(sequence_string, target_id_string)

In [6]:
import numpy as np

def protx_to_submission_array(df: pd.DataFrame, L_full: int, sample: int = 5):
    """Place one Protenix sample into a zero-filled ``(L_full, 3, 3)`` array.

    Slot 0 of each residue receives that residue's ``x/y/z`` for the chosen
    sample; slots 1 and 2 stay zero.

    The previous implementation looped over all five samples while writing to
    the same slot, so only sample 5 ever survived. That effective behaviour is
    kept as the default and made explicit through ``sample``.

    Args:
        df: table with ``resid`` (1-based) and ``x_k, y_k, z_k`` columns.
            An empty table yields an all-zero array.
        L_full: number of residues in the output array.
        sample: 1-based sample index whose coordinates are copied (default 5).

    Returns:
        np.ndarray: float32 array of shape ``(L_full, 3, 3)``.
    """
    coords = np.zeros((L_full, 3, 3), dtype=np.float32)
    if df.empty:
        return coords

    row_idx = df['resid'].astype(int).to_numpy() - 1
    xyz = df[[f'x_{sample}', f'y_{sample}', f'z_{sample}']].to_numpy(dtype=np.float32)
    coords[row_idx, 0, :] = xyz
    return coords


In [7]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined DRFold2 inference, selection, and PDB generation with write_frame_coor_to_pdb,
followed by C1' extraction for submission.
- Uses length-dependent model selection.
- Energy scores medium-length predictions with PotentialFold.
- Builds PDBs using write_frame_coor_to_pdb.
- Extracts C1' coordinates from these PDBs for submission.
"""
import os
import sys
import gc
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from datetime import datetime
import time
import re
import pickle
import tempfile
import shutil
import json
import subprocess
import traceback
from Bio.Seq import Seq
import hashlib
# OPENMM: Import OpenMM libraries
# Optional dependency: OPENMM_AVAILABLE records whether the imports succeeded
# so that later code can check it before using OpenMM.
try:
    from openmm import app
    from openmm import unit
    from openmm import LangevinIntegrator # Or LangevinMiddleIntegrator
    from pdbfixer import PDBFixer
    OPENMM_AVAILABLE = True
    print("OpenMM imported successfully.")
except ImportError as e_openmm:
    # A missing pdbfixer also lands here, since it is imported in the same try.
    OPENMM_AVAILABLE = False
    print(f"WARNING: OpenMM could not be imported. Refinement step will be skipped. Error: {e_openmm}")
    print("If you intend to use OpenMM, please install it (e.g., pip install openmm pdbfixer)")
# --- USER CONFIGURATION ---
# Paths
USALIGN_SOURCE_PATH = '/kaggle/input/usalign/USalign'
USALIGN_WORKING_PATH = '/kaggle/working/USalign'
DEFAULT_MODEL_CFG_FOLDER_NAME="cfg_97"
DRFOLD2_ROOT = "/kaggle/working/DRfold2"
INPUT_MODEL_HUB = "/kaggle/input/drfold2-models/model_hub"
# Location inside DRFOLD2_ROOT where the model hub is symlinked further down.
TARGET_MODEL_HUB = os.path.join(DRFOLD2_ROOT, "model_hub")
DATA_DIR = '/kaggle/input/stanford-rna-3d-folding'
OUT_DIR_PDB = '/kaggle/working/predictions_custom_pdbs' # For saving top PDBs from custom function
SUBMISSION_OUT_DIR = '/kaggle/working/'

# Source used to seed both <DRFOLD2_ROOT>/cfg_97/base.npy and
# PotentialFold/lib/base.npy when those files are missing (see loading code below).
BASE_NPY_INPUT_PATH = "/kaggle/input/refined-dr/DRfold2/cfg_97/base.npy"
OTHER_NPY_INPUT_PATH_FOR_NETWORK = "/kaggle/input/base-npy1/base.npy" # For DRFold2 network's "other_coor"
# OPENMM: Directory for refined PDBs
OUT_DIR_PDB_REFINED = '/kaggle/working/predictions_refined_pdbs'

# Snapshot of sys.path taken before any DRFold2 paths are added; the path
# setup further down rebuilds sys.path from this copy.
original_system_path = list(sys.path)

# OPENMM: Configuration
OPENMM_REFINE = True # Set to False to skip OpenMM refinement globally
OPENMM_FORCE_FIELD_RNA = 'amber14-all.xml'
OPENMM_FORCE_FIELD_WATER = 'amber14/tip3pfb.xml'
# NOTE(review): named "IONS" but the value looks like an implicit-solvent
# force-field file - confirm against where it is used.
OPENMM_FORCE_FIELD_IONS="implicit/gbn2.xml"
OPENMM_MINIMIZATION_STEPS = 100 # Number of minimization steps, 0 to skip
OPENMM_REPORTER_INTERVAL = 10

# Path to the specific other2.npy for PotentialFold/PDB writer (C1' etc.)
# IMPORTANT: Ensure this path is correct and the file exists.
PF_OTHER2_NPY_EXPLICIT_PATH = "/kaggle/input/refined-dr/DRfold2/PotentialFold/lib/other2.npy"
# OPENMM: Create directory for refined PDBs
os.makedirs(OUT_DIR_PDB_REFINED, exist_ok=True)

# Model selection
MAX_SUBMISSION_MODELS = 5
MAX_DRFOLD_PREDICTION_LENGTH = 480
# Five (model id, cfg folder) pairs, all from cfg_97.
MODELS_LONG_NUMERIC_IDS = [    
    {'id': 0, 'cfg': 'cfg_97'},    # Model 0 from cfg_97
    {'id': 1, 'cfg': 'cfg_97'},    # Model 1 from cfg_97
    {'id': 2, 'cfg': 'cfg_97'},
    {'id': 8, 'cfg': 'cfg_97'},
    {'id': 9, 'cfg': 'cfg_97'}
] # Example: 5 specific models


# Larger pool of (model id, cfg folder) pairs drawn from cfg_95, cfg_96,
# cfg_97 and cfg_99. Each entry's 'cfg' key is authoritative.
MODELS_COMBINED_CFG_AWARE =  [
   {'id': 0, 'cfg': 'cfg_95'}, 
   {'id': 2, 'cfg': 'cfg_95'},
  {'id': 3, 'cfg': 'cfg_95'},
    {'id': 4, 'cfg': 'cfg_95'},
   {'id': 6, 'cfg': 'cfg_95'},
    {'id': 8, 'cfg': 'cfg_95'},
    {'id': 9, 'cfg': 'cfg_95'},
    {'id': 10, 'cfg': 'cfg_95'},
    {'id': 13, 'cfg': 'cfg_95'},
    {'id': 17, 'cfg': 'cfg_95'},
    {'id': 18, 'cfg': 'cfg_95'},
    {'id': 2, 'cfg': 'cfg_96'},
    {'id': 7, 'cfg': 'cfg_96'},
    {'id': 13, 'cfg': 'cfg_96'},
    {'id': 14, 'cfg': 'cfg_96'},
    {'id': 17, 'cfg': 'cfg_96'},
    {'id': 18, 'cfg': 'cfg_96'},
    {'id': 19, 'cfg': 'cfg_96'},
    {'id': 2, 'cfg': 'cfg_97'},
    {'id': 16, 'cfg': 'cfg_97'},
    {'id': 0, 'cfg': 'cfg_97'},
    {'id': 8, 'cfg': 'cfg_97'},
    {'id': 15, 'cfg': 'cfg_97'},
    {'id': 12, 'cfg': 'cfg_97'},
    {'id': 6, 'cfg': 'cfg_97'},
    {'id': 10, 'cfg': 'cfg_97'},
    {'id': 13, 'cfg': 'cfg_97'},
    {'id': 1, 'cfg': 'cfg_99'},
    {'id': 0, 'cfg': 'cfg_99'},
    {'id': 2, 'cfg': 'cfg_99'},
    {'id': 3, 'cfg': 'cfg_99'},
    {'id': 4, 'cfg': 'cfg_99'},
    {'id': 7, 'cfg': 'cfg_99'},
    {'id': 9, 'cfg': 'cfg_99'},
    {'id': 10, 'cfg': 'cfg_99'},
    {'id': 11, 'cfg': 'cfg_99'},
    {'id': 13, 'cfg': 'cfg_99'},
    {'id': 15, 'cfg': 'cfg_99'},
    {'id': 16, 'cfg': 'cfg_99'},
    {'id': 18, 'cfg': 'cfg_99'},
]
# --- END OF USER CONFIGURATION ---
# --- END OF USER CONFIGURATION ---

# 0) Ensure non-reentrant checkpoint by monkey-patch
if 'torch.utils.checkpoint' in sys.modules:
    cp_module = sys.modules['torch.utils.checkpoint']
    if hasattr(cp_module, 'checkpoint') and cp_module.checkpoint.__name__ != 'checkpoint_no_reentrant_combined':
        _orig_checkpoint_combined = cp_module.checkpoint
        def checkpoint_no_reentrant_combined(fn, *args, **kwargs):
            if 'use_reentrant' not in kwargs:
                kwargs['use_reentrant'] = False
            return _orig_checkpoint_combined(fn, *args, **kwargs)
        cp_module.checkpoint = checkpoint_no_reentrant_combined
        print("Applied non-reentrant checkpoint patch.")
else:
    print("Note: torch.utils.checkpoint not yet imported; patch might not be active if DRFold2 imports it later.")

# 1) Select device
DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_STR)
print(f"Using device: {DEVICE}")
# Put the device string in sys.argv[1]: appended when only argv[0] exists,
# otherwise it overwrites whatever was in position 1.
# NOTE(review): presumably DRFold2 modules read the device from sys.argv[1] - confirm.
if len(sys.argv) == 1:
    sys.argv.append(DEVICE_STR)
elif len(sys.argv) > 1:
    sys.argv[1] = DEVICE_STR
print(f"Modified sys.argv for DRFold2 compatibility: {sys.argv}")

# 2) Path and Directory Setup
os.makedirs(OUT_DIR_PDB, exist_ok=True)
os.makedirs(SUBMISSION_OUT_DIR, exist_ok=True)
print(f"Output directory for selected custom PDBs: {OUT_DIR_PDB}")
print(f"Output directory for submission file: {SUBMISSION_OUT_DIR}")

# sys.exit raises SystemExit, which stops this cell when the code root is missing.
if not os.path.exists(DRFOLD2_ROOT):
    print(f"CRITICAL ERROR: DRfold2 root directory does not exist at {DRFOLD2_ROOT}. Ensure it's copied or cloned.")
    sys.exit(1)
if not os.path.exists(INPUT_MODEL_HUB):
    print(f"Warning: Input model hub {INPUT_MODEL_HUB} not found. Model loading will fail.")

# Remove whatever currently sits at TARGET_MODEL_HUB (symlink, directory or
# file) so that a fresh symlink can be created below.
if os.path.islink(TARGET_MODEL_HUB):
    os.unlink(TARGET_MODEL_HUB)
elif os.path.exists(TARGET_MODEL_HUB):
    print(f"Note: {TARGET_MODEL_HUB} exists and is not a symlink. Removing it for fresh symlink.")
    if os.path.isdir(TARGET_MODEL_HUB): shutil.rmtree(TARGET_MODEL_HUB)
    else: os.remove(TARGET_MODEL_HUB)

# Create the symlink only when the source exists; failures are reported, not raised.
if not os.path.exists(TARGET_MODEL_HUB) and os.path.exists(INPUT_MODEL_HUB):
    try:
        os.symlink(INPUT_MODEL_HUB, TARGET_MODEL_HUB, target_is_directory=True)
        print(f"Symlink created: {TARGET_MODEL_HUB} -> {INPUT_MODEL_HUB}")
    except Exception as e:
        print(f"Error creating symlink for model_hub: {e}")
elif not os.path.exists(INPUT_MODEL_HUB):
    print(f"Could not create symlink: source {INPUT_MODEL_HUB} does not exist.")
else:
    print(f"Note: {TARGET_MODEL_HUB} already exists or other issue preventing symlink.")

paths_to_add_globally = [
    os.path.join(DRFOLD2_ROOT, DEFAULT_MODEL_CFG_FOLDER_NAME), # For DRFOLD2_ROOT/cfg_97/RNALM2/*
    DRFOLD2_ROOT, # Base for DRFold2 structure
    os.path.join(DRFOLD2_ROOT, "PotentialFold") # For PotentialFold modules
]
# sys.path is rebuilt from the snapshot taken in the configuration section, so
# any entry added to sys.path after that snapshot is dropped here.
current_sys_path = list(original_system_path) # Start fresh for these additions
# Inserting at index 0 in reverse keeps the list's order at the front of sys.path.
for p_add in reversed(paths_to_add_globally): # Prepend in order
    if os.path.exists(p_add) and p_add not in current_sys_path:
        current_sys_path.insert(0, p_add)
sys.path = current_sys_path # Apply the changes
print(f"Sys.path for global imports initialized: {sys.path[:7]}...") # For debugging

# --- Helper Functions ---

def make_drfold_data_dummy_msa(seq_str: str, base_coor_template_np, other_coor_template_np,
                                parse_func, get_base_func_network, device_to_use):
    """Build network inputs for one sequence using a two-row dummy MSA.

    Args:
        seq_str: the sequence.
        base_coor_template_np / other_coor_template_np: templates passed to
            ``get_base_func_network`` together with the sequence.
        parse_func: maps the sequence to its numeric encoding (values must be < 6).
        get_base_func_network: ``(seq, template) -> np.ndarray``.
        device_to_use: device on which every returned tensor is placed.

    Returns:
        tuple: ``(msa_one_hot, base_tensor, other_tensor, idx_tensor)`` where
        ``msa_one_hot`` is the float one-hot (6 classes) of the encoding stacked
        twice, the two template tensors are float32, and ``idx_tensor`` is
        ``1..len(seq_str)`` as int64.
    """
    encoded = parse_func(seq_str)
    # The "MSA" is just the query sequence repeated twice.
    msa_indices = torch.tensor(np.stack([encoded] * 2, axis=0), dtype=torch.long, device=device_to_use)
    msa_one_hot = F.one_hot(msa_indices, num_classes=6).float()

    def _template_tensor(template_np):
        per_residue = get_base_func_network(seq_str, template_np)
        return torch.from_numpy(per_residue).float().to(device_to_use)

    base_tensor = _template_tensor(base_coor_template_np)
    other_tensor = _template_tensor(other_coor_template_np)

    idx_tensor = torch.arange(1, len(seq_str) + 1, dtype=torch.long).to(device_to_use)
    return msa_one_hot, base_tensor, other_tensor, idx_tensor

# a2b is required by write_frame_coor_to_pdb below; without it the run is stopped.
try:
    from PotentialFold import a2b
    print("Successfully imported a2b from PotentialFold.")
except ImportError as e_a2b:
    print(f"CRITICAL: Could not import 'a2b' from PotentialFold. Error: {e_a2b}")
    sys.exit(1)

# Get_base function as defined by you, for write_frame_coor_to_pdb
# This version expects NumPy arrays as input for basenpy_standard and returns a PyTorch tensor
def Get_base_for_pdb_writer(seq, basenpy_standard_np):
    """Expand a per-base template into a per-residue tensor for ``seq``.

    Args:
        seq: sequence string; A/G/C/U (either case) select template rows
            0/1/2/3, T/t is treated like U, any other character stays zero.
        basenpy_standard_np: NumPy array whose first axis has at least 4 rows
            and whose second axis is the number of atoms per base.

    Returns:
        torch.Tensor: float64 tensor of shape ``(len(seq), atoms_per_base, 3)``.

    Raises:
        ValueError: if the template has fewer than 4 rows.
    """
    atoms_per_base = basenpy_standard_np.shape[1]
    if basenpy_standard_np.shape[0] < 4:
        raise ValueError(f"basenpy_standard_np (shape {basenpy_standard_np.shape}) for PDB writer's Get_base is too small. Needs at least 4 entries for A,G,C,U.")

    letters = np.array(list(seq))
    placed = np.zeros([len(seq), atoms_per_base, 3])
    # Template row -> the sequence characters that use it (T/t share U's row).
    for template_row, symbols in enumerate(('Aa', 'Gg', 'Cc', 'UuTt')):
        for symbol in symbols:
            placed[letters == symbol] = basenpy_standard_np[template_row]
    return torch.from_numpy(placed).double() # Returns a tensor

def write_frame_coor_to_pdb(coor, seq, pdbfile,
                            BASE_COOR_NP_ARG, OTHER_COOR_NP_ARG): # These are expected to be NumPy arrays
    """Write a PDB file with 3 frame atoms and 5 extra backbone atoms per residue.

    Per residue the file contains P, C4' and N9/N1 (N9 for A/G, N1 otherwise),
    followed by O5', C5', C3', O3' and C1'. Atoms with a NaN coordinate are skipped.

    Args:
        coor: NumPy array of frame coordinates; the original author notes it
            is ``(L, 3, 3)`` as produced by DRFold2.
        seq: sequence string of length L.
        pdbfile: output path (overwritten).
        BASE_COOR_NP_ARG: NumPy template for the three frame atoms.
        OTHER_COOR_NP_ARG: NumPy template for the five extra atoms.

    Uses the module-level ``DEVICE``, ``a2b`` and ``Get_base_for_pdb_writer``.
    """
    tx_np = coor.astype(np.float64)

    # Per-residue templates (float64 tensors) on the compute device.
    basex_tensor = Get_base_for_pdb_writer(seq, BASE_COOR_NP_ARG).to(DEVICE)
    otherx_tensor = Get_base_for_pdb_writer(seq, OTHER_COOR_NP_ARG).to(DEVICE)

    tx_tensor = torch.from_numpy(tx_np).to(DEVICE)

    L = len(seq)
    # Every column of x is overwritten below; torch.rand is kept so the global
    # RNG stream is consumed exactly as before.
    x = torch.rand([L, 21], device=DEVICE, dtype=torch.double)
    x[:, 18:] = tx_tensor.mean(dim=1)

    biasq = torch.mean(tx_tensor, dim=1, keepdim=True)
    q = tx_tensor - biasq
    m = torch.einsum('bnz,bny->bzy', basex_tensor, q).reshape([L, -1])
    x[:, :9] = x[:, 9:18] = m
    rama = x.double()

    # NOTE(review): a2b.quat2b is opaque here; it is given the template and
    # columns 9..20 of rama - presumably a 3x3 matrix plus a translation. Confirm.
    xyz = a2b.quat2b(basex_tensor, rama.view(L, 21)[:, 9:]).float().cpu().data.numpy()
    other_xyz = a2b.quat2b(otherx_tensor, rama.view(L, 21)[:, 9:]).float().cpu().data.numpy()

    atom_name_list = [' P  ', " C4'", ' N1 ']
    last_name = ['P', 'C', 'N']
    other_atom_name = [" O5'", " C5'", " C3'", " O3'", " C1'"]
    other_last_name = ['O', 'C', 'C', 'O', 'C']
    res_map_1to3_custom = {'A':'ADE', 'C':'CYT', 'G':'GUA', 'U':'URA', 'T':'THY'}

    # The REMARK line needs its own newline. The records are joined with "",
    # so without it the first ATOM record was glued onto the REMARK line and
    # was invisible to PDB parsers.
    lines = [f'REMARK Generated by custom write_frame_coor_to_pdb for {seq}\n']
    pdb_format_line = "ATOM  %5d %-4s %3s %1s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f          %2s  \n"
    count = 1
    for i in range(L):
        res_char_upper = seq[i].upper()
        res_name_3letter = res_map_1to3_custom.get(res_char_upper, 'UNK')

        frame_atom_names = list(atom_name_list)
        if res_char_upper in ['A', 'G']:
            frame_atom_names[2] = ' N9 '
        elif res_char_upper in ['C', 'U', 'T']:
            frame_atom_names[2] = ' N1 '

        # (atom names, element symbols, coordinates) for each group to emit.
        # The element symbols in last_name / other_last_name already match the
        # atom names, so no per-atom override is needed.
        groups = []
        if xyz.shape[0] > i and xyz.shape[1] == 3:
            groups.append((frame_atom_names, last_name, xyz[i]))
        if other_xyz.shape[0] > i and other_xyz.shape[1] == 5:
            groups.append((other_atom_name, other_last_name, other_xyz[i]))

        for names, elements, atom_coords in groups:
            for atom_n, element_sym, pos in zip(names, elements, atom_coords):
                if np.any(np.isnan(pos)):
                    continue
                lines.append(pdb_format_line % (
                    count, atom_n, res_name_3letter, 'A', i + 1,
                    pos[0], pos[1], pos[2],
                    1.00, 0.00, element_sym.rjust(2)))
                count += 1
    lines.append("END\n")
    with open(pdbfile, 'w') as f:
        f.write("".join(lines))
    print(f"Successfully wrote PDB with C1' using write_frame_coor_to_pdb: {pdbfile}")

def extract_c1_prime_coords_from_pdb(pdb_file_path, expected_sequence_length):
    """Read C1' coordinates from a PDB file into a fixed-length array.

    Args:
        pdb_file_path: path of the PDB file.
        expected_sequence_length: number of rows in the result.

    Returns:
        np.ndarray: float32 array ``(expected_sequence_length, 3)``. Row
        ``resid - 1`` holds the first C1' (or C1*) ATOM record seen for that
        residue number; residues without one, or outside the range, stay NaN.
        An all-NaN array is returned when the file is missing or cannot be read.
    """
    def _all_nan():
        return np.full((expected_sequence_length, 3), np.nan, dtype=np.float32)

    if not os.path.exists(pdb_file_path):
        print(f"PDB file not found for C1' extraction: {pdb_file_path}")
        return _all_nan()

    first_seen = {}
    try:
        with open(pdb_file_path, 'r') as handle:
            for record in handle:
                if not record.startswith("ATOM"):
                    continue
                # Fixed-column fields: atom name 13-16, residue number 23-26, x/y/z 31-54.
                if record[12:16].strip() not in ("C1'", "C1*"):
                    continue
                try:
                    res_id = int(record[22:26].strip())
                    xyz = [float(record[30:38]), float(record[38:46]), float(record[46:54])]
                except ValueError:
                    # Malformed numeric field: ignore this record.
                    continue
                first_seen.setdefault(res_id, xyz)
    except Exception as e:
        print(f"Error reading PDB {pdb_file_path} for C1' extraction: {e}")
        return _all_nan()

    final_coords_np = _all_nan()
    for res_id, xyz in first_seen.items():
        slot = res_id - 1
        if 0 <= slot < expected_sequence_length:
            final_coords_np[slot] = xyz
    return final_coords_np

# --- Main Script Components ---
# Both imports rely on the sys.path entries added during path setup; a failure
# in either one prints the traceback and stops the run.
try:
    from RNALM2.Model import RNA2nd
    print("DRFold2 core modules imported successfully.")
    from Selection import Structure as PFStructure
    print("PotentialFold selection modules imported.")
except ImportError as e:
    print(f"CRITICAL ERROR importing DRFold2/PotentialFold modules: {e}"); traceback.print_exc(); sys.exit(1)

# Load the four coordinate templates used later:
#   BASE_COOR_FOR_MODEL / OTHER_COOR_FOR_MODEL - network inputs
#   PF_BASE_COOR_NP / PF_OTHER_COOR_NP         - PDB writer inputs
try:
    # Network base template: copied from BASE_NPY_INPUT_PATH when absent.
    model_input_base_coor_path = os.path.join(DRFOLD2_ROOT, DEFAULT_MODEL_CFG_FOLDER_NAME, "base.npy")
    if not os.path.exists(model_input_base_coor_path):
        if os.path.exists(BASE_NPY_INPUT_PATH):
            os.makedirs(os.path.dirname(model_input_base_coor_path), exist_ok=True)
            shutil.copy(BASE_NPY_INPUT_PATH, model_input_base_coor_path)
        else: raise FileNotFoundError(f"DRFold2 network base.npy source missing: {BASE_NPY_INPUT_PATH}")
    BASE_COOR_FOR_MODEL = np.load(model_input_base_coor_path)
    print(f"Loaded BASE_COOR_FOR_MODEL (for network) from {model_input_base_coor_path}")

    # Network "other" template: falls back to the base template (same array
    # object, not a copy) when the file is missing.
    if os.path.exists(OTHER_NPY_INPUT_PATH_FOR_NETWORK):
        OTHER_COOR_FOR_MODEL = np.load(OTHER_NPY_INPUT_PATH_FOR_NETWORK)
        print(f"Loaded OTHER_COOR_FOR_MODEL (for network) from {OTHER_NPY_INPUT_PATH_FOR_NETWORK}")
    else:
        print(f"Warning: OTHER_NPY_INPUT_PATH_FOR_NETWORK '{OTHER_NPY_INPUT_PATH_FOR_NETWORK}' not found. Using BASE_COOR_FOR_MODEL for network's other_coor.")
        OTHER_COOR_FOR_MODEL = BASE_COOR_FOR_MODEL

    POTENTIALFOLD_LIB_DIR = os.path.join(DRFOLD2_ROOT, "PotentialFold", "lib")
    os.makedirs(POTENTIALFOLD_LIB_DIR, exist_ok=True)
    
    # PDB-writer base template: seeded from the same source file as the network one.
    PF_BASE_NPY_PATH = os.path.join(POTENTIALFOLD_LIB_DIR, "base.npy")
    if not os.path.exists(PF_BASE_NPY_PATH):
        if os.path.exists(BASE_NPY_INPUT_PATH):
            shutil.copy(BASE_NPY_INPUT_PATH, PF_BASE_NPY_PATH)
        else: raise FileNotFoundError(f"PotentialFold/lib/base.npy source missing: {BASE_NPY_INPUT_PATH}")
    PF_BASE_COOR_NP = np.load(PF_BASE_NPY_PATH)
    print(f"Loaded PF_BASE_COOR_NP (for PDB writer) from {PF_BASE_NPY_PATH}")

    # Load the explicit other2.npy for PDB writer
    if os.path.exists(PF_OTHER2_NPY_EXPLICIT_PATH):
        PF_OTHER_COOR_NP = np.load(PF_OTHER2_NPY_EXPLICIT_PATH)
        print(f"Loaded PF_OTHER_COOR_NP (for PDB writer) from explicit path: {PF_OTHER2_NPY_EXPLICIT_PATH}")
    else:
        print(f"CRITICAL ERROR: PF_OTHER2_NPY_EXPLICIT_PATH '{PF_OTHER2_NPY_EXPLICIT_PATH}' not found. This is required for write_frame_coor_to_pdb.")
        # No fallback on purpose: the writer takes its C1' template from this file.
        # SystemExit is not an Exception subclass, so the handler below does not catch it.
        sys.exit(1) # Making it critical

except Exception as e_load_base:
    print(f"CRITICAL ERROR Loading base/other coordinates: {e_load_base}"); traceback.print_exc(); sys.exit(1)

# Load the RNALM checkpoint once and keep the model on DEVICE in eval mode.
print("Loading RNALM model (for DRFold2) ONCE...")
rnalm_global_model = None
try:
    # The checkpoint is looked up under model_hub/RCLM first, then directly under model_hub.
    rnalm_path = os.path.join(TARGET_MODEL_HUB, 'RCLM', 'epoch_67000')
    if not os.path.exists(rnalm_path):
        alt_rnalm_path = os.path.join(TARGET_MODEL_HUB, 'epoch_67000')
        if os.path.exists(alt_rnalm_path): rnalm_path = alt_rnalm_path
        else: raise FileNotFoundError(f"RNALM model checkpoint not found: {rnalm_path} or {alt_rnalm_path}")
    rnalm_global_model = RNA2nd(dict(s_in_dim=5, z_in_dim=2, s_dim=512, z_dim=128, N_elayers=18))

    # Feature-detect `weights_only` rather than comparing version strings:
    # `torch.__version__ >= '1.8'` is lexicographic, so it is False for
    # '1.10'..'1.13' and says nothing about whether the argument exists.
    load_kwargs = {'map_location': torch.device('cpu')}
    if 'weights_only' in torch.load.__code__.co_varnames:
        load_kwargs['weights_only'] = True
    state_rnalm = torch.load(rnalm_path, **load_kwargs)

    # strict=False: keys missing from or unexpected in the checkpoint are tolerated.
    rnalm_global_model.load_state_dict(state_rnalm, strict=False)
    rnalm_global_model.to(DEVICE).eval()
    print("RNALM loaded globally and moved to device.")
except Exception as e_rnalm:
    print(f"CRITICAL ERROR Loading RNALM: {e_rnalm}"); traceback.print_exc(); sys.exit(1)

def load_sequences_for_submission(data_directory):
    """Read ``test_sequences.csv`` from ``data_directory``.

    The target identifier column may be called ``target_id``, ``ID`` or ``id``
    (checked in that order); when it is not already ``target_id`` a ``target_id``
    column is added as a copy of it.

    Returns:
        (DataFrame, list): the loaded frame and the list of target ids in row order.

    Raises:
        SystemExit: (code 1) when the CSV file does not exist.
        ValueError: when none of the accepted id columns is present.
    """
    fname = 'test_sequences.csv'
    seq_file_path = os.path.join(data_directory, fname)
    if not os.path.exists(seq_file_path):
        print(f"CRITICAL ERROR: Sequence file {seq_file_path} not found.")
        sys.exit(1)

    sequences = pd.read_csv(seq_file_path)

    # First matching candidate wins; None means no usable id column.
    id_column = next(
        (candidate for candidate in ('target_id', 'ID', 'id') if candidate in sequences.columns),
        None,
    )
    if id_column is None:
        raise ValueError("Cannot find target ID column.")
    if id_column != 'target_id':
        sequences['target_id'] = sequences[id_column]

    ids_in_order = sequences['target_id'].tolist()
    print(f"Loaded {len(sequences)} sequences from {fname}.")
    return sequences, ids_in_order


MSA_DIR_PATH = "/kaggle/input/your_competition_msa_data_directory"
def load_msa_for_target(target_id, msa_directory):
    """Placeholder MSA loader: always returns an empty list.

    If ``<msa_directory>/<target_id>.a3m`` exists, a notice is printed that parsing
    is not implemented; the file itself is never read.
    """
    msa_file_path = os.path.join(msa_directory, f"{target_id}.a3m")
    if os.path.exists(msa_file_path):
        print(f" MSA file {msa_file_path} found, but MSA parsing not fully implemented.")
    return []


import torch
import torch.nn.functional as F # Assuming F is from here
import numpy as np # If compute_apos uses it
import os

# Define these functions exactly as in your PreMSA
def _static_compute_pos(maxL=2000):
    a = torch.arange(maxL)
    b = (a[None,:]-a[:,None]).clamp(-32,32)
    return F.one_hot(b+32,65).float()

def _static_compute_apos(maxL=2000):
    d_range = torch.arange(maxL)
    m_bits = 14
    return (((d_range[:,None] & (1 << torch.arange(m_bits)))) > 0).float()

# Precompute the length-independent PreMSA position encodings once and write them to disk.
MAX_L_FOR_STATIC_TENSORS = 2000  # maximum length covered by the saved tensors
SAVE_DIR = os.path.join(DRFOLD2_ROOT, "model_assets_precomputed")
os.makedirs(SAVE_DIR, exist_ok=True)

pos_tensor = _static_compute_pos(maxL=MAX_L_FOR_STATIC_TENSORS)
apos_tensor = _static_compute_apos(maxL=MAX_L_FOR_STATIC_TENSORS)

# File name -> tensor; the file name embeds the maximum length.
_static_premsa_outputs = {
    f"static_premsa_pos_L{MAX_L_FOR_STATIC_TENSORS}.pt": pos_tensor,
    f"static_premsa_apos_L{MAX_L_FOR_STATIC_TENSORS}.pt": apos_tensor,
}
for _static_fname, _static_tensor in _static_premsa_outputs.items():
    torch.save(_static_tensor, os.path.join(SAVE_DIR, _static_fname))
print(f"Saved static PreMSA tensors to {SAVE_DIR}")




# ---- Run-level setup: inputs, scratch space, result container, cache directories ----
overall_script_start_time = time.time()
sequences_df, target_ids_list = load_sequences_for_submission(DATA_DIR)

MODELS_LONG = [0, 1, 2, 8, 9]
MODELS_MEDIUM = [*range(20)]

# Scratch directory for per-model pickles, FASTA/.ret files and PDBs.
TEMP_PROCESSING_DIR = tempfile.mkdtemp(prefix="drfold_proc_", dir="/kaggle/working/")
print(f"Temporary processing directory: {TEMP_PROCESSING_DIR}")

# target_id -> one slot per submission model; None means "no coordinates produced".
all_target_c1_prime_coords = {}
for tgt_id in target_ids_list:
    all_target_c1_prime_coords[tgt_id] = [None] * MAX_SUBMISSION_MODELS

SCORING_CONFIG_PATH = os.path.join(DRFOLD2_ROOT, "cfg_for_selection.json")
if not os.path.exists(SCORING_CONFIG_PATH):
    print(f"WARNING: PotentialFold scoring config file not found: {SCORING_CONFIG_PATH}.")

# Runtime caches live under a versioned folder so stale caches from older layouts are ignored.
RUNTIME_CACHE_BASE_DIR = os.path.join(DRFOLD2_ROOT, "runtime_cache_v2")
MDDDM_CACHE_DIR = os.path.join(RUNTIME_CACHE_BASE_DIR, "mdddm_outputs")
PREMSA_INTERNAL_CACHE_DIR = os.path.join(RUNTIME_CACHE_BASE_DIR, "premsa_internals")
for _cache_dir in (MDDDM_CACHE_DIR, PREMSA_INTERNAL_CACHE_DIR):
    os.makedirs(_cache_dir, exist_ok=True)

MAX_LEN_FOR_CACHING = MAX_DRFOLD_PREDICTION_LENGTH

for target_idx, row_data in sequences_df.iterrows():
    current_target_id = target_ids_list[target_idx]
    full_target_sequence = row_data.sequence.strip().upper()
    L_full_seq = len(full_target_sequence)
    target_loop_start_time = time.time()
    print(f"\n{'='*80}\nProcessing Target: {current_target_id} (L: {L_full_seq}) ({target_idx+1}/{len(sequences_df)})\n{'='*80}")

    # Route by length: short targets go through the DRFold2 model set below,
    # long targets are predicted by Protenix and skip the rest of this iteration.
    model_ids_for_this_target_info = []
    if L_full_seq <= MAX_DRFOLD_PREDICTION_LENGTH:
        model_ids_for_this_target_info = MODELS_COMBINED_CFG_AWARE
        print(f"  Target {current_target_id} (L={L_full_seq}): Using combined model set from various CFGs.")
    else:
        print("proteinx started")
        # Expected to return a DataFrame with columns x_1, y_1, z_1, ..., one triplet per model.
        protenix_df_pred = proteinx(full_target_sequence, current_target_id)
        protenix_c1_for_submission_list = [
            np.full((L_full_seq, 3), np.nan, dtype=np.float32)
            for _ in range(MAX_SUBMISSION_MODELS)
        ]
        num_protenix_models_processed = 0

        if protenix_df_pred.empty:
            print(f"    Protenix returned no prediction or an empty DataFrame for {current_target_id}.")
        else:
            for model_idx in range(MAX_SUBMISSION_MODELS):
                protenix_col_suffix = model_idx + 1
                x_col = f'x_{protenix_col_suffix}'
                y_col = f'y_{protenix_col_suffix}'
                z_col = f'z_{protenix_col_suffix}'

                # Stop at the first model whose coordinate triplet is incomplete.
                if not all(col in protenix_df_pred.columns for col in (x_col, y_col, z_col)):
                    if model_idx == 0:
                        print(f"      Protenix output for {current_target_id} missing coordinate columns for its model {protenix_col_suffix} (e.g., {x_col}).")
                    break

                try:
                    current_model_c1_coords_direct = np.full((L_full_seq, 3), np.nan, dtype=np.float32)

                    rows_available = len(protenix_df_pred[x_col])
                    if rows_available < L_full_seq:
                        print(f"      Warning: Protenix DataFrame for model {protenix_col_suffix} has {len(protenix_df_pred[x_col])} rows, less than sequence length {L_full_seq}. Padding with NaNs.")

                    # Copy as many rows as both sides have; surplus rows are dropped,
                    # missing rows stay NaN.
                    valid_len = min(rows_available, L_full_seq)
                    for axis_pos, axis_col in enumerate((x_col, y_col, z_col)):
                        current_model_c1_coords_direct[:valid_len, axis_pos] = protenix_df_pred[axis_col].values[:valid_len]

                    num_protenix_models_processed += 1
                    protenix_c1_for_submission_list[model_idx] = current_model_c1_coords_direct
                except Exception as e_extract:
                    print(f"      Error extracting coordinates for Protenix model {protenix_col_suffix} for {current_target_id}: {e_extract}")

            all_target_c1_prime_coords[current_target_id] = protenix_c1_for_submission_list

            if num_protenix_models_processed > 0:
                print(f"  Added {num_protenix_models_processed} model(s) from Protenix to scoring pool for {current_target_id}.")
            else:
                print(f"  No models from Protenix were added for {current_target_id} due to missing data in its DataFrame output.")

        print(f"  Finished processing target {current_target_id} (Protenix path).")
        continue  # Protenix targets never enter the DRFold2 stages below.

    # ---- DRFold2 prediction stage: run every configured (cfg, model id) pair on this target ----
    # load_msa_for_target is a stub that returns []; its result is not referenced again
    # within this loop body.
    msa_aligned_sequences_list = load_msa_for_target(current_target_id, MSA_DIR_PATH)
    # One metadata dict per successful prediction; the tensors themselves are pickled to disk.
    raw_predictions_for_this_target = []
    for model_info_item in model_ids_for_this_target_info: # each item carries an 'id' and a 'cfg' folder name
        current_model_numeric_id = model_info_item['id']
        current_model_cfg_folder = model_info_item['cfg']
        # --- Dynamic module loading for cfg-specific modules ---
        # Snapshot sys.path so the `finally` clause below can restore it after this model.
        sys_path_before_dynamic_load = list(sys.path)
        
        current_cfg_module_path = os.path.join(DRFOLD2_ROOT, current_model_cfg_folder)
        
        # Build this iteration's sys.path from a copy of original_system_path
        # (defined outside this chunk).
        temp_sys_path_iter = list(original_system_path) # Start from pristine
        
        # Listed from highest to lowest priority; inserted in reverse below so that
        # the first entry ends up at index 0.
        paths_for_this_iter_ordered = [
            current_cfg_module_path,                             # [0] cfg_xx specific modules
            DRFOLD2_ROOT,                                        # [1] For finding cfg_xx/RNALM2 etc.
            os.path.join(DRFOLD2_ROOT, "PotentialFold"),          # [2] For PotentialFold if needed by dynamic modules
        ]
        
        for p_iter in reversed(paths_for_this_iter_ordered):
            if os.path.exists(p_iter) and p_iter not in temp_sys_path_iter:
                temp_sys_path_iter.insert(0, p_iter)
        
        # NOTE(review): temp_sys_path_iter started as a copy of original_system_path and has
        # only had entries inserted, so `p_orig not in temp_sys_path_iter` is never true here
        # and this loop appends nothing.
        for p_orig in original_system_path:
            if DRFOLD2_ROOT not in p_orig and p_orig not in temp_sys_path_iter:
                temp_sys_path_iter.append(p_orig) # Or insert at appropriate place if order matters

        sys.path = temp_sys_path_iter
        # print(f"    Temp sys.path for {current_model_cfg_folder}: {sys.path[:7]}...")


        # NOTE(review): importlib is imported but not referenced in this loop body.
        import importlib

        # Module names that are dropped from sys.modules so the imports below are resolved
        # again against the sys.path set above (i.e. with the current cfg folder first).
        modules_to_clear_and_import = ['data', 'basic', 'IPA', 'EvoMSA', 'EvoPair', 'Structure', 'Evoformer', 'RNALM2.Model', 'EvoMSA2XYZ']
        
        for module_name_str in modules_to_clear_and_import:
            if module_name_str in sys.modules:
                del sys.modules[module_name_str]
            if '.' in module_name_str: # dotted name: also drop the parent package entry (e.g. 'RNALM2')
                pkg_name = module_name_str.split('.')[0]
                if pkg_name in sys.modules:
                    del sys.modules[pkg_name]
        
        # Everything from import to pickling runs inside one try so that a failure in one
        # model is logged and the loop moves on to the next model.
        try:
            import EvoMSA2XYZ as current_evo_module 
            from EvoMSA2XYZ import MSA2XYZ # network class instantiated below
            
            print(f"[Main Script] ID of rnalm_global_model: {id(rnalm_global_model)}")
            
            # Replace the freshly imported module's `RNAlm` attribute with the globally loaded
            # language model so the same instance is reused for every model.
            if hasattr(current_evo_module, 'RNAlm'):
                current_evo_module.RNAlm = rnalm_global_model # Replace it
                print(f"[Main Script] ID of RNAlm in '{current_evo_module.__name__}' AFTER patch: {id(current_evo_module.RNAlm)}")
            else:
                print(f"[Main Script] WARNING: Module '{current_evo_module.__name__}' does not have an 'RNAlm' attribute to patch.")
            import data as drfold_data_module          # resolved via the sys.path set above
            
            # Sequence parsing / base-frame helpers taken from the module imported above.
            parse_seq_rna = drfold_data_module.parse_seq
            get_base_for_network = drfold_data_module.Get_base
            # --- End dynamic module loading ---

            model_checkpoint_path = os.path.join(TARGET_MODEL_HUB, current_model_cfg_folder, f"model_{current_model_numeric_id}")

            if not os.path.exists(model_checkpoint_path):
                print(f"    Checkpoint for model ID {current_model_numeric_id} from CFG {current_model_cfg_folder} not found at {model_checkpoint_path}. Skipping.")
                sys.path = sys_path_before_dynamic_load # Restore sys.path (the finally clause restores it as well)
                continue
            
            print(f"  Predicting with DRFold2 Model ID: {current_model_numeric_id} from CFG: {current_model_cfg_folder}")
            drfold_network = None
            # Build the network, load its weights on CPU, then move it to DEVICE in eval mode.
            drfold_network = MSA2XYZ(seq_dim=6, msa_dim=7, N_ensemble=1, N_cycle=8, m_dim=64, s_dim=64, z_dim=64)
            # Pass `weights_only` only when this torch.load accepts the keyword.
            use_weights_only_model = (torch.__version__ >= '1.8') and ('weights_only' in torch.load.__code__.co_varnames)
            state_dict_model = torch.load(model_checkpoint_path, map_location='cpu', weights_only=use_weights_only_model) if use_weights_only_model else torch.load(model_checkpoint_path, map_location='cpu')
            # strict=False: mismatched keys between checkpoint and network are tolerated.
            drfold_network.load_state_dict(state_dict_model, strict=False)

            drfold_network.to(DEVICE).eval()
            # Defaults for the uncropped case: predict the whole sequence with zero offset.
            crop_offset_val = 0 # Initialize here
            sequence_for_prediction = full_target_sequence # Initialize here
            L_pred_seq = L_full_seq # Initialize here

            # NOTE(review): targets longer than MAX_DRFOLD_PREDICTION_LENGTH are routed to
            # Protenix (and `continue`d) earlier in this loop, so this crop branch looks
            # unreachable; it also leaves L_pred_seq at the full length. Confirm before relying on it.
            if L_full_seq > MAX_DRFOLD_PREDICTION_LENGTH: 
                np.random.seed(L_full_seq + current_model_numeric_id) 
                crop_offset_val = np.random.randint(0, L_full_seq - MAX_DRFOLD_PREDICTION_LENGTH + 1) 
                sequence_for_prediction = full_target_sequence[crop_offset_val : crop_offset_val + MAX_DRFOLD_PREDICTION_LENGTH] 
            
            msa_input_tensor, base_input_tensor, other_input_tensor, idx_input_tensor = None, None, None, None
            is_eligible_for_mdddm_cache = (L_full_seq <= MAX_LEN_FOR_CACHING) # Only cache if original was short (uncropped)

            if is_eligible_for_mdddm_cache:
              # Cache key: sequence length plus the MD5 of the full sequence.
              # NOTE(review): the key does not include the cfg folder, although the cached tensors
              # are built with the cfg-specific parse_seq/Get_base loaded above — confirm those
              # helpers are equivalent across cfg folders.
              seq_key_mdddm = hashlib.md5(full_target_sequence.encode()).hexdigest()
              cache_filename_mdddm = os.path.join(MDDDM_CACHE_DIR, f"mdddm_L{L_full_seq}_{seq_key_mdddm}.pt")

              if os.path.exists(cache_filename_mdddm):
                  try:
                      # print(f"  [Cache Load MDDDM] For {current_target_id}")
                      msa_input_tensor, base_input_tensor, other_input_tensor, idx_input_tensor = \
                          torch.load(cache_filename_mdddm, map_location=DEVICE)
                      print("cache used msa")
                  except Exception as e_load:
                      # An unreadable cache entry falls through to recomputation below.
                      print(f"    [Cache Error MDDDM] Load failed: {e_load}. Recomputing.")
                      msa_input_tensor = None 
        
            if msa_input_tensor is None: # cache miss, cache disabled, or cache load failure
                if is_eligible_for_mdddm_cache: print(f"  [Cache Compute MDDDM] For {current_target_id}")
            
                msa_input_tensor, base_input_tensor, other_input_tensor, idx_input_tensor = \
                  make_drfold_data_dummy_msa(sequence_for_prediction, # Use the actual segment
                                             BASE_COOR_FOR_MODEL, OTHER_COOR_FOR_MODEL,
                                             parse_seq_rna, get_base_for_network, DEVICE)
            
                if is_eligible_for_mdddm_cache: # Save only if it was for a short, uncropped sequence
                  try:
                      torch.save((msa_input_tensor, base_input_tensor, other_input_tensor, idx_input_tensor), 
                                 cache_filename_mdddm)
                      print(f"    [Cache Save MDDDM] Saved for {current_target_id}")
                  except Exception as e_save:
                      # Cache writes are best-effort; the prediction continues without them.
                      print(f"    [Cache Error MDDDM] Save failed: {e_save}") 
                
            # Forward pass without gradient tracking.
            with torch.no_grad():
                drfold_output_dict = drfold_network.pred(msa_input_tensor, idx_input_tensor, None, base_input_tensor, list(sequence_for_prediction))

            # Move any tensor-valued outputs among these keys to CPU NumPy arrays.
            for key_to_convert in ['coor', 'dist_p', 'dist_c', 'dist_n', 'plddt']:
                if key_to_convert in drfold_output_dict and isinstance(drfold_output_dict[key_to_convert], torch.Tensor):
                    drfold_output_dict[key_to_convert] = drfold_output_dict[key_to_convert].cpu().numpy()
            
            if 'coor' in drfold_output_dict: drfold_output_dict['coor'] = drfold_output_dict['coor'].astype(np.float32)
            
            # Persist the outputs to disk and keep only the path plus metadata in memory.
            temp_pred_output_path = os.path.join(TEMP_PROCESSING_DIR, f"{current_target_id}_cfg{current_model_cfg_folder}_m{current_model_numeric_id}_pred_data.pkl") 
            with open(temp_pred_output_path, 'wb') as f_pkl: pickle.dump(drfold_output_dict, f_pkl) 
            
            # These keys are read back by the scoring and PDB-generation stages of this loop.
            raw_predictions_for_this_target.append({
                'model_numeric_id': current_model_numeric_id, 
                'model_cfg': current_model_cfg_folder,       
                'pred_data_path': temp_pred_output_path, 
                'seq_predicted_segment': sequence_for_prediction, 
                'seq_full': full_target_sequence, 
                'seq_len_full': L_full_seq, 
                'seq_len_predicted': L_pred_seq, 
                'crop_offset': crop_offset_val 
            })
            del drfold_output_dict


            print(f"    Prediction successful for model {current_model_numeric_id} from CFG {current_model_cfg_folder}.") 
        
        except Exception as e_pred: 
            # Log and move on to the next model; one failing model does not abort the target.
            print(f"    ERROR during DRFold2 prediction for {current_target_id}, model {current_model_numeric_id} (CFG {current_model_cfg_folder}): {e_pred}"); traceback.print_exc() 
        finally:
            # Runs on success, on failure and on the `continue` above: drop the network and
            # input tensors if they exist, restore sys.path, then free cached GPU memory.
            if 'drfold_network' in locals() and drfold_network is not None: del drfold_network
            if 'msa_input_tensor' in locals(): del msa_input_tensor, base_input_tensor, other_input_tensor, idx_input_tensor
            # drfold_output_dict is deleted inside the try block on the success path.
            sys.path = sys_path_before_dynamic_load # Restore previous sys.path

            gc.collect(); torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # With no usable prediction the target keeps its initial (all-None) slots.
    if not raw_predictions_for_this_target: print(f" No successful DRFold2 predictions for {current_target_id}. Skipping."); continue

    # ---- Model selection: rank predictions by PotentialFold energy (lower is better) ----
    # Energy scoring is attempted only for 1 <= L <= 480 and when the scoring config exists;
    # otherwise the first MAX_SUBMISSION_MODELS predictions are taken in generation order.
    selected_raw_model_predictions_info = []
    if 0 < L_full_seq <= 480 and os.path.exists(SCORING_CONFIG_PATH):
        print(f" Scoring {len(raw_predictions_for_this_target)} raw predictions for {current_target_id} (L={L_full_seq}) using PotentialFold...")
        scored_predictions_temp = []
        for pred_info_item in raw_predictions_for_this_target:
            model_num_id_for_score = pred_info_item['model_numeric_id']
            model_cfg_for_score = pred_info_item['model_cfg']
            pred_data_path = pred_info_item['pred_data_path']

            # The pickle was written by the prediction stage of this same run.
            try:
                with open(pred_data_path, 'rb') as f_pkl: out_d = pickle.load(f_pkl)
            except Exception as e_load_pkl_score:
                print(f" ERROR: Could not load pred_data for scoring from {pred_data_path}: {e_load_pkl_score}. Skipping.")
                scored_predictions_temp.append({'energy': float('inf'), 'pred_info': pred_info_item}); continue

            seq_for_score = pred_info_item['seq_predicted_segment']
            # Shared stem for every temporary file belonging to this (target, cfg, model) triple.
            score_file_tag = f"{current_target_id}_cfg{model_cfg_for_score}_m{model_num_id_for_score}"
            temp_ret_file = os.path.join(TEMP_PROCESSING_DIR, f"{score_file_tag}.ret")
            temp_fasta_file = os.path.join(TEMP_PROCESSING_DIR, f"{score_file_tag}.fasta")

            try:
                # PFStructure reads its inputs from disk: the geometry pickle and a FASTA file.
                # Writing them inside the try means an I/O error is handled like any other
                # scoring failure (energy = +inf) instead of aborting the run.
                with open(temp_ret_file, 'wb') as f_p_ret: pickle.dump(out_d, f_p_ret)
                with open(temp_fasta_file, 'w') as f_p_fasta: f_p_fasta.write(f">{score_file_tag}\n{seq_for_score}\n")

                potential_fold_scorer = PFStructure(fastafile=temp_fasta_file, geofiles=[temp_ret_file], foldconfig=SCORING_CONFIG_PATH,
                                                    saveprefix=os.path.join(TEMP_PROCESSING_DIR, f"score_dummy_{score_file_tag}"))
                if 'coor' in out_d:
                    # The coordinates are scored as a float64 CPU tensor (not moved to DEVICE).
                    coords_for_score_tensor_cpu = torch.from_numpy(out_d['coor'].astype(np.float64))
                    energy_score = potential_fold_scorer.energy_from_coor(coords_for_score_tensor_cpu).item()
                    # A NaN/inf energy would make the sort order undefined; rank it last instead.
                    if not np.isfinite(energy_score):
                        energy_score = float('inf')
                    scored_predictions_temp.append({'energy': energy_score, 'pred_info': pred_info_item})
                else:
                    scored_predictions_temp.append({'energy': float('inf'), 'pred_info': pred_info_item})
            except Exception as e_score:
                # Fixed: the message previously referenced the undefined name `current_model_id`,
                # so any scoring failure raised NameError from inside this handler.
                print(f" Error scoring {current_target_id}, model {model_num_id_for_score} (CFG {model_cfg_for_score}) with PotentialFold: {e_score}")
                scored_predictions_temp.append({'energy': float('inf'), 'pred_info': pred_info_item})
            finally:
                # Release per-model objects before scoring the next prediction.
                if 'out_d' in locals(): del out_d
                if 'potential_fold_scorer' in locals(): del potential_fold_scorer
                gc.collect()

        # Stable sort: predictions with equal energy keep their generation order.
        scored_predictions_temp.sort(key=lambda x: x['energy'])
        selected_raw_model_predictions_info = [item['pred_info'] for item in scored_predictions_temp[:MAX_SUBMISSION_MODELS]]
        print(f" Selected top {len(selected_raw_model_predictions_info)} models for {current_target_id} based on energy.")
    else:
        # Fallback when scoring is not applicable: say why, then take the first predictions.
        if not (0 < L_full_seq <= 480): print(f" Target {current_target_id} (L={L_full_seq}) is outside PotFold scoring range (1-480). Using first predictions.")
        elif not os.path.exists(SCORING_CONFIG_PATH): print(f" Scoring config not found. Using first predictions for {current_target_id}.")
        selected_raw_model_predictions_info = raw_predictions_for_this_target[:MAX_SUBMISSION_MODELS]
        print(f" Selected first {len(selected_raw_model_predictions_info)} models for {current_target_id} (no/failed energy scoring).")


    # ---- Convert each selected prediction to C1' coordinates for the submission ----
    # One (L_full_seq, 3) array per submission slot; slots that cannot be filled stay all-NaN.
    c1_coords_for_target_submission = [np.full((L_full_seq, 3), np.nan, dtype=np.float32) for _ in range(MAX_SUBMISSION_MODELS)]
    for i, current_pred_info in enumerate(selected_raw_model_predictions_info):
        if i >= MAX_SUBMISSION_MODELS: break
        model_num_id_for_pdb = current_pred_info['model_numeric_id']
        model_cfg_for_pdb = current_pred_info['model_cfg']

        print(f"  Generating PDB and extracting C1' for model {i+1}/{len(selected_raw_model_predictions_info)} (orig_CFG: {model_cfg_for_pdb}, orig_ID: {model_num_id_for_pdb}) for {current_target_id}...") #
        pred_data_path = current_pred_info['pred_data_path']
        # The pickle was written by the prediction stage of this same run.
        try:
            with open(pred_data_path, 'rb') as f_pkl: loaded_out_dict = pickle.load(f_pkl)
        except Exception as e_load_pkl:
            print(f" ERROR: Could not load pred_data from {pred_data_path}: {e_load_pkl}. Skipping."); continue

        # Without frame coordinates there is nothing to write; leave this slot as NaN
        # rather than raising KeyError and aborting the whole run.
        if 'coor' not in loaded_out_dict:
            print(f" ERROR: Prediction data at {pred_data_path} has no 'coor' entry. Skipping.")
            del loaded_out_dict; gc.collect(); continue

        raw_coor_L33 = loaded_out_dict['coor']  # frame coordinates produced by the DRFold2 network
        seq_segment_predicted = current_pred_info['seq_predicted_segment']

        unrefined_pdb_path = os.path.join(TEMP_PROCESSING_DIR, f"{current_target_id}_cfg{model_cfg_for_pdb}_m{model_num_id_for_pdb}_unrefined_C1.pdb")
        try:
            write_frame_coor_to_pdb(
                raw_coor_L33,
                seq_segment_predicted,
                unrefined_pdb_path,
                PF_BASE_COOR_NP,
                PF_OTHER_COOR_NP
            )
            if not (os.path.exists(unrefined_pdb_path) and os.path.getsize(unrefined_pdb_path) > 0):
                print(f" Unrefined PDB file {unrefined_pdb_path} was not created or is empty. Skipping further processing for this model.")
                del loaded_out_dict; gc.collect(); continue
        except Exception as e_custom_pdb:
            # Fixed: the message previously read current_pred_info['model_id'], a key that is never
            # written (the keys are 'model_numeric_id' and 'model_cfg'), so any PDB-writing failure
            # raised KeyError from inside this handler.
            print(f" ERROR during custom PDB generation for model (CFG: {model_cfg_for_pdb}, ID: {model_num_id_for_pdb}): {e_custom_pdb}"); traceback.print_exc()
            del loaded_out_dict; gc.collect(); continue

        # OpenMM refinement (refine_pdb_with_openmm) is disabled: coordinates are taken from
        # the unrefined PDB. The two names below are kept for when refinement is re-enabled.
        refined_pdb_path_temp = os.path.join(TEMP_PROCESSING_DIR, f"{current_target_id}_cfg{model_cfg_for_pdb}_m{model_num_id_for_pdb}_refined_C1.pdb")
        descriptive_model_id_for_log = f"cfg{model_cfg_for_pdb}_m{model_num_id_for_pdb}"

        extracted_coords_segment = extract_c1_prime_coords_from_pdb(unrefined_pdb_path, len(seq_segment_predicted))
        if extracted_coords_segment is not None and not np.all(np.isnan(extracted_coords_segment)):
            # Place the predicted segment at its crop offset inside a NaN-filled full-length array.
            current_full_length_coords = np.full((L_full_seq, 3), np.nan, dtype=np.float32)
            offset = current_pred_info['crop_offset']; len_pred = current_pred_info['seq_len_predicted']
            if extracted_coords_segment.shape[0] == len_pred:
                current_full_length_coords[offset : offset + len_pred] = extracted_coords_segment
            else:
                # Copy only the overlapping rows; the rest stays NaN.
                print(f" Warning: Length mismatch for C1' extracted coords. Expected {len_pred}, got {extracted_coords_segment.shape[0]}. Padding.")
                valid_len_extracted = min(extracted_coords_segment.shape[0], len_pred)
                current_full_length_coords[offset : offset + valid_len_extracted] = extracted_coords_segment[:valid_len_extracted]
            c1_coords_for_target_submission[i] = current_full_length_coords
            print(f" Successfully generated PDB and extracted C1' for model (CFG: {current_pred_info['model_cfg']}, ID: {current_pred_info['model_numeric_id']}).")
            # Copying the top-ranked PDB into OUT_DIR_PDB is disabled together with refinement.
        else:
            print(f" Failed to extract valid C1' coordinates for model (CFG: {current_pred_info['model_cfg']}, ID: {current_pred_info['model_numeric_id']}).")
        del loaded_out_dict; gc.collect()

    all_target_c1_prime_coords[current_target_id] = c1_coords_for_target_submission
    print(f" Finished processing for target {current_target_id}. Time: {time.time() - target_loop_start_time:.1f}s")
    print(f"  Cleaning up temporary files for target {current_target_id}...")
    # Substrings that identify this target's scratch files in the shared temp directory.
    target_file_markers = (f"{current_target_id}_cfg", f"score_dummy_{current_target_id}_cfg")
    cleaned_count = 0
    for temp_filename in os.listdir(TEMP_PROCESSING_DIR):
        if not any(marker in temp_filename for marker in target_file_markers):
            continue
        file_to_remove = os.path.join(TEMP_PROCESSING_DIR, temp_filename)
        try:
            # Only regular files are deleted; directories are left in place.
            if os.path.isfile(file_to_remove):
                os.remove(file_to_remove)
                cleaned_count += 1
        except Exception as e_clean_file:
            print(f"    Warning: Could not remove temporary file {file_to_remove}: {e_clean_file}")
    print(f"  Cleaned up {cleaned_count} temporary files for {current_target_id}.")

# ... (Submission file assembly - same as before, using all_target_c1_prime_coords) ...
print("\nAssembling submission file...")
sample_sub_path = os.path.join(DATA_DIR, 'sample_submission.csv')
if not os.path.exists(sample_sub_path):
    print(f"CRITICAL ERROR: sample_submission.csv not found at {sample_sub_path}")
    sys.exit(1)
sample_sub_df = pd.read_csv(sample_sub_path)
submission_data_dict = {col: [] for col in sample_sub_df.columns}
map_target_id_from_sub = {sub_id_str: sub_id_str.rsplit('_',1)[0] for sub_id_str in sample_sub_df['ID']}
map_res_idx_from_sub = {sub_id_str: int(sub_id_str.rsplit('_',1)[1])-1 for sub_id_str in sample_sub_df['ID']}

for _, sample_row in sample_sub_df.iterrows():
    submission_id_str = sample_row['ID']
    submission_data_dict['ID'].append(submission_id_str)
    submission_data_dict['resname'].append(sample_row['resname'])
    submission_data_dict['resid'].append(sample_row['resid'])

    target_id_for_coords = map_target_id_from_sub[submission_id_str]
    residue_idx_for_coords = map_res_idx_from_sub[submission_id_str]
    coords_list_for_this_target = all_target_c1_prime_coords.get(target_id_for_coords, [None] * MAX_SUBMISSION_MODELS)

    for model_slot_idx in range(MAX_SUBMISSION_MODELS):
        axis_x_col = f'x_{model_slot_idx+1}'; axis_y_col = f'y_{model_slot_idx+1}'; axis_z_col = f'z_{model_slot_idx+1}'
        x_val, y_val, z_val = np.nan, np.nan, np.nan
        if model_slot_idx < len(coords_list_for_this_target) and coords_list_for_this_target[model_slot_idx] is not None:
            c1_coords_for_model_slot = coords_list_for_this_target[model_slot_idx]
            if 0 <= residue_idx_for_coords < c1_coords_for_model_slot.shape[0]:
                if not np.any(np.isnan(c1_coords_for_model_slot[residue_idx_for_coords])):
                    x_val = c1_coords_for_model_slot[residue_idx_for_coords, 0]
                    y_val = c1_coords_for_model_slot[residue_idx_for_coords, 1]
                    z_val = c1_coords_for_model_slot[residue_idx_for_coords, 2]
        submission_data_dict[axis_x_col].append(x_val)
        submission_data_dict[axis_y_col].append(y_val)
        submission_data_dict[axis_z_col].append(z_val)

final_submission_df = pd.DataFrame(submission_data_dict)
coord_cols_to_fill_with_zero = []
for c_idx in range(MAX_SUBMISSION_MODELS):
    for axis in ('x', 'y', 'z'): coord_cols_to_fill_with_zero.append(f"{axis}_{c_idx+1}")
for col_name_fill in coord_cols_to_fill_with_zero:
    if col_name_fill in final_submission_df.columns:
        final_submission_df[col_name_fill] = final_submission_df[col_name_fill].fillna(0.0)
    else: print(f"Warning: Expected coord column '{col_name_fill}' not found for NaN replacement.")

# Best-effort removal of the temporary processing directory: a failure is
# reported on stdout but never aborts the run.
try:
    shutil.rmtree(TEMP_PROCESSING_DIR)
except Exception as e_cleanup:
    print(f"Error cleaning up temp directory {TEMP_PROCESSING_DIR}: {e_cleanup}")
else:
    print(f"Cleaned up temporary processing directory: {TEMP_PROCESSING_DIR}")

# Write the submission table as CSV (without the DataFrame index) into the
# configured output directory.
submission_file_path = os.path.join(SUBMISSION_OUT_DIR, 'submission.csv')
final_submission_df.to_csv(path_or_buf=submission_file_path, index=False)

# Closing report: output location, wall-clock minutes since the script start
# timestamp, and a completion marker.
print(f"\nSuccessfully created submission file: {submission_file_path}")
print(f"\nTotal script execution time: {(time.time() - overall_script_start_time)/60:.2f} minutes.")
print("Combined script finished with custom PDB generation for C1'.")

OpenMM imported successfully.
Applied non-reentrant checkpoint patch.
Using device: cuda
Modified sys.argv for DRFold2 compatibility: ['/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py', 'cuda', '/tmp/tmp9xf5lus8.json', '--HistoryManager.hist_file=:memory:']
Output directory for selected custom PDBs: /kaggle/working/predictions_custom_pdbs
Output directory for submission file: /kaggle/working/
Symlink created: /kaggle/working/DRfold2/model_hub -> /kaggle/input/drfold2-models/model_hub
Sys.path for global imports initialized: ['/kaggle/working/DRfold2/cfg_97', '/kaggle/working/DRfold2', '/kaggle/working/DRfold2/PotentialFold', '/kaggle/working', '/kaggle/lib/kagglegym', '/kaggle/lib', '/kaggle/input/stanford-rna-3d-folding']...
Successfully imported a2b from PotentialFold.
DRFold2 core modules imported successfully.
PotentialFold selection modules imported.
Loaded BASE_COOR_FOR_MODEL (for network) from /kaggle/working/DRfold2/cfg_97/base.npy
Loaded OTHER_COOR_FOR_MODEL (

  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


    Protenix: Successfully predicted for R1138. DataFrame shape: (720, 18)
  Added 5 model(s) from Protenix to scoring pool for R1138.
  Finished processing target R1138 (Protenix path).

Processing Target: R1149 (L: 124) (9/12)
  Target R1149 (L=124): Using combined model set from various CFGs.
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
[Main Script] ID of rnalm_global_model: 134200793409552
[Main Script] ID of RNAlm in 'EvoMSA2XYZ' AFTER patch: 134200793409552
  Predicting with DRFold2 Model ID: 0 from CFG: cfg_95
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do checkpoint
will do 