In [1]:
import sys
sys.path.append('/home/xinyiz/pamrats')

import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

import image.loadImage as loadImage
import gae.gae.optimizer as optimizer
import image.modelsCNN as modelsCNN

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN,MiniBatchKMeans,AgglomerativeClustering
from sklearn import metrics

import anndata as ad
import gc
import matplotlib.font_manager as fm
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

In [2]:
# Load per-sample cell centroid coordinates from the raw STARmap AnnData file.
# Result: cellCoord[<short sample label>] -> integer (y, x) array, one row per cell.

datadir='/home/xinyiz/2021-01-13-mAD-test-dataset'
# short sample label -> mouse ID as stored in obs['sample']
sampleidx={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}


scaleddata=scanpy.read_h5ad(datadir+'/2020-12-27-starmap-mAD-raw.h5ad')

cellCoord={}
for s,sampleidx_s in sampleidx.items():
    inSample=scaleddata.obs['sample']==sampleidx_s
    yx=scaleddata.obs.loc[inSample,['y','x']].to_numpy()
    # divide by 0.3 and truncate to int -- presumably converts to image pixel units; TODO confirm
    cellCoord[s]=(yx/0.3).astype(int)
# drop the reference so the (large) AnnData object can be freed
scaleddata=None



In [3]:
# Run configuration: GPU selection, analysis/clustering options, the CNN-VAE
# architecture to rebuild, and where checkpoints / logs / plots live.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# --- analysis switches ---------------------------------------------------------
ifplot = True
ifcluster = True
inverseAct = None
plottype = 'umap'
pca = PCA()
minCells = 15  # min number of cells for analysis
savenameAdd = ''

# --- clustering -----------------------------------------------------------------
# other methods that have been used: 'agglomerative', 'kmeanbatch'
clustermethod = ['leiden']
# umap / leiden parameters
n_neighbors = 10
min_dist = 0.25
n_pcs = 40  # PCs used for clustering
resolution = [0.05, 0.1, 0.2, 0.3, 0.5, 0.8, 1, 1.5]
# DBSCAN
epslist = [6, 8, 10]
min_sampleslist = [15, 30, 45]
# agglomerative
nclusterlist = [2, 3, 4, 5, 8, 10, 15]
aggMetric = ['euclidean']

combineCelltype = {
    'glia': ['Astro', 'Micro', 'OPC', 'Oligo'],
    'CA': ['CA1', 'CA2', 'CA3'],
}

# --- hardware / reproducibility --------------------------------------------------
use_cuda = True
fastmode = False  # Validate during training pass
seed = 3

# --- CNN-VAE architecture (must match the checkpoint loaded later) ----------------
kernel_size = 4
stride = 2
padding = 1
hidden1 = 64   # number of channels in hidden layer 1
hidden2 = 128
hidden3 = 256
hidden4 = 256
hidden5 = 96
fc_dim1 = hidden5 * 25 * 25  # 96*25*25
fc_dim2 = 5000
# alternative (larger) architecture used before:
#   hidden4=512, hidden5=512, fc_dim1=512*25*25, fc_dim2=1024, adv_hidden=128
model_str = 'cnn_vae'
targetBatch = None
plotepoch = 36  # epoch of the checkpoint to load; also used in plot file names

# --- image patches ---------------------------------------------------------------
diamThresh_mul = 800  # side length of the square per-cell image (see reshapes below)
minThresh_mul = 12
overlap = diamThresh_mul // 2

# --- output locations --------------------------------------------------------------
name = 'cd13_thresh25_02'
logsavepath = '/mnt/external_ssd/xinyi/log/train_jointGAEcnn_starmap/' + name
modelsavepath = '/mnt/external_ssd/xinyi/models/train_jointGAEcnn_starmap/' + name
plotsavepath = '/mnt/external_ssd/xinyi/plots/train_jointGAEcnn_starmap/' + name

# --- samples to plot ---------------------------------------------------------------
plot_samples = {
    'disease13': 'AD_mouse9494',
    'control13': 'AD_mouse9498',
    'disease8': 'AD_mouse9723',
    'control8': 'AD_mouse9735',
}
datadir = '/home/xinyiz/2021-01-13-mAD-test-dataset'

