## RNA-FrameFlow

!git clone https://github.com/rish-16/rna-backbone-design.git
# Use the %cd magic: `!cd` runs in a temporary subshell, so the notebook's
# working directory would be unchanged for the following cells.
%cd rna-backbone-design/

```bash
env_name="rna-bb-design"

# Create the environment only when no existing environment has exactly this
# name. An explicit if/else is used instead of `A && B || C`, which would also
# run the create step if the echo itself failed.
if conda env list | cut -d' ' -f1 | grep -Fxq -- "${env_name}"; then
  echo "${env_name} environment already exists"
else
  conda create -n "${env_name}" python=3.10 -y
fi

# NOTE: in a non-interactive shell, `conda activate` typically requires the
# conda shell hook to have been initialised first -- confirm for your setup.
conda activate "${env_name}"
```

All code below has been revised to run exclusively on CPUs.

The changes were made in `flow_module.py`, `all_atom.py`, and `evalsuite.py` so that they run on CPU only. The original code works fine on GPU.

## Data Preprocessing

### process_rna_pdb_files.py

In [None]:



import collections
import functools as fn
import multiprocessing as mp
import os
import time
from tqdm import tqdm
from typing import Any, Dict, Optional

import mdtraj as md
import numpy as np
import pandas as pd
import torch
from Bio import PDB

from src.data import utils
from src.data import parsers

# Fixed arguments that were previously command line parameters.
# They are read by main() below.
CONFIG = {
    # Directory that is scanned for input files whose name contains ".pdb".
    "pdb_dir": "./data/rnasolo/",
    # Size of the multiprocessing pool; a value of 1 selects serial processing.
    "num_processes": 16,
    # Directory that receives the processed pickles and the metadata CSV.
    "write_dir": "./data/rnasolo_proc/",
    # When True, files whose processed pickle already exists are skipped.
    "skip_existing": False,
    # When True, processing is forced to run serially and the metadata CSV is
    # named "rna_metadata_debug.csv" instead of "rna_metadata.csv".
    "debug": False,
    # When True, per-file progress and failure messages are printed.
    "verbose": False
}

