This script generates the input TAM matrices of selected genes, which serve as input for the entire model in the second step of training.

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
from pyfaidx import Fasta
from verstack import stratified_continuous_split
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import mean_squared_error
from sklearn.metrics import f1_score
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder
import scipy
from scipy import stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import h5py
import seaborn as sns
import os
os.chdir('/nfs/public/xixi/scRegulate/SHAREseq')

In [2]:
genes = Fasta('../../ref_genome/mm10.fa')

In [3]:
df_x = pd.read_csv('atac.aggregate_30cells.csv', index_col=0).transpose()
df_x

Unnamed: 0,chr18-31634804-31635104,chr1-165460961-165461261,chr1-75317415-75317715,chr18-46525819-46526119,chr4-3938558-3938858,chr6-71440583-71440883,chr7-105399959-105400259,chr1-161876704-161877004,chr6-113306525-113306825,chr1-177796516-177796816,...,chr10-111689260-111689560,chr10-111388777-111389077,chr10-110841761-110842061,chr10-10884983-10885283,chr10-108335612-108335912,chr10-107887768-107888068,chr10-107038880-107039180,chr10-10549625-10549925,chr10-105270865-105271165,chr10-103291333-103291633
V1,2.112241,1.988862,1.988862,1.664337,1.758608,2.112241,1.758608,1.434083,1.919221,1.919221,...,0.0,0.000000,0.0,0.0,0.554779,0.0,0.0,0.000000,0.0,0.0
V2,2.184422,2.079218,1.957768,1.533112,1.814123,1.889215,1.957768,1.814123,1.731114,1.814123,...,0.0,0.000000,0.0,0.0,0.546106,0.0,0.0,0.000000,0.0,0.0
V3,2.381858,1.865213,1.717455,1.542694,2.298454,1.542694,2.298454,1.993207,1.865213,1.865213,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V4,2.296301,1.952485,1.863144,1.952485,1.364729,2.107085,1.763270,1.952485,1.863144,2.174958,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V5,2.021073,1.682703,1.763392,1.903019,1.763392,1.836383,1.682703,1.763392,1.490239,1.490239,...,0.0,0.000000,0.0,0.0,0.841352,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V839,3.323174,3.171114,3.171114,2.585932,2.322799,2.585932,3.323174,2.585932,2.000750,2.808408,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,1.000375,0.0,0.0
V840,2.806491,2.806491,2.055718,2.656048,2.656048,2.485490,1.770699,2.055718,2.806491,2.288595,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0
V841,2.915099,1.888067,1.626292,2.101953,2.577613,1.888067,2.282790,1.626292,2.439439,2.282790,...,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.813146,0.0,0.0
V842,1.385682,1.421908,1.021733,1.209012,1.346771,1.385682,1.487627,1.259056,1.346771,1.385682,...,0.0,0.363949,0.0,0.0,0.000000,0.0,0.0,0.363949,0.0,0.0


In [4]:
df_y = pd.read_csv('rna.aggregate_30cells.csv', index_col=0).transpose()
df_y

Unnamed: 0,0610009B22Rik,0610009L18Rik,0610010F05Rik,0610010K14Rik,0610012D04Rik,0610012G03Rik,0610025J13Rik,0610030E20Rik,0610038B21Rik,0610039K10Rik,...,mt-Co2,mt-Co3,mt-Cytb,mt-Nd1,mt-Nd2,mt-Nd3,mt-Nd4,mt-Nd4l,mt-Nd5,mt-Nd6
V1,0.0,0.000000,1.179400,0.000000,0.0,0.000000,0.0,0.589700,0.000000,0.0,...,0.0,0.0,2.629728,2.459005,1.869305,0.0,1.369241,0.0,2.667545,1.655498
V2,0.0,0.000000,1.547446,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.191146,1.849178,1.713501,0.0,1.547446,0.0,2.063261,1.333363
V3,0.0,0.581185,1.162371,0.581185,0.0,0.000000,0.0,0.000000,0.581185,0.0,...,0.0,0.0,2.468833,2.591756,2.270628,0.0,1.930656,0.0,2.511841,1.842314
V4,0.0,0.000000,1.265819,0.000000,0.0,0.000000,0.0,1.003138,0.000000,0.0,...,0.0,0.0,2.822417,2.409712,1.776802,0.0,2.102480,0.0,2.688554,2.006276
V5,0.0,0.000000,0.897976,0.000000,0.0,0.566560,0.0,0.566560,0.566560,0.0,...,0.0,0.0,2.526534,2.266239,2.157094,0.0,1.882071,0.0,2.448630,1.699679
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
V839,0.0,0.000000,1.287068,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.039954,3.861204,2.574136,0.0,2.574136,0.0,4.275547,2.039954
V840,0.0,0.000000,0.502798,0.000000,0.0,0.000000,0.0,0.502798,0.000000,0.0,...,0.0,0.0,1.670258,1.670258,1.167460,0.0,1.593831,0.0,2.208447,1.299713
V841,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,2.453196,2.113068,2.113068,0.0,1.674567,0.0,3.787635,0.000000
V842,0.0,0.000000,0.932346,0.000000,0.0,0.000000,0.0,0.571665,0.360681,0.0,...,0.0,0.0,1.558835,1.608430,1.247750,0.0,1.143330,0.0,1.733919,0.837474


