In [None]:
# libraries
import numpy as np
import pandas as pd
import torch
from captum.attr import LayerIntegratedGradients, LayerGradientXActivation
from transformers import XLNetConfig, XLNetForTokenClassification
from xlnet_plabel_utils import GWSDatasetFromPandas  # custom dataset and trainer, CorrCoef, collate_fn, compute_metrics, compute_metrics_saved  # custom dataset and trainer
import pytorch_lightning as pl
from tqdm import tqdm
import os
# Silence every warning for the rest of the session (this includes PyTorch and
# Captum user warnings, so noisy-but-harmless messages will not be shown).
import warnings
warnings.filterwarnings("ignore")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Sharding: the dataset is split into `tot_parts` contiguous shards and this run
# interprets shard `part`. `part` defaults to 0 and can be overridden with the
# INTERPRET_PART environment variable, so the same notebook/script can be
# launched once per shard without editing it.
tot_parts = 16
part = int(os.environ.get("INTERPRET_PART", 0))
if not 0 <= part < tot_parts:
    raise ValueError(f"part must be in [0, {tot_parts}), got {part}")

# Number of target positions attributed per Captum call; also passed to
# LayerIntegratedGradients as internal_batch_size.
gbs = 4

In [3]:
class model_CTRL(torch.nn.Module):
    """Expose the control head (logit channel 0, through ReLU) as one scalar per example.

    Captum attribution methods need a forward function that returns one scalar per
    example. This wrapper runs the wrapped token-classification model, computes
    ``relu(logits[:, :, 0])`` and, for every example ``i``, returns the value at
    sequence position ``index_val[i]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, index_val):
        """Return ``relu(logits[i, index_val[i], 0])`` for each example ``i``.

        Args:
            x: token ids, one row per example (assumed ``(batch, seq_len)``; rows
                must line up with ``index_val``).
            index_val: one sequence position per example (list or 1-D tensor);
                its length must not exceed the batch size of ``x``.

        Returns:
            1-D tensor of length ``len(index_val)`` on the model's device, with the
            logits' dtype. It stays attached to the autograd graph so Captum can
            differentiate it.
        """
        # Follow the wrapped model's device rather than a module-level global.
        param = next(self.model.parameters(), None)
        dev = param.device if param is not None else torch.as_tensor(x).device

        # as_tensor avoids copy-constructing (and warning about) an existing tensor.
        input_ids = torch.as_tensor(x, device=dev).to(torch.int32)
        pred = self.model(input_ids)

        # control channel, clamped at zero
        pred_fin = torch.relu(pred["logits"][:, :, 0])

        # Vectorised gather: row i, column index_val[i] (replaces a per-element
        # Python loop into a CPU tensor, which forced one device sync per element).
        positions = torch.as_tensor(index_val, device=pred_fin.device).long()
        rows = torch.arange(positions.shape[0], device=pred_fin.device)
        return pred_fin[rows, positions]
    
class model_DD(torch.nn.Module):
    """Expose logit channel 1 (the "depr_diff" head) as one scalar per example.

    Captum attribution methods need a forward function that returns one scalar per
    example. This wrapper runs the wrapped token-classification model, takes
    ``logits[:, :, 1]`` (no activation) and, for every example ``i``, returns the
    value at sequence position ``index_val[i]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, index_val):
        """Return ``logits[i, index_val[i], 1]`` for each example ``i``.

        Args:
            x: token ids, one row per example (assumed ``(batch, seq_len)``; rows
                must line up with ``index_val``).
            index_val: one sequence position per example (list or 1-D tensor);
                its length must not exceed the batch size of ``x``.

        Returns:
            1-D tensor of length ``len(index_val)`` on the model's device, with the
            logits' dtype. It stays attached to the autograd graph so Captum can
            differentiate it.
        """
        # Follow the wrapped model's device rather than a module-level global.
        param = next(self.model.parameters(), None)
        dev = param.device if param is not None else torch.as_tensor(x).device

        # as_tensor avoids copy-constructing (and warning about) an existing tensor.
        input_ids = torch.as_tensor(x, device=dev).to(torch.int32)
        pred = self.model(input_ids)

        # channel 1, used raw (it can be negative, so no ReLU here)
        pred_fin = pred["logits"][:, :, 1]

        # Vectorised gather: row i, column index_val[i] (replaces a per-element
        # Python loop into a CPU tensor, which forced one device sync per element).
        positions = torch.as_tensor(index_val, device=pred_fin.device).long()
        rows = torch.arange(positions.shape[0], device=pred_fin.device)
        return pred_fin[rows, positions]

