In [15]:
import sys
sys.path.append('/home/xinyiz/pamrats')
sys.path.append('/home/xinyi/anaconda3/envs/pytorch3/lib/python3.10/site-packages/')

import time
import os

import scanpy 
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN,MiniBatchKMeans,AgglomerativeClustering
from sklearn import metrics

import anndata as ad
import gc

import json
import matplotlib.image as mpimg
from skimage import io
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import euclidean_distances

In [23]:
# Seed value reused below as random_state for UMAP / Leiden; also seeds numpy's global RNG.
seed = 3
np.random.seed(seed)
def plotembeddingbyCT(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1,savenameAdd='',img=None,ncolors=None):
    """Scatter-plot an embedding coloured by label and save it as a JPEG.

    Parameters
    ----------
    ctlist : array-like
        One label per row of ``embedding``. Integer-like labels (e.g. Leiden
        cluster ids) index the palette via ``int(label)``, so a cluster keeps the
        same colour across per-sample subsets. Labels that cannot be converted
        to ``int`` (which used to raise ValueError) fall back to their position
        among the sorted unique labels.
    savename, savenameAdd : str
        The figure is written to ``savepath/<savename><savenameAdd>.jpg``.
    excludelist : collection
        Labels that are not drawn.
    embedding : ndarray
        Indexed as ``embedding[rows, plotdimx]`` / ``embedding[rows, plotdimy]``.
    savepath : str
        Output directory; created if it does not exist.
    plotname : str
        Title is ``plotname + ' embedding'``.
    plotdimx, plotdimy : int
        Columns of ``embedding`` used for the two plot axes.
    img : array-like or None
        Optional background image. When given, the two columns are swapped
        (``plotdimy`` on x, ``plotdimx`` on y) and smaller translucent markers
        are used.
    ncolors : int or None
        Palette size; defaults to the number of unique labels. Pass the global
        label count when plotting a subset so colours stay consistent.
    """
    ctlist=np.asarray(ctlist)
    celltypes=np.unique(ctlist)
    palette=sns.color_palette("husl", celltypes.size if ncolors is None else ncolors)

    def _color(ct,pos):
        # Integer-like labels index the palette directly; anything else uses its rank.
        try:
            return palette[int(ct)]
        except (TypeError,ValueError):
            return palette[pos]

    if img is not None:
        # Axes swapped over an image (as before); presumably so (row, col)-style
        # coordinates line up with the picture.
        xdim,ydim,size,alpha=plotdimy,plotdimx,1.5,0.5
    else:
        xdim,ydim,size,alpha=plotdimx,plotdimy,2.5,1

    fig, ax = plt.subplots(dpi=400)
    if img is not None:
        ax.imshow(img)
    for pos,ct in enumerate(celltypes):
        if ct in excludelist:
            continue
        idx=(ctlist==ct)
        ax.scatter(
            embedding[idx, xdim],
            embedding[idx, ydim],
            color=_color(ct,pos),label=ct,s=size,alpha=alpha
            )

    ax.set_aspect('equal', 'datalim')
    fig.set_figheight(5)
    fig.set_figwidth(5)
    # Shrink the axes vertically to leave room for the legend below them.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True,ncol=5, shadow=True,prop={'size': 6})
    ax.set_title(plotname+' embedding', fontsize=12)
    os.makedirs(savepath, exist_ok=True)
    fig.savefig(os.path.join(savepath,savename+savenameAdd+'.jpg'))
    plt.close('all')

    gc.collect()
    
