In [None]:
import sys
import numpy as np
import pickle
import os
import pandas as pd
from scipy import sparse

import time

import scanpy
import numpy as np

import torch
from torch import optim

import model_lord

import matplotlib.pyplot as plt
import seaborn as sns
import umap
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler

import anndata as ad
import gc

In [None]:
# Locations of the two AnnData files (RNA and ATAC); absolute, machine-specific paths.
skin_rnaPath='/data/xinyi/shareseq/skin_rna.h5ad'
skin_atacPath='/data/xinyi/shareseq/skin_atac.h5ad'

# Read RNA first, then ATAC (same order as the paths are listed).
skin_rna,skin_atac=(scanpy.read(h5adPath) for h5adPath in (skin_rnaPath,skin_atacPath))

In [None]:
# Record every ATAC feature's column position as loaded, before the filtering
# in the next cell drops columns.
skin_atac.var['index']=np.arange(skin_atac.n_vars)

In [None]:
# Two identical rounds of filtering: features must be detected in >=300 cells and
# cells must have >=300 detected features. The second round re-applies the same
# thresholds to the already-filtered matrices.
for _filterPass in range(2):
    scanpy.pp.filter_genes(skin_rna, min_cells=300)
    scanpy.pp.filter_genes(skin_atac, min_cells=300)

    scanpy.pp.filter_cells(skin_atac, min_genes=300)
    scanpy.pp.filter_cells(skin_rna, min_genes=300)

    # Keep the two modalities row-aligned. An RNA cell whose ATAC profile was just
    # filtered out cannot be looked up in skin_atac (the name-based indexing below
    # would raise a KeyError), so drop such cells from RNA first. When every RNA
    # cell is still present in ATAC this is a no-op.
    rnaCellKept=skin_rna.obs_names.isin(skin_atac.obs_names)
    if not rnaCellKept.all():
        skin_rna=skin_rna[rnaCellKept].copy()
    # Reorder/subset ATAC so that row i is the same cell as RNA row i.
    skin_atac=skin_atac[skin_rna.obs.index]


# Dense copies of both matrices (assumes .X is a scipy sparse matrix, as .toarray() requires).
atac=skin_atac.X.toarray()
rna=skin_rna.X.toarray()

# (number of entries - sum of entries) / sum of entries.
# NOTE(review): presumably used as a positive-class weight for a weighted BCE loss;
# it equals the negative:positive ratio only if the matrices are binary - confirm.
atac_total=np.sum(atac)
rna_total=np.sum(rna)
atac_posweight=(atac.size-atac_total)/atac_total
rna_posweight=(rna.size-rna_total)/rna_total

In [None]:
# Preprocessing switches read by the "#preprocess" cell:
#   log_data  - apply log(x + 1/2) to both modalities
#   normalize - 'zscore' or 'minmax'
log_data=True
normalize='minmax'

# Network / latent sizes; these values are also embedded in the checkpoint file names.
#   hiddenSize    - hidden width passed to the encoder/decoder constructors
#   sharedSize    - dimension of the latent space shared by both modalities
#   dSpecificSize - dimension of each modality-specific latent space
hiddenSize,sharedSize,dSpecificSize=1024,50,20

In [None]:
#train-test split
# Fixed seed so the same cells land in each partition on every run.
np.random.seed(3)
pctVal=0.05
pctTest=0.1

nCells_split=atac.shape[0]
nVal_split=int(pctVal*nCells_split)
nTest_split=int(pctTest*nCells_split)

# Shuffle all row indices once, then cut the permutation into
# [validation | test | train] in that order.
allIdx_all=np.arange(nCells_split)
np.random.shuffle(allIdx_all)
valIdx_all,testIdx_all,trainIdx_all=np.split(allIdx_all,[nVal_split,nVal_split+nTest_split])



In [None]:
#preprocess
# NOTE: this cell overwrites `rna` and `atac` in place, so re-running it without
# re-running the cell that creates them applies the transforms twice.
if log_data:
    # log transform with a pseudocount of 1/2
    rna=np.log(rna+1/2)
    atac=np.log(atac+1/2)

if normalize=='minmax':
    # Rescale every row (cell) to [0, 1] using that row's own min and max.
    rna_rowMin=np.min(rna,axis=1,keepdims=True)
    rna_rowMax=np.max(rna,axis=1,keepdims=True)
    rna=(rna-rna_rowMin)/(rna_rowMax-rna_rowMin)

    atac_rowMin=np.min(atac,axis=1,keepdims=True)
    atac_rowMax=np.max(atac,axis=1,keepdims=True)
    atac=(atac-atac_rowMin)/(atac_rowMax-atac_rowMin)
elif normalize=='zscore':
    # Per-feature standardisation; statistics come from the training rows only
    # and are then applied to every row.
    scaler_rna=StandardScaler().fit(rna[trainIdx_all])
    rna=scaler_rna.transform(rna)

    scaler_atac=StandardScaler().fit(atac[trainIdx_all])
    atac=scaler_atac.transform(atac)

In [None]:
#preprocess for de
def _normalized_copy_for_de(adata):
    """Return a copy of `adata` prepared for differential testing.

    The copy is total-count normalised to 1e4 per cell, log1p-transformed and
    scaled per feature without zero-centering. `adata` itself is not modified.
    """
    adata_de=adata.copy()
    scanpy.pp.normalize_total(adata_de, target_sum=1e4)
    scanpy.pp.log1p(adata_de)
    scanpy.pp.scale(adata_de,zero_center=False)
    return adata_de


skin_rna_de=_normalized_copy_for_de(skin_rna)
skin_atac_de=_normalized_copy_for_de(skin_atac)

In [None]:
# CUDA settings for this session (blocking launches, device index "2").
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1", "CUDA_VISIBLE_DEVICES": "2"})

# Run settings. In the cells visible here only `batchsize` and `dropout` are used
# (batched encoding, model construction); the rest are plain assignments.
batchsize=512
saveFreq=100
epochs=10000
weight_decay=0
seed=3
dropout=0.01

# Directory holding the saved latents and checkpoints of this model variant.
testSaveName='shareseq_lord'
name='randNoise_sharedRecon_bceWweight_bce_morefilter'
modelsavepath_lord=os.path.join('/data/xinyi/shareseq/results/models',testSaveName,name)

