This notebook describes how we performed in-silico perturbation of transcription factors (TFs) and candidate cis-regulatory elements (cCREs). Two types of perturbation are applied to each:

1. in-silico knockout (closing): set the expression of a TF to zero (for a cCRE, set its chromatin accessibility to zero)
2. in-silico upregulation (opening): for a TF, add the 0.9-quantile of its expression values across all samples to its original expression values (for a cCRE, set its scaled accessibility score to one in all samples)

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
from pyfaidx import Fasta
from verstack import stratified_continuous_split
import matplotlib.pyplot as plt
import time
import math

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import mean_squared_error
from sklearn.metrics import f1_score
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder
import scipy
from scipy import stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torchvision import datasets
from torcheval.metrics.functional import multiclass_f1_score
from torch_geometric.nn import GCNConv

import h5py
import seaborn as sns
import os
os.chdir('/nfs/public/xixi/scRegulate/T2D')

from typing import Tuple
import scanpy as sc
sc.settings.set_figure_params(dpi=80, facecolor='white')
from collections import Counter

# Data preparation

In [3]:
# ATAC table; transposed after loading so the peak identifiers
# ("chrom-start-end") become the columns and V1, V2, ... the rows.
df_x = pd.read_csv(
    './data/beta.atac.aggregate_30cells.csv',
    index_col=0,
).T
df_x

Unnamed: 0,10-100001665-100002165,10-100002531-100003031,10-100003836-100004336,10-100005433-100005933,10-100006548-100007048,10-100013839-100014339,10-100018334-100018834,10-100019564-100020064,10-100020408-100020908,10-100021589-100022089,...,Y-8756349-8756849,Y-8861982-8862482,Y-8869228-8869728,Y-8894760-8895260,Y-8895712-8896212,Y-8902164-8902664,Y-8903250-8903750,Y-9490312-9490812,Y-9647600-9648100,Y-9649551-9650051
V1,0.000000,0.0,0.00000,0.468211,0.468211,0.936421,0.000000,0.468211,0.468211,0.468211,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V2,0.000000,0.0,0.00000,0.000000,0.468040,0.468040,0.000000,0.000000,0.468040,0.000000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V3,0.000000,0.0,0.00000,0.000000,0.000000,0.000000,0.000000,0.627410,0.000000,0.000000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V4,0.565643,0.0,0.00000,0.000000,0.565643,0.000000,1.131286,0.896523,0.000000,0.000000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V5,0.000000,0.0,0.00000,0.000000,0.000000,0.484623,0.000000,0.484623,0.000000,0.484623,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V1390,0.000000,0.0,0.00000,0.658323,0.000000,1.043417,0.000000,0.658323,0.658323,0.658323,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V1391,0.000000,0.0,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V1392,0.000000,0.0,0.00000,0.000000,0.000000,0.910113,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.910113,0.0,0.0,0.0,0.0,0.0,0.0,0.0
V1393,0.630010,0.0,0.63001,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
# RNA table; transposed after loading so the gene names become the
# columns and V1, V2, ... the rows.
df_y = pd.read_csv(
    './data/beta.rna.aggregate_30cells.csv',
    index_col=0,
).T
df_y