In [5]:
df_peaks = pd.DataFrame(df_x.columns)[0].str.split('-',expand=True) 
df_peaks = df_peaks.rename(columns={0: "chrom", 1: "start", 2: "end"})
df_peaks["start"] = pd.to_numeric(df_peaks["start"])
df_peaks["end"] = pd.to_numeric(df_peaks["end"])
df_peaks

Unnamed: 0,chrom,start,end
0,chr18,31634804,31635104
1,chr1,165460961,165461261
2,chr1,75317415,75317715
3,chr18,46525819,46526119
4,chr4,3938558,3938858
...,...,...,...
338699,chr10,107887768,107888068
338700,chr10,107038880,107039180
338701,chr10,10549625,10549925
338702,chr10,105270865,105271165


In [6]:
geneanno = pd.read_csv('../../ref_genome/mm10_geneanno.txt', sep='\t')
geneanno = geneanno.drop_duplicates(subset=['Gene name'])
geneanno

Unnamed: 0,Gene stable ID,Gene stable ID version,Chromosome/scaffold name,Gene start (bp),Gene end (bp),Strand,Gene name,Source of gene name,Transcription start site (TSS)
0,ENSMUSG00000070103,ENSMUSG00000070103.2,1,158505625,158505733,1,Mir488,MGI Symbol,158505625
1,ENSMUSG00000065567,ENSMUSG00000065567.1,1,23291701,23291784,1,Mir30c-2,MGI Symbol,23291701
2,ENSMUSG00000094946,ENSMUSG00000094946.1,1,83795912,83796032,1,Gm25754,MGI Symbol,83795912
3,ENSMUSG00000093155,ENSMUSG00000093155.1,1,74586896,74586972,-1,Gm25035,MGI Symbol,74586972
4,ENSMUSG00000065458,ENSMUSG00000065458.1,1,137966639,137966718,1,Mir181b-1,MGI Symbol,137966639
...,...,...,...,...,...,...,...,...,...
103991,ENSMUSG00000099633,ENSMUSG00000099633.1,Y,84907473,84910700,1,Gm29071,MGI Symbol,84907473
103993,ENSMUSG00000100388,ENSMUSG00000100388.1,Y,50770044,50773283,1,Gm29116,MGI Symbol,50770044
103995,ENSMUSG00000091987,ENSMUSG00000091987.8,Y,2900989,2912206,1,Gm10352,MGI Symbol,2900989
103997,ENSMUSG00000101667,ENSMUSG00000101667.1,Y,2932582,2939416,1,Gm29289,MGI Symbol,2932582


In [7]:
motif_files = os.listdir('../../ref_genome/JASPAR_motifs_pfm_mouse/pfm.np')
motif_files[:5]

['MA0002.2.Runx1.npy',
 'MA0006.1.Ahr::Arnt.npy',
 'MA0004.1.Arnt.npy',
 'MA0007.3.Ar.npy',
 'MA0009.1.Tbxt.npy']

In [8]:
tfs_kept = []
tf_by_region_mat = []
for i in list(motif_files):
    tf = i.split('.')[-2].capitalize()
    if tf in df_y.columns:
        if tf in tfs_kept:
            continue
            
        tfs_kept.append(tf)
tfs_kept = sorted(tfs_kept)

