This notebook describes how we performed in-silico perturbation of transcription factors (TFs). There are two types of perturbation:

1. In-silico knockout: set the TF's expression to zero in every sample.
2. In-silico upregulation: for a given TF, add the 0.9 quantile of its expression values (computed across all samples) to its original expression value in every sample.

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
from pyfaidx import Fasta
from verstack import stratified_continuous_split
import matplotlib.pyplot as plt
import time

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import mean_squared_error
from sklearn.metrics import f1_score
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder
import scipy
from scipy import stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torchvision import datasets
from torcheval.metrics.functional import multiclass_f1_score

import h5py
import seaborn as sns
import os
os.chdir('/nfs/public/xixi/scRegulate/SHAREseq')

from typing import Tuple

# Data and model preparation

In [3]:
# Read the ATAC table and flip it so that rows are samples (V1, V2, ...) and columns are peaks.
df_x = pd.read_csv('atac.aggregate_30cells.csv', index_col=0).T
df_x

Unnamed: 0,chr18-31634804-31635104,chr1-165460961-165461261,chr1-75317415-75317715,chr18-46525819-46526119,chr4-3938558-3938858,chr6-71440583-71440883,chr7-105399959-105400259,chr1-161876704-161877004,chr6-113306525-113306825,chr1-177796516-177796816,...,chr10-111689260-111689560,chr10-111388777-111389077,chr10-110841761-110842061,chr10-10884983-10885283,chr10-108335612-108335912,chr10-107887768-107888068,chr10-107038880-107039180,chr10-10549625-10549925,chr10-105270865-105271165,chr10-103291333-103291633
V1,2.112241,1.988862,1.988862,1.664337,1.758608,2.112241,1.758608,1.434083,1.919221,1.919221,...,0.0,0.000000,0.0,0.0,0.554779,0.0,0.0,0.000000,0.0,0.0
V2,2.184422,2.079218,1.957768,1.533112,1.814123,1.889215,1.957768,1.814123,1.731114,1.814123,...,0.0,0.000000,0.0,0.0,0.546106,0.0,0.0,0.000000,0.0,0.0
V3,2.381858,1.865213,1.717455,1.542694,2.298454,1.542694,2.298454,1.993207,1.865213,1.865213,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V4,2.296301,1.952485,1.863144,1.952485,1.364729,2.107085,1.763270,1.952485,1.863144,2.174958,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V5,2.021073,1.682703,1.763392,1.903019,1.763392,1.836383,1.682703,1.763392,1.490239,1.490239,...,0.0,0.000000,0.0,0.0,0.841352,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V839,3.323174,3.171114,3.171114,2.585932,2.322799,2.585932,3.323174,2.585932,2.000750,2.808408,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,1.000375,0.0,0.0
V840,2.806491,2.806491,2.055718,2.656048,2.656048,2.485490,1.770699,2.055718,2.806491,2.288595,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V841,2.915099,1.888067,1.626292,2.101953,2.577613,1.888067,2.282790,1.626292,2.439439,2.282790,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.813146,0.0,0.0
V842,1.385682,1.421908,1.021733,1.209012,1.346771,1.385682,1.487627,1.259056,1.346771,1.385682,...,0.0,0.363949,0.0,0.0,0.000000,0.0,0.0,0.363949,0.0,0.0


In [4]:
# Read the RNA table and flip it so that rows are samples (V1, V2, ...) and columns are genes.
df_y = pd.read_csv('rna.aggregate_30cells.csv', index_col=0).T
df_y

Unnamed: 0,0610009B22Rik,0610009L18Rik,0610010F05Rik,0610010K14Rik,0610012D04Rik,0610012G03Rik,0610025J13Rik,0610030E20Rik,0610038B21Rik,0610039K10Rik,...,mt-Co2,mt-Co3,mt-Cytb,mt-Nd1,mt-Nd2,mt-Nd3,mt-Nd4,mt-Nd4l,mt-Nd5,mt-Nd6
V1,0.0,0.000000,1.179400,0.000000,0.0,0.000000,0.0,0.589700,0.000000,0.0,...,0.0,0.0,2.629728,2.459005,1.869305,0.0,1.369241,0.0,2.667545,1.655498
V2,0.0,0.000000,1.547446,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.191146,1.849178,1.713501,0.0,1.547446,0.0,2.063261,1.333363
V3,0.0,0.581185,1.162371,0.581185,0.0,0.000000,0.0,0.000000,0.581185,0.0,...,0.0,0.0,2.468833,2.591756,2.270628,0.0,1.930656,0.0,2.511841,1.842314
V4,0.0,0.000000,1.265819,0.000000,0.0,0.000000,0.0,1.003138,0.000000,0.0,...,0.0,0.0,2.822417,2.409712,1.776802,0.0,2.102480,0.0,2.688554,2.006276
V5,0.0,0.000000,0.897976,0.000000,0.0,0.566560,0.0,0.566560,0.566560,0.0,...,0.0,0.0,2.526534,2.266239,2.157094,0.0,1.882071,0.0,2.448630,1.699679
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V839,0.0,0.000000,1.287068,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.039954,3.861204,2.574136,0.0,2.574136,0.0,4.275547,2.039954
V840,0.0,0.000000,0.502798,0.000000,0.0,0.000000,0.0,0.502798,0.000000,0.0,...,0.0,0.0,1.670258,1.670258,1.167460,0.0,1.593831,0.0,2.208447,1.299713
V841,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.453196,2.113068,2.113068,0.0,1.674567,0.0,3.787635,0.000000
V842,0.0,0.000000,0.932346,0.000000,0.0,0.000000,0.0,0.571665,0.360681,0.0,...,0.0,0.0,1.558835,1.608430,1.247750,0.0,1.143330,0.0,1.733919,0.837474


