In [1]:
import sys
sys.path.append('/home/xinyiz/pamrats')

import time
import os

import numpy as np

import scanpy
import scipy.sparse as sp

import torch
from torch import optim

# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN,MiniBatchKMeans,AgglomerativeClustering
from sklearn import metrics

import anndata as ad
import gc

In [2]:
# ---------------------------------------------------------------------------
# Notebook configuration: every tunable read by the cells below is set here.
# ---------------------------------------------------------------------------
# Sets CUDA_VISIBLE_DEVICES to "1" (selects which GPU index this process may see).
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 
# ifplot is not referenced in the cells visible in this notebook.
ifplot=True
# Gates the cluster-export step of the per-region loop (re-assigned in that cell too).
ifcluster=True

# 'leakyRelu': embeddings are passed through inverseLeakyRelu before use and
# '_beforeAct' is appended to the per-sample output directory name.
inverseAct='leakyRelu'
# inverseAct=None
# 2-D projection used for the exported coordinates: 'umap' or 'pca'.
plottype='umap'
# PCA instance used only when plottype=='pca'.
pca=PCA()
minCells=15 #min number of cells for analysis
# clustermethod=['kmeanbatch']
# Only the 'kmeanbatch' entry is acted on by the visible code; the calls for the
# other methods are commented out in the per-region loop.
clustermethod=['leiden','dbscan','agglomerative','kmeanbatch']
#umap/leiden clustering parameters
# n_neighbors and min_dist are passed to umap.UMAP and encoded in output file names.
n_neighbors=10
min_dist=0.25
n_pcs=40 #for clustering
# resolution=[0.5,1,1.5,2]
# resolution is only referenced by the commented-out leiden call.
resolution=[0.05,0.1,0.2,0.3]
# Training epoch whose checkpoint (<plotepoch>.pt) is loaded; also part of output names.
plotepoch=9990
# Placeholder; overwritten inside the per-region loop before it is used.
savenameAdd=''
#DBscan
# epslist / min_sampleslist are only referenced by the commented-out DBSCAN call.
epslist= [6,8,10]
min_sampleslist=[15,30,45] 
#agglomerative
# nclusterlist is also the list of cluster counts passed to readMinibatchKmean.
nclusterlist=[2,3,4,5,8,10]
# aggMetric is only referenced by the commented-out agglomerative call.
aggMetric=['euclidean']


# Extra groups analysed in addition to the individual 'top_level' labels:
# each key is the union of the listed labels.
combineCelltype={'glia':['Astro','Micro', 'OPC', 'Oligo'],'CA':['CA1', 'CA2', 'CA3']}

# Request GPU execution; switched off by the next cell if CUDA is unavailable.
use_cuda=True
fastmode=False #Validate during training pass
# Seed used for NumPy, torch and UMAP.
seed=3
# useSavedMaskedEdges is not referenced in the cells visible in this notebook.
useSavedMaskedEdges=False
# File-name prefix of the per-sample adjacency .npz files.
maskedgeName='knn20_connectivity'
hidden1=1024 #Number of units in hidden layer 1
hidden2=1024 #Number of units in hidden layer 2
# hidden3=16
# Passed to the *_DCA model constructors.
fc_dim1=1024
# NOTE: fc_dim2/fc_dim3/fc_dim4 and gcn_dim1 are commented out, but the
# 'gcn_vae_xa' and 'gcn_vae_gcnX_inprA' model branches still reference them.
# fc_dim2=2112
# fc_dim3=2112
# fc_dim4=2112
# gcn_dim1=2048

# When truthy, a protein matrix named '<sample>_<protein>.npz' is appended to the
# features; that code path also needs proteinWeights, which is commented out below.
protein=None #'scaled_binary'
# proteinWeights=0.05
# Dropout value handed to the model constructor.
dropout=0.01
# randFeatureSubset=None
# Selects which gae.gae.model class is instantiated in the model-loading cell.
model_str='gcn_vae_xa_e2_d1_dca'
adj_decodeName=None #gala or None
# Short key -> value of obs['sample'] for every sample that is processed.
plot_samples={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}
# Feature variants to process; the first entry decides which .h5ad file is read.
plot_sample_X=['logminmax']
# plot_sample_X=['corrected','scaled']
# When True, features are standardised per column before being fed to the model.
standardizeX=False
# Run name; appended to the log / model / plot directories below.
name='allk20XA_01_dca'
logsavepath='/mnt/xinyi/pamrats/log/train_gae_starmap/'+name
modelsavepath='/mnt/xinyi/pamrats/models/train_gae_starmap/'+name
plotsavepath='/mnt/xinyi/pamrats/plots/train_gae_starmap/'+name
    

