In [None]:
# libraries
import numpy as np
import pandas as pd
import torch
from captum.attr import LayerIntegratedGradients, LayerGradientXActivation
from transformers import XLNetConfig, XLNetForTokenClassification
from xlnet_plabel_utils import GWSDatasetFromPandas  # custom dataset and trainer, CorrCoef, collate_fn, compute_metrics, compute_metrics_saved  # custom dataset and trainer
import pytorch_lightning as pl
from tqdm import tqdm
import os
# suppress warnings
# NOTE(review): this silences every Python warning for the rest of the kernel session,
# including any raised by torch / captum in the cells below.
import warnings
warnings.filterwarnings("ignore")

# run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# number of target positions attributed per attribution call in lig_output / lxg_output;
# lig_output also passes it to Captum as internal_batch_size
gbs = 4

In [3]:
class model_CTRL(torch.nn.Module):
    """Scalar-output wrapper around a token-level model for attribution.

    For every row of the input batch the wrapper returns a single value:
    ``relu(logits[row, index_val[row], 0])``, i.e. the rectified channel-0
    logit at one chosen sequence position.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, index_val):
        """Return the rectified channel-0 logit at ``index_val[row]`` for each row.

        Args:
            x: integer tensor of token ids, indexed as (row, position).
            index_val: one sequence position per row (tensor or sequence of ints).

        Returns:
            1-D tensor with ``len(index_val)`` entries, kept on the same device
            as the logits so no cross-device copy sits inside the autograd graph.
        """
        # Follow the wrapped model's device instead of relying on a notebook global;
        # fall back to the input's device if the wrapper holds no parameters.
        first_param = next(self.parameters(), None)
        run_device = first_param.device if first_param is not None else x.device

        # x is already a tensor: move/cast it directly rather than re-wrapping it
        input_ids = x.to(run_device).to(torch.int32)
        pred = self.model(input_ids)

        # channel 0 of the logits, rectified
        pred_fin = torch.relu(pred["logits"][:, :, 0])

        # pick pred_fin[row, index_val[row]] for every row in one indexing operation
        positions = torch.as_tensor(index_val, dtype=torch.long, device=pred_fin.device)
        rows = torch.arange(positions.shape[0], device=pred_fin.device)
        return pred_fin[rows, positions]
    
class model_DD(torch.nn.Module):
    """Scalar-output wrapper around a token-level model for attribution.

    For every row of the input batch the wrapper returns a single value:
    ``logits[row, index_val[row], 1]``, i.e. the raw (un-rectified) channel-1
    logit at one chosen sequence position.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, index_val):
        """Return the channel-1 logit at ``index_val[row]`` for each row.

        Args:
            x: integer tensor of token ids, indexed as (row, position).
            index_val: one sequence position per row (tensor or sequence of ints).

        Returns:
            1-D tensor with ``len(index_val)`` entries, kept on the same device
            as the logits so no cross-device copy sits inside the autograd graph.
        """
        # Follow the wrapped model's device instead of relying on a notebook global;
        # fall back to the input's device if the wrapper holds no parameters.
        first_param = next(self.parameters(), None)
        run_device = first_param.device if first_param is not None else x.device

        # x is already a tensor: move/cast it directly rather than re-wrapping it
        input_ids = x.to(run_device).to(torch.int32)
        pred = self.model(input_ids)

        # channel 1 of the logits, no activation applied
        pred_fin = pred["logits"][:, :, 1]

        # pick pred_fin[row, index_val[row]] for every row in one indexing operation
        positions = torch.as_tensor(index_val, dtype=torch.long, device=pred_fin.device)
        rows = torch.arange(positions.shape[0], device=pred_fin.device)
        return pred_fin[rows, positions]

