In [None]:
# libraries
import os

import captum
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
from torch_geometric.explain import Explainer, CaptumExplainer
from tqdm import tqdm

from gnn_explain_utils import GCN_LSTMNonEmbed, GCN_LSTM

# Disable cuDNN globally for this notebook.
# NOTE(review): presumably a workaround so gradients can be back-propagated through
# the LSTM while the model is in eval mode (cuDNN's RNN backward only supports
# training mode) — confirm against the GCN_LSTM implementation.
torch.backends.cudnn.enabled = False

In [None]:
# Locations of the held-out test data: one graph file per transcript plus the test-set table.
feature_folder = 'LiverGraphs/test/'
exp_dat = pd.read_csv('data/test.csv')

# Model / training configuration passed to load_from_checkpoint below.
# NOTE(review): annot_thresh, longZerosThresh_val, percNansThresh_val,
# random_walk_length, features_str and loss_fn are not referenced in the cells
# shown here.
tot_epochs = 50
batch_size = 2
dropout_val = 0.4
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05
random_walk_length = 32
alpha = -1
lr = 1e-3
algo = 'SAGE'
edge_attr = 'None'
features = ['embedding']
features_str = '_'.join(features)
loss_fn = 'MAE + PCC'
gcn_layers = [256, 128, 128, 64]
# Input width contributed by each feature group; num_inp_ft is their total.
input_nums_dict = {'embedding': 256}
num_inp_ft = sum(input_nums_dict[ft] for ft in features)

save_loc = 'saved_models/LSTM/best.ckpt'

# Full model (embedding layer included): used for the codon embeddings and for predictions.
l_model = GCN_LSTM.load_from_checkpoint(
    save_loc,
    gcn_layers=gcn_layers,
    dropout_val=dropout_val,
    num_epochs=tot_epochs,
    bs=batch_size,
    lr=lr,
    num_inp_ft=num_inp_ft,
    alpha=alpha,
    algo=algo,
    edge_attr=edge_attr,
)

# Same checkpoint without the embedding layer, so attributions are computed with
# respect to the (continuous) embedded node features rather than integer codon IDs.
# NOTE(review): unlike l_model, `alpha` is not passed here — confirm
# GCN_LSTMNonEmbed does not need it.
non_embed_model = GCN_LSTMNonEmbed.load_from_checkpoint(
    save_loc,
    gcn_layers=gcn_layers,
    dropout_val=dropout_val,
    num_epochs=tot_epochs,
    bs=batch_size,
    lr=lr,
    num_inp_ft=num_inp_ft,
    algo=algo,
    edge_attr=edge_attr,
)

explainer = Explainer(
    model=non_embed_model,  # a LightningModule is an nn.Module, so it is passed directly
    algorithm=CaptumExplainer(attribution_method=captum.attr.InputXGradient),
    explanation_type='model',     # explain the model's own prediction (no external target)
    node_mask_type='attributes',  # one attribution per node feature
    edge_mask_type='object',      # one attribution per edge
    model_config=dict(
        mode='regression',
        task_level='node',
        return_type='raw',  # regression outputs are used as-is (not log probabilities)
    ),
)

In [None]:
# NOTE(review): presumably selects a data partition; `part` is not referenced in the visible cells.
part: int = 0

In [None]:
# Runtime setup for the explanation loop: device, transcript IDs and output location.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the transcript list once and derive every count from it.
transcripts_list = list(exp_dat['transcript'])
total_num_sample = len(transcripts_list)

print("total samples: ", total_num_sample)

# Number of samples the explanation loop iterates over (currently the whole test set).
part_length = total_num_sample

# Destination folder for the per-sample .npz files.
out_folder_path = 'final_useqplus_int/'

In [None]:
# Explainability with Captum: for every test transcript, attribute each node's
# prediction to node features and edges, record the full-model prediction and the
# normalised ground truth, and save everything to one .npz per sample.

# load_from_checkpoint leaves the models in training mode (dropout active) and on
# whatever device the checkpoint tensors were mapped to. Move both to `device` and
# switch to eval mode so `y_pred` is deterministic and consistent with the
# explanations (PyG's Explainer also evaluates its model in eval mode).
l_model.to(device)
l_model.eval()
non_embed_model.to(device)  # in-place, so `explainer` sees the moved model too
non_embed_model.eval()

# np.savez_compressed does not create missing directories.
os.makedirs(out_folder_path, exist_ok=True)

for sample_number in tqdm(range(part_length)):
    out_dict = {}

    # Load the sample once; it feeds both the explanation and the prediction.
    # NOTE: torch.load unpickles the file — only use it on trusted data.
    file_name = feature_folder + 'sample_' + str(sample_number) + '.pt'
    data = torch.load(file_name)
    data.x = torch.tensor([int(k) for k in data.x['codon_seq']], dtype=torch.long)
    data.edge_attr = None
    data = data.to(device)

    # --- Explanations --------------------------------------------------------
    # The explained model starts after the embedding layer, so build its input
    # here: codon embeddings concatenated with the random-walk positional encoding.
    x_explain = torch.concat((l_model.embedding(data.x), data.random_walk_pe), dim=1)

    edge_explain_sample = []
    node_explain_sample = []
    for index in tqdm(range(data.y.shape[0]), leave=False):
        explanation = explainer(x=x_explain, edge_index=data.edge_index, index=index)

        # Stack (source row, target row, edge attribution) into a (3, num_edges)
        # tensor and flatten it; edge_index is cast to the mask dtype explicitly.
        edge_explain = torch.concat(
            [
                data.edge_index.to(explanation.edge_mask.dtype),
                explanation.edge_mask.unsqueeze(dim=0),
            ],
            dim=0,
        ).view(-1)
        edge_explain_sample.append(edge_explain)

        # Sum attributions over the feature dimension -> one score per node.
        node_explain_sample.append(explanation.node_mask.sum(dim=1))

    # Results for all target nodes are concatenated into flat 1-D arrays.
    edge_explain_sample = torch.cat(edge_explain_sample, dim=0)
    node_explain_sample = torch.cat(node_explain_sample, dim=0)
    out_dict['node_attr_ds'] = node_explain_sample.detach().cpu().numpy()
    out_dict['edge_attr_ds'] = edge_explain_sample.detach().cpu().numpy()

    # --- Prediction (full model, codon IDs as input, no autograd needed) -----
    with torch.no_grad():
        pred = l_model(data)

    out_dict['y_pred'] = pred.detach().cpu().numpy()
    out_dict['x_input'] = data.x.detach().cpu().numpy()
    out_dict['edge_index'] = data.edge_index.detach().cpu().numpy()

    # --- Ground truth, scaled so its non-NaN values sum to 1 -----------------
    data.y = data.y / torch.nansum(data.y)
    truth = data.y
    out_dict['y_true'] = truth.detach().cpu().numpy()

    out_dict['transcript'] = transcripts_list[sample_number]

    # The dict is passed positionally, so it is stored as a pickled object array
    # under key 'arr_0'; read it back with
    # np.load(path, allow_pickle=True)['arr_0'].item().
    out_file_name = out_folder_path + 'sample_' + str(sample_number) + '.npz'
    np.savez_compressed(out_file_name, out_dict)