def lig_output(model, x, y, mode='ctrl', n_steps=10, pad_token_id=70):
    """Layer Integrated Gradients attribution matrix for a single sequence.

    For every target position ``j`` of ``x``, the selected head's prediction at
    ``j`` is attributed with Captum's ``LayerIntegratedGradients`` to the output of
    ``model.transformer.word_embedding``. The integration runs from an all-padding
    baseline to the real input. Attributions are summed over the embedding
    dimension and L2-normalised per target position, giving one row per target.

    Args:
        model: token-classification model exposing ``model.transformer.word_embedding``.
        x: token ids of one sequence (assumed 1-D); position 0 is the condition token.
        y: unused; kept for call-site compatibility.
        mode: ``'ctrl'`` (logit channel 0 through ReLU, via ``model_CTRL``) or
            ``'dd'`` (logit channel 1, via ``model_DD``).
        n_steps: number of Gauss-Legendre integration steps.
        pad_token_id: token id used to build the all-padding baseline.

    Returns:
        1-D numpy array: the ``(len(x) - 1) x (len(x) - 1)`` matrix (row = target
        position, column = input position, condition token dropped from both axes)
        flattened in row-major order. A target whose attributions are all zero
        yields a NaN row (0 / 0 in the normalisation).

    Raises:
        ValueError: if ``mode`` is not ``'ctrl'`` or ``'dd'``.
    """
    wrappers = {'ctrl': model_CTRL, 'dd': model_DD}
    if mode not in wrappers:
        # An unknown mode used to fall through and fail with UnboundLocalError.
        raise ValueError(f"mode must be 'ctrl' or 'dd', got {mode!r}")
    model_fin = wrappers[mode](model)

    lig = LayerIntegratedGradients(model_fin, model_fin.model.transformer.word_embedding)

    # Loop-invariant tensors, built once. as_tensor avoids copy-constructing
    # (and warning about) an existing tensor.
    input_ids = torch.as_tensor(x).to(device).to(torch.int32)
    baseline_ids = torch.full_like(input_ids, pad_token_id)

    len_sample = len(x)
    attributions_sample = np.zeros((len_sample, len_sample))

    # Anomaly detection makes autograd raise on NaN produced in backward and
    # records forward traces. It is a debugging aid and slows every backward pass.
    with torch.autograd.set_detect_anomaly(True):
        for j in tqdm(range(0, len_sample, gbs)):
            # Up to `gbs` target positions per call, one copy of the sequence each.
            index_val = torch.arange(j, min(j + gbs, len_sample), device=device)
            n_targets = len(index_val)

            attributions = lig.attribute(
                input_ids.repeat(n_targets, 1),
                baselines=baseline_ids.repeat(n_targets, 1),
                method='gausslegendre',
                return_convergence_delta=False,
                additional_forward_args=index_val,
                n_steps=n_steps,
                internal_batch_size=gbs,
            )

            # The batch axis comes back second (presumably XLNet's sequence-first
            # embedding layout). Move it first, then collapse the embedding
            # dimension -> (n_targets, len_sample).
            attributions = torch.permute(attributions, (1, 0, 2)).sum(dim=2)
            # L2-normalise each target's row.
            attributions = attributions / torch.norm(attributions, dim=1, keepdim=True)
            attributions_sample[j:j + n_targets] = attributions.detach().cpu().numpy()

    # Drop row/column 0: position 0 is the condition token, not a sequence position.
    return attributions_sample[1:, 1:].flatten()