In [5]:
# Split every "chrom-start-end" peak name (the columns of df_x, in order) into a
# three-column coordinate table with numeric start/end.
df_peaks = (
    pd.Series(df_x.columns.to_numpy())
    .str.split('-', expand=True)
    .rename(columns={0: "chrom", 1: "start", 2: "end"})
)
df_peaks[["start", "end"]] = df_peaks[["start", "end"]].apply(pd.to_numeric)
df_peaks

Unnamed: 0,chrom,start,end
0,chr18,31634804,31635104
1,chr1,165460961,165461261
2,chr1,75317415,75317715
3,chr18,46525819,46526119
4,chr4,3938558,3938858
...,...,...,...
338699,chr10,107887768,107888068
338700,chr10,107038880,107039180
338701,chr10,10549625,10549925
338702,chr10,105270865,105271165


In [6]:
#df_peaks.to_csv('/data1/xixi/scRegulate/multiomic_data/10x_PBMC/peaks_all.csv')

In [7]:
# Tab-separated gene annotation; when a gene name occurs more than once only its first row is kept.
geneanno = (
    pd.read_csv('../../ref_genome/mm10_geneanno.txt', sep='\t')
    .drop_duplicates(subset=['Gene name'])
)
geneanno

Unnamed: 0,Gene stable ID,Gene stable ID version,Chromosome/scaffold name,Gene start (bp),Gene end (bp),Strand,Gene name,Source of gene name,Transcription start site (TSS)
0,ENSMUSG00000070103,ENSMUSG00000070103.2,1,158505625,158505733,1,Mir488,MGI Symbol,158505625
1,ENSMUSG00000065567,ENSMUSG00000065567.1,1,23291701,23291784,1,Mir30c-2,MGI Symbol,23291701
2,ENSMUSG00000094946,ENSMUSG00000094946.1,1,83795912,83796032,1,Gm25754,MGI Symbol,83795912
3,ENSMUSG00000093155,ENSMUSG00000093155.1,1,74586896,74586972,-1,Gm25035,MGI Symbol,74586972
4,ENSMUSG00000065458,ENSMUSG00000065458.1,1,137966639,137966718,1,Mir181b-1,MGI Symbol,137966639
...,...,...,...,...,...,...,...,...,...
103991,ENSMUSG00000099633,ENSMUSG00000099633.1,Y,84907473,84910700,1,Gm29071,MGI Symbol,84907473
103993,ENSMUSG00000100388,ENSMUSG00000100388.1,Y,50770044,50773283,1,Gm29116,MGI Symbol,50770044
103995,ENSMUSG00000091987,ENSMUSG00000091987.8,Y,2900989,2912206,1,Gm10352,MGI Symbol,2900989
103997,ENSMUSG00000101667,ENSMUSG00000101667.1,Y,2932582,2939416,1,Gm29289,MGI Symbol,2932582


In [8]:
# Names of the motif matrix files; os.listdir gives no ordering guarantee.
motif_files = os.listdir('../../ref_genome/JASPAR_motifs_pfm_mouse/pfm.np')
# Peek at the first five entries.
motif_files[0:5]

['MA0002.2.Runx1.npy',
 'MA0006.1.Ahr::Arnt.npy',
 'MA0004.1.Arnt.npy',
 'MA0007.3.Ar.npy',
 'MA0009.1.Tbxt.npy']

In [9]:
# The TF symbol is the second-to-last dot-separated field of a motif file name,
# capitalised (e.g. 'MA0002.2.Runx1.npy' -> 'Runx1'). Keep each symbol once, and only
# if it is also a column of the RNA table; the result is sorted alphabetically.
tf_by_region_mat = []
tfs_kept = sorted({
    symbol
    for symbol in (fname.split('.')[-2].capitalize() for fname in motif_files)
    if symbol in df_y.columns
})

In [10]:
# GO term relation table (tab-separated): each row links term `id1` to term `id2`
# through the relation given in the `relation` column.
go = pd.read_csv(
    './go/go_follicle.txt',
    sep='\t',
)
go

Unnamed: 0,id1,name,namespace,def,relation,id2
0,GO:0031069,hair follicle morphogenesis,biological_process,The process in which the anatomical structures...,part_of,GO:0001942
1,GO:0048820,hair follicle maturation,biological_process,"A developmental process, independent of morpho...",part_of,GO:0001942
2,GO:0051797,regulation of hair follicle development,biological_process,"Any process that modulates the frequency, rate...",regulates,GO:0001942
3,GO:0051798,positive regulation of hair follicle development,biological_process,Any process that activates or increases the fr...,positively_regulates,GO:0001942
4,GO:0051799,negative regulation of hair follicle development,biological_process,"Any process that stops, prevents, or reduces t...",negatively_regulates,GO:0001942
5,GO:0060789,hair follicle placode formation,biological_process,The developmental process in which a hair plac...,part_of,GO:0001942
6,GO:0042637,catagen,biological_process,The regression phase of the hair cycle during ...,part_of,GO:0048820
7,GO:0042640,anagen,biological_process,"The growth phase of the hair cycle. Lasts, for...",part_of,GO:0048820
8,GO:0051798,positive regulation of hair follicle development,biological_process,Any process that activates or increases the fr...,is_a,GO:0051797
9,GO:0051799,negative regulation of hair follicle development,biological_process,"Any process that stops, prevents, or reduces t...",is_a,GO:0051797


