In [6]:
import shap
import io
import numpy as np
import pandas as pd
import torch
import sys
import os
from torch import nn
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from functools import reduce
from subprocess import Popen
from Bio import SeqIO
# Directory holding this run's inputs (sequence.fasta, geo.csv, tgeo0.csv,
# lg2hosting_expression.csv, lg2geneexp.csv). Override it with the
# EXPRESSRM_DATA_DIR environment variable; without it the original absolute
# path is used, so existing runs are unaffected.
header=os.environ.get('EXPRESSRM_DATA_DIR','/home/yiyou/test/tmpout')

In [34]:
!pip install deeplift

Collecting deeplift
  Downloading deeplift-0.6.13.0.tar.gz (30 kB)
Building wheels for collected packages: deeplift
  Building wheel for deeplift (setup.py) ... done
  Created wheel for deeplift: filename=deeplift-0.6.13.0-py3-none-any.whl size=36447 sha256=1f95bc026012ba8839cc2255a87f25b51f83322ce14127b129e086efd9079e88
  Stored in directory: /home/yiyou/.cache/pip/wheels/80/42/80/d6af8dbe1e394d4696459ed54b21787722b9bcb9e240dd81f5
Successfully built deeplift
Installing collected packages: deeplift
Successfully installed deeplift-0.6.13.0


In [4]:
def fasta2binonehot(data):
    """One-hot encode nucleotide sequences into A, C, G, U channels.

    Args:
        data: iterable of equal-length sequences (strings or other iterables of
            single characters), conceptually [n, seqlength].

    Returns:
        Integer array of shape [n, seqlength, 4] with channel order A, C, G, U.
        'T' is encoded in the U channel; any other character (e.g. 'N' or a
        lowercase base) becomes an all-zero row. Note that np.squeeze is applied
        to the character matrix, so a single input sequence yields
        [seqlength, 4] rather than [1, seqlength, 4].
    """
    # TODO: handle sequences of differing lengths (would need a ragged/list version).
    chars = np.squeeze(np.array([list(seq) for seq in data]))
    channels = [
        chars == 'A',
        chars == 'C',
        chars == 'G',
        (chars == 'U') | (chars == 'T'),  # DNA 'T' shares the RNA 'U' channel
    ]
    return np.stack(channels, axis=-1).astype(int)

In [7]:
# Read every record's sequence from the FASTA file (in file order) and one-hot encode it.
seq_list = np.asarray([
    seq_record.seq
    for seq_record in SeqIO.parse('%s/sequence.fasta'%(header),format='fasta')
])
sequence = fasta2binonehot(seq_list)
sequence.shape

(252009, 2001, 4)

In [15]:
class MultiAdaptPooling(nn.Module):
    """Apply one shared sub-network to several adaptively pooled views of the input.

    For each size in ``outsizelist`` the input is resized along its last axis
    with ``nn.AdaptiveAvgPool1d(size)`` and passed through ``model``. The
    per-size outputs are concatenated along the last dimension.

    Args:
        model: module applied after every pooling; its weights are shared across sizes.
        outsizelist: iterable of target lengths for the adaptive pooling.
    """
    def __init__(self, model, outsizelist=np.array([9, 25, 64])):
        super().__init__()
        self.model = model
        # nn.ModuleList (not a plain list) registers the pooling layers as
        # submodules so .modules(), .apply(), hooks and repr see them. They hold
        # no parameters or buffers, so state_dict keys are unchanged.
        self.modellist = nn.ModuleList([nn.AdaptiveAvgPool1d(size) for size in outsizelist])

    def forward(self, x):
        """Return ``model(pool(x))`` for every pooling size, concatenated on the last dim."""
        return torch.cat([self.model(pool(x)) for pool in self.modellist], -1)

