In [None]:
# libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd 
import torch
from transformers import XLNetConfig, XLNetForTokenClassification, TrainingArguments
import random
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import pearson_corrcoef
from torchmetrics import Metric
import os
import seaborn as sns
import shutil
import pytorch_lightning as pl
from ipynb.fs.full.utils_dh import RegressionTrainerFive, RiboDatasetGWS, GWSDatasetFromPandas, collate_fn, compute_metrics, compute_metrics_saved  # custom dataset and trainer, CorrCoef, collate_fn, compute_metrics, compute_metrics_saved  # custom dataset and trainer
# from ipynb.fs.full.prediction_utils import analyse_dh_outputs, quantile_metric, attention_maps, captum_LayerGradAct, captum_LayerGrad, interpretability_panels
from pred_utils import analyse_dh_outputs, quantile_metric, attention_maps, captum_LayerGradAct, captum_LayerGrad, interpretability_panels
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# ---- reproducibility -------------------------------------------------------
seed_val = 42
pl.seed_everything(seed_val)  # global RNG seeding via PyTorch Lightning

# ---- data filters & model hyper-parameters ---------------------------------
annot_thresh = 0.3          # passed to RiboDatasetGWS as `threshold`
longZerosThresh_val = 20    # passed to RiboDatasetGWS as `longZerosThresh`
percNansThresh_val = 0.05   # passed to RiboDatasetGWS as `percNansThresh`
d_model_val = 256
n_layers_val = 3
n_heads_val = 8
dropout_val = 0.1
lr_val = 1e-4
batch_size_val = 1          # realised through gradient accumulation in TrainingArguments
loss_fun_name = '5L'        # '5L' selects RegressionTrainerFive

# ---- paths -----------------------------------------------------------------
# NOTE(review): absolute cluster path — consider making this configurable.
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'

# Run identifier encoding architecture, loss, data filters and seed; it is also
# the name of the checkpoint folder under saved_models/.
model_name = (
    f"XLNetDHConds DS: DeprNA [{n_layers_val}, {d_model_val}, {n_heads_val}] FT: [PEL] "
    f"BS: {batch_size_val} Loss: {loss_fun_name} "
    f"Data Conds: [NZ: {longZerosThresh_val}, PNTh: {percNansThresh_val}, AnnotThresh: {annot_thresh}] "
    f"Seed: {seed_val}"
)
output_loc = f"saved_models/{model_name}"

class XLNetDH(XLNetForTokenClassification):
    """XLNet token-classification model whose `classifier` emits 2 values per token.

    After the parent constructor runs, the `classifier` attribute is replaced by
    a linear layer mapping the model width to 2 outputs (the "DH" head —
    presumably one output per prediction head; confirm against the trainer's loss).
    """

    def __init__(self, config):
        super().__init__(config)
        # Size the head from the config instead of the notebook-global
        # `d_model_val`: `from_pretrained` rebuilds the model from the saved
        # config, so this keeps the head consistent with the backbone width.
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# Vocabulary of 71 ids with pad_token_id=70, i.e. ids 0-69 are real tokens and
# id 70 is padding.
# NOTE(review): presumably 64 codon tokens + 6 condition tokens — confirm
# against the tokenisation in utils_dh.
config = XLNetConfig(
    vocab_size=71,
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,   # inner size tied to the model width
    num_labels=1,          # XLNetDH swaps in its own 2-output classifier
    dropout=dropout_val,
)
model = XLNetDH(config)

In [None]:
# Build the train/test splits as pandas DataFrames, then wrap each one in a
# torch-compatible dataset.
# NOTE(review): 'ALL' was described as "all the three conditions", but the
# attention-map cells below list six condition labels — confirm what 'ALL' covers.
ds = 'ALL'
train_df, test_df = RiboDatasetGWS(
    data_folder,
    ds,
    threshold=annot_thresh,
    longZerosThresh=longZerosThresh_val,
    percNansThresh=percNansThresh_val,
)

train_dataset = GWSDatasetFromPandas(train_df)
test_dataset = GWSDatasetFromPandas(test_df)
print("samples in train dataset: ", len(train_dataset))
print("samples in test dataset: ", len(test_dataset))

In [None]:
# Report the available device. Note that `device` is not used to move the
# model in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Restore the best checkpoint saved under output_loc. `from_pretrained` is a
# classmethod, so it is called on the class; the result replaces the freshly
# initialised `model` from the config cell.
best_model_dir = output_loc + "/best_model"
model = XLNetDH.from_pretrained(best_model_dir)

In [None]:
# Hugging Face training arguments. Only `trainer.evaluate()` is called in this
# notebook, but the full training configuration is kept alongside it.
# NOTE(review): recent `transformers` releases rename `evaluation_strategy` to
# `eval_strategy` and deprecate `include_inputs_for_metrics` — update these if
# the installed version warns or errors.
training_args = TrainingArguments(
    output_dir = output_loc,
    learning_rate = lr_val,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = batch_size_val, # training batch size = per_device_train_batch_size * gradient_accumulation_steps
    per_device_eval_batch_size = 1,
    eval_accumulation_steps = 4,
    num_train_epochs = 100,
    weight_decay = 0.01,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    push_to_hub = False,
    dataloader_pin_memory = True,
    save_total_limit = 5,
    dataloader_num_workers = 4,
    include_inputs_for_metrics = True
)

