In [None]:
import sys
sys.path.append('/home/xinyiz/pamrats')

import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

In [None]:
# ---------------------------------------------------------------------------
# Run configuration: every value below is read by the cells that follow.
# ---------------------------------------------------------------------------

# Latent post-processing and 2-D projection used for the plots.
inverseAct = 'leakyRelu'   # 'leakyRelu': map mu back before plotting ('_beforeAct' dirs); or None
plottype = 'pca'           # 'pca' or 'umap'
pca = PCA()

# Runtime / reproducibility.
use_cuda = True
fastmode = False           # Validate during training pass
seed = 3

# Graph used as model input.
useSavedMaskedEdges = False
maskedgeName = 'knn20_connectivity'

# Checkpoint and architecture (must match the trained model).
plotepoch = 1980           # epoch number, used as the checkpoint file name
model_str = 'fc1_dca'
hidden1 = 2048             # Number of units in hidden layer 1
fc_dim1 = 2048
dropout = 0.01
adj_decodeName = None      # gala or None
# Other model_str values additionally need (currently disabled):
#   hidden2=2048, hidden3=16, fc_dim2=fc_dim3=fc_dim4=2112, gcn_dim1=2048
# Also previously tried: inverseAct=None, randFeatureSubset=None.

# Samples to plot: notebook key -> value of obs['sample'].
plot_samples = {
    'disease13': 'AD_mouse9494',
    'control13': 'AD_mouse9498',
    'disease8': 'AD_mouse9723',
    'control8': 'AD_mouse9735',
}
plot_sample_X = ['corrected', 'scaled']   # alternative: ['logminmax']
standardizeX = False

# Run name and output locations.
name = 'c13k20XA_FCXonly_02_dca'
logsavepath = f'/mnt/xinyi/pamrats/log/train_gae_starmap/{name}'
modelsavepath = f'/mnt/xinyi/pamrats/models/train_gae_starmap/{name}'
plotsavepath = f'/mnt/xinyi/pamrats/plots/train_gae_starmap/{name}'

In [None]:
# Reproducibility: seed numpy and torch, then enable CUDA only when it was
# requested AND is actually present (otherwise fall back to the CPU).
np.random.seed(seed)
torch.manual_seed(seed)

if use_cuda:
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.enabled = True
    else:
        print('cuda not available')
        use_cuda = False


In [None]:
# Load the per-sample expression matrices and spatial adjacency matrices for
# every sample listed in plot_samples.
savedir = os.path.join('/mnt/xinyi/', 'starmap')
adj_dir = os.path.join(savedir, 'a')

featureslist = {}   # '<sample key>X_<variant>' -> torch tensor, one row per cell
if plot_sample_X[0] in ['corrected', 'scaled']:
    # Preprocessed AnnData; both variants are stored as layers and both are loaded.
    scaleddata = scanpy.read_h5ad('/mnt/xinyi/2021-01-13-mAD-test-dataset/2020-12-27-starmap-mAD-scaled.h5ad')
    for s, sample_id in plot_samples.items():
        cellmask = (scaleddata.obs['sample'] == sample_id).to_numpy()
        for layer in ('corrected', 'scaled'):
            featureslist[s + 'X_' + layer] = torch.tensor(scaleddata.layers[layer][cellmask])
elif plot_sample_X[0] == 'logminmax':
    scaleddata = scanpy.read_h5ad('/mnt/xinyi/2021-01-13-mAD-test-dataset/2020-12-27-starmap-mAD-raw.h5ad')
    for s, sample_id in plot_samples.items():
        cellmask = (scaleddata.obs['sample'] == sample_id).to_numpy()
        scaleddata_train = scaleddata.X[cellmask]
        # log2(x + 1/2), then min-max scale every cell (row) to [0, 1];
        # MinMaxScaler works column-wise, hence the transposes.
        featurelog_train = np.log2(scaleddata_train + 1/2)
        scaler = MinMaxScaler()
        featurelog_train_minmax = np.transpose(scaler.fit_transform(np.transpose(featurelog_train)))
        featureslist[s + 'X_' + 'logminmax'] = torch.tensor(featurelog_train_minmax)
else:
    # Fail here instead of with an IndexError on an empty featureslist later.
    raise ValueError(f'unsupported plot_sample_X[0]: {plot_sample_X[0]!r}')

