In [1]:
import sys
sys.path.append('/home/xinyiz/pamrats')

import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

import anndata as ad

In [2]:
# ---- run switches ----
ifplot=True
ifcluster=True

# Activation to undo on the encoder mean before plotting ('leakyRelu' or None).
inverseAct='leakyRelu'
# inverseAct=None
# 2-D projection used for every embedding plot: 'umap' or 'pca'.
plottype='umap'
if plottype not in ('umap','pca'):
    # the plotting cells only know how to build these two projections; anything else
    # would leave `embedding` undefined or stale there
    raise ValueError("plottype must be 'umap' or 'pca', got "+repr(plottype))
pca=PCA()
minCells=15 #min number of cells in a group before it is Leiden-clustered
#umap/clustering parameters
clustermethod=['leiden','louvain','hierarchical']
n_neighbors=10
min_dist=0.25
n_pcs=40 #for clustering
# resolution=[0.5,1,1.5,2]
resolution=[0.1,0.2,0.3] #Leiden resolutions to sweep
plotepoch=1000 #checkpoint to load: <modelsavepath>/<plotepoch>.pt
savenameAdd=''

use_cuda=True
fastmode=False #Validate during training pass
seed=3
useSavedMaskedEdges=False
maskedgeName='knn20_connectivity' #prefix of the adjacency .npz files
hidden1=1024 #Number of units in hidden layer 1
hidden2=1024 #Number of units in hidden layer 2
# hidden3=16
fc_dim1=1024
# NOTE(review): other model_str choices in the load-model cell also need
# fc_dim2..fc_dim4 / gcn_dim1 / shareGenePi, which are not defined here.
# fc_dim2=2112
# fc_dim3=2112
# fc_dim4=2112
# gcn_dim1=2048

dropout=0.01
# randFeatureSubset=None
model_str='gcn_vae_xa_e2_d1_dca'
adj_decodeName=None #gala or None
# display key -> sample id in the AnnData obs['sample'] column
plot_samples={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}
plot_sample_X=['logminmax']
# plot_sample_X=['corrected','scaled']
standardizeX=False
name='allk20XA_01_dca'
# All outputs of this run live under <pamrats_root>/{log,models,plots}/train_gae_starmap/<name>.
pamrats_root='/mnt/xinyi/pamrats'
logsavepath=os.path.join(pamrats_root,'log','train_gae_starmap',name)
modelsavepath=os.path.join(pamrats_root,'models','train_gae_starmap',name)
plotsavepath=os.path.join(pamrats_root,'plots','train_gae_starmap',name)
    

In [3]:
# Seed numpy and torch up front so the stochastic steps below are repeatable,
# then fall back to CPU when a GPU was requested but none is visible.
np.random.seed(seed)
torch.manual_seed(seed)

_cuda_requested=use_cuda
use_cuda=_cuda_requested and torch.cuda.is_available()
if _cuda_requested and not use_cuda:
    print('cuda not available')

if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True


In [4]:
#Load data
# For every sample in plot_samples this builds
#   featureslist['<key>X_<repr>'] : torch tensor, one row per cell of that sample
#   adj_list['<key>']              : scipy sparse adjacency read from adj_dir
savedir=os.path.join('/mnt/xinyi/','starmap')
adj_dir=os.path.join(savedir,'a')
datadir='/mnt/xinyi/2021-01-13-mAD-test-dataset'

featureslist={}
if plot_sample_X[0] in ['corrected','scaled']:
    scaleddata=scanpy.read_h5ad(os.path.join(datadir,'2020-12-27-starmap-mAD-scaled.h5ad'))
    for s,sampleid in plot_samples.items():
        cellmask=(scaleddata.obs['sample']==sampleid).to_numpy()
        featureslist[s+'X_'+'corrected']=torch.tensor(scaleddata.layers['corrected'][cellmask])
        featureslist[s+'X_'+'scaled']=torch.tensor(scaleddata.layers['scaled'][cellmask])
elif plot_sample_X[0]=='logminmax':
    scaleddata=scanpy.read_h5ad(os.path.join(datadir,'2020-12-27-starmap-mAD-raw.h5ad'))
    for s,sampleid in plot_samples.items():
        cellmask=(scaleddata.obs['sample']==sampleid).to_numpy()
        scaleddata_train=scaleddata.X[cellmask]
        if sp.issparse(scaleddata_train):
            # scipy sparse matrices cannot add a nonzero scalar; densify first
            scaleddata_train=scaleddata_train.toarray()
        # log2 with a 0.5 pseudocount, then min-max scale each cell (row) independently
        featurelog_train=np.log2(scaleddata_train+1/2)
        scaler = MinMaxScaler()
        featurelog_train_minmax=scaler.fit_transform(featurelog_train.T).T
        featureslist[s+'X_'+plot_sample_X[0]]=torch.tensor(featurelog_train_minmax)