Unnamed: 0,WASH7P,RP11-34P13.7,RP11-34P13.10,AL627309.1,RP11-34P13.13,RP11-34P13.9,AP006222.2,RP11-206L10.17,RP5-857K21.2,RP5-857K21.9,...,CH507-513H4.6,CH17-408M7.1,CH507-39O4.1,CH507-39O4.2,CH507-24F1.1,CH507-338C24.1,CH507-254M2.3,CH507-154B10.2,CH507-145C22.3,CH17-351M24.1
V1,0.0,0.515306,0.0,0.515306,0.515306,0.0,0.515306,1.446647,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.515306,0.816741,0.0,0.000000
V2,0.0,0.494710,0.0,0.494710,0.784097,0.0,0.494710,0.784097,0.0,0.0,...,0.0,0.000000,0.0,0.494710,0.000000,1.278808,0.784097,1.148682,0.0,0.494710
V3,0.0,0.000000,0.0,0.000000,0.629183,0.0,0.000000,0.997232,0.0,0.0,...,0.0,0.000000,0.0,0.629183,0.000000,1.460918,0.629183,0.629183,0.0,0.000000
V4,0.0,0.610912,0.0,0.968272,0.610912,0.0,0.000000,0.968272,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,0.610912,1.418493,0.000000,0.0,0.000000
V5,0.0,0.499200,0.0,0.000000,0.000000,0.0,0.499200,1.159107,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,1.159107,0.499200,0.000000,0.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V1390,0.0,0.000000,0.0,0.000000,0.666096,0.0,1.055737,1.869967,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,1.546626,0.666096,1.055737,0.0,0.000000
V1391,0.0,1.166952,0.0,0.736265,0.000000,0.0,1.472530,2.333904,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,1.166952,0.736265,0.000000,0.0,0.000000
V1392,0.0,0.000000,0.0,0.788265,0.000000,0.0,1.830293,1.249370,0.0,0.0,...,0.0,0.788265,0.0,0.000000,0.000000,0.788265,1.576529,0.788265,0.0,0.000000
V1393,0.0,0.000000,0.0,0.505620,0.505620,0.0,1.011240,0.505620,0.0,0.0,...,0.0,0.000000,0.0,0.000000,0.000000,1.011240,0.505620,0.000000,0.0,0.000000


In [5]:
# Split every peak identifier ("chrom-start-end") into its three fields and
# convert the two coordinates to numbers.
df_peaks = (
    pd.DataFrame(df_x.columns)[0]
    .str.split('-', expand=True)
    .rename(columns={0: "chrom", 1: "start", 2: "end"})
    .assign(
        start=lambda frame: pd.to_numeric(frame["start"]),
        end=lambda frame: pd.to_numeric(frame["end"]),
    )
)
df_peaks

Unnamed: 0,chrom,start,end
0,10,100001665,100002165
1,10,100002531,100003031
2,10,100003836,100004336
3,10,100005433,100005933
4,10,100006548,100007048
...,...,...,...
412092,Y,8902164,8902664
412093,Y,8903250,8903750
412094,Y,9490312,9490812
412095,Y,9647600,9648100


In [7]:
# Gene annotation table; only the first row of each gene name is kept.
geneanno = (
    pd.read_csv('../../ref_genome/hg19_geneanno.txt', sep='\t')
    .drop_duplicates(subset=['Gene name'])
)
geneanno

Unnamed: 0,Gene stable ID,Transcript stable ID,Chromosome/scaffold name,Gene start (bp),Gene end (bp),Strand,Gene name,Transcription start site (TSS)
0,ENSG00000271782,ENST00000607815,1,50902700,50902978,-1,RP5-850O15.4,50902978
1,ENSG00000232753,ENST00000424955,1,103817769,103828355,1,RP11-347K2.1,103817769
2,ENSG00000225767,ENST00000424664,1,50927141,50936822,1,RP5-850O15.3,50927141
3,ENSG00000202140,ENST00000365270,1,50965430,50965529,-1,Y_RNA,50965529
4,ENSG00000207194,ENST00000384465,1,51048076,51048183,1,RNU6-1026P,51048076
...,...,...,...,...,...,...,...,...
179049,ENSG00000229926,ENST00000442569,9,141031547,141038316,1,RP11-424E7.3,141031547
179050,ENSG00000159247,ENST00000503395,9,141044565,141071821,1,TUBBP5,141044565
179053,ENSG00000237419,ENST00000428088,9,141090383,141093775,-1,RP11-885N19.6,141093775
179054,ENSG00000233013,ENST00000446912,9,141106637,141143444,1,FAM157B,141107518


In [8]:
# File names of the motif matrices in the JASPAR directory; judging by the
# listing below they follow "<matrix id>.<version>.<TF name>.npy".
motif_files = os.listdir('../../ref_genome/JASPAR_motifs_pfm_homosapiens/pfm.np')
# Preview the first few names only.
motif_files[:5]

['MA0002.1.RUNX1.npy',
 'MA0003.4.TFAP2A.npy',
 'MA0007.2.AR.npy',
 'MA0014.3.PAX5.npy',
 'MA0009.2.TBXT.npy']

