# Fine-tune RNA-FM to predict the splice sites of pre-mRNAs

In [None]:
import sys
sys.path.append('/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/RNA-FM/')
from token_classification_patch import Config, RnafmForTokenClassification
from transformers import AutoConfig, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk
from safetensors.torch import load_file, load_model
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from scipy.special import softmax
import os
import pandas as pd
import torch
import evaluate
import numpy as np

In [None]:
# All relative paths used in the following cells (tokenizer, dataset, pretrained
# checkpoint, output_dir) are resolved against the LAMAR_baselines checkout.
# The absolute default below is machine-specific; set LAMAR_BASELINES_DIR to run
# the notebook from another location without editing it.
os.chdir(os.environ.get('LAMAR_BASELINES_DIR', '/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/'))

In [None]:
# ---- Paths (relative to the working directory chosen by os.chdir above) ----
tokenizer_path = 'tokenizer/RNA-FM/'
data_path = 'SpliceSitePred/data/RNA-FM/ss_single_nucleotide/'
# NOTE(review): the "bs128" tag does not match batch_size * accum_steps = 8 on one
# device; presumably 16 devices were used -- TODO confirm.
output_dir = 'SpliceSitePred/saving_model/RNA-FM/bs128_lr1e-4_wr0.05_4epochs'
# Scratch location for quick debugging runs:
# output_dir = 'SpliceSitePred/saving_model/RNA-FM/test'

# ---- Tokenisation / model head ----
# Passed to the tokenizer; the data collator also pads every batch to this length.
model_max_length = 1026
hidden_size = 640
num_labels = 3
hidden_dropout_prob = 0
freeze = False  # forwarded to RnafmForTokenClassification(freeze=...)

# ---- Optimisation ----
batch_size = 4        # per-device train batch size
accum_steps = 2       # gradient accumulation steps
peak_lr = 1e-4
warmup_ratio = 0.05
total_epochs = 4
grad_clipping_norm = 1
fp16 = False
flash_attention = False  # not referenced by the other cells in this notebook section

# ---- Checkpointing / logging ----
# Despite its name, save_epochs is passed to TrainingArguments as save_steps,
# i.e. it is an interval in optimiser steps, not epochs.
save_epochs = 1000
logging_steps = 100   # also used as the evaluation interval

In [None]:
# Tokenizer; its model_max_length doubles as the collator's fixed padding length.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    model_max_length=model_max_length,
)

# Hyperparameters for the token-classification model defined in token_classification_patch.
hyperparams = Config(
    hidden_size=hidden_size,
    num_labels=num_labels,
    hidden_dropout_prob=hidden_dropout_prob,
)

# Dataset previously written with `save_to_disk`; its 'train' and 'test' splits
# feed the Trainer.
data = load_from_disk(data_path)

# Pads inputs (and labels) of every batch to exactly tokenizer.model_max_length.
# Label padding uses the collator's default label_pad_token_id (-100), which
# compute_metrics filters out.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding="max_length",
    max_length=tokenizer.model_max_length,
)

# Model initialised from the pretrained RNA-FM weights at pretrained_state_path
# (presumably backbone + token head -- see token_classification_patch).
pretrained_state_path = 'RNA-FM/RNA-FM_pretrained.pth'
model = RnafmForTokenClassification(
    pretrained_weights_location=pretrained_state_path,
    hyperparams=hyperparams,
    freeze=freeze,
)

In [None]:
# Training arguments, grouped by concern (keyword order is irrelevant).
train_args = TrainingArguments(
    output_dir=output_dir,
    report_to="none",
    disable_tqdm=False,
    # Batching: per-device effective train batch = batch_size * accum_steps.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=accum_steps,
    # NOTE(review): in recent transformers releases dataloader_drop_last is also
    # applied to the eval dataloader, so a final partial eval batch would be left
    # out of the metrics -- confirm for the installed version.
    dataloader_drop_last=True,
    # Optimiser and schedule.
    learning_rate=peak_lr,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-8,
    warmup_ratio=warmup_ratio,
    num_train_epochs=total_epochs,
    max_grad_norm=grad_clipping_norm,
    fp16=fp16,
    # Evaluate and log every `logging_steps` steps.
    # NOTE(review): newer transformers versions rename `evaluation_strategy` to
    # `eval_strategy`; keep this in sync with the installed version.
    evaluation_strategy="steps",
    eval_steps=logging_steps,
    logging_strategy = 'steps',
    logging_steps=logging_steps,
    # Checkpoint every `save_epochs` steps, keeping at most one checkpoint on disk.
    save_strategy='steps',
    save_steps=save_epochs,
    save_total_limit=1,
    save_safetensors=False,
)

In [None]:
# evaluation metrics
def compute_binary_pr_auc(reference, predict_logits):
    """Return the area under the precision-recall curve of one binary problem.

    Args:
        reference: binary ground-truth labels (positive class = 1).
        predict_logits: score for the positive class per sample; higher means
            more likely positive.

    Returns:
        Trapezoidal area under the (recall, precision) curve computed with
        ``sklearn.metrics.auc``. This is not the same quantity as
        ``average_precision_score``, which uses a step-wise sum.
    """
    curve = precision_recall_curve(reference, predict_logits)
    precision_vals, recall_vals = curve[0], curve[1]
    # auc() integrates y over x, so recall is the x-axis here.
    return auc(recall_vals, precision_vals)