def lig_output(model, x, y, mode='ctrl', n_steps=10, pad_token_id=70):
    """Layer Integrated Gradients attributions of every output position of one sample.

    The sample is attributed ``gbs`` target positions at a time: the token ids are
    repeated once per target position and the wrapped model returns, for each copy,
    the prediction at that copy's target position. Attributions are taken at the
    word-embedding layer, summed over the embedding dimension and L2-normalised
    per target position.

    Args:
        model: model exposing ``transformer.word_embedding``.
        x: 1-D sequence of token ids for one sample.
        y: unused; kept so existing callers keep working.
        mode: ``'ctrl'`` attributes through ``model_CTRL``, ``'dd'`` through ``model_DD``.
        n_steps: number of integration steps handed to Captum.
        pad_token_id: token id that fills the all-padding baseline sequence.

    Returns:
        Flattened 1-D numpy array of the (len(x) - 1) x (len(x) - 1) attribution
        matrix that remains after dropping row 0 and column 0.

    Raises:
        ValueError: if ``mode`` is neither ``'ctrl'`` nor ``'dd'``.
    """
    if mode == 'ctrl':
        model_fin = model_CTRL(model)
    elif mode == 'dd':
        model_fin = model_DD(model)
    else:
        # previously fell through and failed later with an unbound-variable error
        raise ValueError(f"mode must be 'ctrl' or 'dd', got {mode!r}")

    lig = LayerIntegratedGradients(model_fin, model_fin.model.transformer.word_embedding)

    len_sample = len(x)
    # row = target position, column = input position
    attributions_sample = np.zeros((len_sample, len_sample))

    # Loop-invariant tensors: build the input ids and the all-padding baseline once.
    input_ids = torch.as_tensor(x).to(device).to(torch.int32)
    baseline_ids = torch.full_like(input_ids, pad_token_id)

    # Anomaly detection makes autograd raise on NaN gradients and report the forward
    # op responsible. NOTE(review): it is a debugging aid that slows backward passes;
    # kept to preserve current behaviour.
    with torch.autograd.set_detect_anomaly(True):
        for j in tqdm(range(0, len_sample, gbs)):
            # target positions handled in this chunk
            index_val = torch.arange(j, min(j + gbs, len_sample), device=device)
            n_targets = len(index_val)

            # one copy of the sequence (and of the baseline) per target position
            attributions = lig.attribute(
                input_ids.repeat(n_targets, 1),
                baselines=baseline_ids.repeat(n_targets, 1),
                method='gausslegendre',
                return_convergence_delta=False,
                additional_forward_args=index_val,
                n_steps=n_steps,
                internal_batch_size=gbs,
            )

            # NOTE(review): the permute treats the layer attribution as
            # (position, target, embedding) - presumably because the embedding
            # layer's output is sequence-first; confirm against the model.
            attributions = torch.permute(attributions, (1, 0, 2)).sum(dim=2)

            # L2-normalise each target's row; all-zero rows are left as zeros
            # instead of turning into NaN through 0 / 0
            norms = torch.norm(attributions, dim=1, keepdim=True)
            attributions = attributions / torch.where(norms > 0, norms, torch.ones_like(norms))

            attributions_sample[j:j + n_targets] = attributions.detach().cpu().numpy()

    # drop row 0 and column 0 (position 0; the driver loop reads x[0] as the
    # condition token), then flatten
    return attributions_sample[1:, 1:].flatten()

