In [None]:
# libraries
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, TrainingArguments, EarlyStoppingCallback
from xlnet_plabel_utils import RegressionTrainerFour, RiboDatasetExp1, RiboDatasetExp2, GWSDatasetFromPandas, collate_fn, compute_metrics  # custom dataset and trainer
import pytorch_lightning as pl
import wandb

In [None]:
# ---- run configuration: every value a reader might tune lives in this cell ----

# which dataset variant to build: 'exp1' or 'exp2'
experiment_type = 'exp1'
# seed handed to pl.seed_everything
seed_val = 1

# dataset filtering thresholds, forwarded to the RiboDatasetExp* builders
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05

# XLNet architecture
d_model_val = 512   # also used as d_inner and as the classifier input width
n_layers_val = 6
n_heads_val = 4
dropout_val = 0.1

# optimisation
lr_val = 1e-4
batch_size_val = 2  # realised through gradient accumulation (per-device batch size is 1)
loss_fun_name = '4L'  # only used as a tag inside the run name in this notebook

In [None]:
# fix the random seeds first so everything below is repeatable
pl.seed_everything(seed_val)

# folder holding the processed dataset files
# (not referenced again in the visible cells of this notebook)
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'

# run identifier that spells out the hyperparameters; it doubles as the
# checkpoint sub-folder and the wandb run name.
# NOTE: the two spaces after "XLNet-PLabelDH" are kept on purpose so the name
# (and therefore the output path) stays identical to earlier runs.
model_name = (
    f"XLNet-PLabelDH  Exp: {experiment_type} "
    f"[NL: {n_layers_val}, NH: {n_heads_val}, D: {d_model_val}, "
    f"LR: {lr_val}, BS: {batch_size_val}, LF: {loss_fun_name}, "
    f"Dr: {dropout_val}, S: {seed_val}]"
)
output_loc = f"saved_models/{model_name}"

# register the run with wandb under the same name
wandb.init(project="Riboclette", name=model_name)

In [None]:
# generate dataset
# Each supported experiment maps to its dataset builder. Looking the builder up
# explicitly means an unsupported experiment_type fails right here with a clear
# message, instead of leaving train/val/test undefined (NameError further down)
# or, in a long-lived kernel, silently re-wrapping datasets from an earlier run.
dataset_builders = {
    'exp1': RiboDatasetExp1,  # impute all train genes
    'exp2': RiboDatasetExp2,  # impute train genes + extra mouse genome genes
}
if experiment_type not in dataset_builders:
    raise ValueError(
        "unknown experiment_type " + repr(experiment_type)
        + "; expected one of " + str(sorted(dataset_builders))
    )

# the builders return the three splits in (train, val, test) order
train_df, val_df, test_df = dataset_builders[experiment_type](
    threshold=annot_thresh,
    longZerosThresh=longZerosThresh_val,
    percNansThresh=percNansThresh_val,
)

# convert pandas dataframes into torch datasets
# (separate names for the raw frames keep this step re-runnable on its own)
train_dataset = GWSDatasetFromPandas(train_df)
val_dataset = GWSDatasetFromPandas(val_df)
test_dataset = GWSDatasetFromPandas(test_df)

print("samples in train dataset: ", len(train_dataset))
print("samples in val dataset: ", len(val_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# load xlnet to train from scratch
n_outputs = 2  # number of values the replacement head predicts per token

# The config is still built with num_labels=1, exactly as before, so that the
# random initialisation obtained for a given seed does not change.
config = XLNetConfig(
    vocab_size=71,  # 6 conds, 64 codons, 1 for padding
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetForTokenClassification(config)

# modify the output layer
# model.classifier is the final per-token linear projection (a plain Linear in
# the transformers implementation, with no softmax attached); swap it for a
# head with n_outputs outputs.
model.classifier = torch.nn.Linear(d_model_val, n_outputs, bias=True)

# Keep the recorded label count in sync with the new head. Otherwise the saved
# config.json still says num_labels=1 while the checkpoint holds a 2-output
# classifier, and reloading with from_pretrained() rebuilds a 1-output head
# that does not match the stored weights.
model.config.num_labels = n_outputs
model.num_labels = n_outputs

In [None]:
# xlnet training arguments, grouped by what they control
training_args = TrainingArguments(
    # --- checkpoints ---
    output_dir = output_loc,
    save_strategy = "epoch",
    save_total_limit = 5,
    load_best_model_at_end = True,
    push_to_hub = False,
    # --- optimisation ---
    learning_rate = lr_val,
    weight_decay = 0.01,
    num_train_epochs = 100,
    # effective training batch size = per_device_train_batch_size * gradient_accumulation_steps
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = batch_size_val,
    # --- evaluation ---
    evaluation_strategy = "epoch",
    per_device_eval_batch_size = 1,
    eval_accumulation_steps = 4,
    include_inputs_for_metrics = True,
    # --- data loading ---
    dataloader_num_workers = 4,
    dataloader_pin_memory = True,
)

# early-stopping callback configured with a patience of 10
early_stopping = EarlyStoppingCallback(early_stopping_patience=10)

# initialize the custom trainer with the model, the data splits and the metric function
trainer = RegressionTrainerFour(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# run the training loop
trainer.train()

# store the model held by the trainer after training under "<output_loc>/best_model"
# (load_best_model_at_end=True is set in the training arguments)
best_model_dir = f"{output_loc}/best_model"
trainer.save_model(best_model_dir)

In [None]:
# score the trained model on the test split
test_metrics = trainer.evaluate(eval_dataset=test_dataset)
test_metrics  # bare last expression, so the notebook still displays the returned metrics

In [None]:
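# Optional: reload the saved best model from disk and re-evaluate it on the test set.
# NOTE: rebinding `model` alone does not change the model object that was handed to
# `trainer` when it was created, so the evaluate call below would still score the
# in-memory model unless the trainer is also pointed at the reloaded one.
# model = model.from_pretrained(output_loc + "/best_model")
# trainer.evaluate(eval_dataset=test_dataset)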

In [None]:
# Results notes for the 'exp1' setting. The S<n> labels presumably refer to runs
# with seed_val = n (TODO confirm); no results have been filled in yet.
''' Exp1
S1:
S2:
S3:
S4:
S42:
'''