# PyTorch DeepExplainer: DeepSEA example

This notebook runs SHAP's PyTorch `DeepExplainer` on the DeepSEA model from the Kipoi model repository.

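First, we download the DeepSEA architecture (from the Kipoi models repository) and its pretrained weights (from Zenodo), skipping files that are already present. We then load the PyTorch model in evaluation mode.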

In [1]:
import torch

#pull the model from kipoi
! [[ -f deepsea_model_architecture.py ]] || wget https://raw.githubusercontent.com/kipoi/models/master/DeepSEA/model_architecture.py -O deepsea_model_architecture.py
! [[ -f deepsea_model_weights.pth ]] || wget https://zenodo.org/record/1466993/files/deepsea_variant_effects.pth?download=1 -O deepsea_model_weights.pth
from deepsea_model_architecture import veff_model as pytorch_model, Lambda
pytorch_model.load_state_dict(torch.load("deepsea_model_weights.pth"))
pytorch_model = pytorch_model.eval()

In [2]:
import torch.nn as nn

def gather_list_of_layers(module):
    """Flatten `module` into a depth-first, in-order list of its leaf layers.

    A child is recursed into when its class comes straight from torch's
    ``nn.modules.container`` module (e.g. nn.Sequential, nn.ModuleList).
    Every other child is kept whole. The test looks at the class's printed
    path rather than using ``isinstance``, so container subclasses defined
    in other modules are treated as leaves and are not expanded.

    Returns:
        A new list of modules; `module` itself is not included.
    """
    flattened = []
    for submodule in module.children():
        is_torch_container = 'nn.modules.container' in str(type(submodule))
        flattened += gather_list_of_layers(submodule) if is_torch_container else [submodule]
    return flattened


def create_interpretation_model(original_model, idx_to_select, prenonlinearity):
    """Build a single-output nn.Sequential from `original_model` for attribution.

    The model is flattened with gather_list_of_layers. A Lambda layer is then
    appended that slices the second axis with ``idx_to_select:idx_to_select + 1``,
    so the selected output stays a length-1 axis rather than being squeezed away.

    Args:
        original_model: module to re-wrap. Its layers are reused (shared), not
            copied. Assumes running them one after another reproduces the
            model's forward pass — TODO confirm for models with branching.
        idx_to_select: index of the output column to keep.
        prenonlinearity: if true, drop the last flattened layer before selecting.
            This assumes that layer is the output nonlinearity — confirm when
            reusing with another model.

    Returns:
        A new nn.Sequential wrapping the (possibly truncated) layers plus the selector.
    """
    layers = gather_list_of_layers(original_model)
    if prenonlinearity:
        layers = layers[:-1]
    keep_one_output = Lambda(lambda outputs: outputs[:, idx_to_select:(idx_to_select + 1)])
    return nn.Sequential(*layers, keep_one_output)


In [3]:
# Two example DNA sequences, each exactly 1000 characters long.
# Each sequence is written as adjacent string literals, which Python concatenates
# implicitly; the single comma after "...CTCTTCCCGGG" separates the two sequences.
# The length matches the (1, 4, 1, 1000) all-zeros baseline passed to DeepExplainer later.
# Mixed case is presumably soft-masking from the source genome (TODO confirm provenance).
# The one-hot encoder in this notebook treats upper and lower case identically.
example_sequences = [
    "GGCGATCCTTAGGCCTTGGCCCTGAGACCCCAGGCGAGGTCAGCAACCCAACCG"
    "GGGTGGGACAGGACGAGCAAGAGGTTCTGCTCACGCATGTCCCCACTAACCTGG"
    "CCGAGGGGCTCCCGCCCGGCTTATCCGGACTCCGGGCAGCCTCGCGTGCTTCCC"
    "GTGTCTCCGCTTGTGGAGAATTTTCGGACTCGGATTCGGACTCGGAGTCAAAGC"
    "CCGAAGCTAGGAACTCGTCCACCGTCAGCTCCGCCAGGCGCCTGCGGGTCACGC"
    "AGGAGTCACAGCTGCCCGCACGCCCAGCTCGCCCCAGCCCCGCTGAGAGGAGCA"
    "AGAAAAGCCCCCTTGGATACAGACACCCACCGGGAGGCCAAATCGGCCCTCGGA"
    "CCCGCGGCTTACCTCTTGCGGCTCCCCGCAGCTGCCATGACACCAACCCGAAGC"
    "GTGCACCCCACTTCCGGCCCCAGAATGCCGCGCGGCTGCGCACTTCCGccgccc"
    "aggccccgcccctttccccgccccgccgcgccacgcccagccGAGTGGCTCTAT"
    "GGTTCTCCGACCGCAACGCCGGCGGCCTCAGGGCGGGAGGGCGCGTTCGCGTGC"
    "TCGGTGCGGGCAGCCCCGGTGGGGCCCAGATGCGCCTCCCGCTCGGCGCCCGGC"
    "TCCGTAGGACGCGGTGACGCCGGTGTCCGCCCCGGGGAAGACCGGGAGTCCCGC"
    "CGCGCCCGCAGCCCACCCGGCGCTCCGAAGGCACGCGCCTGCGAGGACGCCAGA"
    "CTGCAACGGCGGGGCTCCTATGCAAAGAGCTCCCACAAATCAACAATAAAAAGC"
    "AGGGAGTCCAGTGGAAAACGCGAGGGGCAGTGGGAACCGCACTGATGTCGCCAG"
    "CTCGACAAAAGACGGGCGACCCGAGGGCCAGGCTGGCTTCGCCTCCGATCCGCG"
    "GAGACCGGGCCAGCGCCACGAACACCACGCAGGGCGCTCCCCGTCCATGGCCCT"
    "CTGGGTGCCGACCGCGGCTCTTCCCGGG",
    "GGGCTGAGGGTGGCCGGGCGGCTGCACACTAGCTGGGTCGCGGCGCAGAAACGC"
    "AGGGGCCGCGAGTGCGCTGGCCGGCGGGTGTCCCGGGTCCACGCTTACGGTCCT"
    "CATGTTCTTTTTCTTCAGGTATCGGGCTTTGGTGCATTTCACAAAGGCTCGAAT"
    "CACGGTTCTGACCGCCAACCTGTAGCAGCGATTTTTCCTTCCCCGGAAGTGCTG"
    "GGACAGAAAACGAGAAACCAGGGTTGTCAgcggggcccgcgccggccgccccTT"
    "GGCCCGCGGGATACCCCGGGCGCCCAGTGCCCAGGCCGGGCAGGCGGCACTCAC"
    "CCTGGCGTGCTTCAGCACCTCCTGGATCCGAAAGTAGCGGTCGGTGACGCGATT"
    "CCGCAGCCAGAGCTGCGCGGTGAGGAAGACCATGGCGCCTGCAGGCCGGCGTCC"
    "CGAACACTCAACAACGCACGCGCAGCGCCGCTGCCATCTTGCCCGGGTCGGAAA"
    "TGGTGGTCACGAGCGCTTCCGGGTCAGCCCCTGCGATACTTCCGGGGCGAAGGT"
    "CGTCTCCCGTCAGCCCGCGGGTGCCCAGTTGTGCTCCTGAACTCGCGGTGGTGG"
    "TGCGTGTTGGGGAGCGGATGTGGGGCCGCGGCGGGGACTGAAAGGAGAACGGGG"
    "CCGCAGCGCCCGTGGCTATTCGCGGACGATGGATAAACAGCAGCGCACGCGGAC"
    "CGTCCCGGAGCACGGCCCCGGCCGCAGCTGTGGCTCCGAGGGCACCGTGAGGGC"
    "AGCGGACCCGGGTcgggggccccgcggccggggagctcgggtgcggcgcgggcg"
    "gggaggggcaggccgcccccTGGGGCCACGAGGATGTTCAGGAACCGAGGTGGA"
    "GATGGTCGCATCGGTGTGAAAGTGCCCGTTGCCTCTGAACCTTGCACtttgttt"
    "acttactcattttgagacggggtctcgcccggtcgccctggctggggtgcagcg"
    "gcccgacctcggctcgccgcggcctctg"
]   

In [4]:
#code for one-hot encoding
import numpy as np
import torch


#One-hot encode a DNA sequence into the channels-first layout the model's
#Conv2d input uses: (channels=4, height=1, length).
def one_hot_encode_along_channel_axis(sequence):
    """Return a (4, 1, len(sequence)) int8 one-hot encoding of `sequence`.

    Channels are ordered A, C, G, T, the order used by
    seq_to_one_hot_fill_in_array. 'N'/'n' positions are all zeros.
    The singleton middle axis makes each example a 1-high "image", so a batch
    has shape (batch, 4, 1, length).

    Raises:
        RuntimeError: propagated from seq_to_one_hot_fill_in_array when
            `sequence` contains a character other than A/C/G/T/N (either case).
    """
    # Fill in (length, channel) layout first, then move channels to the front.
    per_position = np.zeros((len(sequence), 4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=per_position,
                                 sequence=sequence, one_hot_axis=1)
    # (length, 4) -> (4, length) -> (4, 1, length)
    return np.expand_dims(per_position.T, axis=1)


def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Write a one-hot encoding of `sequence` into `zeros_array`, in place.

    Channel order is A=0, C=1, G=2, T=3, and upper and lower case are accepted.
    'N'/'n' positions are skipped, so they stay all zeros. Only the 1s are
    written; the array is expected to start out zeroed.

    Args:
        zeros_array: array to fill. It is indexed [channel, position] when
            one_hot_axis=0 and [position, channel] when one_hot_axis=1.
        sequence: iterable of single characters, e.g. a DNA string.
        one_hot_axis: 0 or 1, the axis of `zeros_array` that holds the channels.

    Raises:
        AssertionError: if one_hot_axis is not 0/1, or if the position axis
            length differs from len(sequence).
        RuntimeError: on any character outside A/C/G/T/N (either case).
            Positions before it have already been written.
    """
    assert one_hot_axis in (0, 1)
    # Positions run along whichever axis is not the one-hot (channel) axis.
    assert zeros_array.shape[1 if one_hot_axis == 0 else 0] == len(sequence)
    # Accepted spellings for each channel, listed in channel order A, C, G, T.
    spellings_by_channel = (("A", "a"), ("C", "c"), ("G", "g"), ("T", "t"))
    for position, base in enumerate(sequence):
        if base in ("N", "n"):
            continue  # ambiguous base: leave this position all zeros
        for channel, spellings in enumerate(spellings_by_channel):
            if base in spellings:
                break
        else:
            raise RuntimeError("Unsupported character: " + str(base))
        if one_hot_axis == 0:
            zeros_array[channel, position] = 1
        elif one_hot_axis == 1:
            zeros_array[position, channel] = 1

            
# Batch the per-sequence encodings. Each is (4, 1, length), so the result is (num_sequences, 4, 1, length).
onehot_data = np.array(list(map(one_hot_encode_along_channel_axis, example_sequences)))


In [5]:
# Index of the DeepSEA output column (task) to explain. Which chromatin feature this
# column corresponds to is not recorded in this notebook — TODO document.
TASK_INDEX = 65

# Single-task model with the final layer (the output nonlinearity) removed, so it
# returns one pre-activation value per example.
interpretation_model = create_interpretation_model(pytorch_model, TASK_INDEX, prenonlinearity=True)

# Inference only: no_grad() skips building the autograd graph, saving memory and time.
with torch.no_grad():
    # Full model: one column per task.
    out1 = pytorch_model(torch.tensor(onehot_data.astype("float32"))).detach().numpy()
    out2 = interpretation_model(torch.tensor(onehot_data.astype("float32"))).detach().numpy()

In [6]:
# Full-model outputs, the single-task pre-activation outputs, then the layer stack being explained.
print(out1, out2, interpretation_model, sep="\n")

[[8.1442165e-01 7.3527199e-01 7.6854521e-01 ... 8.9137870e-01
  2.0305334e-02 7.7324861e-05]
 [8.5935098e-01 8.8919538e-01 8.9216346e-01 ... 9.9293751e-01
  6.0476321e-01 7.9107616e-04]]
[[2.7659173]
 [2.9149015]]
Sequential(
  (0): ReCodeAlphabet()
  (1): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
  (2): Threshold(threshold=0, value=1e-06)
  (3): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
  (4): Dropout(p=0.2)
  (5): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
  (6): Threshold(threshold=0, value=1e-06)
  (7): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
  (8): Dropout(p=0.2)
  (9): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
  (10): Threshold(threshold=0, value=1e-06)
  (11): Dropout(p=0.5)
  (12): Lambda()
  (13): Lambda()
  (14): Linear(in_features=50880, out_features=925, bias=True)
  (15): Threshold(threshold=0, value=1e-06)
  (16): Lambda()
  (17): Linear(in_features=925, out_

In [7]:
import numpy as np
import shap
import importlib
from importlib import reload
import torch

# DeepExplainer attributes each prediction relative to a background (reference) batch.
# An all-zero one-hot input is what the one-hot encoder produces for 'N' bases, i.e. a
# fully ambiguous sequence. The shape is taken from the data, so it always matches the
# inputs: (1, 4, 1, sequence_length).
background = torch.zeros((1,) + onehot_data.shape[1:], dtype=torch.float32)
e = shap.DeepExplainer(interpretation_model, background)
# SHAP values for every example in onehot_data, with respect to the single selected output.
explanation = e.shap_values(torch.tensor(onehot_data.astype("float32")))



In [8]:
# Total attribution in explanation[0]. By SHAP's additivity property this should roughly equal
# the explained model output minus e.expected_value. Whether explanation[0] is the first example
# or the first output depends on the shap version's return layout — TODO confirm.
print(np.asarray(explanation[0]).sum())

7.154758927599687