def lxg_output(model, x, y, mode='ctrl'):
    """Layer Gradient x Activation attributions of every output position of one sample.

    The sample is attributed ``gbs`` target positions at a time: the token ids are
    repeated once per target position and the wrapped model returns, for each copy,
    the prediction at that copy's target position. Attributions are taken at the
    word-embedding layer, summed over the embedding dimension and L2-normalised
    per target position.

    Args:
        model: model exposing ``transformer.word_embedding``.
        x: 1-D sequence of token ids for one sample.
        y: unused; kept so existing callers keep working.
        mode: ``'ctrl'`` attributes through ``model_CTRL``, ``'dd'`` through ``model_DD``.

    Returns:
        Flattened 1-D numpy array of the (len(x) - 1) x (len(x) - 1) attribution
        matrix that remains after dropping row 0 and column 0.

    Raises:
        ValueError: if ``mode`` is neither ``'ctrl'`` nor ``'dd'``.
    """
    if mode == 'ctrl':
        model_fin = model_CTRL(model)
    elif mode == 'dd':
        model_fin = model_DD(model)
    else:
        # previously fell through and failed later with an unbound-variable error
        raise ValueError(f"mode must be 'ctrl' or 'dd', got {mode!r}")

    lxg = LayerGradientXActivation(model_fin, model_fin.model.transformer.word_embedding)

    len_sample = len(x)
    # row = target position, column = input position
    attributions_sample = np.zeros((len_sample, len_sample))

    # Loop-invariant: build the input ids once. Gradient x Activation needs no
    # baseline, so the baseline tensor the loop used to build is gone.
    input_ids = torch.as_tensor(x).to(device).to(torch.int32)

    # Anomaly detection makes autograd raise on NaN gradients and report the forward
    # op responsible. NOTE(review): it is a debugging aid that slows backward passes;
    # kept to preserve current behaviour.
    with torch.autograd.set_detect_anomaly(True):
        for j in tqdm(range(0, len_sample, gbs)):
            # target positions handled in this chunk
            index_val = torch.arange(j, min(j + gbs, len_sample), device=device)
            n_targets = len(index_val)

            # one copy of the sequence per target position
            attributions = lxg.attribute(
                input_ids.repeat(n_targets, 1),
                additional_forward_args=index_val,
            )

            # NOTE(review): the permute treats the layer attribution as
            # (position, target, embedding) - presumably because the embedding
            # layer's output is sequence-first; confirm against the model.
            attributions = torch.permute(attributions, (1, 0, 2)).sum(dim=2)

            # L2-normalise each target's row; all-zero rows are left as zeros
            # instead of turning into NaN through 0 / 0
            norms = torch.norm(attributions, dim=1, keepdim=True)
            attributions = attributions / torch.where(norms > 0, norms, torch.ones_like(norms))

            attributions_sample[j:j + n_targets] = attributions.detach().cpu().numpy()

    # drop row 0 and column 0 (position 0; the driver loop reads x[0] as the
    # condition token), then flatten
    return attributions_sample[1:, 1:].flatten()

In [None]:
# ---- reproducibility: seed the run through pytorch_lightning ----
seed_val = 2
pl.seed_everything(seed_val)

# ---- locations on disk ----
# folder holding the original dataset files
data_folder = '../../../data/orig/'
# folder of the trained model checkpoint
model_loc = '../../../checkpoints/XLNet-PLabelDH_S2/' ## CHANGE to the model with the highest performance

# ---- architecture sizes (read when the XLNetConfig is built) ----
d_model_val = 512
n_layers_val = 6
n_heads_val = 4
dropout_val = 0.1

# ---- remaining run settings (not read by any cell visible in this notebook) ----
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05
lr_val = 1e-4
batch_size_val = 2
loss_fun_name = '4L' # 5L

# ---- token id -> condition name ----
condition_dict_values = {
    64: 'CTRL',
    65: 'ILE',
    66: 'LEU',
    67: 'LEU_ILE',
    68: 'LEU_ILE_VAL',
    69: 'VAL',
}

In [None]:
class XLNetDH(XLNetForTokenClassification):
    """XLNet token classifier whose head emits two values per token.

    The driver loop in this notebook reads channel 0 (after a ReLU) as the
    control prediction and channel 1 as the deprivation difference.
    """

    def __init__(self, config):
        super().__init__(config)
        # Size the head from the model's own config rather than a notebook global,
        # so it always matches the transformer width - including when the config
        # comes from a checkpoint via from_pretrained.
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# vocabulary of 71 ids: 64 sequence tokens (ids 0-63) + 6 condition tokens (ids 64-69,
# see condition_dict_values) + 1 padding token (id 70).
# num_labels only sizes the stock head, which XLNetDH replaces with a 2-output layer.
config = XLNetConfig(
    vocab_size=71,
    pad_token_id=70,
    d_model=d_model_val,
    n_layer=n_layers_val,
    n_head=n_heads_val,
    d_inner=d_model_val,
    num_labels=1,
    dropout=dropout_val,
)
model = XLNetDH(config)

