In [1]:
# Explainability analysis of the InfoMax VAE.
# First part: attribution methods from Captum.
# NOTE(review): the original header said "For G protein family", but the weights
# and data paths loaded further down in this notebook point at the Pfam
# lactamase family -- confirm which family this notebook is meant to cover.

%cd ProtWaveVAE/Pfam_analysis/
!pwd

/storage/ice1/6/9/khari8/condaProtein/vqvae/ProtWaveVAE/Pfam_analysis
/storage/ice1/6/9/khari8/condaProtein/vqvae/ProtWaveVAE/Pfam_analysis


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
!pip install captum
!pip install seaborn



In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import shap
from shap import DeepExplainer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from train_on_pfam import call_model
from generate_samples import load_weights
from source.pfam_preprocess import prepare_Gprotein_dataset
import source.pfam_preprocess as pfam_prep

In [None]:
# Settings used to rebuild the network; per the original note they are meant
# to match the configuration used at training time.
class Args:
    """Plain attribute container passed to `call_model` and `load_weights`.

    Attributes are grouped by name only; their exact meaning is defined by
    the model code, not by this notebook.
    """

    # Run on the GPU when torch can see one, otherwise on the CPU.
    if torch.cuda.is_available():
        DEVICE = 'cuda'
    else:
        DEVICE = 'cpu'

    # Size-like settings
    class_labels = 21
    C_in = 21
    C_out = 512
    z_dim = 3

    # enc*-named settings and related scalars
    encoder_rates = 0
    enc_kernel = 3
    alpha = 0.1
    num_fc = 1

    # wave*/head*/dec*-named settings
    wave_hidden_state = 128
    head_hidden_state = 512
    num_dil_rates = 7
    dec_kernel_size = 3

    # Learning rate and *_weight settings
    lr = 1e-4
    xi_weight = 1
    alpha_weight = 0.99
    lambda_weight = 10

    # Mode switches
    alignment = False
    learning_option = 'unsupervised'


args = Args()

# Sequence length handed to the model constructor.
# NOTE(review): hard-coded -- replace with the actual sequence length of the
# dataset if it differs.
protein_len = 199

In [None]:
# Construct the model from the settings in `args`.
model = call_model(args, protein_len=protein_len).model

# Load the trained weights. The checkpoint path is attached to `args`
# (presumably where `load_weights` reads it from -- confirm in generate_samples).
weights_path = './outputs/train_sess/pfam/lactamase/lacta_model.pth'
args.weights_path = weights_path
model = load_weights(args, model)

# Set the model to evaluation mode.
model.eval()

# Smoke test with an all-zeros input of shape (1, protein_len, C_in).
# The tensor is created directly on the target device, and the forward pass
# runs under no_grad because this check needs no gradients, so no autograd
# graph is built or kept alive.
dummy_input = torch.zeros((1, protein_len, args.C_in), device=args.DEVICE)
with torch.no_grad():
    output = model(dummy_input)

# Verify output: the model returns an indexable result; show its first element's shape.
print("Model output shape:", output[0].shape)

In [None]:
# Input CSVs. Both paths point at the Pfam lactamase family, even though the
# preprocessing helper is named for G-protein data.
train_data_path = "./data/protein_families/lactamase/pfam_lactamase.csv"
test_data_path = "./outputs/prediction/pfam/lactamase/lactamase_sample_sequences.csv"

# Each call returns a pair: a `*_seq_num` array and a `*_seq_OH` array
# (presumably integer-coded and one-hot encodings -- confirm in pfam_preprocess).
train_seq_num, train_seq_OH = prepare_Gprotein_dataset(train_data_path, alignment=False)
test_seq_num, test_seq_OH = prepare_Gprotein_dataset(test_data_path, alignment=False)

In [None]:
# Display the shape of the training array (presumably one-hot encoded, per the `_OH` suffix).
train_seq_OH.shape

In [None]:
# Take the model's `inference` submodule as the network to explain.
# NOTE(review): named `encoder_model` here; presumably the VAE encoder -- confirm in the model source.
encoder_model = model.inference
#decoder_model = model.generator

In [None]:
# Two small float32 batches for the attribution cells, moved to the configured device.
# NOTE(review): both batches are sliced from train_seq_OH (rows 0-99 and 100-109);
# test_seq_OH is not used here -- confirm this is intended.
training_data = torch.tensor(train_seq_OH[:100], dtype=torch.float32).to(args.DEVICE)
test_data = torch.tensor(train_seq_OH[100:110], dtype=torch.float32).to(args.DEVICE)