def plotembeddingbyCT_str(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1,savenameAdd=''):
    """Scatter-plot an embedding coloured by arbitrary (e.g. string) labels; save as JPEG.

    Each label gets the palette colour at its position in ``np.unique(ctlist)``,
    so labels need not be integer-like. Labels in ``excludelist`` are skipped.
    The figure goes to ``savepath/<savename><savenameAdd>.jpg``.
    """
    unique_labels = np.unique(ctlist)
    # label -> colour, in sorted-label order
    color_of = dict(zip(unique_labels, sns.color_palette("husl", unique_labels.size)))

    fig, ax = plt.subplots(dpi=400)
    for label in unique_labels:
        if label in excludelist:
            continue
        mask = ctlist == label
        ax.scatter(embedding[mask, plotdimx], embedding[mask, plotdimy],
                   color=color_of[label], label=label, s=1.5, alpha=0.5)

    ax.set_aspect('equal', 'datalim')
    fig.set_size_inches(5, 5)
    # Make room below the axes for the legend.
    pos = ax.get_position()
    ax.set_position([pos.x0, pos.y0 + pos.height * 0.1, pos.width, pos.height * 0.9])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True, ncol=2, shadow=True, prop={'size': 6})
    ax.set_title(plotname + ' embedding', fontsize=24)
    fig.savefig(os.path.join(savepath, savename + savenameAdd + '.jpg'))
    plt.close('all')

    gc.collect()

In [3]:
np.random.seed(seed)
def plotembeddingbyCT_contrast(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1,savenameAdd='',maxplot=None):
    """Save one "this label vs. all others" scatter plot per unique label.

    For every label ``ct`` a figure is written to
    ``savepath/<savename>_<ct><savenameAdd>.jpg``. All other points are drawn
    in tab10 colour 1 (legend 'others'), and the points of ``ct`` are drawn on top
    in colour 0.

    Parameters
    ----------
    ctlist : array-like
        One label per row of ``embedding``.
    excludelist : collection
        Labels to skip (in addition to 'Unassigned', which is always skipped).
    savepath : str
        Output directory; created if missing.
    maxplot : int or None
        If truthy, integer-like labels with ``int(ct) > maxplot`` are skipped.
    """
    ctlist=np.asarray(ctlist)
    palette=sns.color_palette("tab10")
    os.makedirs(savepath, exist_ok=True)

    for ct in np.unique(ctlist):
        # Evaluate every skip condition before opening a figure. Previously a
        # figure was opened and then `continue`d for 'Unassigned' (never closed),
        # and int('Unassigned') raised when maxplot was set.
        if ct == 'Unassigned' or ct in excludelist:
            continue
        if maxplot and int(ct)>maxplot:
            continue

        fig, ax = plt.subplots()
        is_ct=(ctlist==ct)
        ax.scatter(
            embedding[~is_ct, plotdimx],
            embedding[~is_ct, plotdimy],
            color=palette[1],label='others',s=5,alpha=1
            )
        ax.scatter(
            embedding[is_ct, plotdimx],
            embedding[is_ct, plotdimy],
            color=palette[0],label=ct,s=6,alpha=1
            )

        ax.set_aspect('equal', 'datalim')
        fig.set_figheight(10)
        fig.set_figwidth(10)
        ax.legend()
        ax.set_title(plotname+' embedding', fontsize=24)
        fig.savefig(os.path.join(savepath,savename+'_'+str(ct)+savenameAdd+'.jpg'))
        plt.close(fig)
        gc.collect()

In [4]:
def plotCTcomp(labels,ctlist,savepath,savenamecluster,addname='',byCT=False):
    """Plot a cluster x cell-type composition heat map and save it as a JPEG.

    Entry (i, j) is the number of cells whose cluster is the i-th unique value of
    ``labels`` and whose cell type is the j-th unique value of ``ctlist``. The
    entry is normalised per cluster (rows sum to 1) when ``byCT`` is False, or per
    cell type (columns sum to 1) when ``byCT`` is True. The True case also appends
    '_normbyCT' to the file name.

    ``byCT`` used to be read from a global that is never defined, so the
    function raised NameError. It is now a keyword argument.

    Output: ``savepath/<savenamecluster>_ctComposition<addname>.jpg``.
    """
    labels=np.asarray(labels)
    ctlist=np.asarray(ctlist)
    uniq_labels,label_idx=np.unique(labels,return_inverse=True)
    uniq_cts,ct_idx=np.unique(ctlist,return_inverse=True)
    # counts[i, j] = number of cells with label uniq_labels[i] and type uniq_cts[j]
    counts=np.zeros((uniq_labels.size,uniq_cts.size))
    np.add.at(counts,(label_idx.ravel(),ct_idx.ravel()),1)
    if byCT:
        addname+='_normbyCT'
        res=counts/counts.sum(axis=0,keepdims=True)
    else:
        res=counts/counts.sum(axis=1,keepdims=True)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(res,cmap='binary')
    ax.set_yticks(np.arange(uniq_labels.size))
    ax.set_yticklabels(uniq_labels)
    ax.set_xticks(np.arange(uniq_cts.size))
    ax.set_xticklabels(uniq_cts)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
    fig.tight_layout()
    fig.savefig(os.path.join(savepath,savenamecluster+'_ctComposition'+addname+'.jpg'))
    plt.close(fig)
        