In [140]:
class ExpressRM(pl.LightningModule):
    """Multi-branch model mapping sequence, geo, gene and host-gene inputs to 4 outputs.

    Branches (each can be switched off with its ``use*`` flag):

    * sequence   -> ``conv_model`` -> ``adaptconv_model`` -> ``seqoutsize`` features
      (zeros of that width when ``useseq`` is False)
    * gene       -> ``geneenc`` (28278 -> 1000 -> ``geneoutsize``)
      (zeros of that width when ``usegene`` is False)
    * geo        -> 12 features passed through; columns 0-5 are zeroed when
      ``usegeo`` is False and columns 6+ when ``usetgeo`` is False
    * genelocexp -> 1 feature passed through, zeroed when ``usegenelocexp`` is False

    The concatenation is mapped by ``predicationhead`` to 4 raw (pre-activation) outputs.
    """
    # seqoutsize = 4 pooling branches (MultiAdaptPooling sizes 16/32/64/128)
    #              * adaptoutsize positions * dim channels.
    # NOTE(review): an earlier comment said the network assumes seqlength ~500,
    # but this notebook feeds 2001-nt sequences; confirm the intended length.
    def __init__(self,useseq=True,usegeo=True,usetgeo=True,usegene=True,usegenelocexp=True, patchsize=7, patchstride=5, inchan=4, dim=64, kernelsize=7,
                 adaptoutsize=9, geneoutsize=500, geooutsize=32, droprate=0.25, lr=2e-5):
        super().__init__()
        self.useseq = useseq
        self.usegeo = usegeo
        self.usegene = usegene
        self.usegenelocexp = usegenelocexp
        self.usetgeo = usetgeo
        self.droprate = droprate
        self.seqoutsize = 4 * adaptoutsize * dim
        self.geneoutsize = geneoutsize
        self.geooutsize = geooutsize
        self.learning_rate = lr
        self.posweight=torch.as_tensor(3.0)
        self.save_hyperparameters()
        self.conv_model = nn.Sequential(
            nn.Conv1d(in_channels=inchan, out_channels=dim, kernel_size=patchsize, stride=patchstride),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.MaxPool1d(2),
            nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(),
            nn.Dropout(droprate))
        self.adaptconv_model = MultiAdaptPooling(
            nn.Sequential(
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                # Two more unpadded convs follow, each removing kernelsize-1
                # positions, so pool to adaptoutsize + 2*(kernelsize-1) here.
                nn.AdaptiveAvgPool1d(adaptoutsize + 2*(kernelsize - 1)),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=kernelsize),
                nn.BatchNorm1d(dim),
                nn.LeakyReLU(),
                nn.Dropout(droprate),
                nn.Flatten()
            )
            , np.array([16, 32, 64, 128]))
        self.geneenc = nn.Sequential(nn.Linear(28278, 1000), nn.LeakyReLU(), nn.Dropout(self.droprate),
                                     nn.Linear(1000, self.geneoutsize), nn.LeakyReLU())
        self.predicationhead = nn.Sequential(
            # input = sequence features + gene encoding + 12 geo features + 1 host-gene value
            nn.Linear(self.seqoutsize + self.geneoutsize + 12 + 1, 2048),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Dropout(droprate),
            nn.Linear(1024, 4),
        )

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def forward(self, x, geo, gene, genelocexp):
        """Run all enabled branches and the prediction head.

        Args:
            x: sequence tensor [N, L, inchan]; transposed to channels-first for Conv1d
               (in this notebook L=2001, inchan=4).
            geo: geo features [N, 12]; [N, 1, 12] is also accepted (squeezed on dim 1).
            gene: gene-expression features [N, 28278]; unused when ``usegene`` is False.
            genelocexp: host-gene expression [N, 1].

        Returns:
            Tensor [N, 4] of raw head outputs.

        Disabled inputs are zeroed on copies, so the caller's tensors are never
        modified in place.
        """
        batchsize = x.size(0)
        if self.useseq:
            seqfeat = self.adaptconv_model(self.conv_model(x.transpose(-1, -2)))
        else:
            # Same width the head expects, on the inputs' device/dtype.
            seqfeat = geo.new_zeros((batchsize, self.seqoutsize))
        if self.usegene:
            gene = self.geneenc(gene)
        else:
            # 2-D to match the other concatenated inputs; no hard-coded .cuda().
            gene = geo.new_zeros((batchsize, self.geneoutsize))
        if not (self.usetgeo and self.usegeo):
            # Clone so the caller's geo tensor is left untouched; `...` indexing
            # works for both [N, 12] and [N, 1, 12] inputs.
            geo = geo.clone()
            if not self.usetgeo:
                geo[..., 6:] *= 0
            if not self.usegeo:
                geo[..., :6] *= 0
        if not self.usegenelocexp:
            genelocexp = genelocexp * 0
        merged = torch.cat([seqfeat, gene, geo.squeeze(1), genelocexp], dim=-1)
        return self.predicationhead(merged)
device='cpu'
# load_from_checkpoint is a classmethod: call it on the class. Calling it on an
# instance first builds a throw-away model (including the 28278x1000 gene
# encoder), and Lightning 2.x rejects instance calls with a TypeError.
model=ExpressRM.load_from_checkpoint('/home/yiyou/test/model.ckpt',map_location=device)
model.eval()