In [11]:
# GO annotation table (tab-separated); downstream cells read the gene symbol from
# column X2 and the GO term id from column X4.
goa = pd.read_csv(
    './go/goa_follicle.txt',
    sep='\t',
)
goa

Unnamed: 0,X0,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13,X14,X15,X16
0,MGI,MGI:108359,Alx4,involved_in,GO:0001942,MGI:MGI:4834177|GO_REF:0000096,ISO,UniProtKB:Q9H161,P,aristaless-like homeobox 4,Aristaless-like 4,protein_coding_gene,taxon:10090,20100901,MGI,,
1,MGI,MGI:3513977,Apcdd1,involved_in,GO:0001942,MGI:MGI:4834177|GO_REF:0000096,ISO,UniProtKB:Q8J025,P,adenomatosis polyposis coli down-regulated 1,Drapc1|EIG180,protein_coding_gene,taxon:10090,20100423,MGI,,
2,MGI,MGI:88138,Bcl2,involved_in,GO:0031069,MGI:MGI:63353|PMID:8402909,IMP,,P,B cell leukemia/lymphoma 2,Bcl-2|C430015F12Rik|D830018M01Rik,protein_coding_gene,taxon:10090,20070305,UniProt,,
3,MGI,MGI:2443583,Fermt1,involved_in,GO:0051886,MGI:MGI:5571208|PMID:24681597,IMP,,P,fermitin family member 1,5830467P10Rik|Kindlin-1,protein_coding_gene,taxon:10090,20161121,CAFA,,
4,MGI,MGI:102949,Foxn1,involved_in,GO:0051798,MGI:MGI:5437206|PMID:21109991,IMP,,P,forkhead box N1,D11Bhm185e|Hfh11|whn,protein_coding_gene,taxon:10090,20150111,UniProt,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
176,MGI,MGI:99400,Atp7a,acts_upstream_of_or_within,GO:0031069,PMID:2473662,IMP,MGI:MGI:1856466,P,"ATPase, Cu++ transporting, alpha polypeptide",Menkes protein|MNK|br,protein_coding_gene,taxon:10090,20061107,MGI,,
177,MGI,MGI:99560,Hoxc13,acts_upstream_of_or_within,GO:0001942,PMID:21191399,IMP,MGI:MGI:1926422,P,homeobox C13,N,protein_coding_gene,taxon:10090,20121227,MGI,,
178,MGI,MGI:99560,Hoxc13,acts_upstream_of_or_within,GO:0001942,PMID:9420327,IMP,MGI:MGI:1926422,P,homeobox C13,N,protein_coding_gene,taxon:10090,20080322,MGI,results_in_development_of(EMAPA:26747),
179,MGI,MGI:99560,Hoxc13,acts_upstream_of_or_within,GO:0001942,PMID:9420327,IMP,MGI:MGI:1926422,P,homeobox C13,N,protein_coding_gene,taxon:10090,20080322,MGI,results_in_development_of(EMAPA:36498),


In [12]:
# Marker name = text before the first '.' of each file name in the directory.
# dict.fromkeys removes repeats while keeping the first-seen order.
files = os.listdir('/nfs/public/xixi/scRegulate/SHAREseq/nn.best.feature6.learnW_go')
markers_filtered = list(dict.fromkeys(fname.split('.')[0] for fname in files))
len(markers_filtered)

77

In [13]:
# Unique annotated gene symbols, in order of first appearance.
genes = goa['X2'].unique().tolist()
print(len(genes))
# Step 1: drop genes that are themselves in the kept TF list.
genes_filtered = list(filter(lambda g: g not in tfs_kept, genes))
print(len(genes_filtered))
# Step 2: keep only genes that also appear in markers_filtered.
genes_filtered = list(filter(lambda g: g in markers_filtered, genes_filtered))
print(len(genes_filtered))

113
107
77


In [14]:
# Annotation rows for the retained genes, and the distinct GO ids they mention
# (in order of first appearance).
goa_filtered = goa[goa['X2'].isin(genes_filtered)]
gos = goa_filtered['X4'].unique().tolist()
len(gos)

12

In [15]:
# Keep only relations whose two endpoints are both retained GO ids, then count the
# distinct GO ids that take part in at least one such relation.
both_ends_kept = go['id1'].isin(gos) & go['id2'].isin(gos)
go_filtered = go[both_ends_kept]
len(set(go_filtered['id1']) | set(go_filtered['id2']))

12

In [16]:
# Binary gene-to-GO membership: entry (gene, GO) is 1 when the annotation table links them.
gos = goa_filtered['X4'].unique().tolist()
mask = torch.zeros(len(genes_filtered), len(gos))
for gene_name, go_id in zip(goa_filtered['X2'], goa_filtered['X4']):
    if gene_name in genes_filtered:
        mask[genes_filtered.index(gene_name), gos.index(go_id)] = 1
# Transpose so that rows are GO terms and columns are genes.
mask = mask.t()
mask.shape