def process_file(
    file_path: str,
    write_dir: str,
    inter_chain_interact_dist_threshold: float = 7.0,
    skip_existing: bool = False,
    verbose: bool = False,
) -> Optional[Dict[str, Any]]:
    """Process one PDB structure file into a feature pickle plus metadata.

    The structure is parsed chain by chain, per-chain features are
    concatenated into one complex-level feature dict, inter-chain interface
    residues are located, secondary-structure and radius-of-gyration
    statistics are attempted via mdtraj, and the features are written to
    ``<write_dir>/<pdb_name[1:3]>/<pdb_name>.pkl``.

    Args:
        file_path: Path to the PDB file to read.
        write_dir: Root directory under which the pickle is written.
        inter_chain_interact_dist_threshold: Euclidean distance at or under
            which a pairwise inter-chain residue-atom distance is classified
            as an interaction.
        skip_existing: If True, return early when the pickle already exists.
        verbose: Whether to print diagnostic messages.

    Returns:
        A metadata dict describing the processed structure, or None when the
        file was skipped (already processed), no chain could be processed, or
        the backbone mask of the concatenated features is empty.

    Raises:
        utils.LengthError: If no residue counts as modeled.
        AssertionError: If the backbone mask and residue-type arrays differ
            in length.
    """
    metadata = {}
    pdb_name = os.path.basename(file_path).replace(".pdb", "")
    metadata["pdb_name"] = pdb_name

    # Shard the output by characters 1-2 of the PDB name (lower-cased).
    # The sub-directory is created even if the file is later skipped.
    pdb_subdir = os.path.join(write_dir, pdb_name[1:3].lower())
    os.makedirs(pdb_subdir, exist_ok=True)
    processed_path = os.path.join(pdb_subdir, f"{pdb_name}.pkl")
    metadata["processed_path"] = os.path.abspath(processed_path)
    metadata["raw_path"] = file_path
    if skip_existing and os.path.exists(metadata["processed_path"]):
        return None
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, file_path)

    # Extract all chains
    # NOTE(review): keys are upper-cased chain ids, so chains whose ids differ
    # only by letter case map to the same key and the later one replaces the
    # earlier one.
    struct_chains = {chain.id.upper(): chain for chain in structure.get_chains()}
    metadata["num_chains"] = len(struct_chains)

    # Extract features
    all_seqs = set()
    struct_feats = []
    num_na_chains = 0
    
    # na_natype accumulates chain_mol[-2] across nucleic-acid chains;
    # chain_dict stays None if no chain is successfully processed.
    na_natype, chain_dict = None, None
    
    for chain_id, chain in struct_chains.items():
        # Convert chain id into int
        chain_index = utils.chain_str_to_int(chain_id)
        chain_mol = parsers.process_chain_pdb(chain, chain_index, chain_id, verbose=verbose)
        
        if chain_mol is None:
            # Note: Indicates that neither a protein chain nor a nucleic acid chain was found
            continue
        elif chain_mol[-1]["molecule_type"] == "na":
            num_na_chains += 1
            na_natype = (
                chain_mol[-2]
                if na_natype is None
                else torch.cat((na_natype, chain_mol[-2]), dim=0)
            )

        # Every non-None chain (not only "na" chains) reaches this point and
        # contributes features.
        chain_mol_constants = chain_mol[-1]["molecule_constants"]
        chain_mol_backbone_atom_name = chain_mol[-1]["molecule_backbone_atom_name"]
        chain_dict = parsers.macromolecule_outputs_to_dict(chain_mol)
        
        chain_dict = utils.parse_chain_feats_pdb(
            chain_feats=chain_dict,
            molecule_constants=chain_mol_constants,
            molecule_backbone_atom_name=chain_mol_backbone_atom_name,
        )
        all_seqs.add(tuple(chain_dict["aatype"]))
        struct_feats.append(chain_dict)
    
    # chain_dict is still None only if the loop never processed a chain.
    if chain_dict is None:
        if verbose:
            print(f"No chains were found for PDB {file_path}. Skipping...")
        return None
    
    # A single distinct residue-type sequence across chains => homomer.
    if len(all_seqs) == 1:
        metadata["quaternary_category"] = "homomer"
    else:
        metadata["quaternary_category"] = "heteromer"

    # Add assembly features
    # Chains with identical residue-type sequences share an entity id
    # (1-based, assigned in order of first appearance).
    seq_to_entity_id = {}
    grouped_chains = collections.defaultdict(list)
    for chain_dict in struct_feats:
        seq = tuple(chain_dict["aatype"])
        if seq not in seq_to_entity_id:
            seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
        grouped_chains[seq_to_entity_id[seq]].append(chain_dict)

    # new_all_chain_dict is filled here but not read again in this function.
    # The chain dicts are mutated in place, so struct_feats sees the new ids.
    new_all_chain_dict = {}
    chain_id = 1
    for entity_id, group_chain_features in grouped_chains.items():
        for sym_id, chain_dict in enumerate(group_chain_features, start=1):
            new_all_chain_dict[f"{utils.int_id_to_str_id(entity_id)}_{sym_id}"] = chain_dict
            seq_length = len(chain_dict["aatype"])
            # Per-residue id arrays: running chain number, copy number within
            # the entity, and entity number.
            chain_dict["asym_id"] = chain_id * np.ones(seq_length)
            chain_dict["sym_id"] = sym_id * np.ones(seq_length)
            chain_dict["entity_id"] = entity_id * np.ones(seq_length)
            chain_id += 1

    # Concatenate features
    complex_feats = utils.concat_np_features(struct_feats, add_batch_dim=False)
    if complex_feats["bb_mask"].sum() < 1.0:
        return None
    assert len(complex_feats["bb_mask"]) == len(
        complex_feats["aatype"]
    ), "Number of core atoms must match number of residues."

    # Record molecule metadata
    # The protein chain count is hard-coded to zero; only NA chains are counted.
    metadata["num_protein_chains"] = 0
    metadata["num_na_chains"] = num_na_chains

    # Process geometry features
    complex_aatype = complex_feats["aatype"]
    metadata["seq_len"] = len(complex_aatype)
    metadata["na_seq_len"] = 0 if na_natype is None else len(na_natype)
    # Residue-type codes 20 and 26 are excluded from the modeled region --
    # presumably the unknown/missing residue indices; TODO confirm against the
    # project's residue constants.
    modeled_idx = np.where((complex_aatype != 20) & (complex_aatype != 26))[0]
    na_modeled_idx = None if na_natype is None else np.where(na_natype != 26)[0]
    if np.sum((complex_aatype != 20) & (complex_aatype != 26)) == 0:
        raise utils.LengthError("No modeled residues")
    # Modeled length is the span between the first and last modeled residue
    # (inclusive), not the count of modeled residues.
    metadata["modeled_seq_len"] = np.max(modeled_idx) - np.min(modeled_idx) + 1
    metadata["modeled_protein_seq_len"] = 0
    # NOTE(review): np.max/np.min would raise on an empty na_modeled_idx
    # (NA chains present but every entry equal to 26) -- confirm this cannot
    # occur in practice.
    metadata["modeled_na_seq_len"] = (
        0 if na_natype is None else np.max(na_modeled_idx) - np.min(na_modeled_idx) + 1
    )
    complex_feats["modeled_idx"] = modeled_idx
    complex_feats["na_modeled_idx"] = na_modeled_idx

    # Find inter-chain interface residues
    # Distances are measured from each residue's backbone position to every
    # atom of every residue (atoms flattened across residues).
    num_atoms_per_res = complex_feats["atom_positions"].shape[1]
    bb_pos = torch.from_numpy(complex_feats["bb_positions"]).unsqueeze(0)
    atom_pos = (
        torch.from_numpy(complex_feats["atom_positions"])
        .unsqueeze(0)
        .flatten(start_dim=1, end_dim=2)
    )
    bb_asym_id = torch.from_numpy(complex_feats["asym_id"]).unsqueeze(0)
    # Give each atom the chain (asym) id of the residue it belongs to.
    atom_asym_id = torch.repeat_interleave(
        bb_asym_id.unsqueeze(2), num_atoms_per_res, dim=2
    ).flatten(start_dim=1, end_dim=2)
    dist_mat = torch.cdist(bb_pos, atom_pos)
    inter_chain_mask = bb_asym_id.unsqueeze(-1) != atom_asym_id.unsqueeze(-2)
    # All-ones mask: currently filters nothing out.
    non_h_mask = torch.ones_like(inter_chain_mask)
    interacting_res_mask = (
        inter_chain_mask & non_h_mask & (dist_mat <= inter_chain_interact_dist_threshold)
    )
    # Unique row (residue) indices that have at least one close atom from a
    # different chain.
    complex_feats["inter_chain_interacting_idx"] = torch.nonzero(
        interacting_res_mask.squeeze(0), as_tuple=False
    )[..., 0].unique()

    # The mdtraj-based statistics are best-effort: any failure is tolerated
    # and recorded as NaN in the metadata below.
    try:
        traj = md.load(file_path)
    except Exception as e:
        if verbose:
            print(f"Mdtraj failed to load file {file_path} with error {e}")
        traj = None
    try:
        pdb_ss = md.compute_dssp(traj, simplified=True) if traj is not None else None
    except Exception as e:
        if verbose:
            print(f"Mdtraj's call to DSSP failed with error {e}")
        pdb_ss = None
    try:
        pdb_rg = md.compute_rg(traj) if traj is not None else None
    except Exception as e:
        if verbose:
            print(f"Mdtraj's call to RG failed with error {e}")
        pdb_rg = None

    # Fractions of "C", "H" and "E" labels relative to the modeled span length.
    metadata["coil_percent"] = (
        np.sum(pdb_ss == "C") / metadata["modeled_seq_len"] if pdb_ss is not None else np.nan
    )
    metadata["helix_percent"] = (
        np.sum(pdb_ss == "H") / metadata["modeled_seq_len"] if pdb_ss is not None else np.nan
    )
    metadata["strand_percent"] = (
        np.sum(pdb_ss == "E") / metadata["modeled_seq_len"] if pdb_ss is not None else np.nan
    )

    # Only the first entry of the radius-of-gyration result is kept --
    # presumably the first trajectory frame; TODO confirm.
    metadata["radius_gyration"] = pdb_rg[0] if pdb_rg is not None else np.nan

    # Write features to pickles
    utils.write_pkl(processed_path, complex_feats)

    return metadata

