# Evaluate the performance of RNA-FM-SS in splice site prediction

In [None]:
import sys
sys.path.append('/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/RNA-FM/')
from token_classification_patch import Config, RnafmForTokenClassification
from transformers import AutoConfig, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
from safetensors.torch import load_file, load_model
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import os
import pandas as pd
import torch
from torch import nn
import numpy as np
import json
import tqdm

In [None]:
# Switch the working directory to the LAMAR_baselines checkout; the relative
# tokenizer / checkpoint / data paths configured below are resolved against it.
os.chdir('/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/')

In [None]:
# ---- Evaluation configuration ----
# All paths are relative to the current working directory.
tokenizer_path = 'tokenizer/RNA-FM/'
# Forwarded to the tokenizer as `model_max_length`.
# NOTE(review): presumably 1024 nucleotides plus 2 special tokens - confirm.
model_max_length = 1026
# hidden_size, nlabels and hidden_dropout_prob are passed positionally to `Config`.
hidden_size = 640
nlabels = 3
hidden_dropout_prob = 0
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU so the
# notebook still runs on hosts without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Pretrained weights handed to the model constructor.
pretrained_state_path = 'RNA-FM/RNA-FM_pretrained.pth'
# Forwarded to the model constructor as `freeze`.
freeze = False
# Fine-tuned checkpoint to evaluate (`.bin` or `.safetensors`).
model_state_path = 'SpliceSitePred/saving_model/RNA-FM/bs128_lr1e-4_wr0.05_4epochs/checkpoint-20050/pytorch_model.bin'
# Test set in JSON format.
test_set_path = 'SpliceSitePred/data/RNA-FM/gencode_test.json'

In [None]:
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=model_max_length)

# Config
hyperparams = Config(hidden_size, nlabels, hidden_dropout_prob)

# Model
model = RnafmForTokenClassification(pretrained_weights_location=pretrained_state_path, hyperparams=hyperparams, freeze=freeze)
model = model.to(device)

# Fine-tuned weights: strict=True makes any missing or unexpected key an error.
if model_state_path.endswith('.safetensors'):
    load_model(model, filename=model_state_path, strict=True)
elif model_state_path.endswith('.bin'):
    # map_location restores the tensors onto `device` even when the checkpoint
    # was saved from a different device.
    model.load_state_dict(torch.load(model_state_path, map_location=device), strict=True)
else:
    # Fail loudly instead of silently evaluating a model whose fine-tuned
    # weights were never loaded.
    raise ValueError(f'Unsupported checkpoint format: {model_state_path}')

In [None]:
# evaluation metrics
def compute_binary_prauc(true_label, pred_prob):
    """
    Area under the precision-recall curve for a binary problem.
    Args:
        true_label(np.array): true labels of sites
        pred_prob(np.array): predicted probabilities of sites, seq len * 1
    Returns:
        area under the curve with recall on the x axis and precision on the y axis
    """
    pr_curve = precision_recall_curve(true_label, pred_prob)
    # The third element of the curve (thresholds) is not needed for the area.
    precisions, recalls = pr_curve[0], pr_curve[1]
    return auc(recalls, precisions)


def compute_ovr_prauc(true_label, pred_prob):
    """
    One-vs-rest PRAUC for every class of a single-label, multi-class problem.
    Args:
        true_label(np.array): true labels of sites
        pred_prob(np.array): predicted probabilities of sites, seq len * n classes
    Returns:
        list with one PRAUC per class, in the column order of pred_prob
    """
    # For each class, binarise the labels (this class vs. everything else) and
    # score that class's probability column against them.
    return [
        compute_binary_prauc((true_label == cls).astype(int), pred_prob[:, cls])
        for cls in range(pred_prob.shape[1])
    ]


def compute_metrics(true_label, pred_prob):
    """
    Compute per-class top-K accuracy and one-vs-rest PRAUC (single label, multi-class).

    Sites labelled -100 are dropped before scoring. For class k, K is the number
    of remaining sites whose true label is k; the top-K accuracy is the fraction
    of the K sites with the highest predicted probability for k that truly are k.

    Args:
        true_label(np.array): true labels of sites, -100 marks sites to ignore
        pred_prob(np.array): predicted probabilities of sites, seq len * n classes
    Returns:
        topk_accuracy(list): top-K accuracy per class, indexed by class id
            (NaN for a class that has no true site)
        praucs(list): one-vs-rest PRAUC per class, indexed by class id
    """
    df = pd.DataFrame(pred_prob)
    # Probability columns are labelled 0..n_classes-1; capture them before the
    # label column is appended so no class count is hard-coded.
    class_ids = list(df.columns)
    df['true_label'] = true_label
    df = df[df['true_label'] != -100]
    counts = df['true_label'].value_counts().to_dict()

    # value_counts() is ordered by frequency, not by class id. Walk the classes
    # in id order so this list lines up with `praucs` entry for entry.
    topk_accuracy = []
    for k in class_ids:
        v = counts.get(k, 0)
        if v == 0:
            # No true site of this class: accuracy is undefined, keep the slot.
            topk_accuracy.append(float('nan'))
            continue
        top_v = df.sort_values(by=k, ascending=False)[:v]
        topk_accuracy.append(sum(top_v['true_label'] == k) / v)

    praucs = compute_ovr_prauc(df['true_label'].values, df[class_ids].values)
    return topk_accuracy, praucs

In [None]:
# Load the test set. The encoding is pinned to UTF-8 so the read does not
# depend on the platform's locale default.
with open(test_set_path, encoding='utf-8') as f:
    test_set = json.load(f)

In [None]:
# Gather the sequences and flatten the per-record label lists into one array.
# NOTE(review): labels look like per-position class ids, with -100 for positions
# that are ignored during scoring - confirm against the data preparation.
seqs = []
flat_labels = []
for record in tqdm.tqdm(test_set):
    seqs.append(record['seq'])
    flat_labels += record['label']
true_labels = np.array(flat_labels)

In [None]:
# Run the model over the test sequences one at a time and collect, for every
# position, the class probabilities and the arg-max class.
# NOTE(review): assumes logits are (batch, seq len, n classes) - confirm.
softmax = nn.Softmax(dim=2)
pred_labels, pred_probs = [], []
model.eval()
with torch.no_grad():
    for seq in tqdm.tqdm(seqs):
        encoded = tokenizer(seq, return_tensors='pt', padding=True)
        model_output = model(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        seq_logits = model_output.logits
        # Index 0 picks the single sequence contained in this batch.
        pred_probs += softmax(seq_logits)[0].tolist()
        pred_labels += seq_logits[0].argmax(dim=1).tolist()
pred_probs = np.array(pred_probs)

In [None]:
# Score the flattened predictions against the flattened labels of the whole test set.
topk_accuracy, praucs = compute_metrics(true_labels, pred_probs)

In [None]:
# Average each metric over every entry except the first one.
# NOTE(review): entry 0 is presumably the non-splice-site class - confirm.
topk_accuracy_mean = np.mean(topk_accuracy[1:])
prauc_mean = np.mean(praucs[1:])
# One row per metric entry, followed by a final row holding the means.
result_df = pd.DataFrame({
    'topk_accuracy': [*topk_accuracy, topk_accuracy_mean.tolist()],
    'prauc': [*praucs, prauc_mean.tolist()],
})

In [None]:
# Display the metric rows followed by the mean row.
result_df