In [None]:
import sys
sys.path.append('/home/xinyiz/pamrats')

import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale

In [None]:
# --- Run configuration: everything a reader may want to tune lives here ---

# Device and reproducibility
use_cuda = True
seed = 3
fastmode = False  # Validate during training pass

# Which checkpoint / graph to use
model_str = 'gcn_vae_xa'  # currently available choices: gcn_vae, gcn_ae, gcn_vae3, gcn_vae_xa
plotepoch = 190           # epoch whose saved weights ('<plotepoch>.pt') are loaded for plotting
useSavedMaskedEdges = False
maskedgeName = 'knn20_connectivity'  # file-name prefix of the per-sample adjacency .npz files

# Architecture (must match the checkpoint being loaded)
hidden1 = 32  # Number of units in hidden layer 1
hidden2 = 16  # Number of units in hidden layer 2
fc_dim1 = 128
fc_dim2 = fc_dim3 = fc_dim4 = 2112
dropout = 0

# What to plot: short alias -> sample id in scaleddata.obs['sample']
plot_samples = {
    'disease13': 'AD_mouse9494',
    'control13': 'AD_mouse9498',
    'disease8': 'AD_mouse9723',
    'control8': 'AD_mouse9735',
}
plot_sample_X = ['corrected', 'scaled']  # expression layers to encode
standardizeX = True                      # column-wise z-score of the features before encoding

# Output locations, all keyed by the run name
name = 'c13k20XA_04_lossXreconOnly_wKL'
logsavepath = f'/home/xinyiz/pamrats/log/train_gae_starmap/{name}'
modelsavepath = f'/home/xinyiz/pamrats/models/train_gae_starmap/{name}'
plotsavepath = f'/home/xinyiz/pamrats/plots/train_gae_starmap/{name}'

In [None]:
# Seed the RNGs used in this notebook, and fall back to CPU if CUDA was requested but is missing.
np.random.seed(seed)
torch.manual_seed(seed)

if use_cuda and not torch.cuda.is_available():
    print('cuda not available')
    use_cuda = False

if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True


In [None]:
# Load data.
# For every sample in `plot_samples`, build:
#   featureslist['<alias>X_<layer>'] -> torch tensor of that sample's rows of the AnnData layer
#   adj_list['<alias>']              -> precomputed sparse adjacency '<maskedgeName>_<sample id>.npz'
# The sample ids come from the config cell, so features and graphs cannot drift out of sync with it.
scaleddata = scanpy.read_h5ad('/mnt/xinyi/2021-01-13-mAD-test-dataset/2020-12-27-starmap-mAD-scaled.h5ad')
savedir = os.path.join('/mnt/xinyi/', 'starmap')
adj_dir = os.path.join(savedir, 'a')

featureslist = {}
adj_list = {}
for s, sampleidx in plot_samples.items():
    # Plain boolean array so the row selection works for any array-like layer.
    sample_mask = (scaleddata.obs['sample'] == sampleidx).to_numpy()
    for layer in ('corrected', 'scaled'):
        featureslist[s + 'X_' + layer] = torch.tensor(scaleddata.layers[layer][sample_mask])
    adj_list[s] = sp.load_npz(os.path.join(adj_dir, maskedgeName + '_' + sampleidx + '.npz'))
    # The graph and the feature matrix must describe the same cells (one row per node).
    assert adj_list[s].shape[0] == featureslist[s + 'X_corrected'].shape[0], (
        'adjacency/feature row mismatch for ' + s + ' (' + sampleidx + ')')

In [1]:
# Load model: rebuild the architecture from the config and restore the weights saved at `plotepoch`.
num_nodes, num_features = featureslist['disease13X_corrected'].shape
if model_str == 'gcn_vae_xa':
    model = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2, fc_dim1, fc_dim2, fc_dim3, fc_dim4, dropout)
else:
    # Only the XA variant is wired up here; fail with a clear message instead of a NameError on `model`.
    raise ValueError('unsupported model_str for plotting: ' + model_str)