In [9]:
go = pd.read_csv('./go/go_follicle.txt', sep='\t')
goa = pd.read_csv('./go/goa_follicle.txt', sep='\t')
genes = goa['X2'].drop_duplicates().to_list()
len(genes)

113

In [10]:
files = os.listdir('/nfs/public/xixi/scRegulate/SHAREseq/nn.best.feature6.learnW_go')
markers_filtered = []
for file in files:
    marker = file.split('.')[0]
    if marker not in markers_filtered:
        markers_filtered.append(marker)
len(markers_filtered)

77

In [11]:
genes_filtered = [i for i in genes if i not in tfs_kept]
print(len(genes_filtered))
genes_filtered = [i for i in genes_filtered if i in markers_filtered]
print(len(genes_filtered))

107
77


In [12]:
df = pd.read_csv("pearson.testset.feature6_learnW_go.txt", sep='\t', header=None, index_col=0)
df = df.iloc[:, :3]
df = df.loc[genes_filtered, :]
df

Unnamed: 0_level_0,1,2,3
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Apcdd1,0.6620664186369295,0.6581277545337687,0.6646115313544828
Bcl2,0.6415016865303044,0.6443123261623774,0.6469849540275614
Fermt1,0.6971588066916068,0.6939047551731387,0.6880100671587428
Fzd3,0.5498111114099485,0.5403839438421032,0.5434602596277018
Gnas,0.5307134140393347,0.5393855660013731,0.5340669389632029
...,...,...,...
Notch1,0.9243966577484184,0.925053774067188,0.9209806042191196
Sos1,0.7292985095543261,0.723612591547852,0.7216092182318834
Sox9,0.4597888520074928,0.4429088483211883,0.43847324184673203
Atp7a,0.7934460779217066,0.7941483962350241,0.7970297382727736


In [13]:
def init_with_precomputed_matrix(tensor, precomputed_matrix):
    tensor.data = precomputed_matrix.clone()

class Net(nn.Module):
    def __init__(self, num_peak, num_tf, W):
        super(Net, self).__init__()
        self.num_peak = num_peak
        self.num_tf = num_tf

        self.W = nn.Parameter(torch.Tensor(num_tf, num_peak))
        init_with_precomputed_matrix(self.W, W)

        self.fc1 = nn.Linear(self.num_peak * self.num_tf, 1)
        self.fc1_activate = nn.ReLU()
        
    def forward(self, x):
        x = x * abs(self.W.repeat(x.shape[0], 1, 1))
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

In [14]:
X_all = []
expr_all = []
num_peaks = []
peaks_all = pd.DataFrame()
count = 0
for marker in genes_filtered:
    print(count)
    count = count+1
    length = 250000
    anno = geneanno.loc[geneanno['Gene name']==marker,:]
    if anno.shape[0]>0:
        print(marker)
        chrom = 'chr'+ str(anno['Chromosome/scaffold name'].values[0])
        start = anno['Gene start (bp)'].values[0]
        end = anno['Gene end (bp)'].values[0]
        strand = anno['Strand'].values[0]
        if strand==1:
            peaks = df_peaks.loc[(df_peaks['chrom']==chrom) & ((df_peaks['start']>=start-length) & (df_peaks['start']<start+length) | 
                                                               (df_peaks['end']>start-length) & (df_peaks['end']<=start+length)),:]
        else:
            peaks = df_peaks.loc[(df_peaks['chrom']==chrom) & ((df_peaks['start']>=end-length) & (df_peaks['start']<end+length) | 
                                                               (df_peaks['end']>end-length) & (df_peaks['end']<=end+length)),:]
        cellcount = (df_x.iloc[:, peaks.index]!=0).astype(int).sum(axis=1)
        cells = cellcount[cellcount!=0].index
    else:
        print(marker)
        print('No such gene in the gene annotation file.')
        continue


    y = np.array(df_y.loc[:, marker], dtype='float32').flatten()

    tfs_use = tfs_kept.copy()
    cell_by_tf_mat = np.array(df_y.loc[:, tfs_use], dtype='float32')
    cell_by_peak_mat = np.array(df_x.iloc[:, peaks.index], dtype='float32')
    order = np.argsort(np.array(peaks.loc[:,'start']))
    cell_by_peak_mat = cell_by_peak_mat[:, order]
    # Scale data
    cell_by_peak_mat = cell_by_peak_mat/np.max(cell_by_peak_mat)

    cell_by_tf_mat = np.repeat(cell_by_tf_mat[:, :, np.newaxis], peaks.shape[0], axis=2)
    cell_by_peak_mat = np.repeat(cell_by_peak_mat[:, np.newaxis, :], len(tfs_use), axis=1)
    X = cell_by_tf_mat * cell_by_peak_mat

    y_std = (y - y.min()) / (y.max() - y.min())
    y = y_std * (1 - (0)) + (0)
   
    print(X.shape)
    
    W = torch.ones(X.shape[1], X.shape[2])
    model = Net(num_peak=X.shape[2], num_tf=X.shape[1], W=W)
    num = np.argsort(np.array(df.loc[marker]))[::-1][0]
    model.load_state_dict(torch.load('./nn.best.feature6.learnW_go/'+marker+'.'+str(num)+'.pt', map_location='cpu'))
    W = abs(model.W.data).clone().detach().numpy()[:, order]
    X = X*W    
    
    X_all.append(X)
    expr_all.append(y)
    peaks_all = peaks_all.append(peaks.iloc[order, :].reset_index(drop=True))
    num_peaks.append(X.shape[2])
    