In [3]:
# Seed the CPU-side generators, then decide whether the GPU can really be used.
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    if torch.cuda.is_available():
        # GPU is usable: seed its generator as well and keep cuDNN switched on.
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.enabled = True
    else:
        # GPU was requested but is missing: report it and fall back to CPU.
        print('cuda not available')
        use_cuda=False


In [4]:
#Load data
# Builds three module-level objects used by the later cells:
#   featureslist  - dict '<sample key>X_<variant>' -> torch tensor of cell features
#   adj_list      - dict '<sample key>' -> sparse adjacency loaded from adj_dir
#   feature_names - index of scaleddata.var (one entry per feature column)
savedir=os.path.join('/mnt/xinyi/','starmap')
adj_dir=os.path.join(savedir,'a')

featureslist={}
if plot_sample_X[0] in ['corrected','scaled']:
    scaleddata=scanpy.read_h5ad('/mnt/xinyi/2021-01-13-mAD-test-dataset/2020-12-27-starmap-mAD-scaled.h5ad')
    
    for s in plot_samples.keys():
        # Boolean row mask selecting the cells of this sample.
        samplemask=scaleddata.obs['sample']==plot_samples[s]
        featureslist[s+'X_'+'corrected']=torch.tensor(scaleddata.layers['corrected'][samplemask])
        featureslist[s+'X_'+'scaled']=torch.tensor(scaleddata.layers['scaled'][samplemask])
    
elif plot_sample_X[0]=='logminmax':
    scaleddata=scanpy.read_h5ad('/mnt/xinyi/2021-01-13-mAD-test-dataset/2020-12-27-starmap-mAD-raw.h5ad')
    
    for s in plot_samples.keys():
        scaleddata_train=scaleddata.X[scaleddata.obs['sample']==plot_samples[s]]

        # log2(x + 0.5), then min-max scaling applied to the transposed matrix so
        # that every row (cell) is rescaled to [0, 1] independently.
        featurelog_train=np.log2(scaleddata_train+1/2)
        scaler = MinMaxScaler()
        featurelog_train_minmax=np.transpose(scaler.fit_transform(np.transpose(featurelog_train)))
        featureslist[s+'X_'+plot_sample_X[0]]=torch.tensor(featurelog_train_minmax)
else:
    # Previously this case left featureslist empty and only failed in a later cell.
    raise ValueError("unsupported plot_sample_X[0]: "+repr(plot_sample_X[0])+" (expected 'corrected', 'scaled' or 'logminmax')")

if protein: ##adjust for scaled/corrected
    # NOTE: this branch needs `proteinWeights` to be defined in the config cell.
    proteinsavepath=os.path.join('/mnt/xinyi/','starmap','protein')
    for s in plot_samples.keys():
        pmtx=sp.load_npz(os.path.join(proteinsavepath,plot_samples[s]+'_'+protein+'.npz'))
        pmtx=preprocessing.sparse_mx_to_torch_sparse_tensor(pmtx)
        pmtx=pmtx.to_dense()
        # Rescale the protein block so its total mass is proteinWeights times that
        # of the existing features, then append it as extra feature columns.
        scalefactor=torch.sum(featureslist[s+'X_'+plot_sample_X[0]])/torch.sum(pmtx)*proteinWeights
        featureslist[s+'X_'+plot_sample_X[0]]=torch.cat((featureslist[s+'X_'+plot_sample_X[0]],pmtx*scalefactor),dim=1)