# map_location='cpu' lets a checkpoint saved from GPU load when CUDA is unavailable;
# the plotting cells call model.cuda() themselves when use_cuda is set.
model.load_state_dict(torch.load(os.path.join(modelsavepath, str(plotepoch) + '.pt'), map_location='cpu'))

NameError: name 'featureslist' is not defined

In [None]:
def plotembeddingbyCT(ctlist,savename,excludelist,embedding,savepath,plotname):
    """Scatter a 2-D embedding coloured by category, save it as JPEG and show it.

    Parameters
    ----------
    ctlist : array-like of labels (numpy array or pandas Series), one per row of `embedding`.
    savename : file stem; the figure is written to ``<savepath>/<savename>.jpg``.
    excludelist : labels to leave out of the plot (they still reserve a palette slot).
    embedding : 2-D array; columns 0 and 1 are plotted.
    savepath : output directory (created if it does not exist).
    plotname : inserted into the title ``'UMAP of <plotname> embedding'``.
    """
    celltypes = np.unique(ctlist)
    # Every label (excluded or not) gets a fixed palette slot, so exclusions don't recolour the rest.
    celltypes_dict = {ct: i for i, ct in enumerate(celltypes)}

    # Shuffle the husl palette so adjacent (similar-hue) colours land on unrelated labels.
    # NOTE: this draws from numpy's global RNG (seeded at the top of the notebook).
    colortest = sns.color_palette("husl", celltypes.size)
    np.random.shuffle(colortest)

    fig, ax = plt.subplots(dpi=400)
    for ct in celltypes:
        if ct in excludelist:
            continue
        idx = (ctlist == ct)
        ax.scatter(
            embedding[idx, 0],
            embedding[idx, 1],
            color=colortest[celltypes_dict[ct]], label=ct, s=3, alpha=0.5
            )

    ax.set_aspect('equal', 'datalim')
    fig.set_figheight(10)
    fig.set_figwidth(10)
    # Shrink the axes vertically to leave room for the legend placed below them.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True, shadow=True, ncol=5)
    ax.set_title('UMAP of '+plotname+' embedding', fontsize=24)

    os.makedirs(savepath, exist_ok=True)
    fig.savefig(os.path.join(savepath, savename+'.jpg'))
    plt.show()
    # Callers invoke this many times in loops; close explicitly so figures don't accumulate
    # on backends where plt.show() does not close them.
    plt.close(fig)

In [None]:
def plotembeddingbyCT_contrast(ctlist,savename,excludelist,embedding,savepath,plotname):
    """Save one figure per label, highlighting that label against all other points.

    Parameters
    ----------
    ctlist : array-like of labels (numpy array or pandas Series), one per row of `embedding`.
    savename : file-name stem; each figure goes to ``<savepath>/<savename>_<label>.jpg``.
    excludelist : labels to skip, in addition to 'Unassigned' (which is always skipped).
    embedding : 2-D array; columns 0 and 1 are plotted.
    savepath : output directory (created if it does not exist).
    plotname : inserted into the title ``'UMAP of <plotname> embedding'``.
    """
    celltypes = np.unique(ctlist)
    colortest = sns.color_palette("tab10")
    os.makedirs(savepath, exist_ok=True)

    for ct in celltypes:
        # Decide to skip *before* creating a figure, so no empty figure is left open
        # (and then displayed by a later plt.show()).
        if ct == 'Unassigned' or ct in excludelist:
            continue

        fig, ax = plt.subplots()

        # Background: every point not carrying this label.
        others = (ctlist != ct)
        ax.scatter(
            embedding[others, 0],
            embedding[others, 1],
            color=colortest[1], label='others', s=1, alpha=0.5
            )

        # Foreground: the highlighted label, drawn last so it sits on top.
        idx = (ctlist == ct)
        ax.scatter(
            embedding[idx, 0],
            embedding[idx, 1],
            color=colortest[0], label=ct, s=3, alpha=0.5
            )

        ax.set_aspect('equal', 'datalim')
        fig.set_figheight(10)
        fig.set_figwidth(10)
        ax.legend()
        ax.set_title('UMAP of '+plotname+' embedding', fontsize=24)
        # A path separator inside a label would otherwise be treated as a sub-directory.
        safe_ct = str(ct).replace(os.sep, '_')
        fig.savefig(os.path.join(savepath, savename+'_'+safe_ct+'.jpg'))
        plt.show()
        plt.close(fig)