In [9]:
tf_by_region_mat = []
# TF name = second-to-last dot-separated field of the motif file name.
# Keep each TF once, and only if it is a column of the RNA table.
tfs_kept = sorted(
    {fname.split('.')[-2] for fname in motif_files} & set(df_y.columns)
)

In [45]:
# Load the pre-computed model inputs. The `with` block guarantees the HDF5
# handle is closed even if one of the reads raises.
with h5py.File('./predict_status/data_T2D_float16.h5', 'r') as h5f:
    X = h5f['X'][:]                    # model input tensor
    expr = h5f['expr'][:]              # expression targets (shape printed below)
    num_peaks = h5f['num_peaks'][:]    # number of input columns belonging to each gene
    # 'peaks' is stored as byte strings; decode and lay out as three columns per row.
    peaks = pd.DataFrame(
        np.array(
            [item.decode('utf-8') for j in h5f['peaks'][:] for item in j]
        ).reshape(-1, 3)
    )
    cell_by_tf = h5f['cell_by_tf'][:]
    cell_by_peak = h5f['cell_by_peak'][:]
    W = h5f['W'][:]                    # indexed as W[tf, column] by the perturbation cells
#print(X.shape)
print(expr.shape)
print(num_peaks)

(1394, 239)
[183 114  83 139  90 151 104  67 188 119 163 154  58 144 132 110 128 120
 133 156  82 128 156 106  99 138  99 115 182  65  99 116 134 142  58 134
 166 111  93 110 103 146 152 151 135  99 136 131 104 134 152 150 141  45
 178  75 103  77 115 190  98 107 132 156 100  57 167  63  85  86 145 175
 114 168 183  65 112 166  92 164 113  99  69  94 107 172 165 135  51  43
 164  41 118 106  41 131  99 118 178 120 100  54  27  18 131 101 166  89
  92  30  96  66 107 116 144  98  89 161 144  94 116 167 116 193 164  21
  24 151 127  81 118  80  24 160 162  80 133  78 150 136  80 180 130  83
 171  75  89  88 112  52  46  75  74 182  58 131 109  72  78 165 107  39
 173  77  22 103  99  78 145  48  88  46  42 162  58 133 106  90 133  75
  83 137 145 158  68  67 106  35  44 115 105  95  72 114  60  40  86  48
  74  49  47 104 106 145 117  85 100  59  30  45 146 122 102 168  69  68
  85  18  23  96  73  81 160  34  97  58 135 107  76  41  39  85 134 115
  75 117  67 138  67]


In [46]:
# Running total of `num_peaks`: gene i owns the input columns
# cut[i]:cut[i+1] (this is how Net.forward slices its input).
cut = [0]
for n_cols in num_peaks:
    cut.append(cut[-1] + n_cols)
len(cut)

240

In [10]:
y = pd.read_csv('./data/beta.label.aggregate_30cells.csv', index_col=0)

# `y` stays a plain list of status strings; it is reused for stratified splitting.
y = y['status'].to_list()

# The original cell looped over the class names calling
# np.where(y == value, 0, y_new). Because `y` is a list, `y == value` is the
# scalar False, so nothing was ever replaced (and replacing every class with 0
# would have destroyed the labels). The loop is dropped; the labels are simply
# turned into an (n_samples, 1) column for the encoder.
y_new = np.asarray(y).reshape((len(y), 1))

enc = OneHotEncoder(handle_unknown='ignore')
y_oht = enc.fit_transform(y_new).toarray()
# Show which one-hot column corresponds to which status.
for i, category in enumerate(enc.categories_[0]):
    print(f"{category}: {i}")
print(y_oht)
print(y_oht.shape)

Non-diabetic: 0
Pre-T2D: 1
T2D: 2
[[0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 ...
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]]
(1394, 3)


In [12]:
# Stratified 70 / 30 split, then the held-out 30 % is halved (again stratified)
# into the test and validation parts.
(
    X_train, X_test,
    expr_train, expr_test,
    y_oht_train, y_oht_test,
    y_label_train, y_label_test,
) = train_test_split(
    X, expr, y_oht, y,
    test_size=0.3, random_state=2024, stratify=y,
)
(
    X_test, X_val,
    expr_test, expr_val,
    y_oht_test, y_oht_val,
    y_label_test, y_label_val,
) = train_test_split(
    X_test, expr_test, y_oht_test, y_label_test,
    test_size=0.5, random_state=2024, stratify=y_label_test,
)
print(expr_train.shape)
print(y_oht_train.shape)

