In [None]:
# libraries
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, TrainingArguments, EarlyStoppingCallback
from utils import RegressionTrainerFive, RiboDatasetExp1, RiboDatasetExp2, RiboDatasetExp1_2, GWSDatasetFromPandas, collate_fn, compute_metrics  # custom dataset and trainer
import pytorch_lightning as pl
import wandb

In [None]:
# model parameters
# --- data filtering (forwarded to the RiboDataset* builders) ---
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05

# --- architecture ---
d_model_val = 256
n_layers_val = 3
n_heads_val = 8
dropout_val = 0.1

# --- optimisation ---
lr_val = 1e-4
batch_size_val = 1  # realised through gradient accumulation in the training arguments
loss_fun_name = '5L'  # 5L
seed_val = 42

# --- experiment setup ---
noise_flag = False
experiment_type = 'set1_2'  # one of: set1, set2, set1_2
plabel_quartile_exp1 = 0.25  # one of: 0.25, 0.5, 0.75, 1
plabel_quartile_exp2 = 0.25  # one of: 0.25, 0.5, 0.75, 1
# impute -> only impute the set 1
# impden -> impute and denoise the set 1
# same   -> keep it same
impden = 'impute'

In [None]:
# reproducibility
pl.seed_everything(seed_val)

# lookup tables: pseudolabel quartile -> pseudolabeling threshold, one table per experiment set
plabel_exp1_quartile_dict = {0.25: 0.06655636, 0.5: 0.07077431, 0.75: 0.07649534, 1: 0.15930015}
plabel_exp2_quartile_dict = {0.25: 0.06782594, 0.5: 0.07330298, 0.75: 0.08001823, 1: 0.19505197}

# pseudolabeling threshold
if experiment_type == 'set1':
    plabel_exp1_thresh = plabel_exp1_quartile_dict[plabel_quartile_exp1]
    # set 2 is not used here; the string 'None' is what ends up in the run name
    plabel_quartile_exp2 = 'None'
elif experiment_type in ('set2', 'set1_2'):
    # NOTE(review): the set-1 quartile is overwritten with the set-2 quartile, so the
    # configured plabel_quartile_exp1 is ignored for these experiment types -- confirm
    # this coupling is intended.
    plabel_quartile_exp1 = plabel_quartile_exp2
    plabel_exp1_thresh = plabel_exp1_quartile_dict[plabel_quartile_exp1]
    plabel_exp2_thresh = plabel_exp2_quartile_dict[plabel_quartile_exp2]
else:
    # Fail before a wandb run is started and before the dataset cell hits an
    # unrelated NameError on the missing thresholds / datasets.
    raise ValueError(
        "Unknown experiment_type " + repr(experiment_type) + "; expected 'set1', 'set2' or 'set1_2'"
    )

# dataset paths
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'

# model name and output folder path
# The run name encodes every hyperparameter of the run; it is used both as the
# checkpoint directory name and as the wandb run name.
model_name = (
    f"PLabelXLNetDHConds DS: DeprNA [{n_layers_val}, {d_model_val}, {n_heads_val}]"
    f" FT: [PEL] BS: {batch_size_val}"
    f" Loss: {loss_fun_name}"
    f" Data Conds: [NZ: {longZerosThresh_val}, PNTh: {percNansThresh_val}, AnnotThresh: {annot_thresh}] "
    f"Seed: {seed_val}"
    f"[Exp: {experiment_type}, PLQ1: {plabel_quartile_exp1}, PLQ2: {plabel_quartile_exp2}, ImpDen: {impden}]"
    f" Noisy: {noise_flag}"
)
output_loc = f"saved_models/{model_name}"

# set wandb name to model_name
wandb.init(project="XLNet-DH", name=model_name)

In [None]:
# generate dataset
# Keyword arguments shared by all three dataset builders.
common_dataset_kwargs = dict(
    threshold=annot_thresh,
    longZerosThresh=longZerosThresh_val,
    percNansThresh=percNansThresh_val,
    impden=impden,
)

if experiment_type == 'set1':
    # set 1 only: a single pseudolabel threshold / quartile pair
    train_df, test_df = RiboDatasetExp1(
        plabel_thresh=plabel_exp1_thresh,
        plabel_quartile=plabel_quartile_exp1,
        **common_dataset_kwargs,
    )
