# Evaluate the performance of RNAErnie-SS on splice site prediction

In [1]:
import sys
sys.path.append('/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/RNAErnie')
from tokenization_rnaernie import RNAErnieTokenizer
from transformers import AutoConfig, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer, AutoModelForTokenClassification
from safetensors.torch import load_file, load_model
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import os
import pandas as pd
import torch
from torch import nn
import numpy as np
import json
import tqdm

In [2]:
# Work from the LAMAR_baselines checkout so that the relative paths used in this
# notebook resolve. Set the LAMAR_BASELINES_DIR environment variable to run
# against a checkout in a different location; the original path is the default.
os.chdir(os.environ.get('LAMAR_BASELINES_DIR', '/work/home/rnasys/zhouhanwen/github/LAMAR_baselines/'))

In [3]:
# --- Model / tokenizer settings ---
# Number of output labels handed to the model config as `num_labels`
nlabels = 3
# Maximum length handed to the tokenizer as `model_max_length`
model_max_length = 1026
# Path given to AutoConfig.from_pretrained (a config JSON, despite the name)
model_name = 'RNAErnie/config.json'
tokenizer_path = 'tokenizer/RNAErnie/'

# --- Input files (relative to the working directory) ---
# Fine-tuned weights to evaluate
model_state_path = (
    'SpliceSitePred/saving_model/RNAErnie/'
    'bs256_lr1e-04_wr0.05_4epochs/checkpoint-10000/model.safetensors'
)
# Held-out test set
test_set_path = 'SpliceSitePred/data/gencode_test.json'

In [4]:
# Tokenizer (uses the configured path rather than repeating the literal)
tokenizer = RNAErnieTokenizer.from_pretrained(tokenizer_path, model_max_length=model_max_length)

# Config: size the vocabulary and special-token ids from the tokenizer, and
# disable classifier dropout for evaluation.
config = AutoConfig.from_pretrained(
    model_name,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    num_labels=nlabels,
    classifier_dropout=0,
)

# Model: prefer the first GPU, but fall back to CPU so the notebook still runs
# on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = AutoModelForTokenClassification.from_config(config)
model = model.to(device)
print("Loading parameters of fine-tuning model: {}".format(model_state_path))
# strict=False does not raise on key mismatches, so inspect the returned
# (missing, unexpected) keys explicitly: a model evaluated with missing weights
# would keep its random initialization for those parameters.
missing_keys, unexpected_keys = load_model(model, filename=model_state_path, strict=False)
if missing_keys or unexpected_keys:
    print("WARNING: checkpoint does not fully match the model "
          "(missing keys: {}; unexpected keys: {})".format(sorted(missing_keys), sorted(unexpected_keys)))
# Keep the cell output: show both collections (both empty for a clean load)
(missing_keys, unexpected_keys)

Loading parameters of fine-tuning model: SpliceSitePred/saving_model/RNAErnie/bs256_lr1e-04_wr0.05_4epochs/checkpoint-10000/model.safetensors


(set(), [])

In [5]:
# evaluation metrics
def compute_binary_prauc(true_label, pred_prob):
    """
    Area under the precision-recall curve for a binary (single label) problem.
    Args:
        true_label(np.array): true binary labels of sites
        pred_prob(np.array): predicted probability of the positive class per site
    Returns:
        Area under the curve with recall on the x-axis and precision on the y-axis.
    """
    pr_curve = precision_recall_curve(true_label, pred_prob)
    # The third element (thresholds) is not needed for the area
    precision_values, recall_values = pr_curve[0], pr_curve[1]
    return auc(recall_values, precision_values)


def compute_ovr_prauc(true_label, pred_prob):
    """
    One-vs-rest PRAUC for a multi-class, single-label problem.
    Args:
        true_label(np.array): true class index of each site
        pred_prob(np.array): predicted probabilities, one row per site and one
            column per class
    Returns:
        List of PRAUC values, one per class, ordered by class index.
    """
    class_indices = range(pred_prob.shape[1])
    # Each class in turn is the positive class; all other classes are negatives
    return [
        compute_binary_prauc((true_label == idx).astype(int), pred_prob[:, idx])
        for idx in class_indices
    ]