# One adjacency matrix per configured sample: '<maskedgeName>_<sample id>.npz'.
# Derived from plot_samples so the two can never get out of sync.
adj_list={}
for s in plot_samples.keys():
    adj_list[s]=sp.load_npz(os.path.join(adj_dir,maskedgeName+'_'+plot_samples[s]+'.npz'))

feature_names=scaleddata.var.index

In [5]:
# load model
# Instantiate the architecture named by model_str and restore the weights saved
# at epoch `plotepoch` from modelsavepath.
# NOTE: the 'gcn_vae_xa', 'gcn_vae_gcnX_inprA', '...dcaElemPi' and
# '...dcaConstantDisp' branches use fc_dim2/fc_dim3/fc_dim4, gcn_dim1 or
# shareGenePi, none of which is defined in the visible config cell.
num_nodes,num_features = list(featureslist.values())[0].shape
if model_str=='gcn_vae_xa':
    model  = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
elif model_str=='fc1':
    model  = gae.gae.model.FCVAE1(num_features, hidden1,dropout)
elif model_str == 'gcn_vae_xa_e2_d1':
    model  = gae.gae.model.GCNModelVAE_XA_e2_d1(num_features, hidden1,hidden2, dropout)
elif model_str == 'gcn_vae_gcnX_inprA':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA(num_features, hidden1, hidden2,gcn_dim1, dropout)
elif model_str=='fc1_dca':
    model = gae.gae.model.FCVAE1_DCA(num_features, hidden1,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dca':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaFork':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAfork(num_features, hidden1,hidden2,fc_dim1, dropout)
elif model_str=='gcn_vae_xa_e2_d1_dcaElemPi':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAelemPi(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
elif model_str=='gcn_vae_xa_e2_d1_dcaConstantDisp':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_constantDisp(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
else:
    # Fail here instead of printing and then hitting a NameError (or, worse,
    # silently re-using a `model` left over from an earlier run of this cell).
    raise ValueError('model not found: '+str(model_str))
# The freshly built model lives on the CPU, so map the checkpoint there; this
# also lets a GPU-saved checkpoint load on a machine without CUDA. The model is
# moved to the GPU later if use_cuda is set.
model.load_state_dict(torch.load(os.path.join(modelsavepath,str(plotepoch)+'.pt'),map_location='cpu'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [6]:
# Re-seed NumPy's global random generator with the configured seed.
np.random.seed(seed)
def inverseLeakyRelu(v,slope=0.01):
    """Undo a leaky-ReLU activation element-wise.

    Negative entries are multiplied by ``1/slope``; non-negative entries are
    returned unchanged.

    Parameters
    ----------
    v : array-like
        Activated values.
    slope : float, optional
        Negative slope of the leaky ReLU that produced ``v`` (default 0.01).

    Returns
    -------
    numpy.ndarray
        A new array with the same shape and dtype as ``v``. The input is left
        untouched: previously it was modified in place, which also altered any
        tensor sharing memory with the array.
    """
    out=np.array(v,copy=True)
    negidx=(out<0)
    out[negidx]=1/slope*out[negidx]
    return out

In [24]:
def readMinibatchKmean(nclusterL,n_pcs):
    """Convert saved mini-batch k-means label files into tab-separated tables.

    For every cluster count in ``nclusterL`` the pickled label file
    ``minibatchkmean_ncluster<k>n_pcs<n_pcs>epoch<plotepoch>`` is read from
    ``clustersavedir`` and written to ``cometsavedir`` as ``<same name>.txt``
    with two columns: cell id and cluster label. A missing label file is
    reported with a ``DNE: `` message and skipped.

    Reads the module-level names ``clustersavedir``, ``cometsavedir``,
    ``cellidx`` and ``plotepoch``, which must be set before the call.
    """
    for k in nclusterL:
        basename='minibatchkmean_ncluster'+str(k)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
        labelfile=os.path.join(clustersavedir,basename)
        if os.path.exists(labelfile):
            # pickle executes arbitrary code on load; only read trusted files.
            with open(labelfile,'rb') as fh:
                clusterlabels=pickle.load(fh)
            labelcol=np.reshape(clusterlabels,(-1,1))
            idcol=np.reshape(cellidx,(-1,1))
            table=np.hstack((idcol,labelcol))
            outfile=os.path.join(cometsavedir,basename+'.txt')
            np.savetxt(outfile,table,delimiter='\t',fmt="%s")
        else:
            print('DNE: '+labelfile)

In [8]:
#compute embeddings
# Runs the trained model once per sample and feature variant and stores the
# latent means in mulist['<sample key>X_<variant>'] as NumPy arrays (passed
# through inverseLeakyRelu when inverseAct=='leakyRelu').
mulist={}
# Loop-invariant set-up: inference mode, and the model on the GPU if requested.
model.eval()
if use_cuda:
    model.cuda()
for s in plot_samples.keys():
    adj=adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    adj_decode=None
    if adj_decodeName == 'gala':
        adj_decode=preprocessing.preprocess_graph_sharp(adj)
    if use_cuda:
        adj_norm=adj_norm.cuda()
        # Only move the decoder adjacency when one was actually built.
        if adj_decode is not None:
            adj_decode=adj_decode.cuda()
    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        features=featureslist[samplename]
        if standardizeX:
            features=torch.tensor(scale(features,axis=0, with_mean=True, with_std=True, copy=True))
        # Cast to float32 on both CPU and GPU; before, the cast only happened on
        # the CUDA path, leaving the CPU path with float64 inputs.
        features=features.float()
        if use_cuda:
            features = features.cuda()
        
        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon,mu,logvar,z, features_recon = model(features, adj_norm)
            else:
                adj_recon,mu,logvar,z, features_recon = model(features, adj_norm,adj_decode)
        if inverseAct=='leakyRelu':
            muplot=inverseLeakyRelu(mu.cpu().detach().numpy())
        else:
            muplot=mu.cpu().detach().numpy()
        mulist[samplename]=muplot

In [27]:
plot_samples={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}

ifcluster=True
# separate plots by region and cell types
# For every sample, feature variant, region and selected cell-type group this
# cell writes, under <plotsavepath>/<sample>/cluster_<group>/<region>/cometMarker:
#   umapCoord<suffix>.txt - cell id plus 2-D embedding coordinates
#   features1000.txt      - features x cells table (feature values * 1000)
# and, when ifcluster is set, the k-means label tables via readMinibatchKmean.
np.random.seed(seed)
for s in plot_samples.keys():
#     if s in ['disease13']:
#         continue
    print(s)
    sampleidx=plot_samples[s]
    # Row mask of this sample's cells, computed once and reused below.
    samplemask=scaleddata.obs['sample']==sampleidx
    celltype_broad=scaleddata.obs.loc[samplemask,'top_level']
    celltype_sub=scaleddata.obs.loc[samplemask,'cell_type_label']
    region=scaleddata.obs.loc[samplemask,'region']
    sobj_coord_np=scaleddata.obs.loc[samplemask,['x','y']].to_numpy()
    
    origCT=np.unique(celltype_broad)
    # Individual cell types followed by the combined groups from combineCelltype.
    celltypeplot=np.concatenate((origCT,list(combineCelltype.keys())),axis=None)
    for xcorr in plot_sample_X:
        samplename=s+'X_'+xcorr
        muplot=np.copy(mulist[samplename])
        featureDE=np.copy(featureslist[samplename])*1000

        if inverseAct:
            samplename+='_beforeAct'
        sampledir=os.path.join(plotsavepath,samplename)
        # makedirs also creates missing parents and tolerates existing directories.
        os.makedirs(sampledir,exist_ok=True)
            
        for r in np.unique(region):
            print(r)
            ridx=region==r
            for reg in celltypeplot:
                # Only these (region, cell type) combinations are exported.
                if not ((r=='Cortex' and reg in ['Ex']) or (r=='Hippocampus' and reg in ['CA1','DG','Micro'])):
                    continue
#                 if not (r=='Hippocampus' and reg in ['CA']):
#                     continue
#                 if s=='disease8' and r=='Cortex' and (reg in ['Astro','CA1','CA2','CA3','DG','Endo','Ex','Inhi','LHb','Micro','OPC','Oligo','SMC']):
#                     continue
                print(reg)
                clustersavedir=os.path.join(plotsavepath,samplename,'cluster'+'_'+reg,r)
                cometsavedir=os.path.join(clustersavedir,'cometMarker')
                # os.mkdir raised FileNotFoundError when clustersavedir did not exist yet.
                os.makedirs(cometsavedir,exist_ok=True)

                if reg in origCT:
                    ct_idx=celltype_broad==reg
                else:
                    # Combined group: union of its member cell types.
                    ct_idx=False
                    for i in combineCelltype[reg]:
                        ct_idx=np.logical_or(ct_idx,celltype_broad==i)
                
                reg_idx=np.logical_and(ridx,ct_idx)
                if np.sum(reg_idx)<=3:
                    print('skipped')
                    continue
                
                if plottype=='umap':
                    reducer = umap.UMAP(n_neighbors=n_neighbors,min_dist=min_dist,random_state=seed)
                    embedding = reducer.fit_transform(muplot[reg_idx])
                    savenameAdd='_nn'+str(n_neighbors)+'mdist0'+str(int(min_dist*100))+'epoch'+str(plotepoch)
                elif plottype=='pca':
                    embedding=pca.fit_transform(muplot[reg_idx])
                    savenameAdd='_epoch'+str(plotepoch)
                else:
                    # Otherwise `embedding` from a previous iteration would be reused.
                    raise ValueError('unsupported plottype: '+str(plottype))
                # np.char.add is the public spelling of np.core.defchararray.add.
                cellidx=np.char.add('cell',np.arange(np.sum(reg_idx),dtype='int').astype(str))
                # Keep the first two coordinates: a no-op for the 2-D UMAP output,
                # and the first two components for PCA (reshape(-1,2) broke the
                # row alignment with cellidx there).
                embeddingsave=embedding[:,:2]
                embeddingsave=np.hstack((np.reshape(cellidx,(-1,1)),embeddingsave))
                np.savetxt(os.path.join(cometsavedir,'umapCoord'+savenameAdd+'.txt'),embeddingsave,delimiter='\t',fmt="%s")
                
                # Features x cells table with a header row of cell ids and a
                # first column of feature names.
                featuresave=np.transpose(featureDE[reg_idx])
                featuresave=np.hstack((np.reshape(feature_names,(-1,1)),featuresave))
                featuresave=np.vstack((np.concatenate(([''],cellidx)),featuresave))
                featuresave=featuresave.astype(str)
                np.savetxt(os.path.join(cometsavedir,'features1000'+'.txt'),featuresave,delimiter='\t',fmt="%s")
                if embedding.shape[0]<minCells:
                    continue
                if ifcluster:
#                     if 'leiden' in clustermethod:
#                         clusterLeiden(muplot[reg_idx],n_neighbors,n_pcs,min_dist,resolution,sobj_coord_np[reg_idx],randseed=seed)
#                         assert np.sum(muplot-np.copy(mulist[s+'X_'+xcorr]))==0
#                     if 'dbscan' in clustermethod:
#                         readDBscan(epslist,min_sampleslist,n_pcs)
#                     if 'agglomerative' in clustermethod:
#                         clusterAgg(muplot[reg_idx],nclusterlist,aggMetric,n_pcs,sobj_coord_np[reg_idx])
#                         assert np.sum(muplot-np.copy(mulist[s+'X_'+xcorr]))==0
                    if 'kmeanbatch' in clustermethod:
                        # Uses the clustersavedir / cometsavedir / cellidx globals set above.
                        readMinibatchKmean(nclusterlist,n_pcs)

disease13
Cortex
Ex
Hippocampus
CA1
DG
Micro
White Matter
control13
Cortex
Ex
Hippocampus
CA1
DG
Micro
White Matter
disease8
Cortex
Ex
Hippocampus
CA1
DG
Micro
White Matter
control8
Cortex
Ex
Hippocampus
CA1
DG
Micro
White Matter


In [23]:
# Scratch check of the cell-id construction used in the export loop, written
# with the public np.char.add instead of the private np.core.defchararray path.
np.char.add('cell',np.arange(3,dtype='int').astype(str))

array(['cell0', 'cell1', 'cell2'], dtype='<U25')