In [None]:
# libraries
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, EarlyStoppingCallback
from utils import RegressionTrainer, RiboDatasetGWSDepr, GWSDatasetFromPandas, CorrCoef, compute_metrics, collate_fn  # custom dataset and trainer
from transformers import TrainingArguments
import pytorch_lightning as pl

In [None]:
# ---- run control ----
seed = 1          # RNG seed; also embedded in the run name
tot_epochs = 100  # value handed to the trainer as num_train_epochs

# ---- model architecture ----
d_model_val = 128
n_layers_val = 5
n_heads_val = 4
dropout_val = 0.1

# ---- optimisation ----
lr_val = 1e-3
batch_size_val = 4          # used as the gradient-accumulation step count (per-device batch size is 1)
loss_fun_name = 'MAE_PCC'   # in the visible cells this is only used as a label in the run name

In [None]:
# reproducibility
# Seed the random number generators through pytorch_lightning before any
# dataset or model is constructed in the cells below.
# NOTE(review): the Hugging Face Trainer is documented to re-seed from
# TrainingArguments.seed when it is constructed — confirm that value agrees
# with `seed`, otherwise this call only governs randomness up to that point.
pl.seed_everything(seed)

In [None]:
# Load the GWS splits and wrap each one as a torch dataset in a single pass.
# RiboDatasetGWSDepr() is unpacked as (train, val, test), in that order; every
# split goes through GWSDatasetFromPandas (so the raw splits are presumably
# pandas objects — confirm in utils).
train_dataset, val_dataset, test_dataset = (
    GWSDatasetFromPandas(split) for split in RiboDatasetGWSDepr()
)

# Report how many samples ended up in each split.
print("samples in train dataset: ", len(train_dataset))
print("samples in val dataset: ", len(val_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# Build an XLNet token-classification model from a fresh config (no pretrained
# weights are loaded here).
# The run name encodes the hyper-parameters and is reused as the checkpoint
# sub-directory under saved_models/.
model_name = (
    f"XLNet-CSH [NL: {n_layers_val}, NH: {n_heads_val}, D: {d_model_val}, "
    f"LR: {lr_val}, BS: {batch_size_val}, LF: {loss_fun_name}, "
    f"Dr: {dropout_val}, S: {seed}]"
)
output_loc = f"saved_models/{model_name}"

# NOTE(review): the original comment described the vocabulary as
# "4^3 + 1 for padding" (= 65), which does not match vocab_size=71 /
# pad_token_id=70 — confirm against the tokenisation used by the dataset.
xlnet_config = XLNetConfig(
    vocab_size=71,
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,  # feed-forward width is set equal to the model width
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetForTokenClassification(xlnet_config)

# Swap in a single-output linear head (one scalar per token position).
model.classifier = torch.nn.Linear(in_features=d_model_val, out_features=1, bias=True)

In [None]:
# train xlnet
# Trainer configuration: one sample per device step, with gradients accumulated
# over `batch_size_val` steps; evaluation and checkpointing happen once per
# epoch, and load_best_model_at_end is enabled.
training_args = TrainingArguments(
    output_dir=output_loc,
    learning_rate=lr_val,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=batch_size_val,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,
    num_train_epochs=tot_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # TrainingArguments.seed defaults to 42 and the Trainer seeds from it, so
    # without this the notebook-level `seed` would not control data shuffling
    # or dropout during training, even though the run name records it as "S:".
    seed=seed,
)

# RegressionTrainer is the project's custom trainer (imported from utils).
# Training stops early after 10 evaluations without improvement in the
# best-model metric.
trainer = RegressionTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

In [None]:
# Run training (at most `tot_epochs` epochs; the trainer's early-stopping
# callback may end it sooner).
trainer.train()

# Persist the model held by the trainer after training. With
# load_best_model_at_end=True this is expected to be the best checkpoint.
trainer.save_model(f"{output_loc}/best_model")

In [None]:
# Evaluate the trainer's current model on the held-out test split and print
# the returned metrics.
# NOTE(review): the metrics are presumably reported under the Trainer's default
# "eval_" key prefix even though this is the test split — confirm before
# comparing against validation logs.
res = trainer.evaluate(eval_dataset=test_dataset)

print(res)