# One adjacency matrix per sample, stored as <maskedgeName>_<sample id>.npz
# (derived from plot_samples instead of a hard-coded list of four ids).
adj_list = {
    s: sp.load_npz(os.path.join(adj_dir, maskedgeName + '_' + sample_id + '.npz'))
    for s, sample_id in plot_samples.items()
}

In [None]:
# Rebuild the architecture selected by model_str and load the weights saved at
# epoch `plotepoch`. The input width comes from the loaded feature matrices.
num_nodes, num_features = list(featureslist.values())[0].shape
# NOTE: every branch except 'fc1' and 'fc1_dca' also needs hidden2 / fc_dim2-4 /
# gcn_dim1, which are commented out in the configuration cell.
if model_str == 'gcn_vae_xa':
    model = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2, fc_dim1, fc_dim2, fc_dim3, fc_dim4, dropout)
elif model_str == 'fc1':
    model = gae.gae.model.FCVAE1(num_features, hidden1, dropout)
elif model_str == 'gcn_vae_xa_e2_d1':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1(num_features, hidden1, hidden2, dropout)
elif model_str == 'gcn_vae_gcnX_inprA':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA(num_features, hidden1, hidden2, gcn_dim1, dropout)
elif model_str == 'fc1_dca':
    model = gae.gae.model.FCVAE1_DCA(num_features, hidden1, fc_dim1, dropout)
else:
    # Previously fell through and failed later with a NameError on `model`.
    raise ValueError(f'unknown model_str: {model_str!r}')

# Load onto the CPU so a GPU-trained checkpoint also works on a CPU-only
# machine; the plotting cells call model.cuda() when use_cuda is set.
checkpoint = os.path.join(modelsavepath, str(plotepoch) + '.pt')
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

In [None]:
def plotembeddingbyCT(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1):
    """Scatter two columns of ``embedding`` with one colour per label, and save it as a JPEG.

    Args:
        ctlist: one label per row of ``embedding`` (array-like or pandas Series).
        savename: file stem; the figure is written to ``<savepath>/<savename>.jpg``.
        excludelist: labels that are not drawn (they still reserve a colour slot).
        embedding: 2-D array with one row per cell.
        savepath: output directory; created if it does not exist.
        plotname: title prefix, rendered as ``'<plotname> embedding'``.
        plotdimx, plotdimy: columns of ``embedding`` used for the x and y axes.
    """
    labels = np.asarray(ctlist)
    celltypes = np.unique(labels)
    # Fixed label -> palette index, so excluded labels do not shift other colours.
    celltypes_dict = {ct: i for i, ct in enumerate(celltypes)}

    colortest = sns.color_palette("husl", celltypes.size)
    # Randomise colour assignment (uses the global numpy RNG seeded earlier);
    # presumably so alphabetically adjacent labels get less similar hues.
    np.random.shuffle(colortest)

    fig, ax = plt.subplots(dpi=400)
    for ct in celltypes:
        if ct in excludelist:
            continue
        idx = labels == ct
        ax.scatter(
            embedding[idx, plotdimx],
            embedding[idx, plotdimy],
            color=colortest[celltypes_dict[ct]], label=ct, s=3, alpha=0.5,
        )

    ax.set_aspect('equal', 'datalim')
    fig.set_figheight(10)
    fig.set_figwidth(10)
    # Shrink the axes by 10% from the bottom to make room for the legend below.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True, shadow=True, ncol=5)
    ax.set_title(plotname + ' embedding', fontsize=24)

    os.makedirs(savepath, exist_ok=True)
    fig.savefig(os.path.join(savepath, savename + '.jpg'))
    plt.show()
    # Release the figure: this is called inside loops that create many figures.
    plt.close(fig)

