# Evaluate the performance of the fine-tuned model for IRES prediction

In [1]:
from LAMAR.sequence_classification_patch import EsmForSequenceClassification
from transformers import AutoConfig, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
import tqdm
import numpy as np
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, auc
import pandas as pd
from safetensors.torch import load_file, load_model
import os

  from pandas.core import (


In [2]:
# Work from the LAMAR repository root so that the relative tokenizer, config,
# data and checkpoint paths used in the following cells resolve.
repo_root = '/picb/rnasys2/zhouhanwen/github/LAMAR/'
os.chdir(repo_root)

In [3]:
# Tokenizer: loaded from the local single_nucleotide tokenizer directory,
# with the maximum model input length set to 1500 tokens.
model_max_length = 1500
tokenizer_path = 'tokenizer/single_nucleotide/'
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    model_max_length=model_max_length,
)

# Config: start from the JSON file and override the architecture settings
# below, plus the vocabulary/special-token ids taken from the tokenizer.
model_name = 'config/config_150M.json'
nlabels = 2
token_dropout = False
positional_embedding_type = 'rotary'
hidden_size = 768
intermediate_size = 3072
num_attention_heads = 12
num_hidden_layers = 12
config = AutoConfig.from_pretrained(
    model_name,
    # Values derived from the tokenizer
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    # Classification head size
    num_labels=nlabels,
    # Embedding behaviour
    token_dropout=token_dropout,
    positional_embedding_type=positional_embedding_type,
    # Encoder dimensions
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
)

In [4]:
def compute_pr_auc(true_ids, probs):
    """
    Compute the PR-AUC score (area under the precision-recall curve)
    Input:
        true_ids, ground-truth binary labels, one per sample
        probs, predicted score of the positive class for each sample
            (probabilities, not predicted class labels)
    Return:
        pr_auc, float
    """

    # precision_recall_curve also returns the decision thresholds, which are
    # not needed to integrate the curve.
    precision, recall, _ = precision_recall_curve(true_ids, probs)

    # Area under the curve with recall on the x-axis and precision on the y-axis.
    pr_auc = auc(recall, precision)

    return pr_auc

In [5]:
# Inference data: read the testing-set CSV and pull the sequence and label
# columns out as plain Python lists.
test_set_path = 'IRESPred/data/testing_set.Pos1Fold.Train1Fold.4.csv'
seq_df = pd.read_csv(test_set_path)
seqs, true_labels = (seq_df[column].values.tolist() for column in ('seq', 'label'))

In [6]:
# Model
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_state_path = 'IRESPred/saving_model/mammalian_2048/bs16_lr1e-4_wr0.05_4epochs_4/checkpoint-460/pytorch_model.bin'
model = EsmForSequenceClassification(config, head_type='Linear', freeze=False, kernel_sizes=None, ocs=None)
model = model.to(device)

# Load the fine-tuned weights; strict=True makes any missing or unexpected
# parameter name an error instead of a silent partial load.
if model_state_path.endswith('.safetensors'):
    load_model(model, filename=model_state_path, strict=True)
elif model_state_path.endswith('.bin'):
    # map_location remaps tensors that were saved from another device (e.g. a
    # different GPU index) onto the device used here.
    model.load_state_dict(torch.load(model_state_path, map_location=device), strict=True)
else:
    # Without this branch an unrecognised extension would leave the model with
    # its random initial weights and the evaluation below would be meaningless.
    raise ValueError(
        f"Unsupported checkpoint format: {model_state_path!r} "
        "(expected a '.safetensors' or '.bin' file)"
    )

In [7]:
# Collect, for every test sequence, the predicted class and the probability
# assigned to class index 1.
softmax = nn.Softmax(dim=1)
predict_labels, predict_probs = [], []

model.eval()
with torch.no_grad():
    # Each entry of `seqs` is tokenized and passed through the model on its own.
    for seq in tqdm.tqdm(seqs):
        encoded = tokenizer(seq, return_tensors='pt')
        model_output = model(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        logits = model_output.logits

        # Probability of class index 1 after a softmax over the class dimension.
        class1_probs = softmax(logits)[:, 1]
        predict_probs.extend(class1_probs.tolist())

        # Predicted class = index of the largest logit.
        predict_labels.extend(logits.argmax(dim=1).tolist())

100%|███████████████████████████████████████████████████████████████████████████████| 1952/1952 [00:36<00:00, 53.12it/s]


In [8]:
# Cross-tabulation of labels: rows are the true labels, columns the predictions.
pd.crosstab(index=true_labels, columns=predict_labels)

col_0,0,1
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1
0,908,42
1,63,939


In [9]:
# Metrics computed from the hard class predictions: precision, recall, F1.
for label_metric in (precision_score, recall_score, f1_score):
    print(f'{label_metric(true_labels, predict_labels):.3f}')

# Metrics computed from the class-1 probabilities: ROC-AUC, then PR-AUC.
for score_metric in (roc_auc_score, compute_pr_auc):
    print(f'{score_metric(true_labels, predict_probs):.3f}')

0.957
0.937
0.947
0.985
0.986