torch.Size([12, 77])

In [17]:
# Adjacency matrix over the retained GO ids: W_true[a, b] = 1 when go_filtered has a
# row with id1 == gos[a] and id2 == gos[b].
W_true = np.zeros((len(gos), len(gos)))
for src_id, dst_id in zip(go_filtered['id1'], go_filtered['id2']):
    W_true[gos.index(src_id), gos.index(dst_id)] = 1
W_true

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [18]:
# Load the precomputed model inputs. The `with` block guarantees the HDF5 file is
# closed even if one of the reads raises.
with h5py.File('./predict_lineage_pseudotime/data_follicle2.h5', 'r') as h5f:
    X = h5f['X'][:]
    expr = h5f['expr'][:]
    num_peaks = h5f['num_peaks'][:]
    # The 'peaks' dataset holds byte strings; decode them and lay them out three per row.
    peaks_all = pd.DataFrame(
        np.array([item.decode('utf-8') for j in h5f['peaks'][:] for item in j]).reshape(-1, 3)
    )
    cell_by_tf = h5f['cell_by_tf'][:]
    cell_by_peak = h5f['cell_by_peak'][:]
    W = h5f['W'][:]
print(X.shape)
print(expr.shape)
print(num_peaks)

(843, 165, 11058)
(843, 77)
[142 177  60 140 104 156  78 320 256  98 108 205  93 105  73 198 233 209
 158 176 272 187 249 357 167 120 152 119 104 243  39  15 257  32 124   7
 189 308 115  89 138  54  50 157  74 147 145  74 103 136 110 147  50 131
 134 174 173 218 120 146 100  27 119 129 195 153 116  84 140 154 186 132
 205  88 182  71 262]


In [19]:
# Running totals of num_peaks with a leading 0: cut[i]:cut[i+1] is the slice of the
# concatenated peak axis that belongs to gene i.
cut = [0, *np.cumsum(num_peaks)]
len(cut)

78

In [20]:
# Per-sample metadata table; the columns used here are 'aggr_pseudotime' and 'celltype'.
y = pd.read_csv('skin.aggregate.cellid&cluster&pseudotime_30cells.csv', index_col=0)

# Named `pseudotime` rather than `time` so the `time` module imported at the top of the
# notebook is not shadowed by an array.
pseudotime = np.array(y['aggr_pseudotime'])
# Min-max scale pseudotime to [0, 1] and store it as a column vector.
y_std = (pseudotime - pseudotime.min()) / (pseudotime.max() - pseudotime.min())
y_pseudo = y_std.reshape((len(y), 1))

# From here on `y` is the 1-D array of cell-type labels (later cells stratify on it).
y = np.array(y['celltype'])
y_new = y.reshape((len(y), 1))
enc = OneHotEncoder(handle_unknown='ignore')
y_oht = enc.fit_transform(y_new).toarray()
# Show which one-hot column corresponds to which category.
for i, category in enumerate(enc.categories_[0]):
    print(f"{category}: {i}")
print(y_oht.shape)
# Final target layout: one-hot cell type columns followed by the scaled pseudotime column.
y_oht = np.concatenate([y_oht, y_pseudo], axis=1)
print(y_oht.shape)

0: 0
1: 1
2: 2
3: 3
4: 4
5: 5
(843, 6)
(843, 7)


In [21]:
# Stage 1: hold out 30% of the samples, stratified on the cell-type label.
(X_train, X_test,
 expr_train, expr_test,
 y_oht_train, y_oht_test,
 y_train, y_test) = train_test_split(
    X, expr, y_oht, y,
    test_size=0.3, random_state=2023, stratify=y,
)
# Stage 2: split the held-out part in half into test and validation sets
# (the *_test names are overwritten with the test half).
(X_test, X_val,
 expr_test, expr_val,
 y_oht_test, y_oht_val,
 y_test, y_val) = train_test_split(
    X_test, expr_test, y_oht_test, y_test,
    test_size=0.5, random_state=2023, stratify=y_test,
)
for arr in (X_train, expr_train, y_oht_train, X_val, X_test):
    print(arr.shape)

(590, 165, 11058)
(590, 77)
(590, 7)
(127, 165, 11058)
(126, 165, 11058)


In [22]:
# Apply the same two-stage split (same sizes, seed and stratification labels) to the
# sample indices 0..n-1, so each subset can be traced back to rows of X.
indices_train, indices_test, y_label_train, y_label_test = train_test_split(
    np.arange(X.shape[0]), y,
    test_size=0.3, random_state=2023, stratify=y,
)
indices_test, indices_val, y_label_test, y_label_val = train_test_split(
    indices_test, y_label_test,
    test_size=0.5, random_state=2023, stratify=y_label_test,
)

In [23]:
# Convert the split arrays to float32 tensors.
# torch.as_tensor accepts both ndarrays and tensors, so this cell can be re-run safely
# (torch.from_numpy raises TypeError once X_train etc. are already tensors).
X_train = torch.as_tensor(X_train).float()
X_val = torch.as_tensor(X_val).float()
X_test = torch.as_tensor(X_test).float()
expr_train = torch.as_tensor(expr_train).float()
expr_val = torch.as_tensor(expr_val).float()
expr_test = torch.as_tensor(expr_test).float()
# From here on y_train / y_val / y_test hold the one-hot + pseudotime targets,
# replacing the label arrays produced by the split.
y_train = torch.as_tensor(y_oht_train).float()
y_val = torch.as_tensor(y_oht_val).float()
y_test = torch.as_tensor(y_oht_test).float()
print(X_train.shape)
print(expr_train.shape)
print(y_train.shape)