def process_serially(all_paths, write_dir, skip_existing=False, verbose=False):
    """Process every file in ``all_paths`` sequentially in the current process.

    Progress is shown with tqdm and a timing (or failure) line is printed for
    each file regardless of ``verbose``; ``verbose`` is only forwarded to
    ``process_file``.

    Args:
        all_paths: Iterable of input file paths.
        write_dir: Directory handed to ``process_file`` for its outputs.
        skip_existing: Forwarded to ``process_file``.
        verbose: Forwarded to ``process_file``.

    Returns:
        List of the non-None metadata dicts returned by ``process_file``.
        Files that raise ``utils.DataError`` are reported and left out.
    """
    collected = []
    for file_path in tqdm(all_paths):
        began = time.time()
        try:
            result = process_file(
                file_path, write_dir, skip_existing=skip_existing, verbose=verbose
            )
        except utils.DataError as e:
            print(f"Failed {file_path}: {e}")
            continue
        elapsed_time = time.time() - began
        print(f"Finished {file_path} in {elapsed_time:2.2f}s")
        if result is not None:
            collected.append(result)
    return collected

def process_fn(file_path, write_dir=None, skip_existing=False, verbose=False):
    """Process a single file; intended as the worker for the process pool.

    Args:
        file_path: Input file path.
        write_dir: Directory handed to ``process_file`` for its outputs.
        skip_existing: Forwarded to ``process_file``.
        verbose: When True, print a timing line on success or a failure line
            when ``utils.DataError`` is raised; also forwarded to
            ``process_file``.

    Returns:
        Whatever ``process_file`` returns, or None if it raised
        ``utils.DataError``.
    """
    began = time.time()
    try:
        result = process_file(
            file_path, write_dir, skip_existing=skip_existing, verbose=verbose
        )
    except utils.DataError as e:
        if verbose:
            print(f"Failed {file_path}: {e}")
        return None
    elapsed_time = time.time() - began
    if verbose:
        print(f"Finished {file_path} in {elapsed_time:2.2f}s")
    return result

