# Fine-tune La Morena to predict half-lives of mRNAs based on 3' UTRs

In [None]:
from LaMorena.sequence_classification_patch import EsmForSequenceClassification
from transformers import AutoConfig, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk
import os
import torch
import numpy as np
from safetensors.torch import load_file, load_model
import evaluate
import pandas as pd

In [None]:
# Work from the project directory (the paths configured in later cells are
# relative) and expose only GPU 3 to this process.
os.chdir('/picb/rnasys2/zhouhanwen/nucTran/github/')
os.environ.update(CUDA_VISIBLE_DEVICES='3')

## Parameters
**The following parameters can be changed.**  
`nlabels`: number of predicted labels; 1 means regression, 2 means binary classification  
`data_path`: path to the fine-tuning data  
`pretrain_state_path`: path to the pretrained weights  
`batch_size`: batch size per device; <= 8 for a single card (here we use a V100 32G)  
`peak_lr`: peak learning rate; 1e-5 ~ 1e-4 works in most cases  
`total_epochs`: number of fine-tuning epochs  
`accum_steps`: number of accumulation steps when using gradient accumulation  
`output_dir`: directory where the model is saved  
`logging_steps`: number of training steps between two loss logs  

In [None]:
# ---- Tokenizer / input ----
tokenizer_path = 'tokenizer/single_nucleotide/'
model_max_length = 200

# ---- Backbone architecture (passed as overrides when the config is loaded) ----
model_name = 'config/config_150M.json'
positional_embedding_type = 'rotary'
token_dropout = False
hidden_size, intermediate_size = 768, 3072
num_attention_heads = num_hidden_layers = 12
flash_attention = False

# ---- Prediction head (passed to EsmForSequenceClassification) ----
nlabels = 1
head_type = 'Linear'
freeze = False
kernel_sizes = [2, 3, 5]
ocs = 32

# ---- Data and weights ----
data_path = 'UTR3DegPred/data/DavidErle/training/deg_2'
# Set pretrain_state_path to None to skip loading pretraining weights.
pretrain_state_path = 'pretrain/saving_model/mammalian80D_4096len1mer1sw_80M/checkpoint-250000/model.safetensors'
output_dir = 'DegPred/saving_model/DavidErle/test/mammalian80D_4096len1mer1sw_80M_250k_DegPred_DavidErle_2_bs8_lr5e-5_wr0.05_16epochs'

# ---- Optimisation schedule ----
batch_size = 8
accum_steps = 1
peak_lr = 5e-5
warmup_ratio = 0.05
total_epochs = 16
grad_clipping_norm = 1
fp16 = False

# ---- Checkpointing / logging ----
# NOTE(review): despite its name, save_epochs is passed to TrainingArguments as
# `save_steps`, i.e. it is interpreted as a number of steps.
save_epochs = 10
logging_steps = 100

In [None]:
# Tokenizer, loaded from the local tokenizer directory.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    model_max_length=model_max_length,
)

# Model config: start from the JSON file, then override the vocabulary-dependent
# fields (taken from the tokenizer) and the architecture hyper-parameters.
config = AutoConfig.from_pretrained(
    model_name,
    # Derived from the tokenizer
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    # Task
    num_labels=nlabels,
    # Architecture
    token_dropout=token_dropout,
    positional_embedding_type=positional_embedding_type,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
)

# Fine-tuning dataset saved on disk (the 'train' and 'test' splits are used later).
data = load_from_disk(data_path)

# Collator that pads the examples of each batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

In [None]:
# Model: build the sequence-classification model from `config`, optionally
# replace each layer's self-attention with the flash-attention patch, then
# initialise it from the pretraining checkpoint (if one is configured).
model = EsmForSequenceClassification(config, head_type=head_type, freeze=freeze, kernel_sizes=kernel_sizes, ocs=ocs)
if flash_attention:
    # Imported lazily so the dependency is only needed when the patch is enabled.
    from flash_attn_patch import EsmSelfAttentionAddFlashAttnPatch
    for i in range(config.num_hidden_layers):
        model.esm.encoder.layer[i].attention.self = EsmSelfAttentionAddFlashAttnPatch(config, position_embedding_type='rotary')