In [None]:
def plotembeddingbyCT_contrast(ctlist,savename,excludelist,embedding,savepath,plotname,plotdimx=0,plotdimy=1):
    """Save one figure per label, highlighting that label's cells over all other cells.

    Args:
        ctlist: one label per row of ``embedding`` (array-like or pandas Series).
        savename: file stem; each figure is ``<savepath>/<savename>_<label>.jpg``.
        excludelist: labels to skip; 'Unassigned' is always skipped.
        embedding: 2-D array with one row per cell.
        savepath: output directory; created if it does not exist.
        plotname: title prefix, rendered as ``'<plotname> embedding'``.
        plotdimx, plotdimy: columns of ``embedding`` used for the x and y axes.
    """
    labels = np.asarray(ctlist)
    colortest = sns.color_palette("tab10")
    os.makedirs(savepath, exist_ok=True)
    skip = set(excludelist) | {'Unassigned'}

    for ct in np.unique(labels):
        # Check before creating the figure so skipped labels leave no empty figure behind.
        if ct in skip:
            continue
        fig, ax = plt.subplots()

        others = labels != ct
        ax.scatter(
            embedding[others, plotdimx],
            embedding[others, plotdimy],
            color=colortest[1], label='others', s=1, alpha=0.5,
        )
        ax.scatter(
            embedding[~others, plotdimx],
            embedding[~others, plotdimy],
            color=colortest[0], label=ct, s=3, alpha=0.5,
        )

        ax.set_aspect('equal', 'datalim')
        fig.set_figheight(10)
        fig.set_figwidth(10)
        ax.legend()
        ax.set_title(plotname + ' embedding', fontsize=24)
        # Labels may contain '/', which would otherwise be read as a sub-directory.
        fname_ct = str(ct).replace('/', '_')
        fig.savefig(os.path.join(savepath, savename + '_' + fname_ct + '.jpg'))
        plt.show()
        # Many figures are produced per call; close each one to bound memory.
        plt.close(fig)

In [None]:
def inverseLeakyRelu(v,slope=0.01):
    """Invert a leaky ReLU element-wise: negative entries are multiplied by 1/slope.

    Returns a new array/tensor and leaves ``v`` untouched. The previous version
    rescaled ``v`` in place. On CPU, ``mu.cpu().detach().numpy()`` shares memory
    with ``mu``, so repeated calls inside loops compounded the rescaling.

    Args:
        v: numpy array or torch tensor of activations.
        slope: negative-side slope of the forward leaky ReLU (default 0.01).

    Returns:
        Copy of ``v`` (same type and dtype) with negative entries divided by ``slope``.
    """
    out = v.clone() if isinstance(v, torch.Tensor) else np.array(v, copy=True)
    negidx = out < 0
    out[negidx] = 1 / slope * out[negidx]
    return out

In [None]:
# Per-sample embeddings: encode each sample, project the latent means to 2-D
# and colour the projection by broad / sub cell type and by region.
model.eval()
if use_cuda:
    model.cuda()

for s, sampleidx in plot_samples.items():
    adj = adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    adj_decode = None
    if adj_decodeName == 'gala':
        adj_decode = preprocessing.preprocess_graph_sharp(adj)
    if use_cuda:
        adj_norm = adj_norm.cuda()
        if adj_decode is not None:
            adj_decode = adj_decode.cuda()

    sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
    celltype_broad = sample_obs['top_level']
    celltype_sub = sample_obs['cell_type_label']
    region = sample_obs['region']

    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device; previously only the CUDA path converted to float32.
        features = features.float()
        if use_cuda:
            features = features.cuda()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm)
            else:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm, adj_decode)

        # Explicit copy so post-processing can never write back into mu.
        muplot = mu.detach().cpu().numpy().copy()
        if inverseAct == 'leakyRelu':
            muplot = inverseLeakyRelu(muplot)

        if inverseAct:
            samplename += '_beforeAct'
        savedir = os.path.join(plotsavepath, samplename, 'embedding_' + plottype)
        os.makedirs(savedir, exist_ok=True)

        if plottype == 'umap':
            embedding = umap.UMAP().fit_transform(muplot)
        elif plottype == 'pca':
            embedding = pca.fit_transform(muplot)
        else:
            # Otherwise a stale `embedding` from an earlier iteration would be plotted.
            raise ValueError(f'unknown plottype: {plottype!r}')

        title = plottype + ' of ' + s
        plotembeddingbyCT(celltype_broad, 'celltype_broad', [], embedding, savedir, title)
        plotembeddingbyCT(celltype_sub, 'celltype_sub', [], embedding, savedir, title)
        plotembeddingbyCT(region, 'region', [], embedding, savedir, title)

        plotembeddingbyCT_contrast(celltype_sub, 'celltype_sub', [], embedding, os.path.join(savedir, 'contrast'), title)