# load the test split from the configured data folder (same file as before,
# no longer a second hard-coded copy of the path)
test_df = pd.read_csv(data_folder + 'test.csv')
num_samples = len(test_df)

# attributions are written here; create the folder up front so neither the
# listing below nor the later per-sample saves fail on a fresh checkout
output_folder = 'attr/'
os.makedirs(output_folder, exist_ok=True)

# files already present in the output folder
files = os.listdir(output_folder)

# convert the pandas dataframe into a torch dataset (kept under a separate name
# so test_dataset only ever refers to the dataset object)
test_dataset = GWSDatasetFromPandas(test_df)
print("samples in test dataset: ", len(test_dataset))

In [None]:
# load the trained weights. from_pretrained is a classmethod that builds a new
# model, so call it on the class; os.path.join avoids the doubled '/' that plain
# concatenation produced because model_loc already ends with one.
model = XLNetDH.from_pretrained(os.path.join(model_loc, "best_model"))
model.to(device)

# set model to evaluation mode
model.eval()

In [None]:
# Attribute every test sample with both methods and both output channels, run a
# plain forward pass for the predictions, and write one compressed archive per sample.
# Anomaly detection makes autograd raise on NaN gradients.
# NOTE(review): it is a debugging aid that slows backward passes; kept to preserve
# current behaviour.
with torch.autograd.set_detect_anomaly(True):
    # tqdm wraps the dataset itself (not enumerate) so the bar knows the total
    for i, (x_input, y_true_full, y_true_ctrl, gene, transcript) in enumerate(tqdm(test_dataset)):
        x = torch.as_tensor(x_input)
        y = torch.as_tensor(y_true_full)

        # the first token of the input is the condition token
        condition_token = condition_dict_values[int(x[0].item())]

        # get LIG attributions
        lig_sample_ctrl = lig_output(model, x, y, mode='ctrl')
        lig_sample_dd = lig_output(model, x, y, mode='dd')

        # get LXG attributions
        lxg_sample_ctrl = lxg_output(model, x, y, mode='ctrl')
        lxg_sample_dd = lxg_output(model, x, y, mode='dd')

        # predictions only: no gradients needed, and use the configured device
        # instead of a hard-coded 'cuda' so the cell also runs on CPU-only machines
        with torch.no_grad():
            y_pred_full = model(x.unsqueeze(0).to(device)).logits[0]

        # position 0 (the condition token) is dropped from the predictions
        y_pred_ctrl = torch.relu(y_pred_full[1:, 0]).cpu().numpy()
        y_pred_depr_diff = y_pred_full[1:, 1].cpu().numpy()

        y_pred_full_sample = y_pred_ctrl + y_pred_depr_diff
        y_true_dd_sample = y_true_full - y_true_ctrl

        # make dict out of everything
        out_dict = {
            'x_input': x_input,
            'y_true_full': y_true_full,
            'y_pred_full': y_pred_full_sample,
            'y_true_ctrl': y_true_ctrl,
            'gene': gene,
            'transcript': transcript,
            'lig_ctrl': lig_sample_ctrl,
            'lig_dd': lig_sample_dd,
            'lxg_ctrl': lxg_sample_ctrl,
            'lxg_dd': lxg_sample_dd,
            'y_pred_ctrl': y_pred_ctrl,
            'y_pred_depr_diff': y_pred_depr_diff,
            'y_true_dd': y_true_dd_sample,
            'condition': condition_token
        }

        # save dict. The dict is passed positionally on purpose (file format kept
        # unchanged): numpy stores it as one pickled object under the key 'arr_0',
        # so readers need np.load(..., allow_pickle=True).
        np.savez_compressed(os.path.join(output_folder, 'sample_' + str(i) + '.npz'), out_dict)