if pretrain_state_path:
    print("Loading parameters of pretraining model: {}".format(pretrain_state_path))
    # strict=False: checkpoint and model are allowed to differ in their key sets;
    # the mismatching keys are collected and reported below.
    if pretrain_state_path.endswith('.bin'):
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from (that device may not be visible here).
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(pretrain_state_path, map_location='cpu')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    elif pretrain_state_path.endswith('.safetensors'):
        missing_keys, unexpected_keys = load_model(model, filename=pretrain_state_path, strict=False)
    else:
        # Previously an unknown extension silently skipped loading after
        # announcing it; fail loudly instead of training from random weights.
        raise ValueError(
            "Unsupported checkpoint format (expected '.bin' or '.safetensors'): {}".format(pretrain_state_path)
        )
    # Report how many keys did not match, so that a checkpoint that matches
    # nothing does not go unnoticed.
    print("Missing keys: {} | Unexpected keys: {}".format(len(missing_keys), len(unexpected_keys)))
else:
    print("No Loading parameters of pretraining model !!")

In [None]:
# Training arguments, grouped by concern.
train_args = TrainingArguments(
    # Output / reporting
    output_dir=output_dir,
    report_to="none",
    disable_tqdm=False,
    # Batching
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=accum_steps,
    dataloader_drop_last=True,
    # Optimiser and schedule
    learning_rate=peak_lr,
    warmup_ratio=warmup_ratio,
    num_train_epochs=total_epochs,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-8,
    max_grad_norm=grad_clipping_norm,
    fp16=fp16,
    # Evaluation and logging share the same step interval
    evaluation_strategy="steps",
    eval_steps=logging_steps,
    logging_strategy='steps',
    logging_steps=logging_steps,
    # Checkpointing
    # NOTE(review): `save_epochs` is used here as a number of *steps*
    # (save_strategy='steps'); confirm this is the intended interval.
    save_strategy='steps',
    save_steps=save_epochs,
    save_total_limit=1,
)

In [None]:
# Metrics
def compute_metrics(p):
    """Compute regression metrics from an evaluation pass.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` holds one value
            per example, as an array of shape ``(N,)`` or ``(N, 1)``, or a
            tuple whose first element is such an array. ``labels`` holds the
            true values with shape ``(N,)`` or ``(N, 1)``.

    Returns:
        dict with the mean squared error (``"mse"``) and the Pearson and
        Spearman correlation coefficients between predictions and labels
        (``"corr_coef_pearson"``, ``"corr_coef_spearman"``). The correlations
        are NaN when they are undefined (e.g. a single example).

    Raises:
        ValueError: if the number of predicted values differs from the
            number of labels.
    """
    predictions, labels = p
    # When several outputs are returned as a tuple, the first one is used
    # as the prediction.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Flatten explicitly instead of squeeze(): squeeze() turns a single-example
    # evaluation into a 0-d array, and leaves (N, 1) labels to broadcast
    # against (N,) predictions into an (N, N) matrix.
    predictions = np.asarray(predictions).reshape(-1)
    labels = np.asarray(labels).reshape(-1)
    if predictions.shape != labels.shape:
        raise ValueError(
            "compute_metrics expects one predicted value per label, got {} predictions and {} labels".format(
                predictions.size, labels.size
            )
        )

    mse = np.mean((predictions - labels) ** 2)
    df = pd.DataFrame({'pred': predictions, 'label': labels})
    # .iloc[0, 1] picks the off-diagonal entry of the 2x2 correlation matrix.
    corr_coef_pearson = df.corr(method='pearson').iloc[0, 1]
    corr_coef_spearman = df.corr(method='spearman').iloc[0, 1]

    return {
        "mse": mse,
        "corr_coef_pearson": corr_coef_pearson,
        "corr_coef_spearman": corr_coef_spearman
    }

In [None]:
# Trainer: ties together the model, the training arguments, the train/test
# splits, batch padding and the regression metrics.
trainer = Trainer(
    args=train_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=data['train'],
    eval_dataset=data['test'],
    compute_metrics=compute_metrics,
)

In [None]:
# Training: run fine-tuning with the Trainer built above. Evaluation, logging
# and checkpointing follow the step intervals configured in `train_args`.
# The call is left as the cell's final expression, so its return value is
# displayed by the notebook.
trainer.train()