In [4]:
# Set cuda and seed
# Seed NumPy and torch for reproducibility; if CUDA was requested but is not
# present, print a warning and continue on CPU.
np.random.seed(seed)
if use_cuda:
    if not torch.cuda.is_available():
        print('cuda not available')
        use_cuda=False
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True


In [5]:
# Build the CNN-VAE architecture selected by `model_str` and load the weights
# saved for epoch `plotepoch` from `modelsavepath`.
if model_str=='cnn_vae':
    model = modelsCNN.CNN_VAE(kernel_size, stride, padding, 1, hidden1, hidden2, hidden3, hidden4, hidden5, fc_dim1,fc_dim2)
elif model_str=='cnn_vae_alexnet':
    model = modelsCNN.CNN_VAE_alexnet(fc_dim1)
else:
    # previously this fell through and failed later with an unhelpful NameError
    raise ValueError('unknown model_str: '+str(model_str))
if use_cuda:
    model.cuda()
# Map the checkpoint to CPU so loading also works when use_cuda is False
# (e.g. CUDA unavailable); load_state_dict copies values into the model's own
# parameters, so a CUDA model still ends up with GPU weights.
model.load_state_dict(torch.load(os.path.join(modelsavepath,str(plotepoch)+'.pt'),map_location='cpu'))


<All keys matched successfully>

In [8]:
np.random.seed(seed=seed)  # re-seed NumPy's global RNG with the configured seed
def inverseLeakyRelu(v,slope=0.01):
    """Invert a leaky ReLU whose negative-side slope is `slope`.

    Non-negative entries are returned unchanged; negative entries are scaled by
    1/slope.

    Args:
        v: array supporting boolean-mask indexing and assignment (e.g. a NumPy
           array).
        slope: negative-side slope of the forward leaky ReLU.

    Returns:
        A new floating-point array; `v` itself is not modified.
    """
    # `v * 1.0` makes a copy (so the caller's data is untouched) and promotes
    # integer input to floating point, so the scaled negative values are not
    # truncated when written back.
    out = v * 1.0
    neg = out < 0
    out[neg] = 1/slope*out[neg]
    return out

In [6]:
# Reconstruct example cells from each GAE cluster of each sample and save the
# input and reconstruction as red-channel images with a scalebar.
gaeClusterPath='/mnt/external_ssd/xinyi/plots/train_gae_starmap/allk20XA_02_dca_over/combinedlogminmax_beforeAct/cluster/leiden_nn10mdist025n_pcs40res0.1epoch9990'
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(gaeClusterPath, 'rb') as labelFile:
    gaeclusterlabels = pickle.load(labelFile)
gaeclusterlabels=gaeclusterlabels.astype(int)

# Sample label for every cell, in the order cellCoord is concatenated over
# sampleidx. The cluster labels are assumed to use this same cell order.
sampleNames=np.concatenate([np.repeat(s,cellCoord[s].shape[0]) for s in sampleidx.keys()],axis=None)
assert sampleNames.shape[0]==gaeclusterlabels.shape[0], 'cluster labels do not match the number of cells'


def _minmax_scale(a):
    """Rescale `a` linearly to [0, 1]; a constant array maps to all zeros."""
    lo=np.min(a)
    span=np.max(a)-lo
    if span==0:
        # avoid 0/0 -> NaN image
        return np.zeros_like(a)
    return (a-lo)/span


def _save_red_with_scalebar(img, path, side, fontprops):
    """Save `img` (reshaped to side x side) as the red channel of an RGB figure
    with a white '10 um' scalebar in the lower-left corner."""
    blank=np.zeros((side,side))
    rgb=np.stack((img.reshape((side,side)),blank,blank),axis=2)
    fig, ax = plt.subplots(figsize=(1, 1), dpi=800)
    ax.imshow(rgb)
    # 106 data units (pixels) are drawn for the 10 um label.
    # NOTE(review): confirm against the image pixel size.
    scalebar = AnchoredSizeBar(ax.transData,
                               106, u'10 \u03bcm', 'lower left',
                               pad=0.1,
                               color='white',
                               frameon=False,
                               size_vertical=0.3,
                               fontproperties=fontprops)
    ax.add_artist(scalebar)
    ax.set_yticks([])
    ax.set_xticks([])
    fig.savefig(path,dpi=800)
    plt.close(fig)