In [5]:
# Load the aggregated Visium FFPE files under datadir.
datadir='/home/xinyi/staci_validation/10xVisiumADFFPE/'
scalefactor=0.150015   # multiplies raw position columns 4 and 5 into plotting coordinates

# Spot positions; column 0 holds the barcode and becomes the index.
tissuepospath='VisiumFFPE_Mouse_Brain_Alzheimers_AppNote_aggr_tissue_positions_list.csv'
tissuepos=pd.read_csv(os.path.join(datadir,tissuepospath),header=None)
tissuepos.index=tissuepos.iloc[:,0]

# library_id per aggregated library (row i <-> barcode suffix '-<i+1>'), then the counts.
aggregation_csv=os.path.join(datadir,'VisiumFFPE_Mouse_Brain_Alzheimers_AppNote_aggregation.csv')
counts_h5=os.path.join(datadir,'VisiumFFPE_Mouse_Brain_Alzheimers_AppNote_filtered_feature_bc_matrix.h5')
libraryID=pd.read_csv(aggregation_csv)
features=scanpy.read_10x_h5(counts_h5)

  utils.warn_names_duplicates("var")


In [6]:
# Attach scaled spot coordinates and per-spot sample names to features.obs,
# then keep only genes whose total count exceeds 3.
spot_positions=tissuepos.loc[features.obs.index]
features.obs['x']=spot_positions[4]*scalefactor
features.obs['y']=spot_positions[5]*scalefactor
features.obs['barcodes']=features.obs.index
# Aggregated barcodes end in '-<n>', n being the 1-based row of libraryID.
sampleidx=features.obs['barcodes'].apply(lambda x: x.split('-')[1])
samplenameList=np.zeros(sampleidx.size).astype(str)
samplenameList[:]=libraryID['library_id'].to_numpy()[sampleidx.astype(int).to_numpy()-1]
features.obs['samplename']=samplenameList
features.var_names_make_unique()
gene_totals=np.asarray(features.X.sum(axis=0)).ravel()
features=features[:,gene_totals>3]

In [7]:
# Samples that get per-sample location plots in clusterLeiden_allsample (here: every sample).
plot_samples=np.unique(samplenameList)

In [11]:
# ---- what to run --------------------------------------------------------------
ifplot=True              # draw the sample-coloured embedding plots
ifcluster=True           # run the clustering methods listed in clustermethod
plottype='umap'          # 'umap' or 'pca' embedding of the combined latents
inverseAct='leakyRelu'   # undo this activation on mu before plotting; None to skip

# ---- PCA plots ------------------------------------------------------------------
pca=PCA()
npc=50                   # for pca var ration
npc_plot=10              # for pairwise pc plots

# ---- clustering / UMAP parameters ---------------------------------------------
minCells=15              # min number of cells for analysis
clustermethod=['leiden']
n_neighbors=10
min_dist=0.25
n_pcs=40                 # for clustering
resolution=[0.2,0.3,0.4,0.5,0.6,0.8]   # one Leiden run per value