def main():
    """Preprocess every PDB file under ``CONFIG["pdb_dir"]``.

    Hides CUDA devices, collects the input files, processes them either
    serially or with a process pool (depending on ``CONFIG``), and writes the
    collected metadata as a CSV into ``CONFIG["write_dir"]``.
    """
    # Disable GPU: expose no CUDA devices to this process.
    os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": ""})

    # Collect every directory entry whose name contains ".pdb".
    pdb_dir = CONFIG["pdb_dir"]
    all_file_paths = []
    for item in os.listdir(pdb_dir):
        if ".pdb" in item:
            all_file_paths.append(os.path.join(pdb_dir, item))
    total_num_paths = len(all_file_paths)

    write_dir = CONFIG["write_dir"]
    os.makedirs(write_dir, exist_ok=True)

    # Debug runs write to a separate metadata file.
    if CONFIG["debug"]:
        metadata_file_name = "rna_metadata_debug.csv"
    else:
        metadata_file_name = "rna_metadata.csv"
    metadata_path = os.path.join(write_dir, metadata_file_name)
    print(f"Files will be written to {write_dir}")

    # Shared keyword arguments for both processing paths.
    shared_kwargs = {
        "skip_existing": CONFIG["skip_existing"],
        "verbose": CONFIG["verbose"],
    }

    # Process each PDB file: serially for one process or debug mode,
    # otherwise fan out over a pool and drop the None results.
    if CONFIG["num_processes"] == 1 or CONFIG["debug"]:
        all_metadata = process_serially(all_file_paths, write_dir, **shared_kwargs)
    else:
        worker = fn.partial(process_fn, write_dir=write_dir, **shared_kwargs)
        with mp.Pool(processes=CONFIG["num_processes"]) as pool:
            pool_results = pool.map(worker, all_file_paths)
        all_metadata = [entry for entry in pool_results if entry is not None]

    pd.DataFrame(all_metadata).to_csv(metadata_path, index=False)
    succeeded = len(all_metadata)
    print(f"Finished processing {succeeded}/{total_num_paths} files")