(975, 239)
(975, 3)


In [13]:
# Repeat the same two-stage stratified split on the row numbers, so the
# per-sample tables can be indexed with the rows of each partition.
all_rows = np.arange(X.shape[0])
indices_train, indices_test, y_label_train, y_label_test = train_test_split(
    all_rows, y,
    test_size=0.3, random_state=2024, stratify=y,
)
indices_test, indices_val, y_label_test, y_label_val = train_test_split(
    indices_test, y_label_test,
    test_size=0.5, random_state=2024, stratify=y_label_test,
)

# Def

In [14]:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset built from several tensors that share their first dimension.

    Item ``k`` of the dataset is the tuple made of row ``k`` of every tensor.

    Args:
        *tensors (Tensor): tensors whose first dimension has the same size.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor must have as many rows as the first one.
        for candidate in tensors:
            assert candidate.size(0) == tensors[0].size(0), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # One slice per wrapped tensor, in the order they were given.
        return tuple(member[index] for member in self.tensors)

    def __len__(self):
        # Number of rows of the first tensor.
        first = self.tensors[0]
        return first.size(0)

In [15]:
class subNet(nn.Module):
    """Single linear layer mapping a flattened (num_tf x num_peak) block to one value.

    Args:
        num_peak: number of peak columns in the block.
        num_tf: number of TF rows in the block.
    """

    def __init__(self, num_peak, num_tf):
        super().__init__()
        self.num_peak, self.num_tf = num_peak, num_tf

        self.fc1 = nn.Linear(num_peak * num_tf, 1)
        # The two ReLU modules below are created but not applied in forward().
        self.fc1_activate = nn.ReLU()
        self.abs = nn.ReLU()

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the linear layer."""
        flat = x.reshape(x.shape[0], -1)
        return self.fc1(flat)

In [16]:
class Net(nn.Module):
    """Per-gene linear sub-networks followed by a graph convolution and a 3-way output layer.

    Args:
        num_genes: number of genes (one ``subNet`` is created per gene).
        num_peaks: sequence giving, for gene ``i``, the number of input columns it owns.
        num_tf: number of TF rows of the input.
        cut: column boundaries; gene ``i`` reads ``x[:, :, cut[i]:cut[i+1]]``.
    """

    def __init__(self, num_genes, num_peaks, num_tf, cut):
        super(Net, self).__init__()
        self.num_peaks = num_peaks
        self.num_tf = num_tf
        self.num_genes = num_genes
        self.cut = cut
        self.gene_dim = 2

        self.subnet_modules = nn.ModuleList()
        for i in range(num_genes):
            num_peak = self.num_peaks[i]
            # NOTE(review): assigning to `self.subnet` also registers the most
            # recently created sub-network under the attribute name `subnet`,
            # which (as far as can be told) adds `subnet.*` keys to the
            # state_dict. The saved checkpoints are loaded with the default
            # strict key matching, so this assignment is kept on purpose.
            self.subnet = subNet(num_peak, self.num_tf)
            self.subnet_modules.append(self.subnet)

        self.cat_activate = nn.ReLU()
        self.conv = GCNConv(1, self.gene_dim, add_self_loops=False)
        #self.fc1 = nn.Linear(self.num_genes, 100)
        self.conv_activate = nn.ReLU()
        self.out = nn.Linear(self.num_genes*self.gene_dim, 3)

        #self.initialize_parameters()

    def initialize_parameters(self):
        """Re-initialise the GCN layer's weight (Kaiming uniform) and bias (uniform)."""
        weight = self.conv.lin.weight
        bias = self.conv.bias
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
        #torch.nn.init.xavier_uniform_(weight)
        if bias is not None:
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(bias, -bound, bound)

    def forward(self, x, edge_index=None):
        """Run the model.

        Args:
            x: input indexed as ``x[sample, tf, column]``.
            edge_index: graph connectivity passed to the GCN layer. When
                omitted, the module-level variable ``edge`` is used, which is
                what the original implementation always did.

        Returns:
            Tuple ``(x_cat, x, out)``: the ReLU-activated per-gene values
            (one column per gene), the flattened ReLU-activated GCN output,
            and the 3-column output of the final linear layer.
        """
        if edge_index is None:
            edge_index = edge  # fall back to the notebook-level graph

        # One scalar per gene from that gene's slice of the input columns.
        per_gene = [
            subnet(x[:, :, self.cut[i]:self.cut[i + 1]])
            for i, subnet in enumerate(self.subnet_modules)
        ]
        if per_gene:
            # Single concatenation instead of growing a tensor inside the loop;
            # the result lives on the same device as the input.
            x_cat = torch.cat(per_gene, dim=1)
        else:
            x_cat = torch.zeros(x.shape[0], 0, device=x.device)

        x_cat = self.cat_activate(x_cat)
        x = torch.unsqueeze(x_cat, 2)
        #x = F.dropout(x, p=0.3)
        x = self.conv_activate(self.conv(x, edge_index))
        #x = F.dropout(x, p=0.3)
        x = x.reshape(x.shape[0], -1)
        out = self.out(x)
        return x_cat, x, out