torch.Size([590, 165, 11058])
torch.Size([590, 77])
torch.Size([590, 7])


# Definitions

In [24]:
from torch_geometric.nn import GATConv, GATv2Conv, GCNConv
from torch_geometric.nn.conv import MessagePassing
from torch.nn import Parameter
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax

In [25]:
import torch
from torch.nn import Parameter, functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.utils.num_nodes import maybe_num_nodes
#from torch_scatter import scatter_max, scatter_add
import math

In [26]:
# Edge list of the GO graph as a 2 x num_edges integer tensor: row 0 holds the position
# of each relation's id1 in `gos`, row 1 the position of its id2.
edge_x = np.array([gos.index(go_id) for go_id in go_filtered['id1']])
edge_y = np.array([gos.index(go_id) for go_id in go_filtered['id2']])
edge = torch.from_numpy(np.vstack([edge_x, edge_y]))

In [27]:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset backed by a group of tensors that share their first dimension.

    Sample ``i`` is the tuple formed by indexing every wrapped tensor with ``i``.

    Args:
        *tensors (Tensor): tensors whose first dimensions all have the same size.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor is compared against the first one's leading dimension.
        sizes_match = all(tensors[0].size(0) == t.size(0) for t in tensors)
        assert sizes_match, "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        sample = [t[index] for t in self.tensors]
        return tuple(sample)

    def __len__(self):
        # Length is the size of the shared first dimension.
        first = self.tensors[0]
        return first.size(0)

In [28]:
class subNet(nn.Module):
    """Per-gene sub-network: one linear unit over the flattened input block.

    Args:
        num_peak: size of the peak axis of the block fed to this sub-network.
        num_tf: size of the TF axis of the block.
    """

    def __init__(self, num_peak, num_tf):
        super().__init__()
        self.num_peak, self.num_tf = num_peak, num_tf

        # All num_peak * num_tf inputs are mapped to a single scalar per sample.
        self.fc1 = nn.Linear(self.num_peak * self.num_tf, 1)
        # These two activations are created but never applied in forward().
        self.fc1_activate = nn.LeakyReLU()
        self.abs = nn.ReLU()

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the linear unit."""
        flat = torch.reshape(x, (x.shape[0], -1))
        return self.fc1(flat)

In [29]:
class selfGATConv(MessagePassing):
    """Graph-attention-style message-passing layer with mean aggregation.

    Node features are projected by a shared weight matrix into ``heads`` blocks of
    ``out_channels`` values; each edge gets one attention coefficient per head, and the
    weighted source features are aggregated with ``aggr='mean'``.

    Args:
        in_channels: size of the last dimension of the input node features.
        out_channels: number of output features per attention head.
        heads: number of attention heads.
        concat: if True the heads are laid out side by side in the output
            (``heads * out_channels`` features); otherwise they are averaged.
        negative_slope: slope of the LeakyReLU applied to raw attention scores.
        dropout: dropout probability applied to the attention coefficients
            (active only in training mode).
        bias: whether to add a learnable bias to the output.
        **kwargs: forwarded to ``MessagePassing.__init__``.
    """
    def __init__(self, in_channels, out_channels, heads=2, concat=False,
                 negative_slope=0.2, dropout=0, bias=True, **kwargs):
        super(selfGATConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        # Nodes are indexed along dimension 0 of the feature tensor.
        self.node_dim = 0

        # Shared projection applied to every node's features.
        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        # Attention vector per head; it is multiplied with the concatenation [x_i, x_j].
        self.att = Parameter(torch.Tensor(1, 1, heads, 2 * out_channels))

        # Bias size depends on whether heads are concatenated or averaged.
        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the projection and attention weights and zero the bias."""
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, size=None):
        """Project the node features with ``self.weight`` and propagate along ``edge_index``.

        ``x`` may be a single tensor or a pair of tensors (either element may be None);
        in the pair case each non-None element is projected with the same weight.
        Self-loops are not added: that code is commented out below.
        """
        # if size is None and torch.is_tensor(x):
        #     edge_index, _ = remove_self_loops(edge_index)
        #     edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        return self.propagate(edge_index, size=size, x=x)

    def message(self, edge_index_i, ptr, x_i, x_j, size_i):
        """Compute per-edge, per-head attention and return the weighted source features.

        The softmax-normalised coefficients (before dropout) are stored on ``self.alpha``.
        """
        # Compute attention coefficients.
        # View as (num_edges, -1, heads, out_channels); the inferred middle dimension is
        # presumably the batch of samples when called from Net - TODO confirm.
        x_j = x_j.view(edge_index_i.shape[0], -1, self.heads, self.out_channels)
        if x_i is None:
            # No target features: score with the x_j half of the attention vector only.
            alpha = (x_j * self.att[:, :, :, self.out_channels:]).sum(dim=-1)
        else:
            x_i = x_i.view(edge_index_i.shape[0], -1, self.heads, self.out_channels)
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
            #print((torch.cat([x_i, x_j], dim=-1) * self.att).shape)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalise the scores across the edges grouped by edge_index_i.
        alpha = softmax(alpha, edge_index_i, ptr, size_i)
        # Kept on the module so the attention coefficients can be inspected after a forward pass.
        self.alpha = alpha

        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # print(alpha.shape)

        return x_j * alpha.view(edge_index_i.shape[0], -1, self.heads, 1)

    def update(self, aggr_out):
        """Combine the heads (concatenate or average over dim 2) and add the bias if present."""
        if self.concat is True:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            # dim 2 is the heads dimension of the view built in message().
            aggr_out = aggr_out.mean(dim=2)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)