# ---- trained model / inputs ---------------------------------------------------
plotepoch=9310           # checkpoint <modelsavepath>/<plotepoch>.pt
savenameAdd=''
use_cuda=True            # not read by the cells shown here
fastmode=False           # Validate during training pass
seed=3
useSavedMaskedEdges=False
maskedgeName='knn6_connectivity'
nneighbors=6             # spatial kNN neighbours per spot (self excluded)
hidden1=32               # Number of units in hidden layer 1
hidden2=32               # Number of units in hidden layer 2
fc_dim1=32
dropout=0.01
model_str='gcn_vae_xa_e2_d1_dca_sharded'
adj_decodeName=None      # gala or None
plot_sample_X=['mnn']    # input normalisation: 'mnn', 'combat' or 'logminmax'
plotRecon=''             # 'meanRecon' to use the reconstruction instead of mu
standardizeX=False

# ---- run name and output locations ------------------------------------------
name='10xAD_01_dca_over_mnn' 
logsavepath='/data/xinyi/log/train_gae_visium_validation/'+name
modelsavepath='/data/xinyi/models/train_gae_visium_validation/'+name
plotsavepath='/data/xinyi/plots/train_gae_visium_validation/'+name


In [9]:
def getA_knn(samplename,k,a_mode,savepath=None):
    """Build a k-nearest-neighbour spatial graph over the spots of one sample.

    Coordinates are the global ``features.obs[['x', 'y']]`` rows whose
    ``samplename`` column equals ``samplename``, in ``features.obs`` order.

    Parameters
    ----------
    samplename : str
    k : int
        Neighbours queried per spot. The query points are the fitted points,
        so the spot itself is normally one of them (distance 0). This leaves
        ``k - 1`` real neighbours.
    a_mode : {'connectivity', 'distance'}
        'connectivity': 0/1 edges with any self entry removed.
        'distance': edge weight ``1 / distance``. Zero distances (e.g. the
        self match) stay 0.
    savepath : str or None
        If given, the matrix is also written with ``scipy.sparse.save_npz``.

    Returns
    -------
    scipy sparse matrix of shape (n_spots, n_spots)
    """
    sobj_coord_np=features.obs.loc[features.obs['samplename']==samplename,['x','y']].to_numpy()
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(sobj_coord_np)
    a=nbrs.kneighbors_graph(sobj_coord_np,mode=a_mode)
    if a_mode=='connectivity':
        # Remove the self entries that are actually present. Subtracting a full
        # identity wrote -1 on the diagonal whenever a spot's k neighbours did
        # not include itself (e.g. duplicated coordinates).
        a=a-sp.diags(a.diagonal(),format='csr')
    elif a_mode=='distance':
        # Invert only the stored non-zero distances (same values as a[a!=0]=1/a[a!=0]).
        nonzero=a.data!=0
        a.data[nonzero]=1/a.data[nonzero]
    if savepath is not None:
        sp.save_npz(savepath,a)
    return a

In [16]:
# Build per-sample model inputs:
#   featureslist['<sample>X_<norm>']  normalised expression (norm = plot_sample_X[0])
#   features_raw_list['<sample>X_raw'] raw counts
#   adj_list / adjnormlist             spatial kNN graph and its normalised form
#   pos_weightlist / normlist          loss weights derived from the graph
savedirlist={}
featureslist={}
features_raw_list={}
adj_list={}
coordlist={}
commonGenes=[]


def _to_dense(x):
    """Return ``x`` as a dense numpy array, whether it is scipy-sparse or already dense."""
    return x.toarray() if sp.issparse(x) else np.asarray(x)


sample_names=np.unique(features.obs['samplename'])
xnorm=plot_sample_X[0]

if xnorm=='logminmax':
    for samplename in sample_names:
        print(samplename)
        featurelog_train=np.log2(features.X[features.obs['samplename']==samplename].toarray()+1/2)
        scaler = MinMaxScaler()
        # MinMaxScaler scales columns; the transposes make it scale each spot (row) to [0, 1].
        featurelog_train_minmax=np.transpose(scaler.fit_transform(np.transpose(featurelog_train)))
        featureslist[samplename+'X_'+xnorm]=torch.tensor(featurelog_train_minmax)
