# Fine-tune LAMAR to predict the splice sites of pre-mRNAs

In [None]:
from LaMorena.modeling_nucESM2 import EsmForTokenClassification
from transformers import AutoConfig, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk
from safetensors.torch import load_file, load_model
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from scipy.special import softmax
import os
import pandas as pd
import torch
import evaluate
import numpy as np

In [None]:
# Switch the working directory: the relative paths configured below
# (tokenizer, config, data, checkpoints) are resolved against it.
project_dir = '/picb/rnasys2/zhouhanwen/nucTran/github/'
os.chdir(project_dir)

## Parameters
**The following parameters can be changed.**      
data_path: path to the fine-tuning dataset  
pretrain_state_path: path to the pretrained weights  
batch_size: per-device batch size; <= 8 fits on a single card (here we use a V100 32G)  
peak_lr: peak learning rate, typically between 1e-5 and 1e-4  
total_epochs: number of fine-tuning epochs  
accum_steps: number of gradient accumulation steps (1 disables accumulation)  
output_dir: directory where the fine-tuned model is saved  
logging_steps: number of training steps between loss logs  

In [None]:
# --- Paths (relative to the working directory) ---
tokenizer_path = 'tokenizer/single_nucleotide/'
model_name = 'config/config_150M.json'
data_path = 'SpliceSitePrediction/data/SpliceAI/training/ss_single_nucleotide/'
# Set pretrain_state_path = None to skip loading pretrained weights.
pretrain_state_path = 'pretrain/saving_model/mammalian80D_4096len1mer1sw_80M/checkpoint-250000/model.safetensors'
output_dir = 'SpliceSitePrediction/saving_model/test/test01'

# --- Model architecture (passed as overrides when the config is loaded) ---
model_max_length = 1026
hidden_size = 768
intermediate_size = 3072
num_attention_heads = 12
num_hidden_layers = 12
positional_embedding_type = 'rotary'
token_dropout = False
nlabels = 3

# --- Optimisation ---
batch_size = 8
accum_steps = 1
peak_lr = 5e-5
warmup_ratio = 0.05
total_epochs = 1
grad_clipping_norm = 1

# --- Logging / checkpointing ---
# NOTE(review): despite its name, save_epochs is passed to TrainingArguments
# as `save_steps`, i.e. it counts optimisation steps, not epochs.
save_epochs = 1000
logging_steps = 100

# --- Precision / attention kernels ---
fp16 = False
flash_attention = False

In [None]:
# Tokenizer: loaded from tokenizer_path, with its maximum length set to model_max_length.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    model_max_length=model_max_length,
)

# Config: read the JSON file named by model_name and override the vocabulary,
# label count and architecture fields with the notebook parameters.
config_overrides = dict(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    num_labels=nlabels,
    token_dropout=token_dropout,
    positional_embedding_type=positional_embedding_type,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
)
config = AutoConfig.from_pretrained(model_name, **config_overrides)

# Training data: a dataset previously saved to disk; its 'train' and 'test'
# entries are the ones handed to the Trainer further down.
data = load_from_disk(data_path)

# Data collator: every batch is padded to the tokenizer's maximum length.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding="max_length",
    max_length=tokenizer.model_max_length,
)

# Model: build the token-classification network from `config`, then optionally
# initialise it from a pretraining checkpoint (.bin or .safetensors).
model = EsmForTokenClassification(config)
if pretrain_state_path:
    print("Loading parameters of pretraining model: {}".format(pretrain_state_path))
    if pretrain_state_path.endswith('.bin'):
        # map_location="cpu" lets a checkpoint that was saved on a GPU load on
        # any machine; load_state_dict copies the values into the model anyway.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(pretrain_state_path, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    elif pretrain_state_path.endswith('.safetensors'):
        missing_keys, unexpected_keys = load_model(model, filename=pretrain_state_path, strict=False)
    else:
        # Previously an unknown extension fell through silently and the model
        # stayed randomly initialised although "Loading parameters" was printed.
        raise ValueError(
            "Unsupported checkpoint format (expected .bin or .safetensors): {}".format(pretrain_state_path)
        )
    # strict=False tolerates key mismatches, so report how many there were:
    # a large "missing" count means most weights were NOT initialised from the checkpoint.
    print("Checkpoint keys missing from file: {} | unexpected in file: {}".format(
        len(missing_keys), len(unexpected_keys)
    ))