In [None]:
# Whole-sample UMAPs: for each sample and feature layer, embed the latent means `mu`
# and colour them by broad cell type, sub cell type and region.
if use_cuda:
    model.cuda()
model.eval()
for s in plot_samples.keys():
    sampleidx = plot_samples[s]
    adj = adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    if use_cuda:
        adj_norm = adj_norm.cuda()
    sample_mask = scaleddata.obs['sample'] == sampleidx
    celltype_broad = scaleddata.obs.loc[sample_mask, 'top_level']
    celltype_sub = scaleddata.obs.loc[sample_mask, 'cell_type_label']
    region = scaleddata.obs.loc[sample_mask, 'region']
    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device: previously only the CUDA path called .float(), so CPU runs
        # fed float64 into the model (presumably float32 weights -> dtype mismatch).
        features = features.float()
        if use_cuda:
            features = features.cuda()

        sampledir = os.path.join(plotsavepath, samplename)
        savedir = os.path.join(sampledir, 'embedding_umap')
        # makedirs also creates plotsavepath itself when it does not exist yet.
        os.makedirs(savedir, exist_ok=True)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            adj_recon, mu, logvar, features_recon = model(features, adj_norm)
        reducer = umap.UMAP()
        umap_embedding = reducer.fit_transform(mu.cpu().detach().numpy())

        plotembeddingbyCT(celltype_broad, 'celltype_broad', [], umap_embedding, savedir, s)
        plotembeddingbyCT(celltype_sub, 'celltype_sub', [], umap_embedding, savedir, s)
        plotembeddingbyCT(region, 'region', [], umap_embedding, savedir, s)

        plotembeddingbyCT_contrast(celltype_sub, 'celltype_sub', [], umap_embedding, os.path.join(savedir, 'contrast'), s)

In [None]:
# Separate plots by region: for each sample and feature layer, fit a UMAP on the latent
# means of one region at a time, so within-region structure is visible on its own.
if use_cuda:
    model.cuda()
model.eval()
for s in plot_samples.keys():
    sampleidx = plot_samples[s]
    adj = adj_list[s]
    adj_norm = preprocessing.preprocess_graph(adj)
    if use_cuda:
        adj_norm = adj_norm.cuda()
    sample_mask = scaleddata.obs['sample'] == sampleidx
    celltype_broad = scaleddata.obs.loc[sample_mask, 'top_level']
    celltype_sub = scaleddata.obs.loc[sample_mask, 'cell_type_label']
    region = scaleddata.obs.loc[sample_mask, 'region']
    for xcorr in plot_sample_X:
        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device, not only on the CUDA path, so CPU runs get the same dtype.
        features = features.float()
        if use_cuda:
            features = features.cuda()
        sampledir = os.path.join(plotsavepath, samplename)
        os.makedirs(sampledir, exist_ok=True)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            adj_recon, mu, logvar, features_recon = model(features, adj_norm)
        # Loop-invariant: convert the latents to numpy once per sample/layer.
        muplot = mu.cpu().detach().numpy()

        for reg in np.unique(region):
            savedir = os.path.join(sampledir, 'embedding_umap_' + reg)
            os.makedirs(savedir, exist_ok=True)

            reg_idx = region == reg

            reducer = umap.UMAP()
            umap_embedding = reducer.fit_transform(muplot[reg_idx])

            plotembeddingbyCT(celltype_broad[reg_idx], 'celltype_broad', [], umap_embedding, savedir, s + ' ' + reg)
            plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub', [], umap_embedding, savedir, s + ' ' + reg)

            plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub', [], umap_embedding, os.path.join(savedir, 'contrast'), s + ' ' + reg)