# Build the trainer for the selected loss. Unknown names fail fast, so a typo
# cannot leave `trainer` undefined or silently reuse a stale trainer from an
# earlier run of this cell.
if loss_fun_name == '5L':
    trainer = RegressionTrainerFive(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collate_fn,
        # NOTE(review): presumably also writes preds/{preds,labels,inputs}.npy,
        # which the next cells load — confirm in utils_dh.
        compute_metrics=compute_metrics_saved
    )
else:
    raise ValueError(f"Unsupported loss_fun_name {loss_fun_name!r}; only '5L' is implemented")

In [None]:
# Evaluate the loaded model on the test split and display whatever
# `trainer.evaluate()` returns.
eval_metrics = trainer.evaluate()
eval_metrics

In [None]:
# Gather the prediction arrays produced by the evaluation above into a
# per-model folder, and prepare empty output folders for the analysis plots.
# NOTE(review): preds/preds.npy, preds/labels.npy and preds/inputs.npy are
# presumably written by `compute_metrics_saved` during `trainer.evaluate()` —
# confirm in utils_dh.


def _ensure_dir(path):
    """Create ``path`` (and any parents) if missing; existing contents are kept."""
    os.makedirs(path, exist_ok=True)


def _reset_dir(path):
    """Make ``path`` an empty directory, deleting anything that was there before."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _move_if_present(src, dst):
    """Move ``src`` to ``dst`` (overwriting ``dst``) when ``src`` exists.

    Uses ``os.replace`` rather than a shell ``mv`` so the brackets, colons and
    commas in the model name are never interpreted by a shell. When ``src`` is
    missing (e.g. this cell is re-run without a new evaluation), whatever is
    already at ``dst`` is left in place and a note is printed.
    """
    if os.path.exists(src):
        os.replace(src, dst)  # both paths live under preds/, so a same-filesystem rename
    else:
        print(f"{src} not found; using existing {dst} if present")


# Same run identifier as `output_loc`, including `seed_val` (not a hard-coded
# seed), so predictions are filed under the model that actually produced them.
model_name = (
    f"XLNetDHConds DS: DeprNA [{n_layers_val}, {d_model_val}, {n_heads_val}] FT: [PEL] "
    f"BS: {batch_size_val} Loss: {loss_fun_name} "
    f"Data Conds: [NZ: {longZerosThresh_val}, PNTh: {percNansThresh_val}, AnnotThresh: {annot_thresh}] "
    f"Seed: {seed_val}"
).replace(" ", "_")

preds_dir = os.path.join("preds", model_name)
analysis_dir = os.path.join(preds_dir, "analysis_dh")

# The prediction folder and analysis root persist across runs...
_ensure_dir(preds_dir)
_ensure_dir(analysis_dir)

# ...but every plot folder is wiped so figures from earlier runs don't linger.
for plot_subdir in (
    "full_plots",
    "interpretability_panels",
    "quantile_metric_plots",
    "condition_dists",
    "attn_plots",
    "captum_plots",
):
    _reset_dir(os.path.join(analysis_dir, plot_subdir))

# Move the freshly written arrays into the per-model folder.
for array_name in ("preds", "labels", "inputs"):
    _move_if_present(
        os.path.join("preds", f"{array_name}.npy"),
        os.path.join(preds_dir, f"{array_name}.npy"),
    )

preds = np.load(os.path.join(preds_dir, "preds.npy"))
labels = np.load(os.path.join(preds_dir, "labels.npy"))
inputs = np.load(os.path.join(preds_dir, "inputs.npy"))

In [None]:
# Plot the best 10 and worst 10 predictions; the figures land in
# preds/<model_name>/analysis_dh/full_plots.
dh_analysis_path = "preds/" + model_name + "/analysis_dh"
dh_test_csv = f"data/dh/test_{annot_thresh}_NZ_{longZerosThresh_val}_PercNan_{percNansThresh_val}.csv"
analyse_dh_outputs(preds, labels, inputs, dh_analysis_path, dh_test_csv)

In [None]:
# quantile_metric(preds, labels, inputs, "preds/" + model_name + "/analysis_dh/quantile_metric_plots/", 'data/dh/test_remBadRep_' + str(annot_thresh) + '_NZ_' + str(longZerosThresh_val) + '_PercNan_' + str(percNansThresh_val) + '.csv')

In [None]:
# NOTE(review): these calls produced identical plots for every condition — TODO: check whether attention_maps actually uses the condition argument.
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'CTRL')
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'LEU')
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'ILE')
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'VAL')
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'LEU_ILE')
# attention_maps(model, test_dataset, "preds/" + model_name + "/analysis_dh/attn_plots/", 'LEU_ILE_VAL')

In [None]:
# captum_LayerGradAct(model, test_dataset, "preds/" + model_name + "/analysis_dh/captum_plots/")

In [None]:
# captum_LayerGrad(model, test_dataset, "preds/" + model_name + "/analysis_dh/captum_plots/")

In [None]:
# interpretability_panels(model, preds, labels, inputs, "preds/" + model_name + "/analysis_dh/", 'data/dh/test_' + str(annot_thresh) + '_NZ_' + str(longZerosThresh_val) + '_PercNan_' + str(percNansThresh_val) + '.csv')