# Fine-tune LAMAR to predict IRES

In [1]:
from LAMAR.sequence_classification_patch import EsmForSequenceClassification
from transformers import AutoConfig, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk
import os
import torch
import numpy as np
from safetensors.torch import load_file, load_model
import evaluate
from sklearn.metrics import precision_recall_curve, auc

  from pandas.core import (


In [2]:
os.chdir('/picb/rnasys2/zhouhanwen/github/LAMAR/')
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

## Parameters
**The following parameters can be changed.**    
nlabels: num of predicted labels, = 1 means regression, = 2 means binary classification    
data_path: path of finetuning data  
pretrain_state_path: path of pretraining weights  
batch_size: <= 8 for single card, here we use V100 32G  
peak_lr: peak learning rate, 1e-5 ~ 1e-4 in most conditions   
total_epochs: num of finetuning epochs  
accum_steps: accumulation steps if using gradient accumulation  
output_dir: path of saving model  
logging_steps: num of training steps to log loss value  

In [3]:
tokenizer_path = 'tokenizer/single_nucleotide/'
model_max_length = 1500
model_name = 'config/config_150M.json'
token_dropout = False
positional_embedding_type = 'rotary'
hidden_size = 768
intermediate_size = 3072
num_attention_heads = 12
num_hidden_layers = 12
nlabels = 2
data_path = 'IRESPred/data/IRES_4/'
pretrain_state_path = 'pretrain/saving_model/mammalian80D_2048len1mer1sw_80M/checkpoint-250000/model.safetensors'
# pretrain_state_path = None
batch_size = 8
peak_lr = 1e-4
warmup_ratio = 0.05
total_epochs = 4
grad_clipping_norm = 1
accum_steps = 2
output_dir = 'IRESPred/saving_model/mammalian_2048/bs16_lr1e-4_wr0.05_4epochs_4'
save_epochs = 10
logging_steps = 100
fp16 = False
flash_attention = False
head_type = 'Linear'
freeze = False
kernel_sizes = [2, 3, 5]
ocs = 32

In [4]:
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=model_max_length)

# Config
config = AutoConfig.from_pretrained(
    model_name, vocab_size=len(tokenizer), pad_token_id=tokenizer.pad_token_id, mask_token_id=tokenizer.mask_token_id, num_labels=nlabels, 
    token_dropout=token_dropout, positional_embedding_type=positional_embedding_type, 
    hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, num_hidden_layers=num_hidden_layers
)

# Training data
data = load_from_disk(data_path)

# Data Collator
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding=True
)

In [5]:
# Model
model = EsmForSequenceClassification(config, head_type=head_type, freeze=freeze, kernel_sizes=kernel_sizes, ocs=ocs)
if flash_attention:
    from flash_attn_patch import EsmSelfAttentionAddFlashAttnPatch
    for i in range(config.num_hidden_layers):
        model.esm.encoder.layer[i].attention.self = EsmSelfAttentionAddFlashAttnPatch(config, position_embedding_type='rotary')
if pretrain_state_path:
    print("Loading parameters of pretraining model: {}".format(pretrain_state_path))
    if pretrain_state_path.endswith('.pt'):
        model.load_state_dict(torch.load(pretrain_state_path)['MODEL_STATE'], strict=False)
    elif pretrain_state_path.endswith('.bin'):
        model.load_state_dict(torch.load(pretrain_state_path), strict=False)
    elif pretrain_state_path.endswith('.safetensors'):
        load_model(model, filename=pretrain_state_path, strict=False)
else:
    print("No Loading parameters of pretraining model !!")

Loading parameters of pretraining model: pretrain/saving_model/mammalian80D_2048len1mer1sw_80M/checkpoint-250000/model.safetensors


In [6]:
# Training arguments
train_args = TrainingArguments(
    disable_tqdm=False, 
    save_total_limit=1, 
    dataloader_drop_last=True, 
    per_device_train_batch_size=batch_size, 
    learning_rate=peak_lr, 
    weight_decay=0.01, 
    adam_beta1=0.9, 
    adam_beta2=0.98, 
    adam_epsilon=1e-8, 
    warmup_ratio=warmup_ratio, 
    num_train_epochs=total_epochs, 
    max_grad_norm=grad_clipping_norm, 
    gradient_accumulation_steps=accum_steps, 
    output_dir=output_dir, 
    save_strategy='steps', 
    save_steps=save_epochs, 
    logging_strategy = 'steps', 
    logging_steps=logging_steps, 
    evaluation_strategy="steps", 
    eval_steps=logging_steps, 
    fp16=fp16, 
    report_to='none'
)

In [7]:
def compute_metrics(p):
    """
    labels: true labels
    predictions: predict labels
    pred_probs: predict probabilities
    """
    # metrics
    accuracy = evaluate.load("metrics/accuracy")
    precision = evaluate.load("metrics/precision")
    recall = evaluate.load("metrics/recall")
    f1 = evaluate.load("metrics/f1")
    roc_auc = evaluate.load("metrics/roc_auc")
    
    predictions, labels = p
    pred_probs = np.exp(predictions) / np.sum(np.exp(predictions), axis=1, keepdims=True)
    predictions = np.argmax(predictions, axis=1).flatten()
    labels = np.array(labels).flatten()
    
    accuracy_v = accuracy.compute(references=labels, predictions=predictions)
    precision_v = precision.compute(references=labels, predictions=predictions, zero_division=0)
    recall_v = recall.compute(references=labels, predictions=predictions)
    f1_v = f1.compute(references=labels, predictions=predictions)
    roc_auc_v = roc_auc.compute(references=labels, prediction_scores=pred_probs[:, 1])
    precision_prauc, recall_prauc, threshold_prauc = precision_recall_curve(labels, pred_probs[:, 1])
    pr_auc_v = auc(recall_prauc, precision_prauc) 
    
    return {
        "accuracy": accuracy_v,
        "precision": precision_v,
        "recall": recall_v,
        "f1": f1_v, 
        "roc_auc": roc_auc_v, 
        "pr_auc": pr_auc_v
    }

In [8]:
# Trainer
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=data['train'], 
    eval_dataset=data['test'], 
    compute_metrics=compute_metrics, 
    data_collator=data_collator, 
    tokenizer=tokenizer
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [9]:
# Training
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Roc Auc,Pr Auc
100,0.3944,0.393392,{'accuracy': 0.8473360655737705},{'precision': 0.9731182795698925},{'recall': 0.7225548902195609},{'f1': 0.8293241695303551},{'roc_auc': 0.9238585985922891},0.94002
200,0.1532,0.243532,{'accuracy': 0.9328893442622951},{'precision': 0.9421319796954315},{'recall': 0.9261477045908184},{'f1': 0.934071464519376},{'roc_auc': 0.9795577266519592},0.980482
300,0.0601,0.397364,{'accuracy': 0.9293032786885246},{'precision': 0.9},{'recall': 0.9700598802395209},{'f1': 0.9337175792507204},{'roc_auc': 0.9802873200966489},0.978967
400,0.0276,0.304286,{'accuracy': 0.9477459016393442},{'precision': 0.9563894523326572},{'recall': 0.9411177644710579},{'f1': 0.9486921529175051},{'roc_auc': 0.9850105053051791},0.985793


TrainOutput(global_step=460, training_loss=0.13813452442561316, metrics={'train_runtime': 632.5397, 'train_samples_per_second': 11.699, 'train_steps_per_second': 0.727, 'total_flos': 0.0, 'train_loss': 0.13813452442561316, 'epoch': 3.98})