# Input widths of the two modalities after filtering.
nFeatures_rna,nFeatures_atac=rna.shape[1],atac.shape[1]

# Aliases for the 5% / 10% split made earlier.
train_nodes_idx,val_nodes_idx=trainIdx_all,valIdx_all

# Epoch tags (strings) used to build the checkpoint file names.
loadEpoch_decoders='4900'
loadEpoch_encoders='3900'

In [None]:
#load latent
def _load_frozen_latent(filePrefix,nDims):
    """Build an Embedding with one row per cell and replace its weight from disk.

    The weight is unpickled from
    `<filePrefix>_<sharedSize>_<dSpecificSize>_<hiddenSize>_ep<loadEpoch_decoders>`
    inside `modelsavepath_lord` and gradients are switched off.
    NOTE: pickle.load can execute arbitrary code - only point this at files
    produced by your own training run.
    """
    embedding=torch.nn.Embedding(rna.shape[0],nDims)
    fileName=filePrefix+'_'+str(sharedSize)+'_'+str(dSpecificSize)+'_'+str(hiddenSize)+'_ep'+str(loadEpoch_decoders)
    with open(os.path.join(modelsavepath_lord,fileName), 'rb') as fh:
        embedding.weight=pickle.load(fh)
    embedding.weight.requires_grad=False
    return embedding


latent_shared_dec=_load_frozen_latent('latentShared',sharedSize)
latent_rna_dec=_load_frozen_latent('latentRNA',dSpecificSize)
latent_atac_dec=_load_frozen_latent('latentATAC',dSpecificSize)
            

In [None]:
#load encoders, decoders, and compute latent
# The model classes are taken from `model_lord`, which is the module this notebook
# imports. The previous `gae.gae.model_lord.*` spelling referenced a package that is
# never imported here and raised NameError on a fresh kernel.
# NOTE(review): assumes the top-level `model_lord` module exposes `fc_decode_l4`
# and `fc_encode_l4` with these signatures - confirm.

def _load_on_gpu(model,fileName):
    """Move `model` to the GPU, load `fileName` from `modelsavepath_lord`, switch to eval mode."""
    model.cuda()
    model.load_state_dict(torch.load(os.path.join(modelsavepath_lord,fileName)))
    model.eval()
    return model


# Common pieces of the checkpoint file names.
_sizeTag=str(sharedSize)+'_'+str(dSpecificSize)+'_'+str(hiddenSize)
_decTag=_sizeTag+'_ep'+str(loadEpoch_decoders)
_encTag='encode_'+_sizeTag+'_'+str(loadEpoch_decoders)+'_ep'+str(loadEpoch_encoders)

# Decoders fed with the concatenated shared + modality-specific latent.
model_rna_dec=_load_on_gpu(model_lord.fc_decode_l4(nFeatures_rna,sharedSize+dSpecificSize,hiddenSize, dropout),_decTag+'_rna.pt')
model_atac_dec=_load_on_gpu(model_lord.fc_decode_l4(nFeatures_atac, sharedSize+dSpecificSize,hiddenSize, dropout),_decTag+'_atac.pt')

# Decoders fed with the shared latent only.
model_rna_shared_dec=_load_on_gpu(model_lord.fc_decode_l4(nFeatures_rna,sharedSize,hiddenSize, dropout),_decTag+'_rnaShared.pt')
model_atac_shared_dec=_load_on_gpu(model_lord.fc_decode_l4(nFeatures_atac, sharedSize,hiddenSize, dropout),_decTag+'_atacShared.pt')

# Encoders; each returns a (shared, modality-specific) pair for a batch of inputs.
model_rna=_load_on_gpu(model_lord.fc_encode_l4(nFeatures_rna,hiddenSize,sharedSize,dSpecificSize,sharedSize,dSpecificSize, dropout),_encTag+'_rna.pt')
model_atac=_load_on_gpu(model_lord.fc_encode_l4(nFeatures_atac,hiddenSize,sharedSize,dSpecificSize,sharedSize,dSpecificSize, dropout),_encTag+'_atac.pt')

# Encode every cell in batches. Per-batch outputs are collected in lists and
# concatenated once at the end instead of re-concatenating on every iteration.
all_idx=np.arange(rna.shape[0])
nvalBatches=int(np.ceil(rna.shape[0]/batchsize))
_rnaD_parts,_atacD_parts,_rnaShared_parts,_atacShared_parts=[],[],[],[]
with torch.no_grad():
    for i in range(nvalBatches):
        # Slicing clamps at the end of the array, so the last batch may be shorter.
        valIdx=all_idx[i*batchsize:(i+1)*batchsize]
        valtarget_rna=torch.tensor(rna[valIdx]).cuda().float()
        valtarget_atac=torch.tensor(atac[valIdx]).cuda().float()
        valIdx=torch.tensor(valIdx)
        # Decoder-side latents of the batch. They are not consumed by the encoders
        # below; kept so these names stay defined for later cells.
        # TODO(review): drop if nothing downstream reads them.
        valInput_shared=latent_shared_dec(valIdx).cuda().float()
        valInput_rna=latent_rna_dec(valIdx).cuda().float()
        valInput_atac=latent_atac_dec(valIdx).cuda().float()

        recon_rna_shared,recon_rna_d= model_rna(valtarget_rna)
        atac_recon_shared,atac_recon_d = model_atac(valtarget_atac)

        _rnaD_parts.append(recon_rna_d.cpu().detach())
        _atacD_parts.append(atac_recon_d.cpu().detach())
        _rnaShared_parts.append(recon_rna_shared.cpu().detach())
        _atacShared_parts.append(atac_recon_shared.cpu().detach())

# One row per cell, in the original row order. None when there were no batches
# (same as before).
latent_encoded_rnaD=torch.cat(_rnaD_parts,dim=0) if _rnaD_parts else None
latent_encoded_atacD=torch.cat(_atacD_parts,dim=0) if _atacD_parts else None
latent_encoded_rnaShared=torch.cat(_rnaShared_parts,dim=0) if _rnaShared_parts else None
latent_encoded_atacShared=torch.cat(_atacShared_parts,dim=0) if _atacShared_parts else None
            
            