In [30]:
class Net(nn.Module):
    """Per-gene sub-networks -> masked gene-to-GO layer -> graph attention -> output head.

    Args:
        genes: list of genes; one ``subNet`` is created per gene.
        num_peaks: per-gene sizes of the peak axis (``num_peaks[i]`` belongs to ``genes[i]``).
        num_tf: size of the TF axis of the input.
        cut: boundaries on the last input axis; gene ``i`` uses ``cut[i]:cut[i+1]``.
        mask: (num_gos, num_genes) 0/1 tensor; zeros disable the matching gene-to-GO weights.
        num_gos: number of GO terms (graph nodes).

    ``forward`` reads the module-level ``edge`` tensor (2 x num_edges) as the GO graph.

    Returns (from ``forward``):
        x_cat: (batch, num_genes) raw sub-network outputs.
        x: (batch, num_gos * 4) GO representation after the attention layer, activation
           and dropout.
        out: (batch, 7) output of the final linear layer.
    """

    def __init__(self, genes, num_peaks, num_tf, cut, mask, num_gos):
        super(Net, self).__init__()
        self.num_peaks = num_peaks
        self.num_tf = num_tf
        self.genes = genes
        self.cut = cut
        self.mask = mask
        self.num_gos = num_gos
        self.go_dim=6

        self.subnet_modules = nn.ModuleList()
        for i in range(len(self.genes)):
            num_peak = self.num_peaks[i]
            # Keep assigning to self.subnet: it registers the last sub-network under the
            # name 'subnet', so the state_dict contains 'subnet.*' keys. Dropping it would
            # change the key set and break strict loading of existing checkpoints.
            self.subnet = subNet(num_peak, self.num_tf)
            self.subnet_modules.append(self.subnet)

        self.cat_activate = nn.LeakyReLU()
        # Gene-to-GO weights: go_dim rows per GO term; masked with mask_rep in forward().
        self.fc1_weight = nn.Parameter(torch.Tensor(self.num_gos*self.go_dim, len(self.genes)), requires_grad=True)
        # Each mask row is repeated go_dim times consecutively (rows k*go_dim .. k*go_dim+go_dim-1
        # all belong to GO term k). Stored as a non-trainable parameter so it moves with the model.
        self.mask_rep = nn.Parameter(torch.repeat_interleave(mask, self.go_dim, dim=0), requires_grad=False)
        self.fc1_bias = nn.Parameter(torch.Tensor(self.num_gos*self.go_dim), requires_grad=True)
        self.fc1_activate = nn.LeakyReLU()
        #self.fc1_drop = nn.Dropout()
        self.gat = selfGATConv(self.go_dim, 4, heads=2)
        self.gat_activate = nn.LeakyReLU()
        self.gat_drop = nn.Dropout()
        self.out = nn.Linear(self.num_gos*4, 7)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise fc1_weight / fc1_bias with the Kaiming-uniform scheme used here."""
        nn.init.kaiming_uniform_(self.fc1_weight, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.fc1_weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.fc1_bias, -bound, bound)

    def forward(self, x):
        # Take the device from the input instead of the module-level `device`, so the
        # model works wherever it (and its input) live. The empty (batch, 0) seed keeps
        # the result well-defined even with zero sub-networks.
        gene_outputs = [torch.zeros(x.shape[0], 0, device=x.device)]
        for i, subnet in enumerate(self.subnet_modules):
            # Gene i sees only its own slice of the last axis.
            gene_outputs.append(subnet(x[:, :, self.cut[i]:self.cut[i+1]]))
        # Concatenate once instead of growing the tensor inside the loop.
        x_cat = torch.cat(gene_outputs, dim=1)

        x = self.cat_activate(x_cat)
        x = x.matmul((self.fc1_weight*self.mask_rep).t()) + self.fc1_bias
        x = self.fc1_activate(x)
        #x = self.fc1_drop(x)
        # NOTE(review): mask_rep orders the fc1 outputs as [GO0 x go_dim, GO1 x go_dim, ...],
        # but this reshape splits them as (go_dim, num_gos), so graph node c appears to
        # receive features k = c, c + num_gos, ... rather than the go_dim features of GO
        # term c. Left unchanged because the saved checkpoints were trained with this
        # layout - confirm whether reshape(batch, num_gos, go_dim).permute(1, 0, 2) was intended.
        x = x.reshape(x.shape[0], self.go_dim, -1)
        x = x.permute(2, 0, 1)
        # `edge` is the module-level edge index; .to() is a no-op when it is already on x's device.
        x = self.gat(x, edge.to(x.device))
        x = x.permute(1, 2, 0)
        x = x.reshape(x.shape[0], -1)
        x = self.gat_activate(x)
        x = self.gat_drop(x)
        out = self.out(x)
        return x_cat, x, out

In [31]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch.

    Each batch is a (data, expr, target) triple. The optimised loss is
    0.05 * cluster loss (first six output/target columns, ``out_criterion``)
    + 1 * pseudotime loss (last column, ``expr_criterion``). The expression loss is
    computed only for logging. Progress is printed whenever ``batch_idx`` is a multiple
    of the module-level ``batchsize``.
    """
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        data, expr, target = (t.to(device) for t in batch)
        optimizer.zero_grad()
        expr_hat, cluster_repr, output = model(data)

        loss_out = out_criterion(output[:, :6], target[:, :6])
        loss_pseudo = expr_criterion(output[:, -1], target[:, -1])
        loss_expr = expr_criterion(expr_hat, expr)  # reported, not optimised

        loss = 0.05*loss_out + 1*loss_pseudo
        loss.backward()
        optimizer.step()

        if batch_idx % batchsize != 0:
            continue
        seen = batch_idx * len(data)
        percent_done = 100. * batch_idx / len(train_loader)
        print('\nTrain Epoch: {} [{}/{} ({:.0f}%)], Total loss: {:.6f}, Expr loss: {:.6f}, Cluster loss: {:.6f}, Pseudo loss: {:.6f}'.
              format(
            epoch, seen, len(train_loader.dataset),
            percent_done, loss.item(), loss_expr.item(), loss_out.item(), loss_pseudo.item()))

                  