In [17]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, stepping the optimizer on the output loss.

    ``device`` and ``epoch`` are accepted but not used in the body. The loss
    functions ``out_criterion`` and ``expr_criterion`` are looked up as
    module-level names.

    Args:
        model: module returning ``(expr_hat, cluster_repr, output)``.
        train_loader: iterable of ``(data, expr, target)`` batches.
        optimizer: optimizer holding the model parameters.
    """
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        data, expr, target = (part.float() for part in batch)
        optimizer.zero_grad()
        expr_hat, cluster_repr, output = model(data)
        loss_out = out_criterion(output, target)
        # Evaluated as in the original, but not part of the optimised objective.
        loss_expr = expr_criterion(expr_hat, expr)
        loss_out.backward()
        optimizer.step()

                  
def test(model, device, test_loader, num_clusters):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    The loss functions ``out_criterion`` and ``expr_criterion`` are looked up
    as module-level names.

    Args:
        model: module returning ``(expr_hat, cluster_repr, output)``.
        device: device on which the (initially empty) accumulators are created.
        test_loader: iterable of ``(data, expr, target)`` batches.
        num_clusters: number of classes (columns of ``output``).

    Returns:
        Tuple ``(f1, total_loss)``: the ``multiclass_f1_score`` of the softmax
        outputs against the arg-max of the targets, and the output loss summed
        over batches.
    """
    model.eval()
    with torch.no_grad():
        # Start from empty accumulators so an empty loader still yields tensors.
        prob_parts = [torch.zeros(0, num_clusters).to(device)]
        label_parts = [torch.zeros(0).to(device)]
        test_loss = 0
        for batch in test_loader:
            data, expr, target = (part.float() for part in batch)
            expr_hat, cluster_repr, output = model(data)

            loss_out = out_criterion(output, target)
            # Evaluated as in the original, but not added to the reported loss.
            loss_expr = expr_criterion(expr_hat, expr)
            test_loss = test_loss + loss_out.item()

            prob_parts.append(output.softmax(dim=1))
            label_parts.append(target.argmax(dim=1))

        outputs = torch.cat(prob_parts, dim=0)
        targets = torch.cat(label_parts, dim=0)
        f1 = multiclass_f1_score(outputs, targets, num_classes=num_clusters)

    return (f1, test_loss)

In [18]:
def correlation_score(y_true, y_pred):
    """Return the off-diagonal entry of ``np.corrcoef(y_true, y_pred)``."""
    corr_matrix = np.corrcoef(y_true, y_pred)
    return corr_matrix[1, 0]

def spearman_correlation(y_true, y_pred):
    """Return the absolute value of the Spearman statistic of the two inputs."""
    # First element of the result is the statistic, second the p-value.
    return abs(stats.spearmanr(y_true, y_pred)[0])

