# Evaluate the performance of the fine-tuned model that predicts mRNA half-life from 3' UTR sequences

In [1]:
from LaMorena.sequence_classification_patch import EsmForSequenceClassification
from transformers import AutoConfig, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import tqdm
import numpy as np
import pandas as pd
from safetensors.torch import load_file, load_model
import os
import matplotlib.pyplot as plt

  from pandas.core import (


In [2]:
# Work from the repository checkout so that the relative paths used below
# (tokenizer/, config/, UTR3DegPred/) resolve correctly.
project_root = '/picb/rnasys2/zhouhanwen/nucTran/github/'
os.chdir(project_root)

In [3]:
# Tokenizer: single-nucleotide vocabulary, left padding, max length 1026.
tokenizer_path = 'tokenizer/single_nucleotide/'
model_max_length = 1026
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    model_max_length=model_max_length,
    padding_side='left',
)

# Config: start from the JSON file and override vocabulary/special-token ids
# from the tokenizer plus the architecture hyper-parameters below.
# NOTE(review): the file is named "150M" while the checkpoint path loaded later
# mentions "80M"; confirm these overrides match the fine-tuned checkpoint.
model_name = 'config/config_150M.json'
token_dropout = False
positional_embedding_type = 'rotary'
nlabels = 1  # a single output value per sequence
hidden_size = 768
intermediate_size = 3072
num_attention_heads = 12
num_hidden_layers = 12
config_overrides = dict(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    num_labels=nlabels,
    token_dropout=token_dropout,
    positional_embedding_type=positional_embedding_type,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
)
config = AutoConfig.from_pretrained(model_name, **config_overrides)

In [4]:
# Inference data: the validation split. The 'seq' column holds the input
# sequences and the 'label' column holds the target values.
seq_df = pd.read_csv('UTR3DegPred/data/DavidErle/training/validation_set.csv')
seqs = seq_df['seq'].tolist()
true_labels = seq_df['label'].tolist()

In [5]:
# Model: build the classifier from `config`, move it to the GPU, and load the
# fine-tuned weights. Only .safetensors and .bin checkpoints are supported; any
# other path raises instead of silently evaluating a model whose fine-tuned
# weights were never loaded.
device = torch.device('cuda:0')
model_state_path = 'UTR3DegPred/saving_model/DavidErle/mammalian_4096/mammalian80D_4096len1mer1sw_80M_250k_DegPred_DavidErle_2_bs8_lr5e-5_wr0.05_16epochs/checkpoint-3180/model.safetensors'
model = EsmForSequenceClassification(config, head_type='Linear', freeze=False, kernel_sizes=None, ocs=None)
model = model.to(device)
if model_state_path.endswith('.safetensors'):
    load_model(model, filename=model_state_path, strict=True)
elif model_state_path.endswith('.bin'):
    # map_location='cpu' lets a checkpoint saved on another GPU load on this
    # machine; load_state_dict then copies the tensors onto the model's device.
    model.load_state_dict(torch.load(model_state_path, map_location='cpu'), strict=True)
else:
    raise ValueError(
        f'Unsupported checkpoint format: {model_state_path!r} '
        "(expected a '.safetensors' or '.bin' file)"
    )

In [6]:
# Run every validation sequence through the model on its own (a batch of one),
# so padding=True never inserts pad tokens, and collect the outputs in order.
# NOTE(review): if this loop is ever batched, check how the patched head pools
# features -- with padding_side='left', pad tokens would occupy the first positions.
model.eval()
predict_labels = []
with torch.no_grad():
    for sequence in tqdm.tqdm(seqs):
        encoded = tokenizer(sequence, return_tensors='pt', padding=True)
        token_ids = encoded['input_ids'].to(device)
        token_mask = encoded['attention_mask'].to(device)

        # Every optional argument is passed explicitly as None.
        outputs = model(
            input_ids=token_ids,
            attention_mask=token_mask,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Row 0 holds this sequence's output(s); with nlabels == 1 this adds
        # exactly one value per sequence. (Presumably logits is (1, num_labels).)
        predict_labels.extend(outputs.logits[0].tolist())

100%|█████████████████████████████████████████████████████████████████████████████████| 196/196 [00:04<00:00, 48.02it/s]


In [7]:
# Compare predictions against the ground truth: mean squared error plus
# Pearson (linear) and Spearman (rank) correlation.
prediction_arr = np.array(predict_labels)
truth_arr = np.array(true_labels)
result_df = pd.DataFrame({'predict': predict_labels, 'true': true_labels})
mse = np.mean(np.square(prediction_arr - truth_arr))
pearson_corr_coef = result_df.corr(method='pearson').loc['predict', 'true']
spearman_corr_coef = result_df.corr(method='spearman').loc['predict', 'true']

In [8]:
# Report MSE, Pearson r and Spearman rho, one value per line.
for metric_value in (mse, pearson_corr_coef, spearman_corr_coef):
    print(metric_value)

0.16963906057076436
0.6498806108186871
0.6647705074255787
