In [1]:
import shap
import io
import numpy as np
import pandas as pd
import torch
import sys
import os
from torch import nn
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from functools import reduce
from subprocess import Popen
from Bio import SeqIO
header='/home/yiyou/test/tmpout'

In [2]:
!pip install deeplift



In [3]:
def fasta2binonehot(data):
    # data is a list of sequence: [n,seqlength]
    # possibly need list version where seqlength differ
    data=np.squeeze(np.array(list(map(list, data))))
    A = np.zeros_like(data,dtype=int)
    C = np.zeros_like(data,dtype=int)
    G = np.zeros_like(data,dtype=int)
    U = np.zeros_like(data,dtype=int)
    A[data == 'A'] = 1
    C[data == 'C'] = 1
    G[data == 'G'] = 1
    U[data == 'U'] = 1
    U[data == 'T'] = 1
    A = A[..., np.newaxis]
    C = C[..., np.newaxis]
    G = G[..., np.newaxis]
    U = U[..., np.newaxis]
    bindata=np.append(A,C,axis=-1)
    bindata = np.append(bindata, G, axis=-1)
    bindata = np.append(bindata, U, axis=-1)
    return bindata

In [4]:
seq_list=[]
for seq_record in SeqIO.parse('%s/sequence.fasta'%(header),format='fasta'):
    sequence=seq_record.seq
    seq_list.append(sequence)
seq_list=np.asarray(seq_list)
sequence=fasta2binonehot(seq_list)
sequence.shape

(252009, 2001, 4)

In [5]:
class MultiAdaptPooling(nn.Module):
    def __init__(self, model, outsizelist=np.array([9, 25, 64])):
        super(MultiAdaptPooling, self).__init__()
        self.model = model
        self.modellist = []
        for i in outsizelist:
            self.modellist.append(nn.AdaptiveAvgPool1d(i))
    def forward(self, x):
        outlist = []
        for model in self.modellist:
            outlist.append(self.model(model(x)))
        out=torch.cat(outlist, -1)
        return out

In [6]:
class ExpressRM(pl.LightningModule):
    # unet assume seqlength to be ~500
    def __init__(self,useseq=True,usegeo=True,usetgeo=True,usegene=True,usegenelocexp=True, patchsize=7, patchstride=5, inchan=4, dim=64, kernelsize=7,
                 adaptoutsize=9, geneoutsize=500, geooutsize=32, droprate=0.25, lr=2e-5):
        super(ExpressRM, self).__init__()
        self.useseq = useseq
        self.usegeo = usegeo
        self.usegene = usegene
        self.usegenelocexp = usegenelocexp
        self.usetgeo = usetgeo
        self.droprate = droprate
        self.seqoutsize = 4 * adaptoutsize * dim
        self.geneoutsize = geneoutsize
        self.geooutsize = geooutsize
        self.learning_rate = lr
        self.posweight=torch.as_tensor(3.0)
        self.save_hyperparameters()
        self.conv_model = nn.Sequential(
            nn.Conv1d(in_channels=inchan, out_channels=dim, kernel_size=patchsize, stride=patchstride),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.MaxPool1d(2),
            nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate))
        self.adaptconv_model = MultiAdaptPooling(
            nn.Sequential(
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.AdaptiveAvgPool1d(adaptoutsize + 2*(kernelsize - 1)),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Flatten()
            )
            , np.array([16, 32, 64, 128]))
        self.geneenc = nn.Sequential(nn.Linear(28278, 1000), nn.LeakyReLU(), nn.Dropout(self.droprate),
                                     nn.Linear(1000, self.geneoutsize), nn.LeakyReLU())
        self.predicationhead = nn.Sequential(
            # nn.Flatten(1,-1),
            nn.Linear(self.seqoutsize + self.geneoutsize + 12 + 1, 2048),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(1024, 4),
        )
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    def forward(self, x,geo,gene,genelocexp):
        # x=mergedinput[:,:8004].reshape([-1,2001,4])
        # geo=mergedinput[:,8004:8016].reshape([-1,1,12])
        # gene=mergedinput[:,8016:8016+28278].reshape([-1,28278])
        # genelocexp=mergedinput[:,-1].reshape([-1,1,1])
        # print(x.shape)
        # print(geo.shape)
        # print(gene.shape)
        # print(genelocexp.shape)
        
        batchsize = x.size()[0]
        tissuesize = 1
        if self.useseq:
            x = x.transpose(-1, -2)
            adaptout = self.adaptconv_model(self.conv_model(x))
        # seq [N,2304]
        if self.usegene:
            # gene= self.geneenc(torch.mean(self.geneatt(geneloc,gene),dim=-2))
            gene= self.geneenc(gene)
        else:
            gene= torch.zeros([batchsize,tissuesize,self.geneoutsize]).float().cuda()
            #[N,37,24]
        if not self.usetgeo:
                    geo[:,:,6:]*=0
        if not self.usegeo:
                geo[:, :, :6] *= 0
        if not self.usegenelocexp:
            genelocexp*=0
        # for entry in [adaptout, gene, geo.squeeze(1), genelocexp]:
        #     print(entry.shape)
        adaptout = torch.cat([adaptout, gene, geo.squeeze(1), genelocexp], dim=-1)
        out = self.predicationhead(adaptout)
        return out