# Run the processing only when this code is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Files will be written to ./data/rnasolo_proc/


## Train Model

### train_se3_flows.py

In [None]:
import os
import GPUtil
import torch

import hydra
from omegaconf import DictConfig, OmegaConf

from pytorch_lightning import Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from src.data.pdb_na_datamodule_base import PDBNABaseDataModule
from src.models.flow_module import FlowModule
import src.utils as eu
import wandb

# Module-level logger obtained from the project's logging helper.
log = eu.get_pylogger(__name__)
# Select PyTorch's 'high' float32 matmul precision setting, which lets float32
# matrix multiplications use faster reduced-precision internals where the
# hardware supports them (see the PyTorch documentation for exact semantics).
torch.set_float32_matmul_precision('high')

class Experiment:
    """One training run: builds the model and data module, then trains.

    Args:
        cfg: Full run configuration. The ``data_cfg`` and ``experiment``
            sub-configs are read here; the whole config is passed to
            ``FlowModule``.
    """

    def __init__(self, *, cfg: DictConfig):
        self._cfg = cfg
        self._data_cfg = cfg.data_cfg
        self._exp_cfg = cfg.experiment
        self._model = FlowModule(self._cfg)
        self._datamodule = PDBNABaseDataModule(data_cfg=self._data_cfg)

    def train(self):
        """Configure logging, checkpointing and devices, then fit the model.

        In debug mode no logger or checkpoint callback is used, a single
        device is requested and data-loader workers are disabled. Otherwise a
        Weights & Biases logger and a ``ModelCheckpoint`` callback are set up
        and the run config is saved next to the checkpoints.

        If no GPU is reported as available, training falls back to a single
        CPU device instead of passing an empty device list to the trainer.
        """
        callbacks = []

        if self._exp_cfg.debug:
            log.info("Debug mode.")
            logger = None
            self._exp_cfg.num_devices = 1
            self._data_cfg.loader.num_workers = 0
        else:
            logger = WandbLogger(**self._exp_cfg.wandb)

            # Checkpoint directory
            ckpt_dir = self._exp_cfg.checkpointer.dirpath
            os.makedirs(ckpt_dir, exist_ok=True)
            log.info(f"Checkpoints saved to {ckpt_dir}")

            # Model checkpoints
            callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))

            # Save config. The path is handed straight to OmegaConf.save so
            # the file is opened exactly once (previously it was opened here
            # and then re-opened by name inside OmegaConf.save).
            cfg_path = os.path.join(ckpt_dir, 'config.yaml')
            OmegaConf.save(config=self._cfg, f=cfg_path)

            # Mirror the flattened, resolved config into the W&B run config.
            cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
            flat_cfg = dict(eu.flatten_dict(cfg_dict))
            if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
                logger.experiment.config.update(flat_cfg)

        # Pick the least memory-loaded GPUs, capped at the requested count.
        devices = GPUtil.getAvailable(order='memory', limit=8)[:self._exp_cfg.num_devices]

        # Copy the trainer options so the accelerator can be overridden
        # without passing the same keyword twice.
        trainer_kwargs = dict(self._exp_cfg.trainer)
        if not devices:
            # An empty device list is not accepted by the Lightning Trainer,
            # so run on one CPU device instead.
            # NOTE(review): a GPU-specific `strategy` in the trainer config
            # may also need adjusting for CPU runs -- confirm.
            log.info("No available GPU found; falling back to a single CPU device.")
            trainer_kwargs["accelerator"] = "cpu"
            devices = 1
        log.info(f"Using devices: {devices}")

        trainer = Trainer(
            **trainer_kwargs,
            callbacks=callbacks,
            logger=logger,
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=devices,
        )

        # `warm_start` is passed as the checkpoint to resume from.
        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
            ckpt_path=self._exp_cfg.warm_start
        )