else:
    print("No Loading parameters of pretraining model !!")

In [None]:
# Training arguments, grouped by concern and merged into a single
# TrainingArguments call.

# Optimiser and learning-rate schedule.
optimizer_kwargs = dict(
    learning_rate=peak_lr,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-8,
    warmup_ratio=warmup_ratio,
    max_grad_norm=grad_clipping_norm,
)

# Batching and run length.
batching_kwargs = dict(
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=accum_steps,
    dataloader_drop_last=True,
    num_train_epochs=total_epochs,
)

# Evaluation, checkpointing and logging all run on a step-based schedule;
# evaluation shares its interval with logging.
schedule_kwargs = dict(
    evaluation_strategy="steps",
    eval_steps=logging_steps,
    save_strategy='steps',
    save_steps=save_epochs,
    save_total_limit=1,
    logging_strategy='steps',
    logging_steps=logging_steps,
)

# Output location, progress bar, precision and reporting.
runtime_kwargs = dict(
    output_dir=output_dir,
    disable_tqdm=False,
    fp16=fp16,
    report_to="none",
)

train_args = TrainingArguments(
    **optimizer_kwargs,
    **batching_kwargs,
    **schedule_kwargs,
    **runtime_kwargs,
)

In [None]:
# evaluation metrics
def compute_binary_pr_auc(reference, predict_logits):
    """Return the area under the precision-recall curve for one binary task.

    Args:
        reference: binary ground-truth labels.
        predict_logits: scores for the positive class, one per label.

    Returns:
        The area computed by `auc` with recall on the x axis and precision on
        the y axis.
    """
    curve = precision_recall_curve(reference, predict_logits)
    precision_values, recall_values = curve[0], curve[1]
    return auc(recall_values, precision_values)


def compute_ovr_pr_auc(reference, predict_logits, average=None, ignore_idx=()):
    """Compute one-vs-rest PR AUC for a multi-class problem.

    Each class that is not ignored is scored as a binary task "this class vs.
    everything else" using column `class_idx` of `predict_logits`.

    Args:
        reference: integer class labels; must support elementwise `==` and
            `.astype` (e.g. a NumPy array or pandas Series).
        predict_logits: 2-D array of per-class scores, indexed as
            `predict_logits[:, class_idx]`.
        average: "macro" for the unweighted mean, "weighted" for the mean
            weighted by each kept class's label count, anything else to get
            the raw per-class list.
        ignore_idx: class indices to leave out. Defaults to an immutable empty
            tuple (the original used a shared mutable list as default).

    Returns:
        A float for "macro"/"weighted", otherwise a list with one PR AUC per
        kept class, in ascending class order.
    """
    n_classes = predict_logits.shape[1]
    skipped = set(ignore_idx or ())
    kept_classes = [class_idx for class_idx in range(n_classes) if class_idx not in skipped]

    pr_aucs = [
        compute_binary_pr_auc((reference == class_idx).astype(int), predict_logits[:, class_idx])
        for class_idx in kept_classes
    ]

    if average == "macro":
        return np.mean(pr_aucs)
    if average == "weighted":
        # minlength guarantees a count for every class even when the highest
        # classes never occur in `reference`; indexing with kept_classes keeps
        # the counts aligned with pr_aucs when some classes are ignored.
        class_counts = np.bincount(reference, minlength=n_classes)[kept_classes]
        # Normalise by the kept classes' total so the weights sum to 1; with
        # nothing ignored this equals the original division by len(reference).
        return np.sum(np.array(pr_aucs) * class_counts) / class_counts.sum()
    return pr_aucs