elif experiment_type in ('set2', 'set1_2'):
    # both builders take the same arguments: one threshold / quartile pair per set
    dataset_builder = RiboDatasetExp2 if experiment_type == 'set2' else RiboDatasetExp1_2
    train_df, test_df = dataset_builder(
        plabel_thresh1=plabel_exp1_thresh,
        plabel_quartile1=plabel_quartile_exp1,
        plabel_thresh2=plabel_exp2_thresh,
        plabel_quartile2=plabel_quartile_exp2,
        **common_dataset_kwargs,
    )
else:
    # Without this branch the cell would fall through and fail further down with
    # a NameError that does not mention the real cause.
    raise ValueError(
        "Unknown experiment_type " + repr(experiment_type) + "; expected 'set1', 'set2' or 'set1_2'"
    )

# convert pandas dataframes into torch datasets
# (the raw frames keep their own names so train_dataset / test_dataset only ever
# refer to the torch datasets)
train_dataset = GWSDatasetFromPandas(train_df, 'train', noise_flag)
test_dataset = GWSDatasetFromPandas(test_df, 'test', noise_flag)
print("samples in train dataset: ", len(train_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# load xlnet to train from scratch
# The model is built from a config only, so no pretrained weights are loaded.
# vocab_size = 71: 6 conds, 64 codons, 1 for padding (pad_token_id = 70)
config = XLNetConfig(
    vocab_size=71,
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetForTokenClassification(config)

# modify the output layer
# The token-classification head is swapped for a linear layer that emits two values
# per token. The head is replaced explicitly (instead of asking the config for two
# labels up front) so that it keeps torch.nn.Linear's default initialisation.
n_outputs = 2
model.classifier = torch.nn.Linear(d_model_val, n_outputs, bias=True)

# Keep the label count in sync with the new head. Otherwise the config saved next to
# the checkpoint still says num_labels = 1, and rebuilding the model from that
# checkpoint creates a 1-output head that does not match the saved 2-output weights.
# NOTE(review): assumes RegressionTrainerFive computes its own loss and does not rely
# on num_labels == 1 -- confirm against utils.
model.config.num_labels = n_outputs
model.num_labels = n_outputs

In [None]:
# xlnet training arguments
training_kwargs = dict(
    # where checkpoints are written
    output_dir=output_loc,
    save_strategy="epoch",
    save_total_limit=5,
    push_to_hub=False,
    # optimisation
    learning_rate=lr_val,
    weight_decay=0.01,
    num_train_epochs=100,
    # training batch size = per_device_train_batch_size * gradient_accumulation_steps
    per_device_train_batch_size=1,
    gradient_accumulation_steps=batch_size_val,
    # evaluation runs once per epoch, matching save_strategy
    # NOTE(review): newer transformers releases rename evaluation_strategy to
    # eval_strategy -- confirm against the pinned version before upgrading.
    evaluation_strategy="epoch",
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,
    include_inputs_for_metrics=True,
    load_best_model_at_end=True,
    # data loading
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
)
training_args = TrainingArguments(**training_kwargs)

# initialize trainer
if loss_fun_name == '5L':  # (MAE+PCC) on Final, CTRL, (MAE) on DD
    trainer = RegressionTrainerFive(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        # stop once the monitored metric has not improved for 10 evaluations
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
    )
else:
    # Only the '5L' loss has a trainer in this notebook. Without this branch an
    # unsupported name would leave `trainer` undefined and fail at trainer.train().
    raise ValueError(
        "Unsupported loss_fun_name " + repr(loss_fun_name) + "; only '5L' is implemented"
    )

In [None]:
# train model
trainer.train()

# save best model
# (the training arguments set load_best_model_at_end=True, so the model held by the
# trainer at this point is the best checkpoint rather than the last one)
best_model_dir = f"{output_loc}/best_model"
trainer.save_model(best_model_dir)

In [None]:
# evaluate model
# run evaluation on the trainer's eval dataset and show the result as the cell output
eval_metrics = trainer.evaluate()
eval_metrics