In [None]:
# Combine all latents into one plot: for each feature layer, stack the latent means of every
# sample and embed them jointly, first over all cells and then per region.
if use_cuda:
    model.cuda()
model.eval()
for xcorr in plot_sample_X:
    # Gather per-sample pieces in lists and concatenate once, instead of re-stacking each iteration.
    latent_parts, broad_parts, sub_parts, region_parts, sample_parts = [], [], [], [], []

    for s in plot_samples.keys():
        sampleidx = plot_samples[s]
        adj = adj_list[s]
        adj_norm = preprocessing.preprocess_graph(adj)

        samplename = s + 'X_' + xcorr
        features = featureslist[samplename]
        if standardizeX:
            features = torch.tensor(scale(features, axis=0, with_mean=True, with_std=True, copy=True))
        # Cast on every device, not only on the CUDA path, so CPU runs get the same dtype.
        features = features.float()
        if use_cuda:
            features = features.cuda()
            adj_norm = adj_norm.cuda()
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            adj_recon, mu, logvar, features_recon = model(features, adj_norm)

        sample_obs = scaleddata.obs.loc[scaleddata.obs['sample'] == sampleidx]
        latent_parts.append(mu.cpu().detach().numpy())
        broad_parts.append(np.asarray(sample_obs['top_level']))
        sub_parts.append(np.asarray(sample_obs['cell_type_label']))
        region_parts.append(np.asarray(sample_obs['region']))
        sample_parts.append(np.repeat(s, mu.shape[0]))

    latents = np.vstack(latent_parts)
    celltype_broad = np.concatenate(broad_parts, axis=None)
    celltype_sub = np.concatenate(sub_parts, axis=None)
    region = np.concatenate(region_parts, axis=None)
    samplenameList = np.concatenate(sample_parts, axis=None)
    # Every label array must line up row-for-row with the latents it annotates.
    assert latents.shape[0] == celltype_broad.shape[0] == celltype_sub.shape[0] == region.shape[0] == samplenameList.shape[0]

    sampledir = os.path.join(plotsavepath, 'combined' + xcorr)
    savedir = os.path.join(sampledir, 'embedding_umap')
    os.makedirs(savedir, exist_ok=True)
    reducer = umap.UMAP()
    umap_embedding = reducer.fit_transform(latents)

    plotembeddingbyCT(samplenameList, 'sample', [], umap_embedding, savedir, 'all samples')
    plotembeddingbyCT(celltype_broad, 'celltype_broad', [], umap_embedding, savedir, 'all samples')
    plotembeddingbyCT(celltype_sub, 'celltype_sub', [], umap_embedding, savedir, 'all samples')
    plotembeddingbyCT(region, 'region', [], umap_embedding, savedir, 'all samples')

    plotembeddingbyCT_contrast(celltype_sub, 'celltype_sub', [], umap_embedding, os.path.join(savedir, 'contrast'), 'all samples')

    for reg in np.unique(region):
        savedir = os.path.join(sampledir, 'embedding_umap_' + reg)
        os.makedirs(savedir, exist_ok=True)

        reg_idx = region == reg

        reducer = umap.UMAP()
        umap_embedding = reducer.fit_transform(latents[reg_idx])

        plotembeddingbyCT(samplenameList[reg_idx], 'sample', [], umap_embedding, savedir, 'all samples' + ' ' + reg)
        plotembeddingbyCT(celltype_broad[reg_idx], 'celltype_broad', [], umap_embedding, savedir, 'all samples' + ' ' + reg)
        plotembeddingbyCT(celltype_sub[reg_idx], 'celltype_sub', [], umap_embedding, savedir, 'all samples' + ' ' + reg)

        plotembeddingbyCT_contrast(celltype_sub[reg_idx], 'celltype_sub', [], umap_embedding, os.path.join(savedir, 'contrast'), 'all samples' + ' ' + reg)