np.random.seed(3)

examplesavepath=os.path.join(plotsavepath,'reconExamples')
os.makedirs(examplesavepath,exist_ok=True)
fontprops = fm.FontProperties(size=2,family='arial')
nExamplesPerCluster=8
model.eval()
with torch.no_grad():
    for s in plot_samples.keys():
        labels_s=gaeclusterlabels[sampleNames==s]
        for c in np.unique(gaeclusterlabels):
            cidx=np.flatnonzero(labels_s==c)
            if cidx.size==0:
                continue
            # clusters smaller than nExamplesPerCluster previously raised in np.random.choice
            cidx=np.random.choice(cidx,min(nExamplesPerCluster,cidx.size),replace=False)
            imgInputnp=loadImage.load_cellCentroid(cellCoord[s][cidx],sampleidx[s],datadir,diamThresh_mul,ifFlip=False,seed=3,imagename='pi_sum.tif',minmaxscale=True,nchannels=1)
            for i in range(imgInputnp.shape[0]):
                imgInput=imgInputnp[[i]]  # keep the leading batch dimension
                # always convert to a tensor (the CPU path previously passed numpy to the model)
                imgTensor=torch.tensor(imgInput).float()
                if use_cuda:
                    imgTensor=imgTensor.cuda()
                recon,z,mu,logvar=model(imgTensor)
                recon=_minmax_scale(recon.cpu().numpy())
                prefix=os.path.join(examplesavepath,s+'cluster'+str(c)+'_'+str(i)+'epoch'+str(plotepoch))
                _save_red_with_scalebar(recon,prefix+'_recon_scalebar.jpg',diamThresh_mul,fontprops)
                _save_red_with_scalebar(imgInput,prefix+'_input_scalebar.jpg',diamThresh_mul,fontprops)

(22210, 22344)


findfont: Font family ['arial'] not found. Falling back to DejaVu Sans.


(22210, 22344)
(22210, 22344)
(22210, 22344)
(22210, 22344)
(22210, 22344)
(22210, 22344)
(22210, 22344)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22355, 18953)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22294, 19552)
(22452, 19616)
(22452, 19616)
(22452, 19616)
(22452, 19616)
(22452, 19616)
(22452, 19616)
(22452, 19616)
(22452, 19616)


In [None]:
# Reconstruct a fixed set of cells (indices 100-104) from each sample and save
# the input and the min-max-scaled reconstruction as plain red-channel JPEGs.
examplesavepath=os.path.join(plotsavepath,'reconExamples')
os.makedirs(examplesavepath,exist_ok=True)
exampleCells=slice(100,105)
blankChannel=np.zeros((diamThresh_mul,diamThresh_mul))
model.eval()
with torch.no_grad():
    for s in plot_samples.keys():
        imgInputnp=loadImage.load_cellCentroid(cellCoord[s][exampleCells],sampleidx[s],datadir,diamThresh_mul,ifFlip=False,seed=3,imagename='pi_sum.tif',minmaxscale=True,nchannels=1)
        for i in range(imgInputnp.shape[0]):
            imgInput=imgInputnp[[i]]  # keep the leading batch dimension
            # always convert to a tensor (the CPU path previously passed numpy to the model)
            imgTensor=torch.tensor(imgInput).float()
            if use_cuda:
                imgTensor=imgTensor.cuda()
            recon,z,mu,logvar=model(imgTensor)
            recon=recon.cpu().numpy()
            reconmin=np.min(recon)
            reconspan=np.max(recon)-reconmin
            # a constant reconstruction would otherwise give 0/0 -> NaN
            recon=(recon-reconmin)/reconspan if reconspan>0 else np.zeros_like(recon)
            prefix=os.path.join(examplesavepath,s+str(i)+'epoch'+str(plotepoch))
            plt.imsave(prefix+'_recon.jpg',np.stack((recon.reshape((diamThresh_mul,diamThresh_mul)),blankChannel,blankChannel),axis=2))
            # NOTE(review): assumes minmaxscale=True keeps the input within [0, 1], as imsave requires for float RGB
            plt.imsave(prefix+'_input.jpg',np.stack((imgInput.reshape((diamThresh_mul,diamThresh_mul)),blankChannel,blankChannel),axis=2))