def compute_metrics(true_label, pred_prob):
    """
    Compute per-class top-K accuracy and one-vs-rest PRAUC (multi-class, single label).

    Sites labelled -100 are excluded before any metric is computed.

    For class c with K true sites, top-K accuracy is the fraction of the K sites
    with the highest predicted probability for c whose true label is c.

    Args:
        true_label(np.array): true labels of sites (-100 marks sites to ignore)
        pred_prob(np.array): predicted probabilities of sites, seq len * n classes
    Returns:
        (topk_accuracy, praucs): two lists, both ordered by class index
        (entry i belongs to class i). A class with no true sites gets NaN as its
        top-K accuracy.
    """
    df = pd.DataFrame(pred_prob)
    # Probability columns are 0..n_classes-1; capture them before adding the label column
    class_cols = list(df.columns)
    df['true_label'] = true_label
    df = df[df['true_label'] != -100]

    topk_accuracy = []
    # Iterate by class index, not by value_counts() order (which is by frequency),
    # so the result lines up with the per-class PRAUC list.
    for class_idx in class_cols:
        n_true = int((df['true_label'] == class_idx).sum())
        if n_true == 0:
            topk_accuracy.append(float('nan'))
            continue
        top_k_labels = df.sort_values(by=class_idx, ascending=False)['true_label'].iloc[:n_true]
        topk_accuracy.append(int((top_k_labels == class_idx).sum()) / n_true)

    praucs = compute_ovr_prauc(df['true_label'].values, df[class_cols].values)
    return topk_accuracy, praucs

In [6]:
# Load the test set. The encoding is given explicitly so the read does not
# depend on the platform's default text encoding.
with open(test_set_path, encoding='utf-8') as f:
    test_set = json.load(f)

In [7]:
# Collect the sequences, and flatten the labels of all records into a single
# array in the same order as the sequences.
seqs = []
per_seq_labels = []
for record in tqdm.tqdm(test_set):
    seqs.append(record['seq'])
    per_seq_labels.append(record['label'])
true_labels = np.array([label for labels in per_seq_labels for label in labels])

100%|██████████| 102905/102905 [00:00<00:00, 177514.53it/s]


In [8]:
# Run the model over the test sequences one at a time and collect, for every
# output position, the softmax class probabilities and the argmax class.
softmax = nn.Softmax(dim=2)
pred_labels, pred_probs = [], []
model.eval()
with torch.no_grad():
    for seq in tqdm.tqdm(seqs):
        batch = tokenizer(seq, return_tensors='pt', padding=True)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        model_output = model(input_ids=input_ids, attention_mask=attention_mask)
        batch_logits = model_output.logits
        # Store each sequence's probabilities as a float64 NumPy array instead of
        # nested Python lists: same values and final dtype, far less memory.
        pred_probs.append(softmax(batch_logits)[0].double().cpu().numpy())
        pred_labels.extend(torch.argmax(batch_logits[0], dim=1).tolist())
# np.concatenate rejects an empty list, so handle the no-sequence case explicitly
pred_probs = np.concatenate(pred_probs, axis=0) if pred_probs else np.empty((0, nlabels))

# Metrics pair predictions with labels position by position; fail here with a
# clear message rather than later inside pandas if the two are misaligned.
if len(pred_probs) != len(true_labels):
    raise ValueError(
        "Number of predicted positions ({}) does not match number of labels ({})".format(
            len(pred_probs), len(true_labels)
        )
    )

100%|██████████| 102905/102905 [25:24<00:00, 67.52it/s] 


In [9]:
# Per-class top-K accuracy and PRAUC over all test-set positions
topk_accuracy, praucs = compute_metrics(true_label=true_labels, pred_prob=pred_probs)

In [10]:
# Average each metric over the entries from index 1 onward (entry 0 is left out
# of the mean), then append that mean as a final row of the result table.
topk_accuracy_mean = np.mean(topk_accuracy[1:])
prauc_mean = np.mean(praucs[1:])
result_df = pd.DataFrame(
    {
        'topk_accuracy': [*topk_accuracy, topk_accuracy_mean.tolist()],
        'prauc': [*praucs, prauc_mean.tolist()],
    }
)

In [11]:
# Each row but the last holds one entry of the lists returned by compute_metrics;
# the last row is the mean of the entries from index 1 onward.
result_df

Unnamed: 0,topk_accuracy,prauc
0,0.99996,1.0
1,0.884317,0.941074
2,0.896496,0.951452
3,0.890406,0.946263


In [12]:
# Save the metrics table. The output directory is created first so that a
# missing directory does not make the save fail after the long inference run.
result_path = 'SpliceSitePred/data/RNAErnie/result_bs256_lr1e-04_wr0.05_4epochs.csv'
os.makedirs(os.path.dirname(result_path), exist_ok=True)
result_df.to_csv(result_path, index=False)