In [None]:
# @title Imports

import os
import h5py
import pandas as pd
import torch as tr
from tqdm.notebook import tqdm

print("Installing packages...")
from transformers import AutoTokenizer, AutoModel
!pip install multimolecule > /dev/null 2>&1
from multimolecule import RnaTokenizer

!git clone --quiet https://github.com/sinc-lab/rna-llm-folding > null
os.chdir("rna-llm-folding/")

print("Done!")

In [None]:
# @title Select dataset and RNA-LLM

# Colab form fields: the `#@param` annotations must stay on the assignment line.
Dataset = "PDB-RNA" #@param ["ArchiveII_kfold", "ArchiveII_famfold", "bpRNA", "bpRNA-new", "PDB-RNA"]
RNA_LLM = "RNAErnie" #@param ["RNAErnie", "RNA-FM", "RNABERT", "RNA-MSM", "RiNALMo", "ERNIE-RNA"]

dataset_name = Dataset
# Model identifier: lower-cased display name with the dashes removed.
model_name = "".join(RNA_LLM.lower().split("-"))
llm_path = "multimolecule/" + model_name

print(f"loading {RNA_LLM}")
# NOTE(review): cls/eos tokens are explicitly disabled on the tokenizer here;
# the embedding cell slices the model output accordingly -- confirm they agree.
tokenizer = AutoTokenizer.from_pretrained(
    llm_path,
    trust_remote_code=True,
    cls_token=None,
    eos_token=None,
)
model = AutoModel.from_pretrained(llm_path, trust_remote_code=True)


In [None]:
# @title Load dataset and generate embeddings

# load dataset
# Anything after the first "_" in the dataset name is dropped to get the CSV
# base name (e.g. "ArchiveII_kfold" -> "ArchiveII").
dataset_name_base = dataset_name.split("_")[0]
df = pd.read_csv(f"data/{dataset_name_base}.csv", index_col="id")

# Length cap: only sequences strictly shorter than max_len are kept.
# RNABERT gets a smaller cap -- presumably a model input limit; TODO confirm.
max_len = 512
if RNA_LLM=="RNABERT":
    max_len = 440

df["len"] = df.sequence.str.len()
df = df[df.len<max_len]

# generate embeddings
# Maps each sequence id (the DataFrame index) to its per-position embedding.
embeddings = {}
# Names `seq_id`/`inputs` avoid shadowing the builtins `id` and `input`, which
# would otherwise stay overridden for the rest of the kernel session.
for seq_id, sequence in tqdm(df.sequence.items(), total=len(df)):
    with tr.no_grad():
        inputs = tokenizer(sequence, return_tensors="pt")
        # Take the single batch item and skip position 0 of the hidden states.
        # NOTE(review): position 0 is presumably a leading special token --
        # confirm against the tokenizer settings used when it was loaded.
        output = model(**inputs)["last_hidden_state"][0, 1:, :]
    embeddings[seq_id] = output

emb_file = f'data/embeddings/{model_name}_{dataset_name_base}.h5'
# Make sure the target directory exists; opening the file in 'w' mode fails
# otherwise.
os.makedirs(os.path.dirname(emb_file), exist_ok=True)
with h5py.File(emb_file, 'w') as hdf:
    print("Writing embedding file", emb_file)
    for key, value in embeddings.items():
        # HDF5 dataset names must be strings; tensors are converted explicitly
        # to NumPy arrays rather than relying on implicit conversion.
        hdf.create_dataset(str(key), data=value.detach().cpu().numpy())

In [None]:
# @title Run scripts for train and test

cmd = f"python scripts/run_{dataset_name.lower()}.py --emb {model_name}_{dataset_name_base}"
print(cmd)
!{cmd}