In [None]:
# Per-region embeddings: same encoding as the per-sample plots, but PCA/UMAP
# is fitted separately on the cells of each region of each sample.
model.eval()
if use_cuda:
    model.cuda()

for s, sampleidx in plot_samples.items():
    adj = adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    adj_decode = None
    if adj_decodeName == 'gala':
        adj_decode = preprocessing.preprocess_graph_sharp(adj)
    if use_cuda:
        adj_norm = adj_norm.cuda()
        if adj_decode is not None:
            adj_decode = adj_decode.cuda()

    sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
    celltype_broad = sample_obs['top_level']
    celltype_sub = sample_obs['cell_type_label']
    region = sample_obs['region']

    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device; previously only the CUDA path converted to float32.
        features = features.float()
        if use_cuda:
            features = features.cuda()

        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm)
            else:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm, adj_decode)

        # Computed ONCE per sample. It used to be recomputed inside the region
        # loop; on CPU the numpy view shared mu's memory and the in-place
        # inverse was applied again for every region.
        muplot = mu.detach().cpu().numpy().copy()
        if inverseAct == 'leakyRelu':
            muplot = inverseLeakyRelu(muplot)

        if inverseAct:
            samplename += '_beforeAct'

        for reg in np.unique(region):
            savedir = os.path.join(plotsavepath, samplename, 'embedding_' + plottype + '_' + reg)
            os.makedirs(savedir, exist_ok=True)

            reg_idx = (region == reg).to_numpy()
            if plottype == 'umap':
                embedding = umap.UMAP().fit_transform(muplot[reg_idx])
            elif plottype == 'pca':
                embedding = pca.fit_transform(muplot[reg_idx])
            else:
                raise ValueError(f'unknown plottype: {plottype!r}')

            title = plottype + ' of ' + s + ' ' + reg
            plotembeddingbyCT(celltype_broad[reg_idx], 'celltype_broad', [], embedding, savedir, title)
            plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub', [], embedding, savedir, title)

            plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub', [], embedding, os.path.join(savedir, 'contrast'), title)

In [None]:
# Combined embedding: pool the latent means of all samples, project them
# jointly, and plot all cells as well as each region separately.
model.eval()
if use_cuda:
    model.cuda()

for xcorr in plot_sample_X:
    # Collect per-sample pieces in lists; concatenate once after the loop.
    latent_parts, broad_parts, sub_parts, region_parts, sample_parts = [], [], [], [], []

    for s, sampleidx in plot_samples.items():
        adj = adj_list[s]
        adj_norm = preprocessing.preprocess_graph(adj)
        adj_decode = None
        if adj_decodeName == 'gala':
            adj_decode = preprocessing.preprocess_graph_sharp(adj)

        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device; previously only the CUDA path converted to float32.
        features = features.float()
        if use_cuda:
            features = features.cuda()
            adj_norm = adj_norm.cuda()
            if adj_decode is not None:
                adj_decode = adj_decode.cuda()

        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm)
            else:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm, adj_decode)

        sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
        latent_parts.append(mu.detach().cpu().numpy())
        broad_parts.append(sample_obs['top_level'].to_numpy())
        sub_parts.append(sample_obs['cell_type_label'].to_numpy())
        region_parts.append(sample_obs['region'].to_numpy())
        sample_parts.append(np.repeat(s, mu.shape[0]))

    # np.concatenate always copies, so latents never aliases a model output.
    latents = np.concatenate(latent_parts, axis=0)
    celltype_broad = np.concatenate(broad_parts)
    celltype_sub = np.concatenate(sub_parts)
    region = np.concatenate(region_parts)
    samplenameList = np.concatenate(sample_parts)

    sampledir = os.path.join(plotsavepath, 'combined' + xcorr)
    if inverseAct:
        sampledir += '_beforeAct'
    # The all-cells plots now also live under sampledir. The old path dropped
    # the '_beforeAct' suffix, so its parent directory was never created and
    # os.mkdir raised FileNotFoundError.
    savedir = os.path.join(sampledir, 'embedding_' + plottype)
    os.makedirs(savedir, exist_ok=True)

    if inverseAct == 'leakyRelu':
        latents = inverseLeakyRelu(latents)
    if plottype == 'umap':
        embedding = umap.UMAP().fit_transform(latents)
    elif plottype == 'pca':
        embedding = pca.fit_transform(latents)
    else:
        raise ValueError(f'unknown plottype: {plottype!r}')

    plotembeddingbyCT(samplenameList, 'sample', [], embedding, savedir, 'all samples')
    plotembeddingbyCT(celltype_broad, 'celltype_broad', [], embedding, savedir, 'all samples')
    plotembeddingbyCT(celltype_sub, 'celltype_sub', [], embedding, savedir, 'all samples')
    plotembeddingbyCT(region, 'region', [], embedding, savedir, 'all samples')

    plotembeddingbyCT_contrast(celltype_sub, 'celltype_sub', [], embedding, os.path.join(savedir, 'contrast'), 'all samples')

    # Same plots, with the projection refitted on the cells of each region.
    for reg in np.unique(region):
        savedir = os.path.join(sampledir, 'embedding_' + plottype + '_' + reg)
        os.makedirs(savedir, exist_ok=True)

        reg_idx = region == reg
        if plottype == 'umap':
            embedding = umap.UMAP().fit_transform(latents[reg_idx])
        elif plottype == 'pca':
            embedding = pca.fit_transform(latents[reg_idx])

        title = plottype + ' of ' + 'all samples' + ' ' + reg
        plotembeddingbyCT(samplenameList[reg_idx], 'sample', [], embedding, savedir, title)
        plotembeddingbyCT(celltype_broad[reg_idx], 'celltype_broad', [], embedding, savedir, title)
        plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub', [], embedding, savedir, title)

        plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub', [], embedding, os.path.join(savedir, 'contrast'), title)