In [None]:
# Cell-type summary of the training cells: the distinct labels, each training cell's
# integer label (index into celltype_unique) and the number of cells per label.
# `train_nodes_idx` holds row positions, so select with .iloc: obs is indexed by cell
# name, and plain `series[int_array]` only worked through pandas' deprecated
# positional fallback.
celltype_unique,celltype_labels,celltype_counts=np.unique(skin_atac.obs['celltype'].iloc[train_nodes_idx],return_counts=True,return_inverse=True)

### RNA

In [None]:
# Output directory for the differential-expression results of this model variant.
plotsavepath_de=os.path.join('/data/xinyi/shareseq/results/plots/',testSaveName,name,'de')
# makedirs also creates missing parent directories (os.mkdir would fail if
# .../plots/<testSaveName>/<name> does not exist yet); exist_ok makes re-runs safe.
os.makedirs(plotsavepath_de,exist_ok=True)

In [None]:
# Differential expression (DE) along principal components of the encoded RNA latents,
# repeated over `nsamples` random 80/20 train/held-out splits.
#
# For every split and every PC of the shared / RNA-specific latent space:
#   1. fit PCA on the training cells and cut each PC into nSteps equal-mass bins,
#   2. keep only cells close to the origin in all *other* PCs,
#   3. among those, label the lowest two bins '-1', the highest two '1' and the
#      middle bin 'c', and run t-tests between the groups,
#   4. repeat 2-3 on the held-out cells using the training PCA, bin edges and
#      distance thresholds.

nsamples=36
prevSampled=0    # number of splits already computed and saved; >0 resumes from disk

###pca & binning parameters
nPC_shared=50
nPC_d=20
nSteps=11
# nSteps-1 interior cut points at k*100/nSteps percent -> nSteps equal-mass bins
percentiles=(np.arange(nSteps-1)+1)*100/nSteps
nCellsPerSample=30


def _load_or_init_results(fileName):
    """Return a list with one slot per split: fresh, or resumed from `plotsavepath_de/fileName`.

    NOTE: pickle.load can execute arbitrary code - only load files written by this notebook.
    """
    if prevSampled==0:
        return [None]*nsamples
    with open(os.path.join(plotsavepath_de,fileName), 'rb') as fh:
        results=pickle.load(fh)
    # Pad up to nsamples based on the actual length, so the list never ends up
    # longer than nsamples.
    results.extend([None]*(nsamples-len(results)))
    return results


def _assign_percentile_bins(pcScores,binEdges):
    """Bin every cell along every PC.

    pcScores is (nCells, nPC); binEdges is (nSteps-1, nPC) with ascending cut points
    per PC. Returns a float array (nPC, nCells) with values 0..nSteps-1, where bin s
    holds scores in (edge[s-1], edge[s]].
    The outer bins are open-ended: the smallest training score falls in bin 0, and
    held-out scores outside the training range fall in the first / last bin instead
    of being left unassigned (which previously sent cells above the training maximum
    into the low group).
    """
    nCells,nPC=pcScores.shape
    binID=np.empty((nPC,nCells))
    for pc_i in range(nPC):
        binID[pc_i]=np.searchsorted(binEdges[:,pc_i],pcScores[:,pc_i],side='left')
    return binID


def _offaxis_sqdist(sqScores,pc_i):
    """Squared distance to the origin over all PCs except `pc_i`."""
    return np.sum(sqScores[:,:pc_i],axis=1)+np.sum(sqScores[:,pc_i+1:],axis=1)


def _center_masks(sqScores_train,sqScores_query,pct=15):
    """Per PC, flag query cells that lie near the origin in all other PCs.

    The cut-off is the `pct`-th percentile of the *training* off-axis distances, so
    held-out cells are judged against the training threshold.
    """
    masks=[None]*sqScores_train.shape[1]
    for pc_i in range(sqScores_train.shape[1]):
        thresh_i=np.percentile(_offaxis_sqdist(sqScores_train,pc_i),pct)
        masks[pc_i]=_offaxis_sqdist(sqScores_query,pc_i)<thresh_i
    return masks


def _label_extreme_groups(binID,lowBelow=2,highAbove=8,centerBin=5):
    """Turn bin indices into string group labels.

    '-1' for bins < lowBelow, '1' for bins > highAbove, 'c' for bin == centerBin
    and '0.0' for everything else.
    """
    groupID=np.full(binID.shape,'0.0',dtype='<U32')
    groupID[binID<lowBelow]='-1'
    groupID[binID>highAbove]='1'
    groupID[binID==centerBin]='c'
    return groupID


def _rank_genes_along_pcs(adata,groupID,centerMasks,columnSuffix,minCells=2):
    """Run the per-PC t-tests and collect the result tables.

    For PC i the labels groupID[i] are written to adata.obs[str(i)+columnSuffix]
    (adata is modified), the cells flagged by centerMasks[i] are extracted and:
      * if the 'c' group is usable, it is ranked together with the other usable
        groups and its table is stored under key '0';
      * if both '1' and '-1' are usable, they are ranked as a pair and stored
        under keys '1' and '-1'.
    A group is usable when it has at least `minCells` cells (same >1-cell guard as
    the ATAC cell; a single-cell group has no variance for a t-test).
    Returns a list with one dict per PC, or None where no test could be run.
    """
    results=[None]*len(centerMasks)
    for pc_i,centerMask in enumerate(centerMasks):
        print(pc_i)
        column=str(pc_i)+columnSuffix
        adata.obs[column]=groupID[pc_i].astype(str)
        adata_sub=adata[centerMask].copy()

        labels=adata_sub.obs[column].astype(str).to_numpy()
        # Keeps the sorted order '-1', '1', 'c'.
        usable=[g for g in ('-1','1','c') if np.sum(labels==g)>=minCells]

        if 'c' in usable:
            scanpy.tl.rank_genes_groups(adata_sub, column, method='t-test',groups=usable,use_raw=False)
            results[pc_i]={'0':scanpy.get.rank_genes_groups_df(adata_sub,'c')}

        if '1' in usable and '-1' in usable:
            scanpy.tl.rank_genes_groups(adata_sub, column, method='t-test',groups=['-1','1'],use_raw=False)
            # The 'c' branch may not have run, so create the dict on demand
            # (previously this indexed into None and raised TypeError).
            if results[pc_i] is None:
                results[pc_i]={}
            results[pc_i]['1']=scanpy.get.rank_genes_groups_df(adata_sub,'1')
            results[pc_i]['-1']=scanpy.get.rank_genes_groups_df(adata_sub,'-1')
    return results