def test(model, device, test_loader, num_clusters):
    """Evaluate the model on a loader.

    Args:
        model: network returning (expr_hat, cluster_repr, output) for a batch.
        device: device the batches are moved to.
        test_loader: iterable of (data, expr, target) batches; the first ``num_clusters``
            target columns are the one-hot cluster label, the last one is pseudotime.
        num_clusters: number of cluster columns at the start of output/target.
            (Previously the slices were hard-coded to 6 regardless of this argument.)

    Returns:
        (f1_score, pearsonr, test_loss): multiclass F1 of the cluster prediction, Pearson
        correlation between target and predicted pseudotime, and the sum over batches of
        0.05 * cluster loss + 1 * pseudotime loss.
    """
    model.eval()
    with torch.no_grad():
        outputs = torch.zeros(0, num_clusters).to(device)
        targets = torch.zeros(0).to(device)
        outputs_pseudo = torch.zeros(0).to(device)
        targets_pseudo = torch.zeros(0).to(device)
        test_loss = 0
        for data, expr, target in test_loader:
            data, expr, target = data.to(device), expr.to(device), target.to(device)
            expr_hat, cluster_repr, output = model(data)

            loss_out = out_criterion(output[:, :num_clusters], target[:, :num_clusters])
            loss_pseudo = expr_criterion(output[:, -1], target[:, -1])
            loss = 0.05*loss_out + 1*loss_pseudo
            test_loss = test_loss+loss.item()

            # Cluster head: true class index vs. predicted class probabilities.
            target_cluster = target[:, :num_clusters].argmax(dim=1)
            output_cluster = output[:, :num_clusters].softmax(dim=1)
            outputs = torch.cat((outputs, output_cluster), dim=0)
            targets = torch.cat((targets, target_cluster), dim=0)
            # Pseudotime head: last column of both tensors.
            target_pseudo = target[:, -1]
            output_pseudo = output[:, -1]
            outputs_pseudo = torch.cat((outputs_pseudo, output_pseudo), dim=0)
            targets_pseudo = torch.cat((targets_pseudo, target_pseudo), dim=0)
        f1_score = multiclass_f1_score(outputs, targets, num_classes=num_clusters)
        pearsonr, _ = stats.pearsonr(targets_pseudo.detach().cpu().numpy(), outputs_pseudo.detach().cpu().numpy())

    return(f1_score, pearsonr, test_loss)

In [32]:
def correlation_score(y_true, y_pred):
    """Return the correlation between y_true and y_pred (entry [1, 0] of np.corrcoef)."""
    corr_matrix = np.corrcoef(y_true, y_pred)
    return corr_matrix[1][0]

def spearman_correlation(y_true, y_pred):
    """Return the absolute value of the Spearman rank correlation (the p-value is discarded)."""
    return abs(stats.spearmanr(y_true, y_pred)[0])

def pearson_correlation(y_true, y_pred):
    """Return the absolute value of the Pearson correlation (the p-value is discarded).

    ``stats.pearsonr`` returns a scalar statistic for 1-D inputs, which the previous
    ``statistic[0]`` could not index. ``np.ravel`` turns both the scalar and the array
    case into a 1-D array, so the first element can be taken either way.
    """
    statistic, _ = stats.pearsonr(y_true, y_pred)
    return abs(np.ravel(statistic)[0])

# in-silico perturbation

In [33]:
import math
from captum.attr import visualization as viz
from captum.attr import Lime, LimeBase, DeepLift, IntegratedGradients, GradientShap, NoiseTunnel, FeatureAblation, KernelShap
from captum._utils.models.linear_model import SkLearnLinearRegression, SkLearnLasso

In [34]:
# Use GPU index 3 when CUDA is available; otherwise fall back to the CPU so the
# notebook does not fail on the first .to(device) call on a machine without CUDA.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:3" if use_cuda else "cpu")

