In [None]:
# libraries
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, EarlyStoppingCallback
from xlnet_csh_utils import RegressionTrainer, RiboDatasetGWSDepr, GWSDatasetFromPandas  # custom dataset and trainer
from transformers import TrainingArguments
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import pearson_corrcoef
from torchmetrics import Metric
import wandb 
import pytorch_lightning as pl

In [None]:
class CorrCoef(Metric):
    """Mean per-sequence Pearson correlation between predictions and targets.

    Each call to ``update`` adds one Pearson coefficient per sequence in the
    batch to a running sum and increments a sequence counter; ``compute``
    returns sum / count.
    """

    def __init__(self):
        super().__init__()
        # Running sum of per-sequence correlation coefficients.
        self.add_state("corrcoefs", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Number of sequences that contributed to ``corrcoefs``.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        """Accumulate the Pearson correlation of every sequence in the batch.

        Args:
            preds: per-position predictions with one more position along
                dim 1 than ``target``; the first position is dropped before
                comparison. NOTE(review): presumably that position belongs to
                a leading input token that has no label — confirm.
            target: per-position targets, same shape as ``preds[:, 1:]``.
            mask: boolean tensor, same shape as ``target``; only positions
                where it is True enter the correlation.

        NOTE(review): a sequence whose masked values are constant (or too few)
        may yield NaN from ``pearson_corrcoef``, which then propagates into
        ``compute()``.
        """
        preds = preds[:, 1:]
        assert preds.shape == target.shape, (preds.shape, target.shape)
        assert preds.shape == mask.shape, (preds.shape, mask.shape)
        if len(preds) == 0:
            # An empty batch contributes nothing (torch.stack([]) would raise).
            return
        coeffs = torch.stack([
            pearson_corrcoef(torch.masked_select(p, m), torch.masked_select(t, m))
            for p, t, m in zip(preds, target, mask)
        ])
        self.corrcoefs += coeffs.sum()
        self.total += len(coeffs)

    def compute(self):
        """Return the mean per-sequence correlation (NaN if no sequence was seen, i.e. 0/0)."""
        return self.corrcoefs / self.total

def collate_fn(batch, input_padding_value=64, label_padding_value=-1):
    """Collate dataset samples into a right-padded batch dict for the Trainer.

    Args:
        batch: list of tuples whose first two elements are 1-D tensors
            (input token ids, per-position labels). Any further elements
            (gene / transcript identifiers in this dataset) are ignored.
        input_padding_value: value used to right-pad ``input_ids``.
            NOTE(review): the model config in this notebook sets
            ``pad_token_id=70``, not 64 — confirm which id is meant for padding.
        label_padding_value: value used to right-pad ``labels``; any metric or
            loss must mask this value out.

    Returns:
        dict with ``input_ids`` (batch, max_input_len), ``labels``
        (batch, max_label_len) and ``lengths`` (batch,) holding each sample's
        unpadded input length.
    """
    input_ids, labels, *_ = zip(*batch)

    # Unpadded input lengths, recorded before padding discards them.
    lengths = torch.tensor([len(seq) for seq in input_ids])

    return {
        "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=input_padding_value),
        "labels": pad_sequence(labels, batch_first=True, padding_value=label_padding_value),
        "lengths": lengths,
    }

def compute_metrics(pred, label_pad_values=(-100.0, -1.0)):
    """Trainer ``compute_metrics`` hook: mean per-sequence Pearson r.

    Args:
        pred: an ``EvalPrediction``-like object exposing ``predictions``
            (presumably (N, L, 1) — dim 2 is squeezed away) and ``label_ids``
            (shape of ``predictions[:, 1:]`` after the squeeze; checked by
            ``CorrCoef.update``).
        label_pad_values: label values treated as padding and excluded from
            the correlation. -100 is presumably the padding the Trainer adds
            when concatenating eval batches of different lengths; -1 is the
            label ``padding_value`` used by ``collate_fn``. NOTE(review): this
            assumes -1 never occurs as a real label value — confirm.

    Returns:
        ``{"r": mean per-sequence Pearson correlation}``.
    """
    labels = torch.as_tensor(pred.label_ids)
    preds = torch.squeeze(torch.as_tensor(pred.predictions), dim=2)

    # Keep only real, non-NaN label positions.
    mask = torch.logical_not(torch.isnan(labels))
    for pad_value in label_pad_values:
        mask &= labels != pad_value

    corr_coef = CorrCoef()
    corr_coef.update(preds, labels, mask)
    return {"r": corr_coef.compute()}


In [None]:
# Global random seed: passed to pl.seed_everything and embedded in model_name.
seed: int = 1

In [None]:
# Reproducibility: seed the Python, NumPy and torch RNGs before the model is built.
pl.seed_everything(seed=seed)

In [None]:
# Dataset configuration.
# NOTE(review): `data_folder` is not passed to RiboDatasetGWSDepr below, so it
# has no effect in this notebook; presumably the helper resolves its own data
# location — confirm in xlnet_csh_utils.
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'
annot_thresh = 0.3
loss_fun_name = 'MAE_PCC'  # only used to label the run (model_name)
longZerosThresh_val = 20
percNansThresh_val = 0.05

# Build the train / val / test splits of the GWS dataset and wrap each one
# as a torch dataset, in that order.
train_dataset, val_dataset, test_dataset = map(
    GWSDatasetFromPandas,
    RiboDatasetGWSDepr(
        threshold=annot_thresh,
        longZerosThresh=longZerosThresh_val,
        percNansThresh=percNansThresh_val,
    ),
)

print("samples in train dataset: ", len(train_dataset))
print("samples in val dataset: ", len(val_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# Hyperparameters (consumed by the model, naming and training cells).
# --- architecture ---
n_layers_val: int = 5       # XLNet n_layer
n_heads_val: int = 4        # XLNet n_head
d_model_val: int = 128      # hidden size; also used as d_inner
dropout_val: float = 0.1
# --- optimisation ---
tot_epochs: int = 100       # upper bound; early stopping may end training sooner
lr_val: float = 1e-3
batch_size_val: int = 4     # realised via gradient accumulation (per-device batch is 1)

In [None]:
# Run name encoding the configuration; reused as the wandb run name and as the
# checkpoint directory under saved_models/.
model_name = (
    f"XLNet-CSH [NL: {n_layers_val}, NH: {n_heads_val}, D: {d_model_val}, "
    f"LR: {lr_val}, BS: {batch_size_val}, LF: {loss_fun_name}, "
    f"Dr: {dropout_val}, S: {seed}]"
)
output_loc = f"saved_models/{model_name}"

In [None]:
# XLNet trained from scratch, one output per token (num_labels=1) for
# per-position regression.
# NOTE(review): 71 token ids with pad id 70, whereas collate_fn pads input_ids
# with 64 (and 4^3 + 1 = 65) — confirm the intended vocabulary layout.
xlnet_config = XLNetConfig(
    vocab_size=71,
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetForTokenClassification(xlnet_config)

# Replace the token-classification head with a freshly created d_model -> 1
# linear layer. NOTE(review): presumably the same shape as the built-in head,
# but it gets PyTorch's default initialisation instead of the model's own.
model.classifier = torch.nn.Linear(d_model_val, 1, bias=True)

In [None]:
# Start the Weights & Biases run that the Trainer reports to (report_to="wandb").
# output_loc is already derived from model_name in the naming cell, so it is
# not recomputed here.
wandb.init(project="Riboclette", name=model_name)

In [None]:
# Train XLNet with the (custom) RegressionTrainer. The per-device batch size
# is 1 and gradients are accumulated over `batch_size_val` steps, so the
# effective per-device batch size is batch_size_val.
training_args = TrainingArguments(
    output_dir=output_loc,
    # The Trainer re-seeds with args.seed (default 42) at the start of training
    # and uses it for data shuffling; pass the notebook seed so the run really
    # uses the seed recorded in model_name.
    seed=seed,
    learning_rate=lr_val,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=batch_size_val,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,
    num_train_epochs=tot_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # With no metric_for_best_model set, "best" and early stopping track eval loss.
    load_best_model_at_end=True,
    push_to_hub=False,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    report_to="wandb"
)

trainer = RegressionTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # Stop after 10 consecutive epoch evaluations without improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
)

In [None]:
# Fit the model; because load_best_model_at_end=True, the trainer holds the
# best checkpoint once training finishes.
trainer.train()

# Persist that best model alongside the per-epoch checkpoints.
trainer.save_model(f"{output_loc}/best_model")

In [None]:
# Final evaluation on the held-out test split. metric_key_prefix="test" keeps
# these metrics (test_loss, test_r, ...) apart from the per-epoch validation
# metrics (eval_*) both in `res` and in the logs sent to wandb.
# NOTE(review): EarlyStoppingCallback may log a warning that eval_loss is
# missing for this call; that is expected after training has finished.
res = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

print(res)

In [None]:
# Close the Weights & Biases run started before training.
wandb.finish()