ExpressRM(
  (conv_model): Sequential(
    (0): Conv1d(4, 64, kernel_size=(7,), stride=(5,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.25, inplace=False)
    (4): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Dropout(p=0.25, inplace=False)
    (8): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
    (10): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.01)
    (12): Dropout(p=0.25, inplace=False)
  )
  (adaptconv_model): MultiAdaptPooling(
    (model): Sequential(
      (0): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [11]:
# Load the reference and condition-0 geo tables, keeping columns 6 onward of each.
refgeo = pd.read_csv('%s/geo.csv'%(header)).to_numpy()[:, 6:]
tgeo = pd.read_csv('%s/tgeo0.csv'%(header)).to_numpy()[:, 6:]


In [103]:
# Tissue names come from the column header of the general expression table, so
# the condition order matches the geneexp columns built from the same file.
# Reading just the header here keeps this cell runnable on a fresh kernel
# (the full table is only loaded in a later cell).
tissue_names = pd.read_csv('/home/yiyou/lg2geneexp.csv', index_col=0, nrows=0).columns
# Condition 0: reference geo features + this sample's tgeo0.csv features.
geo_list = [np.append(refgeo, tgeo, axis=-1)]
for tissue in tissue_names:
    # Separate name so re-running this cell never reuses a tissue table as condition 0.
    tissue_tgeo = np.asarray(pd.read_csv('/home/yiyou/tissue/geo/%s.csv'%(tissue)))[:, 6:]
    geo_list.append(np.append(refgeo, tissue_tgeo, axis=-1))
# Stack conditions on axis 1: [rows, conditions, features].
geo = np.stack(geo_list, axis=1)
geo.shape

(252009, 38, 12)

In [128]:
# NOTE(review): this file is read without index_col=0 while lg2geneexp.csv uses
# index_col=0, so [:, :1] may select an ID/name column rather than a value
# column — confirm against the file layout.
hostgeneexp=np.asarray(pd.read_csv('%s/lg2hosting_expression.csv'%(header)))[:,:1]
geneexp=np.asarray(pd.read_csv('%s/lg2geneexp.csv'%(header),index_col=0))[:,:1]
# Parse the general (per-tissue) expression table once. The DataFrame (gene-name
# index, tissue columns) and its raw array are both used later.
generalexppd=pd.read_csv('/home/yiyou/lg2geneexp.csv',index_col=0)
generalexp=np.asarray(generalexppd)

In [130]:
# Column 0: first column of %s/lg2geneexp.csv; then one column per tissue from
# the general table. Taking geneexp[:, :1] (instead of all of geneexp) makes the
# cell idempotent: re-running it no longer appends the tissue columns twice.
geneexp=np.append(geneexp[:, :1],generalexp,axis=1)

In [131]:

# Convert every model input to a float32 tensor on `device`.
sequence = torch.as_tensor(sequence, dtype=torch.float32, device=device)
geo = torch.as_tensor(geo, dtype=torch.float32, device=device)
# Add a trailing axis: [n, 1] -> [n, 1, 1].
hostgeneexp = torch.as_tensor(hostgeneexp, dtype=torch.float32, device=device).unsqueeze(-1)
# Swap axes: [genes, conditions] -> [conditions, genes].
geneexp = torch.as_tensor(geneexp, dtype=torch.float32, device=device).T

In [135]:
def extract_entry(i,j):
    """Return the 4 model inputs for row i under condition j, each with batch size 1.

    Order matches ExpressRM.forward: [sequence, geo, gene expression, host-gene
    expression]. Reads the module-level tensors sequence, geo, geneexp and
    hostgeneexp.
    """
    rows = slice(i, i + 1)
    seq_in = sequence[rows]
    geo_in = geo[rows, j]
    gene_in = geneexp[j:j + 1]
    host_in = hostgeneexp[rows, 0]
    return [seq_in, geo_in, gene_in, host_in]
def extract_group_entry(npi,npj):
    """Batch extract_entry over every (i, j) pair into 4 concatenated tensors.

    Pairs are visited with j in the outer loop and i in the inner loop, so output
    row r corresponds to i = npi[r % len(npi)] and j = npj[r // len(npi)].

    Returns:
        List of 4 tensors in extract_entry order, each with len(npi) * len(npj) rows.

    Raises:
        ValueError: if npi or npj is empty (nothing to concatenate).
    """
    entries = [extract_entry(i, j) for j in npj for i in npi]
    if not entries:
        raise ValueError('extract_group_entry needs at least one index in both npi and npj')
    # zip(*entries) regroups per input: (all sequences, all geo rows, ...).
    return [torch.cat(parts) for parts in zip(*entries)]

In [138]:
# Background: the first 50 rows paired with each of the 38 conditions (geo's
# second axis), i.e. 50 * 38 = 1900 samples.
background = extract_group_entry(np.arange(50), np.arange(38))
shap_explainer = shap.GradientExplainer(model, background)
# Explain row 0 under condition 0.
raw_shap_explanations = shap_explainer.shap_values(extract_entry(0, 0))

0
1
2
3
torch.Size([1900, 2001, 4])
torch.Size([1900, 12])
torch.Size([1900, 28278])
torch.Size([1900, 1])
torch.Size([1900, 2304])
torch.Size([1900, 500])
torch.Size([1900, 12])
torch.Size([1900, 1])
torch.Size([50, 2001, 4])
torch.Size([50, 12])
torch.Size([50, 28278])
torch.Size([50, 1])
torch.Size([50, 2304])
torch.Size([50, 500])
torch.Size([50, 12])
torch.Size([50, 1])
torch.Size([50, 2001, 4])
torch.Size([50, 12])
torch.Size([50, 28278])
torch.Size([50, 1])
torch.Size([50, 2304])
torch.Size([50, 500])
torch.Size([50, 12])
torch.Size([50, 1])
torch.Size([50, 2001, 4])
torch.Size([50, 12])
torch.Size([50, 28278])
torch.Size([50, 1])
torch.Size([50, 2304])
torch.Size([50, 500])
torch.Size([50, 12])
torch.Size([50, 1])
torch.Size([50, 2001, 4])
torch.Size([50, 12])
torch.Size([50, 28278])
torch.Size([50, 1])
torch.Size([50, 2304])
torch.Size([50, 500])
torch.Size([50, 12])
torch.Size([50, 1])
torch.Size([50, 2001, 4])
torch.Size([50, 12])
torch.Size([50, 28278])
torch.Size([50, 1])


In [190]:
def quantiles(a):
    """Print the shape of ``a`` and then a fixed ladder of its quantiles.

    Prints one value per line for q = 0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9,
    0.99, 0.999, each computed by ``np.quantile`` over the whole array.
    """
    levels = (0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999)
    print(a.shape)
    for level in levels:
        print(np.quantile(a, level))
# raw_shap_explanations[0][2] has shape (1, 28278): the gene-expression input's attributions.
gene_shap = raw_shap_explanations[0][2]
quantiles(gene_shap)

(1, 28278)
-0.45865870984287443
-0.17472765757842137
-0.02175326724908
-0.0034240561220748294
0.0023451372628989792
0.025745516929735333
0.07694365741983478
0.32651892193666743
0.6129571111891621


In [201]:
# 20 genes with the largest SHAP values, largest first. Reuses the index of the
# already-loaded generalexppd instead of re-parsing the 28278-row CSV.
generalexppd.index[np.argsort(raw_shap_explanations[0][2][0])[-20:]][::-1]

Index(['HYAL2', 'CTSC', 'AKAP13', 'HIVEP2', 'THBS1', 'PPP4R3A', 'FBXL12',
       'MYL12B', 'ARHGAP21', 'ATN1', 'HSPA2', 'RPPH1', 'SOS1', 'MIER3',
       'PCDHGA10', 'ZC3H4', 'DIPK1B', 'IER5L', 'DYRK1A', 'ZMIZ1'],
      dtype='object', name='GeneName')

In [202]:
# 20 genes with the smallest (most negative) SHAP values, smallest first.
# Reuses the index of the already-loaded generalexppd instead of re-parsing the CSV.
generalexppd.index[np.argsort(raw_shap_explanations[0][2][0])[:20]]

Index(['JRKL', 'AGPAT2', 'MARCHF2', 'ETS2', 'PARP9', 'DNAJC3-DT', 'KPNA1',
       'MICALL2', 'ASH1L-AS1', 'DPH7', 'RBM12', 'RNA18SN5', 'TASOR2', 'KLF10',
       'THAP5', 'POMGNT2', 'CCDC174', 'PIP5K1C', 'ATOX1', 'HERPUD1'],
      dtype='object', name='GeneName')