def pearson_correlation(y_true, y_pred):
    """Return the absolute Pearson correlation between two vectors.

    Both inputs are flattened first, so plain 1-D sequences and (n, 1) column
    vectors are accepted alike. The original indexed the statistic with
    ``[0]``, which fails for 1-D inputs because ``pearsonr`` then returns a
    scalar statistic.

    Args:
        y_true: array-like of observed values.
        y_pred: array-like of predicted values with the same number of elements.

    Returns:
        ``abs(r)`` as a scalar.
    """
    statistic, pvalue = stats.pearsonr(np.ravel(y_true), np.ravel(y_pred))
    return abs(statistic)

# in-silico perturbation

In [19]:
import math
from captum.attr import visualization as viz
from captum.attr import Lime, LimeBase, DeepLift, IntegratedGradients, GradientShap, NoiseTunnel, FeatureAblation, KernelShap
from captum._utils.models.linear_model import SkLearnLinearRegression, SkLearnLasso

## TF

In [157]:
use_cuda = False
device = torch.device("cuda:3" if use_cuda else "cpu")

# torch.as_tensor accepts NumPy arrays as well as tensors, so this cell can be
# re-run (or run after another cell has already converted these variables)
# without the TypeError that torch.from_numpy raises on tensor input.
X_test = torch.as_tensor(X_test).float().to(device)
expr_test = torch.as_tensor(expr_test).to(device)
y_test = torch.as_tensor(y_oht_test).to(device)

In [158]:
# Edge list read from file, one edge per row; transposed so each column is one edge.
edge_pairs = torch.from_numpy(
    np.array(pd.read_csv('./predict_status/string_filtered.txt', sep='\t'))
).t()
# Append the row-flipped copy of every column, i.e. each edge in the opposite direction.
edge = torch.cat((edge_pairs, edge_pairs.flip([0])), dim=1).to(device)

In [202]:
# Axis 1 is indexed per TF and axis 2 per input column by the perturbation
# loops below; axis 0 is presumably one row per test sample.
X_test.shape

torch.Size([209, 457, 25198])

### in-silico upregulation

In [22]:
repeats = 10
num_clusters = 3
# results[i, sample, tf, class]: change in class probability caused by perturbing one TF
results = np.zeros([repeats, X_test.shape[0], X_test.shape[1], num_clusters])
results_orig = np.zeros_like(results)
results_perturb = np.zeros_like(results)

# Loop-invariant: accessibility rows of the test samples.
peak_mat = cell_by_peak[indices_test, :]
# One scratch copy of the input. Only the perturbed TF row is overwritten and
# then restored, instead of cloning the whole tensor for every TF.
X_new = X_test.clone()

for i in range(repeats):
    model = Net(num_genes=expr_test.shape[1], num_peaks=num_peaks, num_tf=X_test.shape[1], cut=cut).to(device)
    model.load_state_dict(torch.load('./predict_status/model.GCN.string.'+str(i)+'.pt',
                                    map_location=device))
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, orig = model(X_test)
        # dim=1 is the class axis (what the implicit-dim call resolved to for 2-D input).
        orig = F.softmax(orig, dim=1)
        orig_np = orig.cpu().numpy()

        for j in range(X_test.shape[1]):
            # Upregulate TF j: add the 0.9-quantile of its values over all samples.
            tf_mat = cell_by_tf[indices_test, j] + np.quantile(cell_by_tf[:, j], 0.9)
            # Rebuild the TF's input row: TF value * accessibility * weight (broadcast).
            new = tf_mat[:, np.newaxis] * peak_mat * W[j, :][np.newaxis, :]

            X_new[:, j, :] = torch.from_numpy(new).float().to(device)
            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb, dim=1)
            # Put the original row back before moving on to the next TF.
            X_new[:, j, :] = X_test[:, j, :]

            results[i, :, j, :] = (perturb - orig).cpu().numpy()
            results_orig[i, :, j, :] = orig_np
            results_perturb[i, :, j, :] = perturb.cpu().numpy()

In [23]:
# Persist the probability differences, the baseline and the perturbed probabilities.
for out_path, out_array in (
    ('./predict_status/interpret/input_output_TF_0.9quantile.GCN.string.npy', results),
    ('./predict_status/interpret/input_output_TF_0.9quantile_orig.GCN.string.npy', results_orig),
    ('./predict_status/interpret/input_output_TF_0.9quantile_perturb.GCN.string.npy', results_perturb),
):
    np.save(out_path, out_array)

### in-silico knockout