def lxg_output(model, x, y, mode='ctrl'):
    """Layer gradient x activation attribution matrix for a single sequence.

    For every target position ``j`` of ``x``, the selected head's prediction at
    ``j`` is attributed with Captum's ``LayerGradientXActivation`` to the output of
    ``model.transformer.word_embedding``. Attributions are summed over the
    embedding dimension and L2-normalised per target position, giving one row per
    target.

    Args:
        model: token-classification model exposing ``model.transformer.word_embedding``.
        x: token ids of one sequence (assumed 1-D); position 0 is the condition token.
        y: unused; kept for call-site compatibility.
        mode: ``'ctrl'`` (logit channel 0 through ReLU, via ``model_CTRL``) or
            ``'dd'`` (logit channel 1, via ``model_DD``).

    Returns:
        1-D numpy array: the ``(len(x) - 1) x (len(x) - 1)`` matrix (row = target
        position, column = input position, condition token dropped from both axes)
        flattened in row-major order. A target whose attributions are all zero
        yields a NaN row (0 / 0 in the normalisation).

    Raises:
        ValueError: if ``mode`` is not ``'ctrl'`` or ``'dd'``.
    """
    wrappers = {'ctrl': model_CTRL, 'dd': model_DD}
    if mode not in wrappers:
        # An unknown mode used to fall through and fail with UnboundLocalError.
        raise ValueError(f"mode must be 'ctrl' or 'dd', got {mode!r}")
    model_fin = wrappers[mode](model)

    lxg = LayerGradientXActivation(model_fin, model_fin.model.transformer.word_embedding)

    # Loop-invariant input, built once. No baseline is needed for gradient x activation.
    input_ids = torch.as_tensor(x).to(device).to(torch.int32)

    len_sample = len(x)
    attributions_sample = np.zeros((len_sample, len_sample))

    # Anomaly detection makes autograd raise on NaN produced in backward and
    # records forward traces. It is a debugging aid and slows every backward pass.
    with torch.autograd.set_detect_anomaly(True):
        for j in tqdm(range(0, len_sample, gbs)):
            # Up to `gbs` target positions per call, one copy of the sequence each.
            index_val = torch.arange(j, min(j + gbs, len_sample), device=device)
            n_targets = len(index_val)

            attributions = lxg.attribute(input_ids.repeat(n_targets, 1), additional_forward_args=index_val)

            # The batch axis comes back second (presumably XLNet's sequence-first
            # embedding layout). Move it first, then collapse the embedding
            # dimension -> (n_targets, len_sample).
            attributions = torch.permute(attributions, (1, 0, 2)).sum(dim=2)
            # L2-normalise each target's row.
            attributions = attributions / torch.norm(attributions, dim=1, keepdim=True)
            attributions_sample[j:j + n_targets] = attributions.detach().cpu().numpy()

    # Drop row/column 0: position 0 is the condition token, not a sequence position.
    return attributions_sample[1:, 1:].flatten()

In [None]:
# Reproducibility: seed the RNGs through Lightning.
seed_val = 2
pl.seed_everything(seed_val)

# Which dataset shard this run interprets.
print("part: ", part)

# Data-filtering thresholds and model hyper-parameters; these must match the
# checkpoint that is loaded below.
annot_thresh, longZerosThresh_val, percNansThresh_val = 0.3, 20, 0.05
d_model_val, n_layers_val, n_heads_val = 512, 6, 4
dropout_val, lr_val, batch_size_val = 0.1, 1e-4, 2
loss_fun_name = '4L'  # alternative: '5L'

# Processed dataset location (absolute cluster path).
data_folder = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Jan_2024/Lina/processed/'

# Checkpoint to interpret -- CHANGE to the model with the highest performance.
model_name = 'XLNet-PLabelDH  Exp: exp1 [NL: 6, NH: 4, D: 512, LR: 0.0001, BS: 2, LF: 4L, Dr: 0.1, S: 2]'
output_loc = f'saved_models/{model_name}'

# Condition-token id -> condition name; the id sits at position 0 of every input.
condition_dict_values = {64: 'CTRL', 65: 'ILE', 66: 'LEU', 67: 'LEU_ILE', 68: 'LEU_ILE_VAL', 69: 'VAL'}

In [None]:
class XLNetDH(XLNetForTokenClassification):
    """XLNet token classifier whose per-token head has two outputs.

    The stock ``num_labels``-wide classifier is replaced by a 2-wide linear layer.
    Downstream, channel 0 is read as the control prediction (through ReLU) and
    channel 1 as the "depr_diff" prediction.
    """

    def __init__(self, config):
        super().__init__(config)
        # Size the head from the config (not a module-level constant) so it always
        # matches the hidden size of the checkpoint passed to from_pretrained.
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# Vocabulary of 71 ids: 64-69 are the condition tokens (keys of condition_dict_values),
# 70 is padding (pad_token_id), and the remaining ids 0-63 are the other (sequence) tokens.
config = XLNetConfig(vocab_size=71, pad_token_id=70, d_model=d_model_val, n_layer=n_layers_val, n_head=n_heads_val, d_inner=d_model_val, num_labels=1, dropout=dropout_val)
model = XLNetDH(config)

# load the dataframe to interpret
test_dataset = pd.read_csv('data/orig/val_0.3_NZ_20_PercNan_0.05.csv')
num_samples = len(test_dataset)

output_folder = 'all_interpret_orig_val/'
# Create the output folder if needed. It is listed just below (os.listdir raises
# on a missing directory) and the interpretation loop writes into it.
os.makedirs(output_folder, exist_ok=True)
files = os.listdir(output_folder)
existing_files = set(files)  # O(1) membership checks for the resume scan