else:
    # previously this fell through with an empty featureslist and failed later
    raise ValueError('unsupported plot_sample_X: '+str(plot_sample_X[0]))

# adjacency files are named <maskedgeName>_<sample id>.npz; derive them from
# plot_samples so the graphs always match the configured samples
adj_list={s:sp.load_npz(os.path.join(adj_dir,maskedgeName+'_'+sampleid+'.npz'))
          for s,sampleid in plot_samples.items()}

In [5]:
# load model: build the architecture named by model_str, then restore the
# checkpoint saved at epoch plotepoch
num_nodes,num_features = next(iter(featureslist.values())).shape
# NOTE(review): several branches need fc_dim2..fc_dim4 / gcn_dim1 / shareGenePi,
# which must be defined in the config cell before selecting those models.
if model_str=='gcn_vae_xa':
    model  = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
elif model_str=='fc1':
    model  = gae.gae.model.FCVAE1(num_features, hidden1,dropout)
elif model_str == 'gcn_vae_xa_e2_d1':
    model  = gae.gae.model.GCNModelVAE_XA_e2_d1(num_features, hidden1,hidden2, dropout)
elif model_str == 'gcn_vae_gcnX_inprA':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA(num_features, hidden1, hidden2,gcn_dim1, dropout)
elif model_str=='fc1_dca':
    model = gae.gae.model.FCVAE1_DCA(num_features, hidden1,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dca':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaFork':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAfork(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaElemPi':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAelemPi(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
elif model_str=='gcn_vae_xa_e2_d1_dcaConstantDisp':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_constantDisp(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
else:
    # fail loudly instead of loading weights into an undefined (or stale) `model`
    raise ValueError('model not found: '+str(model_str))
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine
# (use_cuda may have been switched off by the seed cell); the embedding cell moves
# the model to the GPU afterwards when use_cuda is set.
model.load_state_dict(torch.load(os.path.join(modelsavepath,str(plotepoch)+'.pt'),map_location='cpu'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [6]:
def plotembeddingbyCT(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1,savenameAdd=''):
    """Scatter a 2-D embedding with one colour per label and save it as a JPEG.

    Parameters
    ----------
    ctlist : labels, one per row of ``embedding`` (numpy array or pandas Series;
        matched positionally).
    savename : file-name stem; the file is ``savepath/savename+savenameAdd+'.jpg'``.
    excludelist : labels left out of the plot.
    embedding : 2-D array indexed as ``embedding[mask, dim]``.
    savepath : output directory, created if missing.
    plotname : title prefix; ``' embedding'`` is appended.
    plotdimx, plotdimy : columns of ``embedding`` used for x and y.
    savenameAdd : suffix appended to ``savename``.
    """
    ctlist=np.asarray(ctlist)
    celltypes=np.unique(ctlist)
    # colour index per label, in sorted-label order
    celltypes_dict={ct:i for i,ct in enumerate(celltypes)}

    # Shuffled with numpy's global RNG (presumably to separate similar adjacent hues),
    # so the colours depend on how many plots were drawn earlier in the session.
    colortest=sns.color_palette("husl", celltypes.size)
    np.random.shuffle(colortest)

    os.makedirs(savepath,exist_ok=True)
    fig, ax = plt.subplots(dpi=400)
    for ct in celltypes:
        if ct in excludelist:
            continue
        idx=(ctlist==ct)
        ax.scatter(
            embedding[idx, plotdimx],
            embedding[idx, plotdimy],
            color=colortest[celltypes_dict[ct]],label=ct,s=3,alpha=0.5
            )

    ax.set_aspect('equal', 'datalim')
    fig.set_figheight(10)
    fig.set_figwidth(10)
    # shrink the axes vertically to make room for the legend placed below it
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True, shadow=True, ncol=5)
    ax.set_title(plotname+' embedding', fontsize=24)
    fig.savefig(os.path.join(savepath,savename+savenameAdd+'.jpg'))
    # close explicitly: this is called many times in loops
    plt.close(fig)

In [7]:
def plotembeddingbyCT_contrast(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1,savenameAdd='',maxplot=None):
    """Save one scatter per label, highlighting that label against all other points.

    Parameters
    ----------
    ctlist : labels, one per row of ``embedding`` (numpy array or pandas Series;
        matched positionally).
    savename : file-name stem; each file is
        ``savepath/savename+'_'+<label>+savenameAdd+'.jpg'``.
    excludelist : labels that get no figure ('Unassigned' is always skipped).
    embedding : 2-D array indexed as ``embedding[mask, dim]``.
    savepath : output directory, created if missing.
    plotname : title prefix; ``' embedding'`` is appended.
    plotdimx, plotdimy : columns of ``embedding`` used for x and y.
    savenameAdd : suffix appended after the label in each file name.
    maxplot : if set, labels are parsed with ``int()`` and those greater than
        ``maxplot`` are skipped (meant for numeric Leiden cluster ids).
    """
    ctlist=np.asarray(ctlist)
    colortest=sns.color_palette("tab10")
    os.makedirs(savepath,exist_ok=True)

    for ct in np.unique(ctlist):
        # Decide whether to plot BEFORE opening a figure, so skipped labels don't leak
        # figures; check 'Unassigned' first so int() is never applied to it.
        if ct == 'Unassigned' or ct in excludelist:
            continue
        if maxplot and int(ct)>maxplot:
            continue

        fig, ax = plt.subplots()
        others=(ctlist!=ct)
        ax.scatter(
            embedding[others, plotdimx],
            embedding[others, plotdimy],
            color=colortest[1],label='others',s=1,alpha=0.5
            )
        highlight=(ctlist==ct)
        ax.scatter(
            embedding[highlight, plotdimx],
            embedding[highlight, plotdimy],
            color=colortest[0],label=ct,s=3,alpha=0.5
            )

        ax.set_aspect('equal', 'datalim')
        fig.set_figheight(10)
        fig.set_figwidth(10)
        ax.legend()
        ax.set_title(plotname+' embedding', fontsize=24)
        # str(): labels are not guaranteed to be strings
        fig.savefig(os.path.join(savepath,savename+'_'+str(ct)+savenameAdd+'.jpg'))
        plt.close(fig)

In [8]:
def inverseLeakyRelu(v,slope=0.01):
    """Invert a leaky-ReLU element-wise.

    Negative entries are multiplied by ``1/slope``; non-negative entries are unchanged.
    Returns a new array: the input is left untouched. This matters because callers
    pass ``tensor.numpy()`` views that share memory with the tensor.

    Parameters
    ----------
    v : array-like of numbers.
    slope : negative-side slope of the leaky ReLU being inverted.
    """
    out=np.array(v,copy=True)
    negidx=(out<0)
    out[negidx]=1/slope*out[negidx]
    return out

In [9]:
def clusterLeiden(inArray,n_neighbors,n_pcs,min_dist,resolution,randseed=seed):
    """Leiden-cluster the rows of ``inArray`` and return one label per row.

    Runs PCA, builds a kNN graph on the leading ``n_pcs`` components, then runs
    Leiden at ``resolution``.

    Parameters
    ----------
    inArray : 2-D array, rows are observations (cells).
    n_neighbors : kNN graph size for ``scanpy.pp.neighbors``.
    n_pcs : requested number of PCs; capped below both dimensions of ``inArray``.
    min_dist : unused. It is kept so existing call sites still work. The UMAP embedding
        it used to configure was never returned, and Leiden does not read it.
    resolution : Leiden resolution.
    randseed : random_state passed to Leiden.

    Returns
    -------
    numpy array of cluster labels (as produced in ``adata.obs['leiden']``).
    """
    n_pcs=np.min([inArray.shape[0]-1,inArray.shape[1]-1,n_pcs])
    adata=ad.AnnData(inArray)
    scanpy.tl.pca(adata, svd_solver='arpack')
    scanpy.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)
    scanpy.tl.leiden(adata,resolution=resolution,random_state=randseed)
    return adata.obs['leiden'].to_numpy()

In [10]:
#compute embeddings
# mulist['<key>X_<repr>'] = encoder mean `mu` as a numpy array (passed through
# inverseLeakyRelu when inverseAct=='leakyRelu'), one row per cell.
if use_cuda:
    model.cuda()
model.eval()
mulist={}
for s in plot_samples.keys():
    adj=adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    adj_decode=None
    if adj_decodeName == 'gala':
        adj_decode=preprocessing.preprocess_graph_sharp(adj)
    if use_cuda:
        adj_norm=adj_norm.cuda()
        if adj_decode is not None:
            adj_decode=adj_decode.cuda()

    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        features=featureslist[samplename]
        if standardizeX:
            features=torch.tensor(scale(features,axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on both paths. The original only cast when moving to the GPU, so on CPU
        # float64 features could reach float32 weights.
        features=features.float()
        if use_cuda:
            features=features.cuda()

        # inference only: no autograd graph needed
        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon,mu,logvar,z, features_recon = model(features, adj_norm)
            else:
                adj_recon,mu,logvar,z, features_recon = model(features, adj_norm,adj_decode)
        muplot=mu.cpu().detach().numpy()
        if inverseAct=='leakyRelu':
            muplot=inverseLeakyRelu(muplot)
        mulist[samplename]=muplot

In [None]:
#all cells: embed and Leiden-cluster each sample's latents on its own
for s,sampleidx in plot_samples.items():
    sampleobs=scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx]
    celltype_broad=sampleobs['top_level']
    celltype_sub=sampleobs['cell_type_label']
    region=sampleobs['region']
    sobj_coord_np=sampleobs[['x','y']].to_numpy()
    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        muplot=mulist[samplename]

        if inverseAct:
            samplename+='_beforeAct'
        sampledir=os.path.join(plotsavepath,samplename)
        savedir=os.path.join(sampledir,'embedding_'+plottype)
        clustersavedir=os.path.join(sampledir,'cluster')
        # makedirs (unlike mkdir) also creates plotsavepath itself on a fresh run
        os.makedirs(savedir,exist_ok=True)
        os.makedirs(clustersavedir,exist_ok=True)

        if plottype=='umap':
            reducer = umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
            embedding = reducer.fit_transform(muplot)
            savenameAdd='_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
        elif plottype=='pca':
            embedding=pca.fit_transform(muplot)
            savenameAdd='_epoch'+str(plotepoch)
        else:
            # otherwise `embedding` would be undefined or left over from a previous sample
            raise ValueError('unknown plottype: '+str(plottype))

        title=plottype+' of '+s
        contrastdir=os.path.join(savedir,'contrast')
        plotembeddingbyCT(celltype_broad,'celltype_broad',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(celltype_sub,'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(region,'region',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT_contrast(celltype_sub,'celltype_sub',[],embedding,contrastdir,title,savenameAdd=savenameAdd)

        # too few cells to cluster
        if embedding.shape[0]<minCells:
            continue
        locationtitle='location'+' of '+s
        for r in resolution:
            clusterRes=clusterLeiden(muplot,n_neighbors,n_pcs,min_dist,r,randseed=seed)
            savenamecluster='leiden_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'n_pcs'+str(n_pcs)+'res'+str(r)+'epoch'+str(plotepoch)
            with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
            plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
            plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,contrastdir,title,savenameAdd=savenamecluster)

            # same labels drawn on the tissue coordinates
            plotembeddingbyCT(clusterRes,'leiden_location',[],sobj_coord_np,savedir,locationtitle,savenameAdd=savenamecluster)
            plotembeddingbyCT_contrast(clusterRes,'leiden_location',[],sobj_coord_np,contrastdir,locationtitle,savenameAdd=savenamecluster)


In [None]:
# separate plots by cell type: re-embed and re-cluster each broad cell type
# (obs 'top_level') of each sample on its own
for s,sampleidx in plot_samples.items():
    sampleobs=scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx]
    celltype_broad=sampleobs['top_level']
    celltype_sub=sampleobs['cell_type_label']
    region=sampleobs['region']
    sobj_coord_np=sampleobs[['x','y']].to_numpy()
    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        muplot=mulist[samplename]

        if inverseAct:
            samplename+='_beforeAct'
        sampledir=os.path.join(plotsavepath,samplename)

        for ct in np.unique(celltype_broad):
            # positional mask over this sample's cells (rows of muplot / sobj_coord_np)
            ct_idx=(celltype_broad==ct).to_numpy()
            # too few cells to embed; same cutoff as the combined-sample cell, checked
            # before creating directories so no empty folders are left behind
            if np.sum(ct_idx)<3:
                continue
            savedir=os.path.join(sampledir,'embedding_'+plottype+'_'+ct)
            clustersavedir=os.path.join(sampledir,'cluster'+'_'+ct)
            os.makedirs(savedir,exist_ok=True)
            os.makedirs(clustersavedir,exist_ok=True)

            if plottype=='umap':
                reducer = umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
                embedding = reducer.fit_transform(muplot[ct_idx])
                savenameAdd='_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
            elif plottype=='pca':
                embedding=pca.fit_transform(muplot[ct_idx])
                savenameAdd='_epoch'+str(plotepoch)
            else:
                raise ValueError('unknown plottype: '+str(plottype))

            title=plottype+' of '+s+' '+ct
            contrastdir=os.path.join(savedir,'contrast')
            plotembeddingbyCT(celltype_sub[ct_idx],'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT(region[ct_idx],'region',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT_contrast(celltype_sub[ct_idx],'celltype_sub',[],embedding,contrastdir,title,savenameAdd=savenameAdd)

            if embedding.shape[0]<minCells:
                continue
            locationtitle='location'+' of '+s+' '+ct
            for r in resolution:
                clusterRes=clusterLeiden(muplot[ct_idx],n_neighbors,n_pcs,min_dist,r,randseed=seed)
                savenamecluster='leiden_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'n_pcs'+str(n_pcs)+'res'+str(r)+'epoch'+str(plotepoch)
                with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                    pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
                plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,contrastdir,title,savenameAdd=savenamecluster,maxplot=50)

                plotembeddingbyCT(clusterRes,'leiden_location',[],sobj_coord_np[ct_idx],savedir,locationtitle,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden_location',[],sobj_coord_np[ct_idx],contrastdir,locationtitle,savenameAdd=savenamecluster,maxplot=50)


In [None]:
# separate plots by region: re-embed and re-cluster each region of each sample on its own
for s,sampleidx in plot_samples.items():
    sampleobs=scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx]
    celltype_broad=sampleobs['top_level']
    celltype_sub=sampleobs['cell_type_label']
    region=sampleobs['region']
    sobj_coord_np=sampleobs[['x','y']].to_numpy()
    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        muplot=mulist[samplename]

        if inverseAct:
            samplename+='_beforeAct'
        sampledir=os.path.join(plotsavepath,samplename)

        for reg in np.unique(region):
            # positional mask over this sample's cells (rows of muplot / sobj_coord_np)
            reg_idx=(region==reg).to_numpy()
            # too few cells to embed; same cutoff as the combined-sample cell, checked
            # before creating directories so no empty folders are left behind
            if np.sum(reg_idx)<3:
                continue
            savedir=os.path.join(sampledir,'embedding_'+plottype+'_'+reg)
            clustersavedir=os.path.join(sampledir,'cluster'+'_'+reg)
            os.makedirs(savedir,exist_ok=True)
            os.makedirs(clustersavedir,exist_ok=True)

            if plottype=='umap':
                reducer = umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
                embedding = reducer.fit_transform(muplot[reg_idx])
                savenameAdd='_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
            elif plottype=='pca':
                embedding=pca.fit_transform(muplot[reg_idx])
                savenameAdd='_epoch'+str(plotepoch)
            else:
                raise ValueError('unknown plottype: '+str(plottype))

            title=plottype+' of '+s+' '+reg
            contrastdir=os.path.join(savedir,'contrast')
            plotembeddingbyCT(celltype_broad[reg_idx],'celltype_broad',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT(celltype_sub[reg_idx],'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT_contrast(celltype_sub[reg_idx],'celltype_sub',[],embedding,contrastdir,title,savenameAdd=savenameAdd)

            if embedding.shape[0]<minCells:
                continue
            locationtitle='location'+' of '+s+' '+reg
            for r in resolution:
                clusterRes=clusterLeiden(muplot[reg_idx],n_neighbors,n_pcs,min_dist,r,randseed=seed)
                savenamecluster='leiden_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'n_pcs'+str(n_pcs)+'res'+str(r)+'epoch'+str(plotepoch)
                with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                    pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
                plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,contrastdir,title,savenameAdd=savenamecluster,maxplot=50)

                plotembeddingbyCT(clusterRes,'leiden_location',[],sobj_coord_np[reg_idx],savedir,locationtitle,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden_location',[],sobj_coord_np[reg_idx],contrastdir,locationtitle,savenameAdd=savenamecluster,maxplot=50)

In [None]:
# combine all latents to one plot: pool every sample's latents, then embed and
# Leiden-cluster the pooled set overall, per region, per region x broad cell type,
# and per broad cell type.
def _embed_combined(x):
    """Project rows of ``x`` to 2-D with the configured ``plottype``.

    Returns ``(embedding, savenameAdd)``; ``savenameAdd`` is the file-name suffix
    describing the projection settings.
    """
    if plottype=='umap':
        reducer=umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
        return reducer.fit_transform(x),'_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
    if plottype=='pca':
        return pca.fit_transform(x),'_epoch'+str(plotepoch)
    raise ValueError('unknown plottype: '+str(plottype))


def _cluster_savename(r):
    """File-name stem for the pickled Leiden labels (and their plots) at resolution ``r``."""
    return 'leiden_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'n_pcs'+str(n_pcs)+'res'+str(r)+'epoch'+str(plotepoch)


for xcorr in plot_sample_X:
    # Gather the per-sample pieces and concatenate once. Every pooled array is then
    # a plain numpy array, with row i of each referring to the same cell, however
    # many samples are configured.
    latents_parts,broad_parts,sub_parts,region_parts,coord_parts,name_parts=[],[],[],[],[],[]
    for s,sampleidx in plot_samples.items():
        muplot=mulist[s+'X_'+xcorr]
        sampleobs=scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx]
        latents_parts.append(muplot)
        broad_parts.append(sampleobs['top_level'].to_numpy())
        sub_parts.append(sampleobs['cell_type_label'].to_numpy())
        region_parts.append(sampleobs['region'].to_numpy())
        coord_parts.append(sampleobs[['x','y']].to_numpy())
        name_parts.append(np.repeat(s,muplot.shape[0]))
    latents=np.vstack(latents_parts)
    celltype_broad=np.concatenate(broad_parts)
    celltype_sub=np.concatenate(sub_parts)
    region=np.concatenate(region_parts)
    sobj_coord_np=np.concatenate(coord_parts,axis=0)
    samplenameList=np.concatenate(name_parts)

    sampledir=os.path.join(plotsavepath,'combined'+xcorr)
    if inverseAct:
        sampledir+='_beforeAct'
    savedir=os.path.join(sampledir,'embedding_'+plottype)
    clustersavedir=os.path.join(sampledir,'cluster')
    os.makedirs(savedir,exist_ok=True)
    os.makedirs(clustersavedir,exist_ok=True)

    embedding,savenameAdd=_embed_combined(latents)
    title=plottype+' of all samples'
    plotembeddingbyCT(samplenameList,'sample',[],embedding,savedir,title,savenameAdd=savenameAdd)
    plotembeddingbyCT(celltype_broad,'celltype_broad',[],embedding,savedir,title,savenameAdd=savenameAdd)
    plotembeddingbyCT(celltype_sub,'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
    plotembeddingbyCT(region,'region',[],embedding,savedir,title,savenameAdd=savenameAdd)
    plotembeddingbyCT_contrast(celltype_sub,'celltype_sub',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenameAdd)

    # too few pooled cells: skip clustering and all the sub-group analyses below
    if embedding.shape[0]<minCells:
        continue
    for r in resolution:
        clusterRes=clusterLeiden(latents,n_neighbors,n_pcs,min_dist,r,randseed=seed)
        savenamecluster=_cluster_savename(r)
        with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
            pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
        plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
        plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenamecluster,maxplot=50)

        # same labels drawn on each sample's tissue coordinates
        for s in plot_samples.keys():
            sidx=(samplenameList==s)
            plotembeddingbyCT(clusterRes[sidx],'leiden_location'+s,[],sobj_coord_np[sidx],savedir,'location'+' of '+s,savenameAdd=savenamecluster)
            plotembeddingbyCT_contrast(clusterRes[sidx],'leiden_location'+s,[],sobj_coord_np[sidx],os.path.join(savedir,'contrast'),'location'+' of '+s,savenameAdd=savenamecluster,maxplot=50)

    #by region
    for reg in np.unique(region):
        reg_idx=(region==reg)
        # too few cells to embed; checked before creating directories
        if np.sum(reg_idx)<3:
            continue
        savedir=os.path.join(sampledir,'embedding_'+plottype+'_'+reg)
        clustersavedir=os.path.join(sampledir,'cluster'+'_'+reg)
        os.makedirs(savedir,exist_ok=True)
        os.makedirs(clustersavedir,exist_ok=True)

        embedding,savenameAdd=_embed_combined(latents[reg_idx])
        title=plottype+' of '+'all samples'+' '+reg
        plotembeddingbyCT(samplenameList[reg_idx],'sample',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(celltype_broad[reg_idx],'celltype_broad',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(celltype_sub[reg_idx],'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT_contrast(celltype_sub[reg_idx],'celltype_sub',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenameAdd)

        if embedding.shape[0]>=minCells:
            for r in resolution:
                clusterRes=clusterLeiden(latents[reg_idx],n_neighbors,n_pcs,min_dist,r,randseed=seed)
                savenamecluster=_cluster_savename(r)
                with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                    pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
                plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenamecluster,maxplot=50)

                for s in plot_samples.keys():
                    sidx=(samplenameList==s)
                    # clusterRes rows are the reg_idx cells, so select sample s within them
                    plotembeddingbyCT(clusterRes[sidx[reg_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,reg_idx)],savedir,'location'+' of '+s+' '+reg,savenameAdd=savenamecluster)
                    plotembeddingbyCT_contrast(clusterRes[sidx[reg_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,reg_idx)],os.path.join(savedir,'contrast'),'location'+' of '+s+' '+reg,savenameAdd=savenamecluster,maxplot=50)

        #by region and celltype
        for ct in np.unique(celltype_broad):
            ct_idx=np.logical_and(reg_idx,celltype_broad==ct)
            # too few cells to embed; checked before creating directories
            if np.sum(ct_idx)<3:
                continue
            savedir=os.path.join(sampledir,'embedding_'+plottype+'_'+reg+ct)
            clustersavedir=os.path.join(sampledir,'cluster'+'_'+reg+ct)
            os.makedirs(savedir,exist_ok=True)
            os.makedirs(clustersavedir,exist_ok=True)

            embedding,savenameAdd=_embed_combined(latents[ct_idx])
            title=plottype+' of '+reg+' all samples'+' '+ct
            plotembeddingbyCT(samplenameList[ct_idx],'sample',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT(celltype_sub[ct_idx],'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
            plotembeddingbyCT_contrast(celltype_sub[ct_idx],'celltype_sub',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenameAdd)

            if embedding.shape[0]<minCells:
                continue
            for r in resolution:
                clusterRes=clusterLeiden(latents[ct_idx],n_neighbors,n_pcs,min_dist,r,randseed=seed)
                savenamecluster=_cluster_savename(r)
                with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                    pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
                plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenamecluster,maxplot=50)

                for s in plot_samples.keys():
                    sidx=(samplenameList==s)
                    plotembeddingbyCT(clusterRes[sidx[ct_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,ct_idx)],savedir,'location'+' of '+reg+' '+s+' '+ct,savenameAdd=savenamecluster)
                    plotembeddingbyCT_contrast(clusterRes[sidx[ct_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,ct_idx)],os.path.join(savedir,'contrast'),'location'+' of '+reg+' '+s+' '+ct,savenameAdd=savenamecluster,maxplot=50)

    #by celltype
    for ct in np.unique(celltype_broad):
        ct_idx=(celltype_broad==ct)
        # too few cells to embed; checked before creating directories
        if np.sum(ct_idx)<3:
            continue
        savedir=os.path.join(sampledir,'embedding_'+plottype+'_'+ct)
        clustersavedir=os.path.join(sampledir,'cluster'+'_'+ct)
        os.makedirs(savedir,exist_ok=True)
        os.makedirs(clustersavedir,exist_ok=True)

        embedding,savenameAdd=_embed_combined(latents[ct_idx])
        title=plottype+' of '+'all samples'+' '+ct
        plotembeddingbyCT(samplenameList[ct_idx],'sample',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(celltype_sub[ct_idx],'celltype_sub',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT(region[ct_idx],'region',[],embedding,savedir,title,savenameAdd=savenameAdd)
        plotembeddingbyCT_contrast(celltype_sub[ct_idx],'celltype_sub',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenameAdd)

        if embedding.shape[0]<minCells:
            continue
        for r in resolution:
            clusterRes=clusterLeiden(latents[ct_idx],n_neighbors,n_pcs,min_dist,r,randseed=seed)
            savenamecluster=_cluster_savename(r)
            with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
                pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
            plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,title,savenameAdd=savenamecluster)
            plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,os.path.join(savedir,'contrast'),title,savenameAdd=savenamecluster,maxplot=50)

            for s in plot_samples.keys():
                sidx=(samplenameList==s)
                plotembeddingbyCT(clusterRes[sidx[ct_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,ct_idx)],savedir,'location'+' of '+s+' '+ct,savenameAdd=savenamecluster)
                plotembeddingbyCT_contrast(clusterRes[sidx[ct_idx]],'leiden_location'+s,[],sobj_coord_np[np.logical_and(sidx,ct_idx)],os.path.join(savedir,'contrast'),'location'+' of '+s+' '+ct,savenameAdd=savenamecluster,maxplot=50)

In [None]:
# Separate plots by region and cell type, one sample at a time.
# For every sample in `plot_samples`, every X variant in `plot_sample_X`, every region and
# every broad cell type ('top_level') with at least 3 cells:
#   1. embed the latent means (UMAP or PCA) and colour by fine cell type;
#   2. for populations with >= minCells cells, run Leiden clustering at each resolution,
#      pickle the labels, and plot them in embedding space and in tissue (x, y) coordinates.
# Cluster pickles and Leiden plot names include the region so that different regions of
# the same cell type no longer overwrite each other's results.
# NOTE(review): this cell reassigns the globals celltype_broad, celltype_sub, region,
# sobj_coord_np, ct_idx, savenameAdd and clusterRes with per-sample values; re-run the
# all-samples cell before reusing those names there.
if plottype not in ('umap', 'pca'):
    # Otherwise `embedding`/`savenameAdd` below would be undefined or stale.
    raise ValueError('unsupported plottype: ' + str(plottype))

for s in plot_samples.keys():
    sampleidx = plot_samples[s]
    # Slice the sample's obs rows once; every per-sample column below comes from it.
    sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
    celltype_broad = sample_obs['top_level']
    celltype_sub = sample_obs['cell_type_label']
    region = sample_obs['region']
    sobj_coord_np = sample_obs[['x', 'y']].to_numpy()
    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        muplot = mulist[samplename]

        if inverseAct:
            samplename += '_beforeAct'
        sampledir = os.path.join(plotsavepath, samplename)
        os.makedirs(sampledir, exist_ok=True)

        for r in np.unique(region):
            ridx = region == r
            for reg in np.unique(celltype_broad):
                savedir = os.path.join(plotsavepath, samplename, 'embedding_' + plottype + '_' + reg)
                clustersavedir = os.path.join(plotsavepath, samplename, 'cluster' + '_' + reg)
                os.makedirs(savedir, exist_ok=True)
                os.makedirs(clustersavedir, exist_ok=True)

                ct_idx = celltype_broad == reg
                reg_idx = np.logical_and(ridx, ct_idx)
                if np.sum(reg_idx) < 3:
                    # Too few cells in this region/cell-type combination to embed.
                    continue

                if plottype == 'umap':
                    reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=seed)
                    embedding = reducer.fit_transform(muplot[reg_idx])
                    savenameAdd = '_nn' + str(n_neighbors) + 'mdist0' + str(int(min_dist * 100)) + 'epoch' + str(plotepoch)
                else:  # plottype == 'pca' (validated at the top of the cell)
                    embedding = pca.fit_transform(muplot[reg_idx])
                    savenameAdd = '_epoch' + str(plotepoch)

                title = plottype + ' of ' + r + ' ' + s + ' ' + reg
                plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub_' + r, [], embedding, savedir, title, savenameAdd=savenameAdd)
                plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub_' + r, [], embedding, os.path.join(savedir, 'contrast'), title, savenameAdd=savenameAdd)

                if embedding.shape[0] < minCells:
                    continue
                locationtitle = 'location' + ' of ' + r + ' ' + s + ' ' + reg
                for res in resolution:
                    clusterRes = clusterLeiden(muplot[reg_idx], n_neighbors, n_pcs, min_dist, res, randseed=seed)
                    savenamecluster = 'leiden_nn' + str(n_neighbors) + 'mdist0' + str(int(min_dist * 100)) + 'n_pcs' + str(n_pcs) + 'res' + str(res) + 'epoch' + str(plotepoch)
                    # clustersavedir is per cell type only, so the region must be part of
                    # the file name or each region would overwrite the previous one.
                    with open(os.path.join(clustersavedir, savenamecluster + '_' + r), 'wb') as output:
                        pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
                    plotembeddingbyCT(clusterRes, 'leiden_' + r, [], embedding, savedir, title, savenameAdd=savenamecluster)
                    plotembeddingbyCT_contrast(clusterRes, 'leiden_' + r, [], embedding, os.path.join(savedir, 'contrast'), title, savenameAdd=savenamecluster, maxplot=50)

                    plotembeddingbyCT(clusterRes, 'leiden_location_' + r, [], sobj_coord_np[reg_idx], savedir, locationtitle, savenameAdd=savenamecluster)
                    plotembeddingbyCT_contrast(clusterRes, 'leiden_location_' + r, [], sobj_coord_np[reg_idx], os.path.join(savedir, 'contrast'), locationtitle, savenameAdd=savenamecluster, maxplot=50)

In [70]:
# Scratch check: shape of the latents selected by the all-samples mask `ct_idx`.
# NOTE(review): the per-sample plotting cell also reassigns `ct_idx`, to a mask over a
# single sample's cells, so this presumably only lines up with `latents` right after the
# all-samples loop — TODO recompute the mask here or remove this cell before sharing.
latents[ct_idx].shape

(19, 1024)

In [None]:
# Scratch check: first entry of `clusterRes` as an int.
# NOTE(review): `clusterRes` is set by both the all-samples and the per-sample clustering
# loops, so the value depends on which cell ran last — TODO remove before sharing.
int(clusterRes[0])