from hydra.core.config_store import ConfigStore
from hydra import initialize, compose

def main():
    """Compose the Hydra training config and launch an ``Experiment``.

    When a warm-start checkpoint is configured and config override is
    enabled, the ``model`` section of the ``config.yaml`` stored next to that
    checkpoint is merged over the freshly composed ``model`` section before
    training starts.
    """
    # Initialize hydra
    with initialize(version_base=None, config_path="./configs"):
        # Compose the configuration
        cfg = compose(config_name="config")

        warm_start_ckpt = cfg.experiment.warm_start
        if warm_start_ckpt is not None and cfg.experiment.warm_start_cfg_override:
            # Loads warm start config from the checkpoint's directory.
            warm_start_cfg_path = os.path.join(
                os.path.dirname(warm_start_ckpt), 'config.yaml'
            )
            warm_start_cfg = OmegaConf.load(warm_start_cfg_path)

            # Warm start config may not have latest fields in the base config.
            # Relax struct mode on both model sections so the merge can add
            # fields, then lock the merged result again.
            for model_section in (cfg.model, warm_start_cfg.model):
                OmegaConf.set_struct(model_section, False)
            cfg.model = OmegaConf.merge(cfg.model, warm_start_cfg.model)
            OmegaConf.set_struct(cfg.model, True)
            log.info(f'Loaded warm start config from {warm_start_cfg_path}')

        Experiment(cfg=cfg).train()

# Start training only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

The training code runs successfully once Weights & Biases (wandb) has been authorized, but training takes a long time; the exact runtime depends on the GPU used.

After training, the final saved checkpoint can be found at ckpt/se3-fm/rna-frameflow/last.ckpt

## Inference 

### Inference_se3_flows.py

In [1]:
import os
import time
import numpy as np
import torch
from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from hydra import initialize, compose

import src.utils as eu
from src.models.flow_module import FlowModule
from src.data.pdb_na_dataset_base import LengthDataset
from src.analysis.evalsuite import EvalSuite

# Select PyTorch's 'high' float32 matmul precision setting (see the PyTorch
# documentation for the exact speed/precision trade-off).
torch.set_float32_matmul_precision('high')
# Module-level logger obtained from the project's logging helper.
log = eu.get_pylogger(__name__)

class Sampler:
    """Loads a trained ``FlowModule`` checkpoint and generates samples on CPU."""

    def __init__(self, cfg: DictConfig):
        """Prepare configuration, output directory and the flow module.

        The ``config.yaml`` stored next to the checkpoint is merged over the
        inference config, the merged config is written to the output
        directory, and the checkpoint is loaded in eval mode.

        Args:
            cfg: inference config.
        """
        ckpt_path = cfg.inference.ckpt_path
        training_cfg_file = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        ckpt_cfg = OmegaConf.load(training_cfg_file)

        # Set-up config: relax struct mode on both configs, then merge with
        # the checkpoint's config applied second.
        for conf in (cfg, ckpt_cfg):
            OmegaConf.set_struct(conf, False)
        merged = OmegaConf.merge(cfg, ckpt_cfg)
        merged.experiment.checkpointer.dirpath = './'

        self._cfg = merged
        self._infer_cfg = merged.inference
        self._samples_cfg = self._infer_cfg.samples
        self._rng = np.random.default_rng(self._infer_cfg.seed)

        # Checkpoint name: the last three '/'-separated components of the
        # checkpoint path with '.ckpt' removed.
        stem_parts = ckpt_path.replace('.ckpt', '').split('/')
        self._ckpt_name = '/'.join(stem_parts[-3:])

        # Set-up directories to write results to
        self._output_dir = os.path.join(self._infer_cfg.output_dir, self._infer_cfg.name)
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')

        # Keep a copy of the effective config alongside the samples.
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module, then attach the inference
        # settings it reads during prediction.
        flow_module = FlowModule.load_from_checkpoint(checkpoint_path=ckpt_path)
        flow_module.eval()
        flow_module._infer_cfg = self._infer_cfg
        flow_module._samples_cfg = self._samples_cfg
        flow_module._output_dir = self._output_dir
        self._flow_module = flow_module

    def run_sampling(self):
        """Run prediction over the length dataset with a CPU-only trainer."""
        log.info("Running on CPU")

        length_dataset = LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            length_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        cpu_trainer = Trainer(accelerator="cpu")

        start_time = time.time()
        cpu_trainer.predict(self._flow_module, dataloaders=dataloader)
        elapsed_time = time.time() - start_time
        log.info(f'Finished in {elapsed_time:.2f}s')
        log.info(f'Generated samples are stored here: {self._cfg.inference.output_dir}/{self._cfg.inference.name}/')

