In [None]:
# libraries
import numpy as np
import pandas as pd 
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, TrainingArguments, EarlyStoppingCallback
import random
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import pearson_corrcoef
from torchmetrics import Metric
from ipynb.fs.full.utils_dh import RegressionTrainerFive, RiboDatasetGWS, GWSDatasetFromPandas, collate_fn, compute_metrics  # custom dataset and trainer

In [None]:
# hyperparameters and data-filtering settings used by the cells below

# dataset filtering thresholds (forwarded to RiboDatasetGWS)
annot_thresh = 0.3
longZerosThresh_val, percNansThresh_val = 20, 0.05

# XLNet architecture: hidden size, number of layers, attention heads, dropout
d_model_val, n_layers_val, n_heads_val = 256, 3, 8
dropout_val = 0.1

# optimisation settings
lr_val = 1e-4
batch_size_val = 1  # passed to the trainer as gradient_accumulation_steps
loss_fun_name = '5L'  # checked when the trainer is initialised

In [None]:
# fix the seeds of every RNG in play (python, numpy, torch) so runs are repeatable
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# folder holding the processed dataset files
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'

# run identifier: encodes architecture, batch size, loss and data-filtering settings;
# it doubles as the name of the output directory under saved_models/
model_name = (
    f"XLNetDH DS: DeprNA [{n_layers_val}, {d_model_val}, {n_heads_val}]"
    f" FT: [PEL] BS: {batch_size_val}"
    f" Loss: {loss_fun_name}"
    f" Data Conds: [NZ: {longZerosThresh_val}, PNTh: {percNansThresh_val},"
    f" AnnotThresh: {annot_thresh}]"
)
output_loc = f"saved_models/{model_name}"

In [None]:
# build the train/test splits
ds = 'ALL'  # uses all the 6 conditions + liver
train_dataset, test_dataset = RiboDatasetGWS(
    data_folder,
    ds,
    threshold=annot_thresh,
    longZerosThresh=longZerosThresh_val,
    percNansThresh=percNansThresh_val,
)

# wrap both pandas splits as torch datasets (train first, then test)
train_dataset, test_dataset = (
    GWSDatasetFromPandas(train_dataset),
    GWSDatasetFromPandas(test_dataset),
)

# report split sizes
print("samples in train dataset: ", len(train_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# build an XLNet from a fresh config (trained from scratch, no pretrained weights)
# vocabulary: 64*6 tokens + 1 for padding -> 385 ids, with id 384 used as the pad token
config = XLNetConfig(
    vocab_size=385,
    pad_token_id=384,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetForTokenClassification(config)

# replace the output layer with a linear head producing 2 values per token
# NOTE(review): config.num_labels stays 1 while the new head has 2 outputs —
# confirm nothing downstream relies on config.num_labels matching the head size.
model.classifier = torch.nn.Linear(in_features=d_model_val, out_features=2, bias=True)

In [None]:
# Hugging Face training configuration for the XLNet run
training_args = TrainingArguments(
    # where checkpoints are written
    output_dir=output_loc,
    push_to_hub=False,
    save_total_limit=5,
    # optimisation
    learning_rate=lr_val,
    weight_decay=0.01,
    num_train_epochs=100,
    # batching: training batch size = per_device_train_batch_size * gradient_accumulation_steps
    per_device_train_batch_size=1,
    gradient_accumulation_steps=batch_size_val,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,
    # evaluate and checkpoint once per epoch, then restore the best checkpoint
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # data loading
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # hand the model inputs to compute_metrics as well
    include_inputs_for_metrics=True,
)

# initialize trainer
# Only the '5L' loss is wired up: (MAE+PCC) on Final, CTRL, (MAE) on DD.
if loss_fun_name != '5L':
    # Fail right here instead of leaving `trainer` undefined (NameError later at
    # trainer.train()) or silently reusing a stale `trainer` from an earlier
    # run of this cell in the same kernel.
    raise ValueError(
        "Unsupported loss_fun_name: " + repr(loss_fun_name) + " (expected '5L')"
    )

trainer = RegressionTrainerFive(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # stop when the monitored metric has not improved for 20 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=20)],
)

In [None]:
# run the training loop configured on the trainer
trainer.train()

# write the trainer's model to <output_loc>/best_model
trainer.save_model(f"{output_loc}/best_model")

In [None]:
# evaluate model
# Called with no arguments; presumably this scores the eval_dataset given to the
# trainer (test_dataset) — confirm against RegressionTrainerFive.
# Kept as a bare expression so that, as the last statement of the cell, its
# return value is displayed as the cell output.
trainer.evaluate()