# Keep the encoder on the same device as the batches (last expression, so its repr is displayed).
encoder_model.to(args.DEVICE)
#decoder_model.to(args.DEVICE)

In [None]:
class SHAPWrapper(torch.nn.Module):
    """Adapter that exposes only the first output of a multi-output module.

    The wrapped module is switched to evaluation mode when the wrapper is
    constructed.

    Args:
        model: module whose call result is indexable; element 0 is returned.
    """

    def __init__(self, model):
        super().__init__()
        model.eval()
        self.model = model

    def forward(self, x):
        """Run the wrapped module on `x` and return element 0 of its result."""
        outputs = self.model(x)
        return outputs[0]

# Wrap the encoder so the attribution methods see only the first element of its output.
wrapped_model = SHAPWrapper(encoder_model)

In [None]:
# Row labels for the residue axis of the attribution plots: the 20 standard
# amino acids in alphabetical one-letter order, followed by the gap symbol '-'.
# 'S' (serine) was missing from the original list, which left 20 labels for
# the 21 input channels (Args.C_in) and shifted every label after 'R'.
# NOTE(review): assumes the one-hot channel order produced by
# source.pfam_preprocess is this alphabetical order -- confirm against the
# preprocessing vocabulary.
protein_string = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','-']

In [None]:
from captum.attr import IntegratedGradients

# Integrated Gradients for output index 0 of the wrapped encoder, computed on
# the training_data batch with its last two axes swapped.
ig = IntegratedGradients(wrapped_model)
ig_inputs = training_data.permute(0, 2, 1)
attributions0, delta0 = ig.attribute(ig_inputs, target=0, return_convergence_delta=True)

# Sum over the batch axis (the training_data samples) to get one 2-D map.
attribution_map0 = attributions0.sum(dim=0).cpu().numpy()
ax = sns.heatmap(attribution_map0, yticklabels=protein_string, cmap='seismic', center=0)
ax.set_title("Integrated Gradients Attributions")
plt.show()


In [None]:
from captum.attr import IntegratedGradients

# Integrated Gradients for output index 1 of the wrapped encoder, computed on
# the training_data batch with its last two axes swapped.
ig = IntegratedGradients(wrapped_model)
ig_inputs = training_data.permute(0, 2, 1)
attributions1, delta1 = ig.attribute(ig_inputs, target=1, return_convergence_delta=True)

# Sum over the batch axis (the training_data samples) to get one 2-D map.
attribution_map1 = attributions1.sum(dim=0).cpu().numpy()
ax = sns.heatmap(attribution_map1, yticklabels=protein_string, cmap='seismic', center=0)
ax.set_title("Integrated Gradients Attributions")
plt.show()


In [None]:
from captum.attr import IntegratedGradients

# Integrated Gradients for output index 2 of the wrapped encoder, computed on
# the training_data batch with its last two axes swapped.
ig = IntegratedGradients(wrapped_model)
ig_inputs = training_data.permute(0, 2, 1)
attributions2, delta2 = ig.attribute(ig_inputs, target=2, return_convergence_delta=True)

# Sum over the batch axis (the training_data samples) to get one 2-D map.
attribution_map2 = attributions2.sum(dim=0).cpu().numpy()
ax = sns.heatmap(attribution_map2, yticklabels=protein_string, cmap='seismic', center=0)
ax.set_title("Integrated Gradients Attributions")
plt.show()


In [None]:
# Element-wise total of the three per-target Integrated Gradients maps,
# accumulated left to right into a fresh array.
attribution_map_final = attribution_map0
for extra_map in (attribution_map1, attribution_map2):
    attribution_map_final = attribution_map_final + extra_map

ax = sns.heatmap(attribution_map_final, yticklabels=protein_string, cmap='seismic', center=0)
ax.set_title("Integrated Gradients Attributions Final")
plt.show()

In [None]:
from captum.attr import GradientShap

# test_data with its last two axes swapped, plus an all-zeros baseline of the same shape.
gs_inputs = test_data.permute(0, 2, 1)
reference_input = torch.zeros_like(gs_inputs).to(args.DEVICE)