def _save_train_vs_heldout_scatter(trainScores,heldoutScores,outPath):
    """Scatter PC1 vs PC2 of training (blue) and held-out (red) cells and save to outPath."""
    plt.scatter(trainScores[:,0],trainScores[:,1],s=0.1,c='blue')
    plt.scatter(heldoutScores[:,0],heldoutScores[:,1],s=0.1,c='red')
    plt.savefig(outPath)
    plt.close()


# One entry per split; each entry becomes a list with one item per PC.
deShared_centered_2bins_rna=_load_or_init_results('de_shared_all_centered_2bins_rna')
deD_centered_2bins_rna=_load_or_init_results('de_d_all_centered_2bins_rna')
deShared_centered_2bins_heldout_rna=_load_or_init_results('de_shared_all_centered_2bins_heldout_rna')
deD_centered_2bins_heldout_rna=_load_or_init_results('de_d_all_centered_2bins_heldout_rna')


for h in range(prevSampled,nsamples):
    print(h)
    plotsavepath_de_h=os.path.join(plotsavepath_de,str(h))
    os.makedirs(plotsavepath_de_h,exist_ok=True)

    #train-test split (seeded by the split number, so each split is reproducible)
    # NOTE: this rebinds pctTest, allIdx_all, testIdx_all and trainIdx_all from the
    # earlier 5%/10% split.
    np.random.seed(h)
    pctTest=0.2

    allIdx_all=np.arange(atac.shape[0])
    np.random.shuffle(allIdx_all)
    nHeldout_h=int(pctTest*atac.shape[0])
    testIdx_all=allIdx_all[:nHeldout_h]
    trainIdx_all=allIdx_all[nHeldout_h:]

    #PCA of training
    pca_train_shared_rna=PCA(n_components=nPC_shared)
    pca_train_d_rna=PCA(n_components=nPC_d)
    pca_train_shared_rna.fit(latent_encoded_rnaShared[trainIdx_all])
    pca_train_d_rna.fit(latent_encoded_rnaD[trainIdx_all])

    #bin training cells
    latent_shared_rna_pca=pca_train_shared_rna.transform(latent_encoded_rnaShared[trainIdx_all])
    latent_shared_rna_pca_percentile=np.percentile(latent_shared_rna_pca,percentiles,axis=0)
    latent_shared_rna_pca_max=np.max(latent_shared_rna_pca,axis=0)
    latent_shared_rna_pca_min=np.min(latent_shared_rna_pca,axis=0)
    latent_d_rna_pca=pca_train_d_rna.transform(latent_encoded_rnaD[trainIdx_all])
    latent_d_rna_pca_percentile=np.percentile(latent_d_rna_pca,percentiles,axis=0)
    latent_d_rna_pca_max=np.max(latent_d_rna_pca,axis=0)
    latent_d_rna_pca_min=np.min(latent_d_rna_pca,axis=0)
    dist2origin_rna_shared=np.square(latent_shared_rna_pca)
    dist2origin_rna_d=np.square(latent_d_rna_pca)

    latent_shared_rna_pca_binID=_assign_percentile_bins(latent_shared_rna_pca,latent_shared_rna_pca_percentile)
    latent_d_rna_pca_binID=_assign_percentile_bins(latent_d_rna_pca,latent_d_rna_pca_percentile)
    centerIdx_shared_all=_center_masks(dist2origin_rna_shared,dist2origin_rna_shared)
    centerIdx_d_all=_center_masks(dist2origin_rna_d,dist2origin_rna_d)

    ###DE of training###
    groupID_shared=_label_extreme_groups(latent_shared_rna_pca_binID)
    groupID_d=_label_extreme_groups(latent_d_rna_pca_binID)

    # Subset first, then copy: avoids duplicating the full matrix and gives a real
    # (non-view) AnnData to attach the group columns to.
    nmco_ad=skin_rna_de[trainIdx_all].copy()
    deRes_shared_rna_centered=_rank_genes_along_pcs(nmco_ad,groupID_shared,centerIdx_shared_all,'percentileGroups_shared')
    deRes_d_rna_centered=_rank_genes_along_pcs(nmco_ad,groupID_d,centerIdx_d_all,'percentileGroups_d')

    deShared_centered_2bins_rna[h]=deRes_shared_rna_centered
    deD_centered_2bins_rna[h]=deRes_d_rna_centered

    ### test PCA###
    # Held-out cells are projected with the PCA fitted on the training cells.
    latent_shared_rna_pca_heldout=pca_train_shared_rna.transform(latent_encoded_rnaShared[testIdx_all])
    _save_train_vs_heldout_scatter(latent_shared_rna_pca,latent_shared_rna_pca_heldout,os.path.join(plotsavepath_de_h,'pca_shared_rna.pdf'))

    latent_d_rna_pca_heldout=pca_train_d_rna.transform(latent_encoded_rnaD[testIdx_all])
    _save_train_vs_heldout_scatter(latent_d_rna_pca,latent_d_rna_pca_heldout,os.path.join(plotsavepath_de_h,'pca_d_rna.pdf'))

    ## bin & de
    dist2origin_rna_shared_heldout=np.square(latent_shared_rna_pca_heldout)
    dist2origin_rna_d_heldout=np.square(latent_d_rna_pca_heldout)

    # Training bin edges and training distance thresholds are reused for the
    # held-out cells; centerIdx_*_all now refer to held-out rows.
    latent_shared_rna_pca_heldout_binID=_assign_percentile_bins(latent_shared_rna_pca_heldout,latent_shared_rna_pca_percentile)
    latent_d_rna_pca_heldout_binID=_assign_percentile_bins(latent_d_rna_pca_heldout,latent_d_rna_pca_percentile)
    centerIdx_shared_all=_center_masks(dist2origin_rna_shared,dist2origin_rna_shared_heldout)
    centerIdx_d_all=_center_masks(dist2origin_rna_d,dist2origin_rna_d_heldout)

    groupID_shared_heldout=_label_extreme_groups(latent_shared_rna_pca_heldout_binID)
    groupID_d_heldout=_label_extreme_groups(latent_d_rna_pca_heldout_binID)

    ####DE heldout###
    nmco_ad_heldout=skin_rna_de[testIdx_all].copy()
    deRes_shared_rna_centered_heldout=_rank_genes_along_pcs(nmco_ad_heldout,groupID_shared_heldout,centerIdx_shared_all,'percentileGroups_shared')
    deRes_d_rna_centered_heldout=_rank_genes_along_pcs(nmco_ad_heldout,groupID_d_heldout,centerIdx_d_all,'percentileGroups_d')

    deShared_centered_2bins_heldout_rna[h]=deRes_shared_rna_centered_heldout
    deD_centered_2bins_heldout_rna[h]=deRes_d_rna_centered_heldout