In [None]:
# Embeddings per (region, top-level cell type): for every sample, refit the
# 2-D projection on the cells of each region x broad cell type and colour
# them by sub cell type.
model.eval()
if use_cuda:
    model.cuda()

for s, sampleidx in plot_samples.items():
    adj = adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    adj_decode = None
    if adj_decodeName == 'gala':
        adj_decode = preprocessing.preprocess_graph_sharp(adj)
    if use_cuda:
        adj_norm = adj_norm.cuda()
        if adj_decode is not None:
            adj_decode = adj_decode.cuda()

    sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
    celltype_broad = sample_obs['top_level']
    celltype_sub = sample_obs['cell_type_label']
    region = sample_obs['region']

    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device; previously only the CUDA path converted to float32.
        features = features.float()
        if use_cuda:
            features = features.cuda()

        with torch.no_grad():
            if adj_decodeName is None:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm)
            else:
                adj_recon, mu, logvar, z, features_recon = model(features, adj_norm, adj_decode)

        # Computed once, outside the loops: re-applying the inverse activation
        # per subset compounded on CPU, where the numpy view shares mu's memory.
        muplot = mu.detach().cpu().numpy().copy()
        if inverseAct == 'leakyRelu':
            muplot = inverseLeakyRelu(muplot)

        if inverseAct:
            samplename += '_beforeAct'

        for r in np.unique(region):
            ridx = (region == r).to_numpy()
            # `reg` is a top-level cell type here (not a region).
            for reg in np.unique(celltype_broad):
                reg_idx = ridx & (celltype_broad == reg).to_numpy()
                # A 2-D plot needs >= 2 cells: PCA on one cell yields a single
                # component and embedding[:, 1] would raise IndexError.
                if np.count_nonzero(reg_idx) < 2:
                    continue

                savedir = os.path.join(plotsavepath, samplename, 'embedding_' + plottype + '_' + reg)
                os.makedirs(savedir, exist_ok=True)

                if plottype == 'umap':
                    embedding = umap.UMAP().fit_transform(muplot[reg_idx])
                elif plottype == 'pca':
                    embedding = pca.fit_transform(muplot[reg_idx])
                else:
                    raise ValueError(f'unknown plottype: {plottype!r}')

                title = plottype + ' of ' + r + ' ' + s + ' ' + reg
                plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub_' + r, [], embedding, savedir, title)

                plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub_' + r, [], embedding, os.path.join(savedir, 'contrast'), title)

In [None]:
# Display the current value of `savedir`. After a top-to-bottom run this is the
# last plot directory assigned by the plotting loops above (debug leftover).
savedir