elif xnorm=='mnn':
    feature_corrected=features.copy()
    scanpy.pp.normalize_total(feature_corrected, target_sum=1e4)
    scanpy.pp.log1p(feature_corrected)
    scanpy.pp.highly_variable_genes(feature_corrected,batch_key='samplename')
    # One AnnData per sample in sorted-name order (the first is the MNN reference
    # batch). This replaces 12 hard-coded names that were listed in this same order.
    mnn_batches=[feature_corrected[feature_corrected.obs['samplename']==b] for b in sample_names]
    featureslist_correct=scanpy.external.pp.mnn_correct(*mnn_batches,batch_key='samplename',var_subset=feature_corrected.var.index[feature_corrected.var.highly_variable],do_concatenate=False)
    for corrected in featureslist_correct[0]:
        # .iloc: obs is indexed by barcode strings, so [0] relied on pandas'
        # deprecated positional fallback.
        samplename=corrected.obs['samplename'].iloc[0]
        featureslist[samplename+'X_'+xnorm]=torch.tensor(_to_dense(corrected.X))
elif xnorm=='combat':
    feature_corrected=features.copy()
    scanpy.pp.normalize_total(feature_corrected, target_sum=1e4)
    scanpy.pp.log1p(feature_corrected)
    scanpy.pp.scale(feature_corrected)
    scanpy.pp.combat(feature_corrected,key='samplename')
    for b in sample_names:
        # Take .X: torch.tensor cannot consume an AnnData object.
        featureslist[b+'X_'+xnorm]=torch.tensor(_to_dense(feature_corrected[feature_corrected.obs['samplename']==b].X))
else:
    raise ValueError('unknown plot_sample_X[0]: '+str(xnorm))

for samplename in sample_names:
    print(samplename)
    adj_list[samplename]=getA_knn(samplename,nneighbors+1,'connectivity')
    features_raw_list[samplename+'X_'+'raw']=torch.tensor(features.X[features.obs['samplename']==samplename].toarray())


num_features=features.shape[1]
print(num_features)
adjnormlist={}
pos_weightlist={}
normlist={}
for ai,adj in list(adj_list.items()):
    n_spots=adj.shape[0]
    n_edges=adj.sum()
    adjnormlist[ai]=preprocessing.preprocess_graph(adj)
    # Loss weights from the full unmasked adjacency:
    # pos_weight = #non-edges / #edges, norm = N^2 / (2 * #non-edges).
    pos_weightlist[ai]=torch.tensor(float(n_spots*n_spots-n_edges)/n_edges)
    normlist[ai]=n_spots*n_spots/float((n_spots*n_spots-n_edges)*2)
    # From here on adj_list holds the dense label A + I as a tensor; the raw
    # sparse A itself is no longer kept.
    adj_label=adj+sp.eye(n_spots)
    adj_list[ai]=torch.tensor(adj_label.toarray())
    
        

  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  hvg = hvg.append(missing_hvg, ignore_index=True)
  dist[i, j] = np.dot(m[i], n[j])
  scale = np.dot(working, grad)
  curproj = np.dot(grad, curcell)


Cython module _utils not initialized. Fallback to python.
Performing cosine normalization...
Starting MNN correct iteration. Reference batch: 0
Step 1 of 11: processing batch 1
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 2 of 11: processing batch 2
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 3 of 11: processing batch 3
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 4 of 11: processing batch 4
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 5 of 11: processing batch 5
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 6 of 11: processing batch 6
  Looking for MNNs...
  Computing correction vectors...
  Adjusting variance...
  Applying correction...
Step 7 of 11: processing bat

  self._set_arrayXarray_sparse(i, j, x)


Done.
Transgenic_17p9_rep1
Transgenic_17p9_rep2
Transgenic_2p5_rep1
Transgenic_2p5_rep2
Transgenic_5p7_rep1
Transgenic_5p7_rep2
Wildtype_13p4_rep1
Wildtype_13p4_rep2
Wildtype_2p5_rep1
Wildtype_2p5_rep2
Wildtype_5p7_rep1
Wildtype_5p7_rep2
17186
0
0
0
0
0
0
0
0
0
0
0
0


In [17]:
# load model: build the architecture named by model_str, then restore checkpoint <plotepoch>.pt
num_nodes,num_features = list(featureslist.values())[0].shape
if model_str=='gcn_vae_xa':
    model = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