# Disabled alternative: interpret only an explicit list of dataset rows (e.g. samples
# missing from an earlier run), split across `tot_parts` jobs. Re-enabling it also
# requires naming output files by `indices[...]` instead of `start + i`.
# full_indices = [*range(2986, 3164), *range(8371, 8439), *range(10529, 10549),
#                 *range(14712, 14769), *range(16870, 16879)]
# start = int(part * len(full_indices) / tot_parts)
# stop = int((part + 1) * len(full_indices) / tot_parts)
# print("Number of samples to interpret: ", stop - start)
# indices = full_indices[start:stop]
# test_dataset = test_dataset.iloc[indices]

# Split the dataset into `tot_parts` contiguous shards and keep shard `part`.
start = int(part * len(test_dataset) / tot_parts)
stop = int((part + 1) * len(test_dataset) / tot_parts)

# Resume: advance `start` over the leading run of rows whose output file already
# exists. `start` lands ON the last existing file (not after it), so that one
# sample is recomputed and overwritten.
for i in range(start, stop):
    if f'sample_{i}.npz' not in existing_files:
        break
    start = i

print("start: ", start)
print("stop: ", stop)

# positional slice of this shard's rows
test_dataset = test_dataset.iloc[start:stop]

# convert the dataframe into a torch dataset
test_dataset = GWSDatasetFromPandas(test_dataset)
print("samples in test dataset: ", len(test_dataset))

In [None]:
# Build the model from the saved checkpoint. from_pretrained is a classmethod, so
# the weights come from disk, not from the instance created earlier. Then move the
# model to the compute device and switch to evaluation mode (dropout off).
model = type(model).from_pretrained(f"{output_loc}/best_model")
model.to(device)
model.eval()

In [None]:
# Interpret every sample of this shard and save one compressed .npz per sample.
# Anomaly detection makes autograd raise on NaN produced in backward; it is a
# debugging aid and slows every backward pass.
with torch.autograd.set_detect_anomaly(True):
    for i, (x_input, y_true_full, y_true_ctrl, gene, transcript) in tqdm(enumerate(test_dataset), total=len(test_dataset)):
        # as_tensor avoids copy-constructing (and warning about) an existing tensor
        x = torch.as_tensor(x_input)
        y = torch.as_tensor(y_true_full)

        # position 0 of every input holds the condition token id (64-69)
        condition_token = condition_dict_values[int(x[0].item())]

        # attribution matrices (flattened, condition token removed) for both heads
        lig_sample_ctrl = lig_output(model, x, y, mode='ctrl')
        lig_sample_dd = lig_output(model, x, y, mode='dd')
        lxg_sample_ctrl = lxg_output(model, x, y, mode='ctrl')
        lxg_sample_dd = lxg_output(model, x, y, mode='dd')

        # Plain prediction: no autograd graph needed. Use the configured device
        # instead of a hard-coded 'cuda', which failed on CPU-only machines.
        with torch.no_grad():
            y_pred_full = model(torch.unsqueeze(x, 0).to(device)).logits[0]
        # drop position 0 (condition token); channel 0 via ReLU, channel 1 raw
        y_pred_ctrl = torch.relu(y_pred_full[1:, 0]).cpu().numpy()
        y_pred_depr_diff = y_pred_full[1:, 1].cpu().numpy()

        y_pred_full_sample = y_pred_ctrl + y_pred_depr_diff
        y_true_dd_sample = y_true_full - y_true_ctrl

        out_dict = {
            'x_input': x_input,
            'y_true_full': y_true_full,
            'y_pred_full': y_pred_full_sample,
            'y_true_ctrl': y_true_ctrl,
            'gene': gene,
            'transcript': transcript,
            'lig_ctrl': lig_sample_ctrl,
            'lig_dd': lig_sample_dd,
            'lxg_ctrl': lxg_sample_ctrl,
            'lxg_dd': lxg_sample_dd,
            'y_pred_ctrl': y_pred_ctrl,
            'y_pred_depr_diff': y_pred_depr_diff,
            'y_true_dd': y_true_dd_sample,
            'condition': condition_token
        }

        # The dict is passed positionally, so numpy stores it as a pickled 0-d object
        # array under the key 'arr_0'. Read it back with
        # np.load(path, allow_pickle=True)['arr_0'].item().
        # Files are numbered by absolute dataset row (start + i), which is the name
        # the resume check looks for. When re-running an explicit index list, name
        # them by that list's entry instead.
        np.savez_compressed(os.path.join(output_folder, f'sample_{start + i}.npz'), out_dict)