In [None]:

# Persist the four RNA result lists (training / held-out x shared / RNA-specific)
# under plotsavepath_de; these are the files the resume branch of the loop cell reads.
_rnaDeOutputs=(
    ('de_shared_all_centered_2bins_rna',deShared_centered_2bins_rna),
    ('de_d_all_centered_2bins_rna',deD_centered_2bins_rna),
    ('de_shared_all_centered_2bins_heldout_rna',deShared_centered_2bins_heldout_rna),
    ('de_d_all_centered_2bins_heldout_rna',deD_centered_2bins_heldout_rna),
)
for _fileName,_deResults in _rnaDeOutputs:
    with open(os.path.join(plotsavepath_de,_fileName), 'wb') as fh:
        pickle.dump(_deResults,fh,pickle.HIGHEST_PROTOCOL)


In [None]:
# Explained variance per principal component for the two RNA PCAs.
# Both PCA objects are the ones fitted in the last iteration of the resampling loop.
for _label,_nShown,_fittedPca in (('shared',50,pca_train_shared_rna),('d-specific',20,pca_train_d_rna)):
    print(_label)
    plt.bar(np.arange(_nShown),_fittedPca.explained_variance_[:_nShown])
    plt.show()


In [None]:
# Fraction of variance captured by the first 30 PCs of the shared-latent PCA
# (the PCA fitted in the last resampling iteration).
pca_train_shared_rna.explained_variance_ratio_[:30].sum()

In [None]:
# Fraction of variance captured by the first 13 PCs of the RNA-specific-latent PCA
# (the PCA fitted in the last resampling iteration).
pca_train_d_rna.explained_variance_ratio_[:13].sum()

### ATAC

In [None]:
# nsamples=12
nsamples=36
prevSampled=0
if prevSampled==0:
    deShared_centered_2bins_atac=[None]*nsamples
    deD_centered_2bins_atac=[None]*nsamples
    deShared_centered_2bins_heldout_atac=[None]*nsamples
    deD_centered_2bins_heldout_atac=[None]*nsamples
else:
    with open(os.path.join(plotsavepath_de,'de_shared_all_centered_2bins_atac'), 'rb') as output:
        deShared_centered_2bins_atac=pickle.load(output)
    with open(os.path.join(plotsavepath_de,'de_d_all_centered_2bins_atac'), 'rb') as output:
        deD_centered_2bins_atac=pickle.load(output)

    with open(os.path.join(plotsavepath_de,'de_shared_all_centered_2bins_heldout_atac'), 'rb') as output:
        deShared_centered_2bins_heldout_atac=pickle.load(output)
    with open(os.path.join(plotsavepath_de,'de_d_all_centered_2bins_heldout_atac'), 'rb') as output:
        deD_centered_2bins_heldout_atac=pickle.load(output)
        
    deShared_centered_2bins_atac.extend([None]*(nsamples-prevSampled))
    deD_centered_2bins_atac.extend([None]*(nsamples-prevSampled))
    deShared_centered_2bins_heldout_atac.extend([None]*(nsamples-prevSampled))
    deD_centered_2bins_heldout_atac.extend([None]*(nsamples-prevSampled))



###pca & binning parameters
nPC_shared=50
nPC_d=20
nSteps=11 
minbin=2
maxbin=8
percentiles=(np.arange(nSteps-1)+1)*100/nSteps
nCellsPerSample=30