print('All done!')

0
Apcdd1
(843, 165, 142)
1
Bcl2
(843, 165, 177)
2
Fermt1
(843, 165, 60)
3
Fzd3
(843, 165, 140)
4
Gnas
(843, 165, 104)
5
Hdac1
(843, 165, 156)
6
Hdac2
(843, 165, 78)
7
Krt17
(843, 165, 320)
8
Krt71
(843, 165, 256)
9
Nsun2
(843, 165, 98)
10
Numa1
(843, 165, 108)
11
Pias4
(843, 165, 205)
12
Tgfb2
(843, 165, 93)
13
Trps1
(843, 165, 105)
14
Wnt5a
(843, 165, 73)
15
Krt28
(843, 165, 198)
16
Krt25
(843, 165, 233)
17
Krt27
(843, 165, 209)
18
Ext1
(843, 165, 158)
19
Lamc1
(843, 165, 176)
20
Pum2
(843, 165, 272)
21
Lncpint
(843, 165, 187)
22
Ppard
(843, 165, 249)
23
Rela
(843, 165, 357)
24
Lama5
(843, 165, 167)
25
Myo5a
(843, 165, 120)
26
Tfap2c
(843, 165, 152)
27
Smo
(843, 165, 119)
28
Fzd6
(843, 165, 104)
29
Barx2
(843, 165, 243)
30
Nsdhl
(843, 165, 39)
31
Fgf10
(843, 165, 15)
32
Celsr1
(843, 165, 257)
33
Eda
(843, 165, 32)
34
Psen1
(843, 165, 124)
35
Dkk1
(843, 165, 7)
36
Trp63
(843, 165, 189)
37
Acvr1b
(843, 165, 308)
38
Lgr5
(843, 165, 115)
39
Edar
(843, 165, 89)
40
Tnfrsf19
(843, 165, 138)


In [15]:
X_new = np.concatenate(X_all, axis=2)
X_new.shape

(843, 165, 11058)

In [16]:
expr = np.array(expr_all).T
expr.shape

(843, 77)

In [17]:
s = 0
for i in num_peaks:
    s = s + i
print(s)

11058


In [19]:
peaks_all

Unnamed: 0,chrom,start,end
0,chr18,62684266,62684566
1,chr18,62684923,62685223
2,chr18,62685434,62685734
3,chr18,62699736,62700036
4,chr18,62717629,62717929
...,...,...,...
257,chr15,103160035,103160335
258,chr15,103163184,103163484
259,chr15,103165467,103165767
260,chr15,103168994,103169294


In [20]:
peaks_all['chrom'] = peaks_all['chrom'].astype(str)
peaks_all['start'] = peaks_all['start'].astype(str)
peaks_all['end'] = peaks_all['end'].astype(str)

In [21]:
h5f = h5py.File('/nfs/public/xixi/scRegulate/SHAREseq/predict_lineage_pseudotime/data_follicle2.h5', 'w')
h5f.create_dataset('X', data=X_new)
h5f.create_dataset('expr', data=expr)
h5f.create_dataset('num_peaks', data=np.array(num_peaks))
h5f.create_dataset('peaks', data=peaks_all.to_numpy())
h5f.close()