def compute_ovr_pr_auc(reference, predict_logits, average=None, ignore_idx=None):
    """Compute one-vs-rest PR AUC for each class column of ``predict_logits``.

    Args:
        reference: 1-D integer class labels, one per row of ``predict_logits``.
        predict_logits: 2-D score array of shape (n_samples, n_classes); column
            ``c`` is the score for class ``c``.
        average: ``"macro"`` returns the unweighted mean over kept classes;
            ``"weighted"`` returns the mean weighted by each kept class's support
            in ``reference``; anything else (default ``None``) returns the
            per-class list.
        ignore_idx: iterable of class indices to skip (default: none).

    Returns:
        A list of per-class PR AUCs in class order (ignored classes omitted), or a
        scalar when ``average`` is ``"macro"`` or ``"weighted"``.
    """
    # None instead of a mutable [] default; any iterable of class ids is accepted.
    ignored = set(ignore_idx) if ignore_idx is not None else set()
    n_classes = predict_logits.shape[1]
    kept_classes = [c for c in range(n_classes) if c not in ignored]
    pr_aucs = [
        compute_binary_pr_auc((reference == c).astype(int), predict_logits[:, c])
        for c in kept_classes
    ]
    if average == "macro":
        return np.mean(pr_aucs)
    if average == "weighted":
        # Count support per kept class explicitly so the weights line up with
        # pr_aucs even when classes are ignored or absent from `reference`
        # (np.bincount would be misaligned there and rejects negative labels).
        labels = np.asarray(reference)
        support = np.array([np.sum(labels == c) for c in kept_classes], dtype=float)
        return np.sum(np.asarray(pr_aucs) * support / support.sum())
    return pr_aucs


def compute_ovo_pr_auc(reference, predict_logits, average=None):
    """Placeholder for one-vs-one PR AUC; always raises ``NotImplementedError``.

    ``precision_recall_curve`` only handles a single binary problem, so an OvO
    variant would need explicit pairwise class handling that is not written yet.
    """
    message = "OvO PR AUC computation is not implemented yet."
    raise NotImplementedError(message)


def pr_auc_score(reference, predict_logits, multi_class=None, average=None):
    """Dispatch PR-AUC computation, mirroring ``roc_auc_score``'s arguments.

    Args:
        reference: ground-truth labels.
        predict_logits: scores; 2-D (n_samples, n_classes) for multi-class modes,
            1-D positive-class scores for the binary mode.
        multi_class: ``"ovr"`` (one-vs-rest), ``"ovo"`` (not implemented, raises
            ``NotImplementedError``), or anything else for a binary problem.
        average: forwarded to the multi-class helpers; ignored in binary mode.

    Returns:
        Whatever the selected helper returns (scalar or per-class list).
    """
    if multi_class == "ovr":
        return compute_ovr_pr_auc(reference, predict_logits, average=average)
    if multi_class == "ovo":
        return compute_ovo_pr_auc(reference, predict_logits, average=average)
    # Fallback (including None): single binary problem.
    return compute_binary_pr_auc(reference, predict_logits)


def compute_metrics(p):
    """Compute token-level metrics for the Hugging Face ``Trainer``.

    Args:
        p: ``(logits, labels)`` pair from the evaluation loop. Softmax and argmax
            run over axis 2 (the class axis); axes 0 and 1 are flattened together
            to line up with ``labels`` -- presumably (examples, positions), confirm.
            Positions whose label is -100 are excluded from every metric.

    Returns:
        dict with keys, in order:
        - ``"topk"``: ``{"topk": {class: value}}``. For a class with ``v``
          labelled positions, value is the fraction of the ``v`` positions with
          the highest predicted probability for that class that truly carry it.
        - ``"roc_auc"``: per-class one-vs-rest ROC AUC, as a list.
        - ``"pr_auc"``: per-class one-vs-rest PR AUC, as a list.
    """
    ignore_label = -100
    raw_logits, labels = p

    probs = softmax(raw_logits, axis=2)
    n_rows = probs.shape[0] * probs.shape[1]

    # One row per position: probability columns 0..C-1, then argmax and label.
    table = pd.DataFrame(probs.reshape((n_rows, -1)))
    # Argmax predictions; not read by the metrics below.
    table["pred"] = np.array(np.argmax(probs, axis=2).astype(np.int8)).flatten()
    table["true"] = np.array(labels).flatten()
    table = table[table["true"] != ignore_label]

    # Top-k precision per class, with k = that class's true support.
    class_support = table.true.value_counts().to_dict()
    topk_by_class = {}
    for cls, k in class_support.items():
        top_rows = table.sort_values(by=cls, ascending=False).iloc[:k]
        topk_by_class[cls] = sum(top_rows.true == cls) / k

    truth = table["true"]
    scores = table.drop(columns=["pred", "true"]).to_numpy()
    return {
        "topk": {"topk": topk_by_class},
        "roc_auc": list(roc_auc_score(truth, scores, multi_class="ovr", average=None)),
        "pr_auc": list(pr_auc_score(truth, scores, multi_class="ovr", average=None)),
    }

In [None]:
# Trainer: fits on data['train'] and periodically evaluates data['test'] with
# compute_metrics, following the schedule in train_args.
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=data_collator,
    train_dataset=data['train'],
    eval_dataset=data['test'],
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=`; confirm against the installed version.
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning. Kept as the cell's final expression so the notebook displays
# the value returned by trainer.train().
trainer.train()