for h in range(prevSampled,nsamples):
    print(h)
    plotsavepath_de_h=os.path.join(plotsavepath_de,str(h))
    if not os.path.exists(plotsavepath_de_h):
        os.mkdir(plotsavepath_de_h)
    
    #train-test split
    np.random.seed(h)
    pctTest=0.2

    allIdx_all=np.arange(atac.shape[0])
    np.random.shuffle(allIdx_all)
    testIdx_all=allIdx_all[:int(pctTest*atac.shape[0])]
    trainIdx_all=allIdx_all[int(pctTest*atac.shape[0]):]

    #PCA of training
    pca_train_shared_atac=PCA(n_components=nPC_shared)
    pca_train_d_atac=PCA(n_components=nPC_d)
    pca_train_shared_atac.fit(latent_encoded_atacShared[trainIdx_all])
    pca_train_d_atac.fit(latent_encoded_atacD[trainIdx_all])

    #bin training cells
    latent_shared_atac_pca=pca_train_shared_atac.transform(latent_encoded_atacShared[trainIdx_all])
    latent_shared_atac_pca_percentile=np.percentile(latent_shared_atac_pca,percentiles,axis=0)
    latent_shared_atac_pca_max=np.max(latent_shared_atac_pca,axis=0)
    latent_shared_atac_pca_min=np.min(latent_shared_atac_pca,axis=0)
    latent_d_atac_pca=pca_train_d_atac.transform(latent_encoded_atacD[trainIdx_all])
    latent_d_atac_pca_percentile=np.percentile(latent_d_atac_pca,percentiles,axis=0)
    latent_d_atac_pca_max=np.max(latent_d_atac_pca,axis=0)
    latent_d_atac_pca_min=np.min(latent_d_atac_pca,axis=0)
    dist2origin_atac_shared=np.square(latent_shared_atac_pca)
    dist2origin_atac_d=np.square(latent_d_atac_pca)

    latent_shared_atac_pca_binID=np.zeros((nPC_shared,latent_shared_atac_pca.shape[0]))-1
    centerIdx_shared_all=[None]*nPC_shared
    latent_d_atac_pca_binID=np.zeros((nPC_d,latent_d_atac_pca.shape[0]))-1
    centerIdx_d_all=[None]*nPC_d
    for pc_i in range(nPC_shared):
        #idx of cells at the center of other pcs
        dist2origin_atac_shared_i=np.sum(dist2origin_atac_shared[:,:pc_i],axis=1)+np.sum(dist2origin_atac_shared[:,pc_i+1:],axis=1)
        thresh_shared_i=np.percentile(dist2origin_atac_shared_i,15)
        centerIdx_shared_all[pc_i]=dist2origin_atac_shared_i<thresh_shared_i

        for sidx in range(nSteps):
            if sidx==0:
                min_shared_s=latent_shared_atac_pca_min[pc_i]
            else:
                min_shared_s=latent_shared_atac_pca_percentile[sidx-1,pc_i]
            if sidx==nSteps-1:
                max_shared_s=latent_shared_atac_pca_max[pc_i]
            else:
                max_shared_s=latent_shared_atac_pca_percentile[sidx,pc_i]

            latent_shared_atac_pca_binID[pc_i,np.logical_and(latent_shared_atac_pca[:,pc_i]>min_shared_s,latent_shared_atac_pca[:,pc_i]<=max_shared_s)]=sidx
    
    for pc_i in range(nPC_d):
        #idx of cells at the center of other pcs
        dist2origin_atac_d_i=np.sum(dist2origin_atac_d[:,:pc_i],axis=1)+np.sum(dist2origin_atac_d[:,pc_i+1:],axis=1)
        thresh_d_i=np.percentile(dist2origin_atac_d_i,15)
        centerIdx_d_all[pc_i]=dist2origin_atac_d_i<thresh_d_i

        for sidx in range(nSteps):
            if sidx==0:
                min_d_s=latent_d_atac_pca_min[pc_i]
            else:
                min_d_s=latent_d_atac_pca_percentile[sidx-1,pc_i]
            if sidx==nSteps-1:
                max_d_s=latent_d_atac_pca_max[pc_i]
            else:
                max_d_s=latent_d_atac_pca_percentile[sidx,pc_i]

            latent_d_atac_pca_binID[pc_i,np.logical_and(latent_d_atac_pca[:,pc_i]>min_d_s,latent_d_atac_pca[:,pc_i]<=max_d_s)]=sidx
    
    ###DE of training###
    groupID_shared=np.zeros_like(latent_shared_atac_pca_binID).astype(str)
    groupID_shared[latent_shared_atac_pca_binID<2]=-1
    groupID_shared[latent_shared_atac_pca_binID>8]=1
    groupID_shared[latent_shared_atac_pca_binID==5]='c'
    groupID_d=np.zeros_like(latent_d_atac_pca_binID).astype(str)
    groupID_d[latent_d_atac_pca_binID<2]=-1
    groupID_d[latent_d_atac_pca_binID>8]=1
    groupID_d[latent_d_atac_pca_binID==5]='c'

    deRes_shared_atac_centered=[None]*nPC_shared
    deRes_d_atac_centered=[None]*nPC_d
    nmco_ad=skin_atac_de.copy()[trainIdx_all]

    for pc_i in range(nPC_shared):
        print(pc_i)
        nmco_ad.obs[str(pc_i)+'percentileGroups_shared']=groupID_shared[pc_i].astype(str)
        nmco_ad_subShared=nmco_ad[centerIdx_shared_all[pc_i]].copy()

        _groups=np.intersect1d(np.unique(nmco_ad_subShared.obs[str(pc_i)+'percentileGroups_shared']),['-1','1','c'])
        if 'c' in np.unique(nmco_ad_subShared.obs[str(pc_i)+'percentileGroups_shared']) and _groups.size>1:
#             print('deDirecC')
            groups=[]
            for g in _groups:
                if np.sum(nmco_ad_subShared.obs[str(pc_i)+'percentileGroups_shared']==g)>1:
                    groups.append(g)
            scanpy.tl.rank_genes_groups(nmco_ad_subShared, str(pc_i)+'percentileGroups_shared', method='t-test',groups=groups,use_raw=False)
            deRes_shared_atac_centered[pc_i]={}
            deRes_shared_atac_centered[pc_i]['0']=scanpy.get.rank_genes_groups_df(nmco_ad_subShared,'c')

        if np.sum(nmco_ad_subShared.obs[str(pc_i)+'percentileGroups_shared']=='1')>1 and np.sum(nmco_ad_subShared.obs[str(pc_i)+'percentileGroups_shared']=='-1')>1:
            print('deDirec1')
            scanpy.tl.rank_genes_groups(nmco_ad_subShared, str(pc_i)+'percentileGroups_shared', method='t-test',groups=['-1','1'],use_raw=False)
            deRes_shared_atac_centered[pc_i]['1']=scanpy.get.rank_genes_groups_df(nmco_ad_subShared,'1')
            deRes_shared_atac_centered[pc_i]['-1']=scanpy.get.rank_genes_groups_df(nmco_ad_subShared,'-1')

    # Differential accessibility along each dataset-specific PC. For PC pc_i only the
    # cells selected by centerIdx_d_all[pc_i] are compared, grouped by the labels in
    # groupID_d[pc_i] ('-1', '1', 'c'; any other label is only ever part of the "rest").
    for pc_i in range(nPC_d):
        print(pc_i)
        group_col=str(pc_i)+'percentileGroups_d'
        nmco_ad.obs[group_col]=groupID_d[pc_i].astype(str)
        nmco_ad_subD=nmco_ad[centerIdx_d_all[pc_i]].copy()

        labels=nmco_ad_subD.obs[group_col]
        _groups=np.intersect1d(np.unique(labels),['-1','1','c'])
        # Only labels with more than one cell can be passed to the t-test.
        groups=[g for g in _groups if np.sum(labels==g)>1]

        # Results are gathered locally so that either comparison can be stored on its
        # own; the slot for this PC is left untouched when neither comparison is possible.
        pc_results={}

        # Centre group vs. rest. Requiring 'c' in `groups` (not merely present) skips the
        # lookup when the centre group has a single cell and was therefore not tested.
        if 'c' in groups and _groups.size>1:
            scanpy.tl.rank_genes_groups(nmco_ad_subD, group_col, method='t-test',groups=groups,use_raw=False)
            pc_results['0']=scanpy.get.rank_genes_groups_df(nmco_ad_subD,'c')

        # Both extremes need more than one cell each.
        if '1' in groups and '-1' in groups:
            scanpy.tl.rank_genes_groups(nmco_ad_subD, group_col, method='t-test',groups=['-1','1'],use_raw=False)
            pc_results['1']=scanpy.get.rank_genes_groups_df(nmco_ad_subD,'1')
            pc_results['-1']=scanpy.get.rank_genes_groups_df(nmco_ad_subD,'-1')

        if pc_results:
            deRes_d_atac_centered[pc_i]=pc_results

    # File this run's DE tables under key h.
    deShared_centered_2bins_atac[h]=deRes_shared_atac_centered
    deD_centered_2bins_atac[h]=deRes_d_atac_centered

    ### test PCA###
    # Project the held-out cells (testIdx_all) with the already-fitted PCA models.
    latent_shared_atac_pca_heldout=pca_train_shared_atac.transform(latent_encoded_atacShared[testIdx_all])
    latent_d_atac_pca_heldout=pca_train_d_atac.transform(latent_encoded_atacD[testIdx_all])

    # One PDF per latent space: first two PCs, reference embedding in blue and the
    # held-out projection drawn on top in red.
    for _pcsRef,_pcsHeldout,_pdfName in (
        (latent_shared_atac_pca,latent_shared_atac_pca_heldout,'pca_shared_atac.pdf'),
        (latent_d_atac_pca,latent_d_atac_pca_heldout,'pca_d_atac.pdf'),
    ):
        for _pcs,_colour in ((_pcsRef,'blue'),(_pcsHeldout,'red')):
            plt.scatter(_pcs[:,0],_pcs[:,1],s=0.1,c=_colour)
        plt.savefig(os.path.join(plotsavepath_de_h,_pdfName))
        plt.close()
    
    ## bin & de
    # Held-out cells are centred and binned with statistics taken from the reference
    # embedding latent_*_atac_pca (distance threshold, per-PC bin edges), not from the
    # held-out projection itself.
    dist2origin_atac_shared=np.square(latent_shared_atac_pca) #use training distance thresh
    dist2origin_atac_d=np.square(latent_d_atac_pca)#use training distance thresh
    dist2origin_atac_shared_heldout=np.square(latent_shared_atac_pca_heldout)
    dist2origin_atac_d_heldout=np.square(latent_d_atac_pca_heldout)

    # Bin index per (PC, held-out cell); -1 marks a cell that fell into no bin.
    latent_shared_atac_pca_heldout_binID=np.full((nPC_shared,latent_shared_atac_pca_heldout.shape[0]),-1.0)
    centerIdx_shared_all=[None]*nPC_shared
    latent_d_atac_pca_heldout_binID=np.full((nPC_d,latent_d_atac_pca_heldout.shape[0]),-1.0)
    centerIdx_d_all=[None]*nPC_d
    for pc_i in range(nPC_shared):
        #idx of cells at the center of other pcs
        # Squared distance to the origin over every PC except pc_i; the cut-off is the
        # 15th percentile of that distance in the reference embedding.
        dist2origin_atac_shared_i=np.sum(dist2origin_atac_shared[:,:pc_i],axis=1)+np.sum(dist2origin_atac_shared[:,pc_i+1:],axis=1)
        thresh_shared_i=np.percentile(dist2origin_atac_shared_i,15)

        dist2origin_atac_shared_i_heldout=np.sum(dist2origin_atac_shared_heldout[:,:pc_i],axis=1)+np.sum(dist2origin_atac_shared_heldout[:,pc_i+1:],axis=1)
        centerIdx_shared_all[pc_i]=dist2origin_atac_shared_i_heldout<thresh_shared_i

        # Bins are half-open intervals (lower, upper] along PC pc_i.
        for sidx in range(nSteps):
            if sidx==0:
                min_shared_s=latent_shared_atac_pca_min[pc_i]
            else:
                min_shared_s=latent_shared_atac_pca_percentile[sidx-1,pc_i]
            if sidx==nSteps-1:
                # Open-ended top bin: a held-out value above latent_shared_atac_pca_max[pc_i]
                # would otherwise keep the -1 sentinel, which the `<2` grouping rule
                # downstream treats as the LOW extreme.
                max_shared_s=np.inf
            else:
                max_shared_s=latent_shared_atac_pca_percentile[sidx,pc_i]

            latent_shared_atac_pca_heldout_binID[pc_i,np.logical_and(latent_shared_atac_pca_heldout[:,pc_i]>min_shared_s,latent_shared_atac_pca_heldout[:,pc_i]<=max_shared_s)]=sidx

    for pc_i in range(nPC_d):
        #idx of cells at the center of other pcs
        # Same centring rule as for the shared space, applied to the d-specific PCs.
        dist2origin_atac_d_i=np.sum(dist2origin_atac_d[:,:pc_i],axis=1)+np.sum(dist2origin_atac_d[:,pc_i+1:],axis=1)
        thresh_d_i=np.percentile(dist2origin_atac_d_i,15)

        dist2origin_atac_d_i_heldout=np.sum(dist2origin_atac_d_heldout[:,:pc_i],axis=1)+np.sum(dist2origin_atac_d_heldout[:,pc_i+1:],axis=1)
        centerIdx_d_all[pc_i]=dist2origin_atac_d_i_heldout<thresh_d_i

        for sidx in range(nSteps):
            if sidx==0:
                min_d_s=latent_d_atac_pca_min[pc_i]
            else:
                min_d_s=latent_d_atac_pca_percentile[sidx-1,pc_i]
            if sidx==nSteps-1:
                # Open-ended top bin, so values above latent_d_atac_pca_max[pc_i] land in
                # the highest bin instead of keeping the -1 sentinel.
                max_d_s=np.inf
            else:
                max_d_s=latent_d_atac_pca_percentile[sidx,pc_i]

            latent_d_atac_pca_heldout_binID[pc_i,np.logical_and(latent_d_atac_pca_heldout[:,pc_i]>min_d_s,latent_d_atac_pca_heldout[:,pc_i]<=max_d_s)]=sidx

    # Turn bin indices into string labels for the DE tests:
    #   bin < 2  -> '-1' (low end; also catches the -1 "no bin" sentinel)
    #   bin > 8  -> '1'  (high end)
    #   bin == 5 -> 'c'  (centre)
    # Every other bin keeps the default label '0.0' produced by the float-to-str cast.
    # A wider centre (bins 4-6) was tried at some point and is currently disabled.
    groupID_shared_heldout=np.zeros_like(latent_shared_atac_pca_heldout_binID).astype(str)
    groupID_d_heldout=np.zeros_like(latent_d_atac_pca_heldout_binID).astype(str)
    for _labels,_binID in (
        (groupID_shared_heldout,latent_shared_atac_pca_heldout_binID),
        (groupID_d_heldout,latent_d_atac_pca_heldout_binID),
    ):
        _labels[_binID<2]='-1'
        _labels[_binID>8]='1'
        _labels[_binID==5]='c'

    ####DE heldout###
    # One slot per PC; a slot stays None when no comparison could be run for that PC.
    deRes_shared_atac_centered_heldout=[None]*nPC_shared
    deRes_d_atac_centered_heldout=[None]*nPC_d
    # Subset before copying: only the held-out rows are duplicated, and the result is a
    # real AnnData (not a view), so the .obs assignments below do not trigger an
    # implicit copy.
    nmco_ad_heldout=skin_atac_de[testIdx_all].copy()

    # Differential accessibility along each shared PC, restricted to the held-out cells
    # selected by centerIdx_shared_all[pc_i].
    for pc_i in range(nPC_shared):
        print(pc_i)
        group_col=str(pc_i)+'percentileGroups_shared'
        nmco_ad_heldout.obs[group_col]=groupID_shared_heldout[pc_i].astype(str)
        nmco_ad_heldout_subShared=nmco_ad_heldout[centerIdx_shared_all[pc_i]].copy()

        labels=nmco_ad_heldout_subShared.obs[group_col]
        _groups=np.intersect1d(np.unique(labels),['-1','1','c'])
        # Only labels with more than one cell can be passed to the t-test.
        groups=[g for g in _groups if np.sum(labels==g)>1]

        # Gathered locally so the '-1' vs '1' comparison can be stored even when the
        # centre comparison was skipped (the slot is None until something is stored).
        pc_results={}

        # Centre group vs. rest. Requiring 'c' in `groups` skips the lookup when the
        # centre group has a single cell and was therefore not tested.
        if 'c' in groups and _groups.size>1:
            scanpy.tl.rank_genes_groups(nmco_ad_heldout_subShared, group_col, method='t-test',groups=groups,use_raw=False)
            pc_results['0']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subShared,'c')

        # Both extremes need more than one cell each.
        if '1' in groups and '-1' in groups:
            scanpy.tl.rank_genes_groups(nmco_ad_heldout_subShared, group_col, method='t-test',groups=['-1','1'],use_raw=False)
            pc_results['1']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subShared,'1')
            pc_results['-1']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subShared,'-1')

        if pc_results:
            deRes_shared_atac_centered_heldout[pc_i]=pc_results

    # Differential accessibility along each dataset-specific PC, restricted to the
    # held-out cells selected by centerIdx_d_all[pc_i].
    for pc_i in range(nPC_d):
        print(pc_i)
        group_col=str(pc_i)+'percentileGroups_d'
        nmco_ad_heldout.obs[group_col]=groupID_d_heldout[pc_i].astype(str)
        nmco_ad_heldout_subD=nmco_ad_heldout[centerIdx_d_all[pc_i]].copy()

        labels=nmco_ad_heldout_subD.obs[group_col]
        _groups=np.intersect1d(np.unique(labels),['-1','1','c'])
        # Only labels with more than one cell can be passed to the t-test.
        groups=[g for g in _groups if np.sum(labels==g)>1]

        # Gathered locally so the '-1' vs '1' comparison can be stored even when the
        # centre comparison was skipped; the slot for this PC is left untouched when
        # neither comparison is possible.
        pc_results={}

        # Centre group vs. rest. Requiring 'c' in `groups` skips the lookup when the
        # centre group has a single cell and was therefore not tested.
        if 'c' in groups and _groups.size>1:
            scanpy.tl.rank_genes_groups(nmco_ad_heldout_subD, group_col, method='t-test',groups=groups,use_raw=False)
            pc_results['0']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subD,'c')

        # Both extremes need more than one cell each.
        if '1' in groups and '-1' in groups:
            scanpy.tl.rank_genes_groups(nmco_ad_heldout_subD, group_col, method='t-test',groups=['-1','1'],use_raw=False)
            pc_results['1']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subD,'1')
            pc_results['-1']=scanpy.get.rank_genes_groups_df(nmco_ad_heldout_subD,'-1')

        if pc_results:
            deRes_d_atac_centered_heldout[pc_i]=pc_results

    # File this run's held-out DE tables under key h.
    deShared_centered_2bins_heldout_atac[h]=deRes_shared_atac_centered_heldout
    deD_centered_2bins_heldout_atac[h]=deRes_d_atac_centered_heldout