In [24]:
repeats = 10
num_clusters = 3
# results[i, sample, tf, class]: change in class probability caused by knocking out one TF
results = np.zeros([repeats, X_test.shape[0], X_test.shape[1], num_clusters])
results_orig = np.zeros_like(results)
results_perturb = np.zeros_like(results)

# One scratch copy of the input. Only the knocked-out TF row is overwritten and
# then restored, instead of cloning the whole tensor for every TF.
X_new = X_test.clone()

for i in range(repeats):
    model = Net(num_genes=expr_test.shape[1], num_peaks=num_peaks, num_tf=X_test.shape[1], cut=cut).to(device)
    model.load_state_dict(torch.load('./predict_status/model.GCN.string.'+str(i)+'.pt',
                                    map_location=device))
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, orig = model(X_test)
        # dim=1 is the class axis (what the implicit-dim call resolved to for 2-D input).
        orig = F.softmax(orig, dim=1)
        orig_np = orig.cpu().numpy()

        for j in range(X_test.shape[1]):
            # Knock out TF j: zero its whole input row.
            X_new[:, j, :] = 0
            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb, dim=1)
            # Put the original row back before moving on to the next TF.
            X_new[:, j, :] = X_test[:, j, :]

            results[i, :, j, :] = (perturb - orig).cpu().numpy()
            results_orig[i, :, j, :] = orig_np
            results_perturb[i, :, j, :] = perturb.cpu().numpy()

In [25]:
# Persist the probability differences, the baseline and the perturbed probabilities.
for out_path, out_array in (
    ('./predict_status/interpret/input_output_TF*0.GCN.string.npy', results),
    ('./predict_status/interpret/input_output_TF*0_orig.GCN.string.npy', results_orig),
    ('./predict_status/interpret/input_output_TF*0_perturb.GCN.string.npy', results_perturb),
):
    np.save(out_path, out_array)

## cCRE

**Note:**

We recommend running this part as a standalone script with `nohup` rather than inside Jupyter, since it takes more than 10 hours.

In [None]:
use_cuda = True
device = torch.device("cuda:3" if use_cuda else "cpu")

# torch.as_tensor accepts NumPy arrays as well as tensors. The TF section above
# has already turned X_test / expr_test into tensors, in which case
# torch.from_numpy would raise a TypeError here.
X_test = torch.as_tensor(X_test).float().to(device)
expr_test = torch.as_tensor(expr_test).to(device)
y_test = torch.as_tensor(y_oht_test).to(device)

# Edge list read from file, one edge per row; transposed so each column is one
# edge, then the row-flipped copy is appended so every edge is present in both
# directions.
edge = pd.read_csv('./predict_status/string_filtered.txt', sep='\t')
edge = torch.from_numpy(np.array(edge)).t()
edge = torch.cat((edge, edge.flip([0])), dim=1).to(device)

### in-silico opening

In [None]:
repeats = 10
num_clusters = 3
# results[i, sample, column, class]: change in class probability caused by opening one peak
results = np.zeros([repeats, X_test.shape[0], X_test.shape[2], num_clusters])
results_orig = np.zeros_like(results)
results_perturb = np.zeros_like(results)

# Rows of `peaks` that are identical refer to the same peak and are perturbed
# together. Group the column positions once, instead of scanning the whole
# table with a row-wise `apply` for every column of every repeat.
# (Assumes `peaks` has no missing values; groupby would drop such rows.)
assert len(peaks) == X_test.shape[2], "peaks must have one row per input column"
peak_groups = list(peaks.groupby(list(peaks.columns), sort=False).indices.values())
assert sum(len(cols) for cols in peak_groups) == X_test.shape[2], "every column must belong to a peak group"

# Loop-invariant TF values of the test samples. float64 matches the precision
# of the original computation, which multiplied by a float64 matrix of ones.
tf_mat = cell_by_tf[indices_test, :].astype(np.float64)

# One scratch copy of the input. Only the perturbed columns are overwritten and
# then restored, instead of cloning the whole tensor for every column.
X_new = X_test.clone()