elif model_str=='fc1':
    model = gae.gae.model.FCVAE1(num_features, hidden1,dropout)
elif model_str == 'gcn_vae_xa_e2_d1':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1(num_features, hidden1,hidden2, dropout)
elif model_str == 'gcn_vae_gcnX_inprA':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA(num_features, hidden1, hidden2,gcn_dim1, dropout)
elif model_str=='fc1_dca':
    model = gae.gae.model.FCVAE1_DCA(num_features, hidden1,fc_dim1, dropout)
elif model_str=='fc1_dca_sharded':
    model = gae.gae.model.FCVAE1_DCA_sharded(num_features, hidden1,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dca':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dca_sharded':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_sharded(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaFork':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAfork(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaElemPi':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAelemPi(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
elif model_str=='gcn_vae_xa_e2_d1_dcaConstantDisp':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_constantDisp(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
else:
    # Fail loudly. Printing and continuing either crashed on `model` below or,
    # in a live kernel, silently loaded weights into a stale `model` from an earlier run.
    raise ValueError('model not found: '+str(model_str))
# The model is built on the CPU, so map any CUDA-saved tensors there; this lets
# the checkpoint load on machines without a GPU.
model.load_state_dict(torch.load(os.path.join(modelsavepath,str(plotepoch)+'.pt'),map_location='cpu'))

<All keys matched successfully>

In [18]:
np.random.seed(seed)  # re-seed numpy's global RNG with the configured seed
def inverseLeakyRelu(v,slope=0.01):
    """Invert a LeakyReLU: scale negative entries by ``1/slope``; keep the rest.

    Returns a new array and leaves ``v`` untouched. The previous version wrote
    into ``v`` in place. Its caller passes ``mu.cpu().detach().numpy()``, which
    for a CPU tensor shares memory with ``mu``, so the model output was
    silently modified.
    """
    v=np.asarray(v)
    return np.where(v<0,1/slope*v,v)

In [19]:
np.random.seed(seed)


def _leiden_savename(n_neighbors,min_dist,n_pcs,r):
    """File stem identifying one Leiden run (its parameters plus the global ``plotepoch``)."""
    return 'leiden_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'n_pcs'+str(n_pcs)+'res'+str(r)+'epoch'+str(plotepoch)


def clusterLeiden_single(inArray,n_neighbors,n_pcs,min_dist,resolution,randseed=seed):
    """Leiden-cluster the rows of ``inArray``; return one label per row.

    PCA (arpack), then a kNN graph with ``n_neighbors`` on the first ``n_pcs`` PCs
    (clamped to ``min(n_rows - 1, n_cols - 1)``), then Leiden at ``resolution``
    with ``random_state=randseed``.

    ``min_dist`` is kept for signature compatibility but is unused. The old
    version also ran ``scanpy.tl.umap`` here; its result was discarded and is
    not an input to Leiden, which clusters the neighbour graph.
    """
    n_pcs=np.min([inArray.shape[0]-1,inArray.shape[1]-1,n_pcs])
    adata=ad.AnnData(inArray)
    scanpy.tl.pca(adata, svd_solver='arpack')
    scanpy.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)
    scanpy.tl.leiden(adata,resolution=resolution,random_state=randseed)
    return adata.obs['leiden'].to_numpy()


def clusterLeiden(inArray,n_neighbors,n_pcs,min_dist,resolution,sobj_coord_np,randseed=seed,plotdir=None,clusterdir=None,samplename=None):
    """Leiden-cluster ``inArray`` at each resolution, pickle the labels and plot them on the coordinates.

    ``randseed`` is now honoured; previously the global ``seed`` was always used.
    ``plotdir`` / ``clusterdir`` default to the globals ``savedir`` /
    ``clustersavedir`` as before. ``samplename`` (used in the plot title)
    defaults to the global ``s`` as before.
    """
    plot_dir=savedir if plotdir is None else plotdir
    cluster_dir=clustersavedir if clusterdir is None else clusterdir
    title_name=s if samplename is None else samplename
    for r in resolution:
        clusterRes=clusterLeiden_single(inArray,n_neighbors,n_pcs,min_dist,r,randseed=randseed)
        savenamecluster=_leiden_savename(n_neighbors,min_dist,n_pcs,r)
        with open(os.path.join(cluster_dir,savenamecluster), 'wb') as output:
            pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
        plotembeddingbyCT_contrast(clusterRes,'leiden_location',[],sobj_coord_np,os.path.join(plot_dir,'contrast'),'location'+' of '+title_name,savenameAdd=savenamecluster)


def clusterLeiden_allsample(embedding,savedir,clustersavedir,inArray,n_neighbors,n_pcs,min_dist,resolution,sobj_coord_np,samplenameList,randseed=seed):
    """Cluster the pooled latents at each resolution; save labels and plots.

    For every resolution ``r``:
      * Leiden-cluster ``inArray`` (using ``randseed``, which was previously ignored);
      * pickle the labels to ``clustersavedir/<savename>``;
      * plot the labels on ``embedding`` (all samples);
      * plot them on ``sobj_coord_np`` separately for every sample in the global
        ``plot_samples``.

    Rows of ``embedding``, ``inArray``, ``sobj_coord_np`` and ``samplenameList``
    must all be in the same order.
    """
    samplenameList=np.asarray(samplenameList)
    os.makedirs(clustersavedir, exist_ok=True)
    for r in resolution:
        clusterRes=clusterLeiden_single(inArray,n_neighbors,n_pcs,min_dist,r,randseed=randseed)
        savenamecluster=_leiden_savename(n_neighbors,min_dist,n_pcs,r)
        with open(os.path.join(clustersavedir,savenamecluster), 'wb') as output:
            pickle.dump(clusterRes, output, pickle.HIGHEST_PROTOCOL)
        plotembeddingbyCT(clusterRes,'leiden',[],embedding,savedir,plottype+' of all samples',savenameAdd=savenamecluster)
        plotembeddingbyCT_contrast(clusterRes,'leiden',[],embedding,os.path.join(savedir,'contrast'),plottype+' of all samples',savenameAdd=savenamecluster,maxplot=50)

        # Use the global cluster count so each cluster keeps its colour in every per-sample plot.
        ncolors=np.unique(clusterRes).size
        for s in plot_samples:
            sidx=(samplenameList==s)
            # Background image disabled; it would be datadir/spatial/<s>/tissue_hires_image.png.
            img=None
            plotembeddingbyCT(clusterRes[sidx],'leiden_location'+s,[],sobj_coord_np[sidx],savedir,'location'+' of '+s,savenameAdd=savenamecluster+'_noImg',img=img,ncolors=ncolors)
            plotembeddingbyCT_contrast(clusterRes[sidx],'leiden_location'+s,[],sobj_coord_np[sidx],os.path.join(savedir,'contrast'),'location'+' of '+s,savenameAdd=savenamecluster,maxplot=50)
         

In [20]:
#compute embeddings
# mulist['<sample>X_<xcorr>'] holds either the per-spot latent mean mu (passed
# through inverseLeakyRelu when inverseAct=='leakyRelu') or, with
# plotRecon=='meanRecon', features_recon[3].
mulist={}
model.eval()  # hoisted out of the loop; switches off dropout
with torch.no_grad():  # inference only: no autograd graph is needed
    for s in np.unique(samplenameList):
        # adj_list[s] was overwritten in the preprocessing cell with the dense
        # A + I label tensor. adjnormlist[s] was normalised from the raw kNN
        # adjacency. Re-normalising the A + I tensor here, as this cell did,
        # likely gives a different graph (self-loops counted twice if
        # preprocess_graph adds them itself).
        # NOTE(review): confirm adjnormlist matches the training-time normalisation.
        adj_norm=adjnormlist[s]
        adj_decode=None
        if adj_decodeName == 'gala':
            # NOTE(review): still derived from the A + I tensor, as before.
            adj_decode=preprocessing.preprocess_graph_sharp(adj_list[s])
        for xcorr in plot_sample_X:
            samplename=s+'X_'+xcorr
            features_s=featureslist[samplename]
            if adj_decodeName is None:
                adj_recon,mu,logvar,z, features_recon = model(features_s, adj_norm)
            else:
                adj_recon,mu,logvar,z, features_recon = model(features_s, adj_norm,adj_decode)
            mu_np=mu.cpu().detach().numpy()
            muplot=inverseLeakyRelu(mu_np) if inverseAct=='leakyRelu' else mu_np
            if plotRecon=='meanRecon':
                mulist[samplename]=features_recon[3].cpu().detach().numpy()
            elif plotRecon:
                # Previously any other non-empty value silently stored nothing,
                # which surfaced later as a KeyError.
                raise ValueError('unsupported plotRecon: '+str(plotRecon))
            else:
                mulist[samplename]=muplot
        

0
0
0
0
0
0
0
0
0
0
0
0


In [24]:
# combine all latents to one plot, then plot and cluster the pooled embedding
np.random.seed(seed)
all_coords=features.obs.loc[:,('x','y')].to_numpy()
samples=np.unique(samplenameList)
for xcorr in plot_sample_X:
    # Stack each sample's latents in np.unique(samplenameList) order. Every per-row
    # array built here (coordinates, sample labels) follows that same order,
    # which need not match the row order of features.obs / samplenameList. This
    # is why samplenameList_plot, not samplenameList, must label the rows below.
    mu_copies=[np.copy(mulist[s+'X_'+xcorr]) for s in samples]
    latents=np.vstack(mu_copies)
    sobj_coord_np=np.concatenate([all_coords[samplenameList==s] for s in samples],axis=0)
    samplenameList_plot=np.concatenate([np.repeat(s,mu.shape[0]) for s,mu in zip(samples,mu_copies)],axis=None)

    sampledir=os.path.join(plotsavepath,'combined'+xcorr)
    if inverseAct:
        sampledir+='_beforeAct'
    savedir=os.path.join(sampledir,'embedding_'+plottype)
    clustersavedir=os.path.join(sampledir,'cluster')
    for d in (savedir,clustersavedir):
        os.makedirs(d,exist_ok=True)

    if plottype=='umap':
        # Local instead of overwriting the global npc_plot (which made later pca runs depend on run history).
        n_plot_dims=2
        reducer = umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
        embedding = reducer.fit_transform(latents)
        savenameAdd='_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
    elif plottype=='pca':
        pca.fit(latents)
        # The latent width can be smaller than npc; clip so plt.bar gets matching lengths.
        nbar=min(npc,pca.explained_variance_ratio_.size)
        for values,stem in ((pca.explained_variance_ratio_,'varRatio_'),(pca.explained_variance_,'var_')):
            fig, ax = plt.subplots(dpi=400)
            fig.set_figheight(2.5)
            fig.set_figwidth(10)
            ax.bar(np.arange(nbar),values[:nbar])
            fig.savefig(os.path.join(savedir,stem+str(npc)+'_epoch'+str(plotepoch)+'.jpg'))
            plt.close(fig)
        embedding=pca.transform(latents)
        n_plot_dims=min(npc_plot,embedding.shape[1])
        savenameAdd='_epoch'+str(plotepoch)
    else:
        raise ValueError('unsupported plottype: '+str(plottype))

    if ifplot:
        for dim1 in range(n_plot_dims-1):
            for dim2 in range(dim1+1,n_plot_dims):
                plotembeddingbyCT_str(samplenameList_plot,'sample',[],embedding,savedir,plottype+' of all samples',plotdimx=dim1,plotdimy=dim2,savenameAdd=savenameAdd+'_pc'+str(dim1)+'pc'+str(dim2))

    if embedding.shape[0]<minCells:
        continue
    if ifcluster and 'leiden' in clustermethod:
        clusterLeiden_allsample(embedding,savedir,clustersavedir,latents,n_neighbors,n_pcs,min_dist,resolution,sobj_coord_np,samplenameList_plot,randseed=seed)
        # Sanity check: clustering and plotting must not have modified the stored latents.
        for s,mu in zip(samples,mu_copies):
            assert np.array_equal(mu,mulist[s+'X_'+xcorr])