device='cpu'
model=ExpressRM().load_from_checkpoint('/home/yiyou/test/model.ckpt',map_location=device)
model.eval()

ExpressRM(
  (conv_model): Sequential(
    (0): Conv1d(4, 64, kernel_size=(7,), stride=(5,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.25, inplace=False)
    (4): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Dropout(p=0.25, inplace=False)
    (8): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
    (10): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.01)
    (12): Dropout(p=0.25, inplace=False)
  )
  (adaptconv_model): MultiAdaptPooling(
    (model): Sequential(
      (0): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [7]:
refgeo=np.asarray(pd.read_csv('%s/geo.csv'%(header)))[:,6:]
tgeo=np.asarray(pd.read_csv('%s/tgeo0.csv'%(header)))[:,6:]


In [8]:
hostgeneexp=np.asarray(pd.read_csv('%s/lg2hosting_expression.csv'%(header)))[:,:1]
geneexp=np.asarray(pd.read_csv('%s/lg2geneexp.csv'%(header),index_col=0))[:,:1]
# generalhostexp=np.asarray(pd.read_csv('/home/yiyou/lg2hosting_expression.csv',index_col=0))
generalexp=np.asarray(pd.read_csv('/home/yiyou/lg2geneexp.csv',index_col=0))
generalexppd=pd.read_csv('/home/yiyou/lg2geneexp.csv',index_col=0)

In [9]:
geo_list=[]
geo=np.append(refgeo,tgeo,axis=-1)
geo_list.append(geo)
for tissue in np.array(generalexppd.columns):
    tgeo=np.asarray(pd.read_csv('/home/yiyou/tissue/geo/%s.csv'%(tissue)))[:,6:]
    geo=np.append(refgeo,tgeo,axis=-1)
    geo_list.append(geo)
geo=np.array(geo_list).transpose([1,0,2])
geo.shape

(252009, 38, 12)

In [10]:
geneexp=np.append(geneexp,generalexp,axis=1)

In [11]:

sequence=torch.as_tensor(sequence).float().to(device)
geo=torch.as_tensor(geo).float().to(device)
hostgeneexp=torch.as_tensor(hostgeneexp).float().unsqueeze(2).to(device)
geneexp=torch.as_tensor(geneexp).float().transpose(1,0).to(device)

In [12]:
def extract_entry(i,j):
    return [sequence[i:i+1], geo[i:i+1,j], geneexp[j:j+1], hostgeneexp[i:(i+1),0]]
def extract_group_entry(npi,npj):
    entry_list=[[],[],[],[]]
    for j in npj:
        for i in npi:
            entry=extract_entry(i,j)
            for k in range(4):
                entry_list[k].append(entry[k])
    
    for k in range(4):
        print(k)
        entry_list[k]=torch.cat(entry_list[k])
    return entry_list

In [13]:
tesidx=np.load('/home/yiyou/testidx.npy')

In [14]:
top_entries=torch.argsort(model.forward(*extract_group_entry(tesidx[:200],np.arange(1)))[:,0],descending=True)[:20]

0
1
2
3


In [27]:
top_entries.numpy()

array([172, 121, 136, 149,  46,  37,  14, 180, 144,  83,   5,  76, 169,
       158,  26,  11,  60,  35, 104, 182])

In [16]:
shap_explainer = shap.GradientExplainer(model, extract_group_entry(tesidx[:2000],np.arange(38)))
raw_shap_explanations = shap_explainer.shap_values(extract_group_entry(top_entries.numpy(),[0]))

0
1
2
3
0
1
2
3


In [34]:
top_entries2000=torch.argsort(model.forward(*extract_group_entry(tesidx[:2000],np.arange(1)))[:,0],descending=True)[:200]
raw_shap_explanations200 = shap_explainer.shap_values(extract_group_entry(top_entries2000.numpy(),[0]))

0
1
2
3
0
1
2
3


In [52]:
def quantiles(a):
    print(a.shape)
    print(np.quantile(a,0.001))
    print(np.quantile(a,0.01))
    print(np.quantile(a,0.1))
    print(np.quantile(a,0.25))
    print(np.quantile(a,0.5))
    print(np.quantile(a,0.75))
    print(np.quantile(a,0.9))
    print(np.quantile(a,0.99))
    print(np.quantile(a,0.999))
quantiles(raw_shap_explanations200[0][2])

(200, 28278)
-0.4911646758532898
-0.19461555771718295
-0.023784792024426002
-0.0036727941354961434
0.002633280063718251
0.027812499545780262
0.08363447141400158
0.358455559018508
0.6696466023355953


In [49]:

gene_list=[]
gene_idx_list=np.zeros_like(raw_shap_explanations200[0][2][0])
gene_idx=pd.read_csv('/home/yiyou/lg2geneexp.csv',index_col=0).index
for i in range(200):
    a=raw_shap_explanations200[0][2][i]
    gene_idx_list+=a>np.quantile(a,0.99)
    gene_list.append(gene_idx[a>np.quantile(a,0.99)])
top_idx=np.where(gene_idx_list>=190)[0]

In [20]:
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

In [53]:
import pickle
with open('shap_value_top20.pkl', 'wb') as output:
    pickle.dump(raw_shap_explanations[0],output)
output.close()
with open('shap_value_top200.pkl', 'wb') as output:
    pickle.dump(raw_shap_explanations200[0],output)
output.close()

with open('shap_value_top_full20.pkl', 'wb') as output:
    pickle.dump(raw_shap_explanations,output)
output.close()
with open('shap_value_top_full200.pkl', 'wb') as output:
    pickle.dump(raw_shap_explanations200,output)
output.close()



In [54]:
with open('shap_value_top_full200.pkl', 'rb') as file:
    tmp=pickle.load(file)

unidentified/unexpressed genes are most important?

In [47]:
tmp[0].shape

(200, 2001, 4)

In [50]:
geneexp[0,top_idx]-torch.mean(geneexp[:,top_idx],axis=0)

tensor([ -7.9097,  -6.6658,  -5.1002,  -6.9264,  -5.2148,   5.0860,  -6.5798,
         -4.5753,  -6.0232,  -6.0740,  -4.9180,  -5.2970,  -8.4340,  -5.8198,
         -7.9616,  -3.0600,  -6.4386,  -5.8607,  -6.5008,  -7.8746,  -8.3614,
         -6.2209,  -6.1740,  -5.3491,  -7.6326,  -4.7816,  -3.9192,  -5.8391,
         -5.4816,  -7.3350,  -6.3848,  -4.8049,  -4.8870,  -8.7246,  -6.2435,
         -7.3524,  -4.4110,  -5.3610,  -5.7321,  -6.4962,  -7.2795,  -4.7718,
         -5.7339,  -6.5236,  -5.5394,  -5.4978,  -7.2411,  -5.8040,  -6.6957,
         -5.3208,  -7.2633,  -4.9193,  -5.3102,  -6.6393,  -6.2073,  -6.1045,
         -8.5248,  -5.7696,  -6.9978,  -6.8733,  -7.0491,  -4.3317,  -7.0385,
         -4.1929,  -6.4812,  -5.9206,  -7.2155,  -5.3074,  -5.1881,  -6.5410,
         -6.2910,  -4.9710,  -5.9401,  -9.0877,  -5.5116,  -5.9718,  -4.2068,
         -6.1840,  -5.7003,  -5.5157,  -5.5467,  -8.5396,  -7.9249,  -5.1496,
         -7.1384,  -7.2848,  -6.1860,  -6.1407,  -7.4484,  -4.81

In [24]:
gene_idx[top_idx[5]]

'DOCK7'

cell below shows top-0.1% positive influenctial genes in all twenty examined samples.

In [25]:
np.unique(np.array(gene_list),return_counts=1)[0][np.unique(np.array(gene_list),return_counts=1)[1]>19]

array(['ADCY9', 'ADO', 'ADRB1', 'AFDN', 'AFG3L2', 'AHNAK2', 'AIG1',
       'AKAP1', 'AKAP13', 'AKAP6', 'AKT1', 'ALG11', 'ANGPTL2', 'ANO8',
       'ANTXR2', 'ARHGAP21', 'ARHGAP5-AS1', 'ARMCX6', 'ATN1', 'ATP6V0D1',
       'AVPI1', 'BACE1', 'BCCIP', 'BEX3', 'BICC1', 'BMI1', 'BOD1L1',
       'BPTF', 'BRF1', 'C11orf95', 'C16orf91', 'C1GALT1C1', 'C1orf116',
       'CALR', 'CCDC32', 'CCDC9B', 'CCL21', 'CDC42EP2', 'CDK12',
       'CDK2AP2', 'CEP85L', 'CEP97', 'CHST2', 'CIC', 'CLCN7', 'CLMP',
       'CRIP1', 'CSRNP2', 'CTSC', 'CUL1', 'DAPK1', 'DCAF10', 'DCBLD1',
       'DEPP1', 'DIPK1B', 'DIPK2A', 'DOCK7', 'DYRK1A', 'EDNRB', 'EFL1',
       'EFNB2', 'EGLN2', 'ELK4', 'EMILIN1', 'EP300', 'EPB41L2', 'ERAL1',
       'FAF2', 'FAM171A1', 'FBXL12', 'FBXL3', 'FBXO34', 'FGF9', 'FILIP1L',
       'FRMD4A', 'FRMD6', 'GABPB1-AS1', 'GBGT1', 'GCC1', 'GDE1', 'GEMIN4',
       'GIMAP8', 'GLYR1', 'GNE', 'GOLGA8B', 'GTF2A1', 'H2AJ', 'HDAC7',
       'HIVEP2', 'HNRNPF', 'HSP90B1', 'HSPA2', 'HYAL2', 'IER5L', 'JADE2',


cell below shows top-20 positive influenctial genes in one sample.

In [26]:
pd.read_csv('/home/yiyou/lg2geneexp.csv',index_col=0).index[np.argsort(raw_shap_explanations[0][2][0])[-20:]][::-1]

Index(['HYAL2', 'AKAP13', 'CTSC', 'THBS1', 'HIVEP2', 'PPP4R3A', 'MYL12B',
       'FBXL12', 'ATN1', 'ARHGAP21', 'RPPH1', 'HSPA2', 'SOS1', 'MIER3',
       'PCDHGA10', 'HNRNPF', 'ZC3H4', 'DYRK1A', 'ANTXR2', 'AFG3L2'],
      dtype='object', name='GeneName')