In [None]:

# Write the four DE result collections (training / held-out x shared / d-specific)
# into plotsavepath_de, one pickle file each.
for _pickleName,_deTables in (
    ('de_shared_all_centered_2bins_atac',deShared_centered_2bins_atac),
    ('de_d_all_centered_2bins_atac',deD_centered_2bins_atac),
    ('de_shared_all_centered_2bins_heldout_atac',deShared_centered_2bins_heldout_atac),
    ('de_d_all_centered_2bins_heldout_atac',deD_centered_2bins_heldout_atac),
):
    with open(os.path.join(plotsavepath_de,_pickleName), 'wb') as output:
        pickle.dump(_deTables,output,pickle.HIGHEST_PROTOCOL)


In [None]:
# Scree plots: explained variance of the leading components of each fitted PCA
# (up to 50 for the shared space, up to 20 for the d-specific space).
# The x positions are derived from the slice itself, so the bar plot still works
# if a PCA was fitted with fewer components than the cap.
print('shared')
explVar_shared=pca_train_shared_atac.explained_variance_[:50]
plt.bar(np.arange(explVar_shared.size),explVar_shared)
plt.xlabel('principal component')
plt.ylabel('explained variance')
plt.show()
print('d-specific')
explVar_d=pca_train_d_atac.explained_variance_[:20]
plt.bar(np.arange(explVar_d.size),explVar_d)
plt.xlabel('principal component')
plt.ylabel('explained variance')
plt.show()


In [None]:
# Fraction of variance captured by the first 29 shared components
# (29 presumably matches the number of shared PCs analysed above -- TODO confirm).
pca_train_shared_atac.explained_variance_ratio_[:29].sum()

In [None]:
# Fraction of variance captured by the first 13 d-specific components
# (13 presumably matches the number of d-specific PCs analysed above -- TODO confirm).
pca_train_d_atac.explained_variance_ratio_[:13].sum()