def compute_ovo_pr_auc(reference, predict_logits, average=None):
    """Placeholder for one-vs-one PR AUC; always raises.

    The arguments mirror the one-vs-rest variant but are not used.

    Raises:
        NotImplementedError: on every call, because one-vs-one is not directly
            supported by precision_recall_curve.
    """
    message = "OvO PR AUC computation is not implemented yet."
    raise NotImplementedError(message)


def pr_auc_score(reference, predict_logits, multi_class=None, average=None):
    """Dispatch PR AUC computation according to `multi_class`.

    Args:
        reference: ground-truth labels.
        predict_logits: predicted scores.
        multi_class: "ovr" selects the one-vs-rest computation, "ovo" the
            one-vs-one computation; any other value is treated as a plain
            binary problem.
        average: forwarded to the multi-class computations; unused in the
            binary case.

    Returns:
        Whatever the selected computation returns.
    """
    if multi_class == "ovr":
        return compute_ovr_pr_auc(reference, predict_logits, average=average)
    if multi_class == "ovo":
        return compute_ovo_pr_auc(reference, predict_logits, average=average)
    return compute_binary_pr_auc(reference, predict_logits)


def compute_metrics(p):
    """Compute per-class top-k accuracy, ROC AUC and PR AUC for token predictions.

    Args:
        p: a pair `(logits, labels)`. The logits are soft-maxed over axis 2 and
            flattened over their first two axes; labels are flattened the same
            way. Positions whose label is -100 are excluded from all metrics.

    Returns:
        A dict with three entries:
          * "topk": `{"topk": {label: accuracy}}`, where for a label occurring
            `v` times the accuracy is the fraction of the `v` highest-scoring
            positions for that label whose true label matches it.
          * "roc_auc": list of one-vs-rest ROC AUC values, one per class.
          * "pr_auc": list of one-vs-rest PR AUC values, one per class.
    """
    ignore_label = -100
    raw_logits, labels = p

    # Class probabilities and hard predictions per position.
    probs = softmax(raw_logits, axis=2)
    hard_pred = np.argmax(probs, axis=2).astype(np.int8)

    # One row per position: probability columns, then prediction and truth.
    n_rows = probs.shape[0] * probs.shape[1]
    table = pd.DataFrame(probs.reshape((n_rows, -1)))
    table["pred"] = np.array(hard_pred).flatten()
    table["true"] = np.array(labels).flatten()
    keep = table["true"] != ignore_label
    table = table[keep]

    # Top-k accuracy: rank positions by the probability of each label and look
    # at as many top rows as there are true occurrences of that label.
    counts = table.true.value_counts().to_dict()
    topk = {}
    for label, n_true in counts.items():
        ranked = table.sort_values(by=label, ascending=False)
        topk[label] = sum(ranked[:n_true].true == label) / n_true

    result = {}
    result["topk"] = {"topk": topk}

    # Probability columns only (everything except the two bookkeeping columns).
    is_bookkeeping = table.columns.isin(["pred", "true"])
    scores = table.loc[:, table.columns[~is_bookkeeping]].values

    per_class_roc = roc_auc_score(
        table["true"],
        scores,
        multi_class="ovr",
        average=None,
    )
    result["roc_auc"] = list(per_class_roc)

    per_class_pr = pr_auc_score(
        table["true"],
        scores,
        multi_class="ovr",
        average=None,
    )
    result["pr_auc"] = list(per_class_pr)
    return result

In [None]:
# Trainer: wire together the model, training arguments, train/test splits,
# collator, tokenizer and metric function.
trainer_kwargs = {
    "model": model,
    "args": train_args,
    "train_dataset": data['train'],
    "eval_dataset": data['test'],
    "data_collator": data_collator,
    "tokenizer": tokenizer,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**trainer_kwargs)

In [None]:
# Start fine-tuning with the Trainer built above; evaluation, logging and
# checkpointing follow the step intervals set in train_args.
trainer.train()