for i in range(repeats):
    print(i)
    model = Net(num_genes=expr_test.shape[1], num_peaks=num_peaks, num_tf=X_test.shape[1], cut=cut).to(device)
    model.load_state_dict(torch.load('./predict_status/model.GCN.string.'+str(i)+'.pt',
                                     map_location=device))
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, orig = model(X_test)
        # dim=1 is the class axis (what the implicit-dim call resolved to for 2-D input).
        orig = F.softmax(orig, dim=1)
        orig_np = orig.cpu().numpy()

        for cols in peak_groups:
            # Open the peak: accessibility is set to 1 in every sample, so each
            # affected column becomes TF value * weight.
            for col in cols:
                col = int(col)
                new = tf_mat * W[:, col][np.newaxis, :]
                X_new[:, :, col] = torch.from_numpy(new).float().to(device)

            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb, dim=1)

            # Put the original columns back before the next peak.
            for col in cols:
                col = int(col)
                X_new[:, :, col] = X_test[:, :, col]

            # Every column of the group receives the same result, exactly as if
            # each had been evaluated separately.
            perturb_np = perturb.cpu().numpy()
            delta_np = (perturb - orig).cpu().numpy()
            results[i][:, cols, :] = delta_np[:, np.newaxis, :]
            results_orig[i][:, cols, :] = orig_np[:, np.newaxis, :]
            results_perturb[i][:, cols, :] = perturb_np[:, np.newaxis, :]

np.save('./predict_status/interpret/input_output_openness_all1_GCN.string.npy', results)
np.save('./predict_status/interpret/input_output_openness_all1_orig_GCN.string.npy', results_orig)
np.save('./predict_status/interpret/input_output_openness_all1_perturb_GCN.string.npy', results_perturb)

### in-silico closing

In [None]:
repeats = 10
num_clusters = 3
# results[i, sample, column, class]: change in class probability caused by closing one peak
results = np.zeros([repeats, X_test.shape[0], X_test.shape[2], num_clusters])
results_orig = np.zeros_like(results)
results_perturb = np.zeros_like(results)

# Rows of `peaks` that are identical refer to the same peak and are perturbed
# together. Group the column positions once, instead of scanning the whole
# table with a row-wise `apply` for every column of every repeat.
# (Assumes `peaks` has no missing values; groupby would drop such rows.)
assert len(peaks) == X_test.shape[2], "peaks must have one row per input column"
peak_groups = list(peaks.groupby(list(peaks.columns), sort=False).indices.values())
assert sum(len(cols) for cols in peak_groups) == X_test.shape[2], "every column must belong to a peak group"

# One scratch copy of the input. Only the perturbed columns are overwritten and
# then restored, instead of cloning the whole tensor for every column.
X_new = X_test.clone()

for i in range(repeats):
    print(i)
    model = Net(num_genes=expr_test.shape[1], num_peaks=num_peaks, num_tf=X_test.shape[1], cut=cut).to(device)
    model.load_state_dict(torch.load('./predict_status/model.GCN.string.'+str(i)+'.pt',
                                     map_location=device))
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, orig = model(X_test)
        # dim=1 is the class axis (what the implicit-dim call resolved to for 2-D input).
        orig = F.softmax(orig, dim=1)
        orig_np = orig.cpu().numpy()

        for cols in peak_groups:
            # Close the peak: zero every column that refers to it.
            for col in cols:
                X_new[:, :, int(col)] = 0

            _, _, perturb = model(X_new)
            perturb = F.softmax(perturb, dim=1)

            # Put the original columns back before the next peak.
            for col in cols:
                col = int(col)
                X_new[:, :, col] = X_test[:, :, col]

            # Every column of the group receives the same result, exactly as if
            # each had been evaluated separately.
            perturb_np = perturb.cpu().numpy()
            delta_np = (perturb - orig).cpu().numpy()
            results[i][:, cols, :] = delta_np[:, np.newaxis, :]
            results_orig[i][:, cols, :] = orig_np[:, np.newaxis, :]
            results_perturb[i][:, cols, :] = perturb_np[:, np.newaxis, :]

np.save('./predict_status/interpret/input_output_openness*0_GCN.string.npy', results)
np.save('./predict_status/interpret/input_output_openness*0_orig_GCN.string.npy', results_orig)
np.save('./predict_status/interpret/input_output_openness*0_perturb_GCN.string.npy', results_perturb)