def run():
    """Compose the inference config, then optionally sample and evaluate.

    Sampling runs when ``cfg.inference.run_inference`` is set; the evaluation
    suite runs on the sample directory when
    ``cfg.inference.evalsuite.run_eval`` is set.
    """
    # Initialize hydra
    with initialize(version_base=None, config_path="./camera_ready_ckpts"):
        # Compose the configuration
        cfg = compose(config_name="inference")

        # Read model checkpoint and run inference
        if cfg.inference.run_inference:
            log.info('Starting inference on CPU')
            Sampler(cfg).run_sampling()

        # The evaluation step is optional; nothing more to do without it.
        if not cfg.inference.evalsuite.run_eval:
            return

        print("Starting EvalSuite on generated backbones ...")
        print(f"Sample directory: {cfg.inference.output_dir}/{cfg.inference.name}/")

        rna_bb_samples_dir = f"{cfg.inference.output_dir}/{cfg.inference.name}"
        eval_cfg = cfg.inference.evalsuite
        saving_dir = eval_cfg.eval_save_dir

        # init evaluation module; both GPU ids are None so no GPU is requested
        # for the inverse-folding or the forward-folding model.
        evalsuite = EvalSuite(
            save_dir=saving_dir,
            paths=eval_cfg.paths,
            constants=eval_cfg.constants,
            gpu_id1=None,
            gpu_id2=None,
        )

        # run self-consistency pipeline; its return value is not needed here
        # because the metrics are re-read from disk just below.
        evalsuite.perform_eval(rna_bb_samples_dir, flatten_dir=True)

        # print out global self-consistency metrics
        metrics_fp = os.path.join(saving_dir, "final_metrics.pt")
        metric_dict = evalsuite.load_from_metric_dict(metrics_fp)
        evalsuite.print_metrics(metric_dict)

# Run inference (and optional evaluation) only when executed directly.
if __name__ == '__main__':
    run()

  from .autonotebook import tqdm as notebook_tqdm
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\nikhi\anaconda3\envs\rna-bb-design\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
c:\Users\nikhi\anaconda3\envs\rna-bb-design\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `Data

#### Generating sequences with the following lengths: [40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150]
Predicting DataLoader 0:   4%|▍         | 1/24 [00:07<02:45,  0.14it/s]



Predicting DataLoader 0: 100%|██████████| 24/24 [11:16<00:00,  0.04it/s]
Starting EvalSuite on generated backbones ...
Sample directory: cam_ready_rna_bb_samples/generated_samples/
Created savedir: evalsuite_metrics
Instantiating gRNAde v0.3
    Using device: cpu
    Creating RNA graph featurizer for max_num_conformers=1
    Initialising GNN encoder-decoder model
    Loading model checkpoint: src/tools/grnade_api/checkpoints/gRNAde_ARv1_1state_all.h5
Finished initialising gRNAde v0.3

Flattening directory ...
Loaded metadata CSV with 343 filtered cluster samples ...


  0%|          | 0/12 [00:00<?, ?it/s]

: 

Results will be saved in the "results" folder. However, the evaluation step computes metrics, which requires instantiating the gRNAde model. Because the gRNAde model runs on Linux, this code should be run in a Linux environment, where it runs without issues.