In [35]:
# Rebuild the GO edge list (row 0: position of id1 in `gos`, row 1: position of id2)
# and place it on the selected device, where Net.forward reads it.
edge_x = np.array([gos.index(go_id) for go_id in go_filtered['id1']])
edge_y = np.array([gos.index(go_id) for go_id in go_filtered['id2']])
edge = torch.from_numpy(np.vstack([edge_x, edge_y])).to(device)

In [36]:
# Full input array (all samples) as a float32 tensor on the selected device.
X_tensor = torch.from_numpy(X).to(torch.float32).to(device)

## TF

### in-silico upregulation

In [102]:
repeats = 10       # number of independently trained checkpoints (files ...0.pt to ...9.pt)
num_clusters = 6   # number of leading output columns that hold the cell-type logits

# Axes: [checkpoint, sample, perturbed TF, cell type].
# results         = perturbed - original cell-type probabilities
# results_orig    = original probabilities (identical for every TF index)
# results_perturb = probabilities after the perturbation
results = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])
results_orig = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])
results_perturb = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])

# Inference only: no_grad avoids building autograd graphs over the full input tensor.
with torch.no_grad():
    # One working copy is enough: only a single TF slice differs from X_tensor at any
    # time and it is restored after each forward pass (X_tensor itself is never modified).
    X_new = X_tensor.clone()
    for i in range(repeats):
        model = Net(genes=genes_filtered, num_peaks=num_peaks, num_tf=X.shape[1], cut=cut, mask=mask, num_gos=W_true.shape[0]).to(device)
        model.load_state_dict(torch.load('./predict_lineage_pseudotime/model.mask_GAT.leaky.lr001.loss_0.05_1.'+str(i)+'.pt',
                                        map_location=device))
        # eval() disables dropout, so the forward passes below are deterministic.
        model.eval()

        _, _, orig = model(X_tensor)
        # Explicit dim=1: softmax over the cell-type columns of each sample.
        orig = F.softmax(orig[:, :num_clusters], dim=1).cpu().numpy()
        for j in range(X_tensor.shape[1]):
            # Up-regulated expression of TF j: original value + its 0.9 quantile across samples.
            tf_expr = cell_by_tf[:, j] + np.quantile(cell_by_tf[:, j], 0.9)
            # Recompute the TF's input slice as expression * peak signal * W[j]
            # (broadcast over samples x peaks, same values as repeating the vectors).
            new = tf_expr[:, np.newaxis] * cell_by_peak * W[j, :][np.newaxis, :]

            X_new[:, j, :] = torch.from_numpy(new).float().to(device)
            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb[:, :num_clusters], dim=1).cpu().numpy()
            # Put the original slice back before moving on to the next TF.
            X_new[:, j, :] = X_tensor[:, j, :]

            results[i, :, j, :] = perturb - orig
            results_orig[i, :, j, :] = orig
            results_perturb[i, :, j, :] = perturb

In [67]:
# Persist the up-regulation results: probability differences, original and perturbed probabilities.
for out_path, arr in (
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF_0.9quantile.npy', results),
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF_0.9quantile_orig.npy', results_orig),
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF_0.9quantile_perturb.npy', results_perturb),
):
    np.save(out_path, arr)

### in-silico knockout

In [37]:
repeats = 10       # number of independently trained checkpoints (files ...0.pt to ...9.pt)
num_clusters = 6   # number of leading output columns that hold the cell-type logits

# Axes: [checkpoint, sample, knocked-out TF, cell type].
# results         = perturbed - original cell-type probabilities
# results_orig    = original probabilities (identical for every TF index)
# results_perturb = probabilities after the knockout
results = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])
results_orig = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])
results_perturb = np.zeros([repeats, X_tensor.shape[0], X_tensor.shape[1], num_clusters])

# Inference only: no_grad avoids building autograd graphs over the full input tensor.
with torch.no_grad():
    # One working copy is enough: only a single TF slice differs from X_tensor at any
    # time and it is restored after each forward pass (X_tensor itself is never modified).
    X_new = X_tensor.clone()
    for i in range(repeats):
        model = Net(genes=genes_filtered, num_peaks=num_peaks, num_tf=X.shape[1], cut=cut, mask=mask, num_gos=W_true.shape[0]).to(device)
        model.load_state_dict(torch.load('./predict_lineage_pseudotime/model.mask_GAT.leaky.lr001.loss_0.05_1.'+str(i)+'.pt',
                                        map_location=device))
        # eval() disables dropout, so the forward passes below are deterministic.
        model.eval()

        _, _, orig = model(X_tensor)
        # Explicit dim=1: softmax over the cell-type columns of each sample.
        orig = F.softmax(orig[:, :num_clusters], dim=1).cpu().numpy()
        for j in range(X_tensor.shape[1]):
            # Knock out TF j: zero its whole input slice for every sample.
            X_new[:, j, :] = 0
            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb[:, :num_clusters], dim=1).cpu().numpy()
            # Put the original slice back before moving on to the next TF.
            X_new[:, j, :] = X_tensor[:, j, :]

            results[i, :, j, :] = perturb - orig
            results_orig[i, :, j, :] = orig
            results_perturb[i, :, j, :] = perturb

In [38]:
# Persist the knockout results: probability differences, original and perturbed probabilities.
for out_path, arr in (
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF*0.npy', results),
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF*0_orig.npy', results_orig),
    ('./predict_lineage_pseudotime/interpret_GAT/input_out_TF*0_perturb.npy', results_perturb),
):
    np.save(out_path, arr)