# GradientShap attributions for output index 0.
gradient_shap = GradientShap(wrapped_model)
attributions0, delta0 = gradient_shap.attribute(
    gs_inputs,
    baselines=reference_input,
    target=0,
    return_convergence_delta=True,
)
attributions_values0 = attributions0.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum0 = attributions_values0.sum(axis=0)
plt.imshow(np.abs(batch_sum0), cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import GradientShap

# test_data with its last two axes swapped, plus an all-zeros baseline of the same shape.
gs_inputs = test_data.permute(0, 2, 1)
reference_input = torch.zeros_like(gs_inputs).to(args.DEVICE)

# GradientShap attributions for output index 1.
gradient_shap = GradientShap(wrapped_model)
attributions1, delta1 = gradient_shap.attribute(
    gs_inputs,
    baselines=reference_input,
    target=1,
    return_convergence_delta=True,
)
attributions_values1 = attributions1.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum1 = attributions_values1.sum(axis=0)
plt.imshow(np.abs(batch_sum1), cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import GradientShap

# test_data with its last two axes swapped, plus an all-zeros baseline of the same shape.
gs_inputs = test_data.permute(0, 2, 1)
reference_input = torch.zeros_like(gs_inputs).to(args.DEVICE)

# GradientShap attributions for output index 2.
gradient_shap = GradientShap(wrapped_model)
attributions2, delta2 = gradient_shap.attribute(
    gs_inputs,
    baselines=reference_input,
    target=2,
    return_convergence_delta=True,
)
attributions_values2 = attributions2.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum2 = attributions_values2.sum(axis=0)
plt.imshow(np.abs(batch_sum2), cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
# Total of the three per-target magnitude maps: each array is summed over its
# batch axis and made non-negative before the maps are added together.
abs_maps = [
    np.abs(values.sum(axis=0))
    for values in (attributions_values0, attributions_values1, attributions_values2)
]
attributions_value_final = abs_maps[0] + abs_maps[1] + abs_maps[2]

plt.imshow(attributions_value_final, cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import GuidedBackprop

# Guided Backpropagation attributions for output index 0, on test_data with
# its last two axes swapped.
guided_backprop = GuidedBackprop(wrapped_model)
gb_inputs = test_data.permute(0, 2, 1)
attributions0 = guided_backprop.attribute(gb_inputs, target=0)
attributions_values0 = attributions0.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum0 = attributions_values0.sum(axis=0)
plt.imshow(np.abs(batch_sum0), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import GuidedBackprop

# Guided Backpropagation attributions for output index 1, on test_data with
# its last two axes swapped.
guided_backprop = GuidedBackprop(wrapped_model)
gb_inputs = test_data.permute(0, 2, 1)
attributions1 = guided_backprop.attribute(gb_inputs, target=1)
attributions_values1 = attributions1.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum1 = attributions_values1.sum(axis=0)
plt.imshow(np.abs(batch_sum1), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import GuidedBackprop

# Guided Backpropagation attributions for output index 2, on test_data with
# its last two axes swapped.
guided_backprop = GuidedBackprop(wrapped_model)
gb_inputs = test_data.permute(0, 2, 1)
attributions2 = guided_backprop.attribute(gb_inputs, target=2)
attributions_values2 = attributions2.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum2 = attributions_values2.sum(axis=0)
plt.imshow(np.abs(batch_sum2), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
# Total of the three per-target magnitude maps: each array is summed over its
# batch axis and made non-negative before the maps are added together.
abs_maps = [
    np.abs(values.sum(axis=0))
    for values in (attributions_values0, attributions_values1, attributions_values2)
]
attributions_value_final = abs_maps[0] + abs_maps[1] + abs_maps[2]

plt.imshow(attributions_value_final, cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import NoiseTunnel

# Wrap the existing `gradient_shap` explainer in a NoiseTunnel and attribute
# output index 0; the zero baseline `reference_input` defined in the
# GradientShap cells is reused.
noise_tunnel = NoiseTunnel(gradient_shap)
nt_inputs = test_data.permute(0, 2, 1)
attributions0 = noise_tunnel.attribute(nt_inputs, baselines=reference_input, target=0)
attributions_values0 = attributions0.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum0 = attributions_values0.sum(axis=0)
plt.imshow(np.abs(batch_sum0), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import NoiseTunnel

# Wrap the existing `gradient_shap` explainer in a NoiseTunnel and attribute
# output index 1; the zero baseline `reference_input` defined in the
# GradientShap cells is reused.
noise_tunnel = NoiseTunnel(gradient_shap)
nt_inputs = test_data.permute(0, 2, 1)
attributions1 = noise_tunnel.attribute(nt_inputs, baselines=reference_input, target=1)
attributions_values1 = attributions1.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum1 = attributions_values1.sum(axis=0)
plt.imshow(np.abs(batch_sum1), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import NoiseTunnel

# Wrap the existing `gradient_shap` explainer in a NoiseTunnel and attribute
# output index 2; the zero baseline `reference_input` defined in the
# GradientShap cells is reused.
noise_tunnel = NoiseTunnel(gradient_shap)
nt_inputs = test_data.permute(0, 2, 1)
attributions2 = noise_tunnel.attribute(nt_inputs, baselines=reference_input, target=2)
attributions_values2 = attributions2.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum2 = attributions_values2.sum(axis=0)
plt.imshow(np.abs(batch_sum2), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
# Total of the three per-target magnitude maps: each array is summed over its
# batch axis and made non-negative before the maps are added together.
abs_maps = [
    np.abs(values.sum(axis=0))
    for values in (attributions_values0, attributions_values1, attributions_values2)
]
attributions_value_final = abs_maps[0] + abs_maps[1] + abs_maps[2]

plt.imshow(attributions_value_final, cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeatureAblation

# Feature Ablation attributions for output index 0, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_ablation = FeatureAblation(wrapped_model)
fa_inputs = test_data.permute(0, 2, 1)
ablation0 = feature_ablation.attribute(fa_inputs, target=0, show_progress=True)

# Stored as a NumPy array under `attributes0` (not `attributions_values0`).
attributes0 = ablation0.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum0 = attributes0.sum(axis=0)
plt.imshow(np.abs(batch_sum0), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeatureAblation

# Feature Ablation attributions for output index 1, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_ablation = FeatureAblation(wrapped_model)
fa_inputs = test_data.permute(0, 2, 1)
ablation1 = feature_ablation.attribute(fa_inputs, target=1, show_progress=True)

# Stored as a NumPy array under `attributes1` (not `attributions_values1`).
attributes1 = ablation1.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum1 = attributes1.sum(axis=0)
plt.imshow(np.abs(batch_sum1), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeatureAblation

# Feature Ablation attributions for output index 2, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_ablation = FeatureAblation(wrapped_model)
fa_inputs = test_data.permute(0, 2, 1)
ablation2 = feature_ablation.attribute(fa_inputs, target=2, show_progress=True)

# Stored as a NumPy array under `attributes2` (not `attributions_values2`).
attributes2 = ablation2.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum2 = attributes2.sum(axis=0)
plt.imshow(np.abs(batch_sum2), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
# Total of the three per-target Feature Ablation magnitude maps.
# The Feature Ablation cells store their results in attributes0/1/2. The
# original line summed attributions_values0/1/2 instead, which at this point
# still hold the NoiseTunnel results, so this plot repeated the NoiseTunnel
# total rather than showing Feature Ablation.
attributions_value_final = (
    np.abs(attributes0.sum(axis=0))
    + np.abs(attributes1.sum(axis=0))
    + np.abs(attributes2.sum(axis=0))
)
plt.imshow(attributions_value_final, cmap = 'seismic', aspect = 'auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeaturePermutation

# Feature Permutation attributions for output index 0, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_permutation = FeaturePermutation(wrapped_model)
fp_inputs = test_data.permute(0, 2, 1)
attributions0 = feature_permutation.attribute(fp_inputs, show_progress=True, target=0)
attributions_values0 = attributions0.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum0 = attributions_values0.sum(axis=0)
plt.imshow(np.abs(batch_sum0), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeaturePermutation

# Feature Permutation attributions for output index 1, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_permutation = FeaturePermutation(wrapped_model)
fp_inputs = test_data.permute(0, 2, 1)
attributions1 = feature_permutation.attribute(fp_inputs, show_progress=True, target=1)
attributions_values1 = attributions1.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum1 = attributions_values1.sum(axis=0)
plt.imshow(np.abs(batch_sum1), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
from captum.attr import FeaturePermutation

# Feature Permutation attributions for output index 2, on test_data with its
# last two axes swapped; a progress bar is shown while it runs.
feature_permutation = FeaturePermutation(wrapped_model)
fp_inputs = test_data.permute(0, 2, 1)
attributions2 = feature_permutation.attribute(fp_inputs, show_progress=True, target=2)
attributions_values2 = attributions2.detach().cpu().numpy()

# Sum over the batch axis, then plot magnitudes.
batch_sum2 = attributions_values2.sum(axis=0)
plt.imshow(np.abs(batch_sum2), cmap='hot', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()

In [None]:
# Total of the three per-target magnitude maps: each array is summed over its
# batch axis and made non-negative before the maps are added together.
abs_maps = [
    np.abs(values.sum(axis=0))
    for values in (attributions_values0, attributions_values1, attributions_values2)
]
attributions_value_final = abs_maps[0] + abs_maps[1] + abs_maps[2]

plt.imshow(attributions_value_final, cmap='seismic', aspect='auto')
plt.yticks(range(len(protein_string)), protein_string)
plt.colorbar()