In [1]:
import sys
sys.path.append('/home/xinyiz/pamrats')

In [2]:
##This is adapted from https://github.com/tkipf/gae/blob/master/gae/train.py and https://github.com/tkipf/pygcn/blob/master/pygcn/train.py##

import time
import os

# Train on CPU (hide GPU) due to memory constraints
# os.environ['CUDA_VISIBLE_DEVICES'] = ""

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim
from torch.utils.data import DataLoader
# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score

import image.loadImage as loadImage
import gae.gae.optimizer as optimizer
import image.modelsCNN as modelsCNN

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib.colors
from scipy import stats


In [3]:
imageSizes={}
imageSizes['disease13']=(22210, 22344)
imageSizes['control13']=(22355, 18953)
imageSizes['disease8']=(22294, 19552)
imageSizes['control8']=(22452, 19616)

In [4]:
# Settings
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
use_cuda=True
fastmode=False #Validate during training pass
seed=3
epochs=10000
saveFreq=10
lr=0.0005 #initial learning rate
lr_adv=0.001
weight_decay=0 #Weight for L2 loss on embedding matrix.

batchsize=16

dropout=0.01
testNodes=0.1 #fraction of total nodes for testing
valNodes=0.05 #fraction of total nodes for validation
# randFeatureSubset=None
model_str='alexnet_regrs'

kernel_size=4
stride=2
padding=1

hidden1=64 #Number of channels in hidden layer 1
hidden2=128 
hidden3=256
hidden4=512
hidden5=512
fc_dim1=512*25*25
fc_dim2=1024

pretrainedAE=None #{'name':'controlphy5XAbin_01_dca','epoch':9990}
# training_samples=['control13','disease13','disease8','control8']
training_samples=['disease13','control13']
targetBatch=None
switchFreq=1
diamThresh_mul=800
minThresh_mul=12
overlap=int(diamThresh_mul*0.7)
areaThresh=diamThresh_mul*diamThresh_mul*0.7
plaqueMaskName='PlaqueMask'
plaqueMaskImg='Maskofplaque.tif'
name='cd13regrs_thresh25min12_overlap70area70_01'
logsavepath='/mnt/external_ssd/xinyi/log/train_cnnRegrs_starmap/'+name
modelsavepath='/mnt/external_ssd/xinyi/models/train_cnnRegrs_starmap/'+name
plotsavepath=os.path.join('/mnt/external_ssd/xinyi/plots/train_cnnRegrs_starmap/'+name,'allk20XA_02_dca_over_leiden0.1_epoch9990')
if not os.path.exists(plotsavepath):
    os.mkdir(plotsavepath)
    
#Load data
sampleidx={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}
datadir='/home/xinyiz/2021-01-13-mAD-test-dataset'

# Set cuda and seed
np.random.seed(seed)
if use_cuda and (not torch.cuda.is_available()):
    print('cuda not available')
    use_cuda=False
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True

# Load data
# if randFeatureSubset != None:
#     idx=np.random.choice(features.shape[1],randFeatureSubset,replace=False)
#     features=features[:,idx]

In [5]:
plaqueCentroids={}
plaqueCentroids['disease13']=pd.read_csv('/home/xinyiz/2021-01-13-mAD-test-dataset/AD_mouse9494/trimmed_images/'+plaqueMaskName+'.csv', header=0)
plaqueCentroids['disease8']=pd.read_csv('/home/xinyiz/2021-01-13-mAD-test-dataset/AD_mouse9723/trimmed_images/'+plaqueMaskName+'.csv', header=0)
maxArea=max(np.max(plaqueCentroids['disease13']['Area']),np.max(plaqueCentroids['disease8']['Area']))
plaqueCutoffRadius=max(int(np.sqrt(maxArea)/2),int(diamThresh_mul/2))
# plaqueSizeFactor=maxArea/100
plaqueSizeFactor=1

In [6]:
gaeClusterPath='/mnt/external_ssd/xinyi/plots/train_gae_starmap/allk20XA_02_dca_over/combinedlogminmax_beforeAct/cluster/leiden_nn10mdist025n_pcs40res0.1epoch9990'
with open(gaeClusterPath, 'rb') as input:
    gaeclusterlabels = pickle.load(input)

plot_samples={'disease13':'AD_mouse9494','control13':'AD_mouse9498','disease8':'AD_mouse9723','control8':'AD_mouse9735'}

gaeCoord=None
sampleNames=None
scaleddata=scanpy.read_h5ad(datadir+'/2020-12-27-starmap-mAD-raw.h5ad')
for s in plot_samples.keys():
    sampleidx_s=plot_samples[s] 
    if gaeCoord is None:
        gaeCoord=scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx_s,['y','x']].to_numpy()
        sampleNames=np.repeat(s,np.sum(scaleddata.obs['sample']==sampleidx_s))
    else:
        gaeCoord=np.concatenate((gaeCoord,scaleddata.obs.loc[scaleddata.obs['sample']==sampleidx_s,['y','x']].to_numpy()),axis=0)
        sampleNames=np.concatenate((sampleNames,np.repeat(s,np.sum(scaleddata.obs['sample']==sampleidx_s))),axis=None)
gaeCoord=gaeCoord/0.3
scaleddata=None

In [7]:
def plotExprs(exprs,exprs2,savename,embedding,embedding2,savepath,savenameAdd='',norm=None):
#     fig, ax = plt.subplots(dpi=400)
    if not exprs is None:
        plt.scatter(embedding[:,0],embedding[:,1],s=5,c=exprs,cmap='Greys',edgecolors='blue',linewidth=0.2,alpha=1,marker='o',norm=norm)
    if not exprs2 is None:
        plt.scatter(embedding2[:,0],embedding2[:,1],s=5,c=exprs2,cmap='Greys',edgecolors='green',linewidth=0.2,alpha=1,marker='^',norm=norm)
    plt.colorbar(orientation='vertical', shrink = 0.5)
    plt.gca().set_aspect('equal', 'datalim')
#     fig.set_figheight(5)
#     fig.set_figwidth(5)
    plt.savefig(os.path.join(savepath,savename+savenameAdd+'.jpg'),dpi=400)
#     plt.show()
    
#     fig.clf()
    plt.close('all')

In [8]:
def plotExprsDiff(exprsPos,exprsNeg,exprs2,savename,embeddingPos,embeddingNeg,embedding2,savepath,savenameAdd='',norm=None):
#     fig, ax = plt.subplots(dpi=400)
    if not exprsPos is None:
        plt.scatter(embeddingPos[:,0],embeddingPos[:,1],s=5,c=exprsPos,cmap='Reds',edgecolors='blue',linewidth=0.2,alpha=1,marker='o',norm=norm)
        plt.colorbar(orientation='vertical', shrink = 0.5)
    if not exprsNeg is None:
        plt.scatter(embeddingNeg[:,0],embeddingNeg[:,1],s=5,c=exprsNeg,cmap=plt.cm.get_cmap('Blues'),edgecolors='blue',linewidth=0.2,alpha=1,marker='o',norm=norm)
        plt.colorbar(orientation='vertical', shrink = 0.5)
    if not exprs2 is None:
        plt.scatter(embedding2[:,0],embedding2[:,1],s=5,c=exprs2,cmap='Greys',edgecolors='green',linewidth=0.2,alpha=1,marker='^',norm=norm)
        plt.colorbar(orientation='vertical', shrink = 0.5)
    plt.gca().set_aspect('equal', 'datalim')
#     fig.set_figheight(5)
#     fig.set_figwidth(5)
    plt.savefig(os.path.join(savepath,savename+savenameAdd+'.jpg'),dpi=1200)
#     plt.show()
    
#     fig.clf()
    plt.close('all')

In [9]:
def getClusterLabels(clusterlabels,cellCoords,minPt,imgCoords,diamThresh,imgSize):
    imgClusters=np.zeros(imgCoords.shape[0])-1
    for i in range(imgCoords.shape[0]):
        centroid=imgCoords[i]
        rowstart=centroid[0]-diamThresh/2
        rowend=min(rowstart+diamThresh,imgSize[0])
        colstart=centroid[1]-diamThresh/2
        colend=min(colstart+diamThresh,imgSize[1])
        #find corresponding cluster
        clusterIdxRow=np.logical_and(cellCoords[:,0]>=rowstart,cellCoords[:,0]<rowend)
        clusterIdxCol=np.logical_and(cellCoords[:,1]>=colstart,cellCoords[:,1]<colend)
        clusterRes=clusterlabels[np.logical_and(clusterIdxRow,clusterIdxCol)]
        if clusterRes.size==0:
            print('no cells')
            continue
        clusterResMode,modeCounts=stats.mode(clusterRes,axis=None)
        if modeCounts[0]/clusterRes.size<minPt:
            print('mode less than thresh')
            continue
        imgClusters[i]=clusterResMode[0]
    return imgClusters

In [10]:
# lossCE_binary=torch.nn.CrossEntropyLoss(reduction='none')
minPt=0
savenameAdd='_thresh30'
# savenameAdd='_thresh4445'
lossThreshSize=30
def plotLoss(inputNp,labelsNp,coordNp,name,plotsavepath,savenameAdd=''):
    if not os.path.exists(plotsavepath):
        os.mkdir(plotsavepath)
    
    loss_test_all=np.zeros(inputNp.shape[0])
    loss_test_all_binary=np.zeros(inputNp.shape[0])
    loss_test_all_diff=np.zeros(inputNp.shape[0])
    for i in range(inputNp.shape[0]):
        testInput=inputNp[[i]]
        labels=labelsNp[[i]]
        if use_cuda:
            testInput=torch.tensor(testInput).cuda().float()
            labels=torch.tensor(labels).cuda().float()
        pred = model(testInput)
        loss_test_all[i]=lossCE(pred.flatten(),labels).item()
#         if pred[0]*labels[0]>0: 
        if labels[0]>0 and pred[0]>lossThreshSize:
            loss_test_all_binary[i]=0
#         elif pred[0]+labels[0]>0.001: #original
        elif pred[0]>lossThreshSize or labels[0]>0:
            loss_test_all_binary[i]=1
        else:
            loss_test_all_binary[i]=0
        loss_test_all_diff[i]=pred.flatten()[0]-labels[0]
        
    posidx=(labelsNp>0)
    negidx=labelsNp==0
    loss_test_posLoss=np.sum(loss_test_all[posidx])/np.sum(posidx)
    loss_test_negLoss=np.sum(loss_test_all[negidx])/np.sum(negidx)
    print(name+' results',
          'loss positive: {:.4f}'.format(loss_test_posLoss),
         'loss negative: {:.4f}'.format(loss_test_negLoss))
    
    largeridx=np.logical_and(posidx,loss_test_all_diff>=0)
    smalleridx=np.logical_and(posidx,loss_test_all_diff<0)
    if np.sum(largeridx)>0:
        fig, ax = plt.subplots(dpi=400)
        fig.set_figheight(2.5)
        fig.set_figwidth(10)
        plt.xscale('log')
        plt.hist(loss_test_all_diff[largeridx],bins=np.logspace(np.log10(np.min(loss_test_all_diff[largeridx])),np.log10(np.max(loss_test_all_diff[largeridx])),51))
        plt.savefig(os.path.join(plotsavepath,name+'loss'+s+'_diffHist_positiveLarger'+'.jpg'))
        plt.close()
    if np.sum(smalleridx)>0:
        fig, ax = plt.subplots(dpi=400)
        fig.set_figheight(2.5)
        fig.set_figwidth(10)
        plt.xscale('log')
        plt.hist(np.abs(loss_test_all_diff[smalleridx]),bins=np.logspace(np.log10(np.min(np.abs(loss_test_all_diff[smalleridx]))),np.log10(np.max(np.abs(loss_test_all_diff[smalleridx]))),51))
        plt.savefig(os.path.join(plotsavepath,name+'loss'+s+'_diffHist_positiveSmaller'+'.jpg'))
        plt.close()
    if np.sum(negidx)>0:
        fig, ax = plt.subplots(dpi=400)
        fig.set_figheight(2.5)
        fig.set_figwidth(10)
        plt.xscale('log')
        plt.hist(loss_test_all_diff[negidx]+0.1,bins=np.logspace(np.log10(np.min(loss_test_all_diff[negidx]+0.1)),np.log10(np.max(loss_test_all_diff[negidx]+0.1)),51))
        plt.savefig(os.path.join(plotsavepath,name+'loss'+s+'_diffHist_negative'+'.jpg'))
        plt.close()
    
    if np.sum(posidx)>0 and np.sum(negidx)>0:
        plotExprs(loss_test_all[posidx],loss_test_all[negidx]+0.1,name+'loss'+s,coordNp[posidx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
    if np.sum(posidx)>0:
        plotExprs(loss_test_all[posidx],None,name+'loss'+s+'_positive',coordNp[posidx],None,plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
    if np.sum(negidx)>0:
        plotExprs(None,loss_test_all[negidx]+0.1,name+'loss'+s+'_negative',None,coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())

    if np.sum(posidx)>0:
        if np.sum(largeridx)>0 and np.sum(smalleridx)>0:
            plotExprsDiff(loss_test_all_diff[largeridx]+0.1,np.abs(loss_test_all_diff[smalleridx]),loss_test_all_diff[negidx]+0.1,name+'loss'+s+'_diff',coordNp[largeridx],coordNp[smalleridx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            plotExprsDiff(loss_test_all_diff[largeridx]+0.1,np.abs(loss_test_all_diff[smalleridx]),None,name+'loss'+s+'_diff_positive',coordNp[largeridx],coordNp[smalleridx],None,plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            plotExprsDiff(np.abs(loss_test_all_diff[smalleridx]),loss_test_all_diff[largeridx]+0.1,loss_test_all_diff[negidx]+0.1,name+'loss'+s+'_diff2',coordNp[smalleridx],coordNp[largeridx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            plotExprsDiff(np.abs(loss_test_all_diff[smalleridx]),loss_test_all_diff[largeridx]+0.1,None,name+'loss'+s+'_diff2_positive',coordNp[smalleridx],coordNp[largeridx],None,plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
        elif np.sum(largeridx)>0:
            plotExprsDiff(loss_test_all_diff[largeridx]+0.1,None,loss_test_all_diff[negidx]+0.1,name+'loss'+s+'_diff',coordNp[largeridx],coordNp[smalleridx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            plotExprsDiff(loss_test_all_diff[largeridx]+0.1,None,None,name+'loss'+s+'_diff_positive',coordNp[largeridx],None,None,plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
        elif np.sum(smalleridx)>0:
            plotExprsDiff(None,np.abs(loss_test_all_diff[smalleridx]),loss_test_all_diff[negidx]+0.1,name+'loss'+s+'_diff',None,coordNp[smalleridx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            plotExprsDiff(None,np.abs(loss_test_all_diff[smalleridx]),None,name+'loss'+s+'_diff_positive',None,coordNp[smalleridx],None,plotsavepath,savenameAdd=savenameAdd,norm=matplotlib.colors.LogNorm())
            
            
            
    loss_test_posLoss_binary=np.sum(loss_test_all_binary[posidx])/np.sum(posidx)
    loss_test_negLoss_binary=np.sum(loss_test_all_binary[negidx])/np.sum(negidx)
    print(name+' results',
          'loss positive binary: {:.4f}'.format(loss_test_posLoss_binary),
         'loss negative binary: {:.4f}'.format(loss_test_negLoss_binary))
    
    plotExprs(loss_test_all_binary[posidx],loss_test_all_binary[negidx],name+'loss'+s+'_binary',coordNp[posidx],coordNp[negidx],plotsavepath,savenameAdd=savenameAdd)
    if np.sum(posidx)>0:
        plotExprs(loss_test_all_binary[posidx],None,name+'loss'+s+'_positive'+'_binary',coordNp[posidx],None,plotsavepath,savenameAdd=savenameAdd)
    plotExprs(None,loss_test_all_binary[negidx],name+'loss'+s+'_negative'+'_binary',None,coordNp[negidx],plotsavepath,savenameAdd=savenameAdd)

    
testepoch=920
# Create model
if model_str=='alexnet':
    model = modelsCNN.AlexNet(2)
    lossCE=torch.nn.CrossEntropyLoss(torch.tensor([negweight,posweight]).cuda().float())
if model_str=='alexnet_regrs':
    model = modelsCNN.AlexNet(1,regrs=True)
    lossCE=torch.nn.MSELoss(reduction='none')
if use_cuda:
    model.cuda()

model.load_state_dict(torch.load(os.path.join(modelsavepath,str(testepoch)+'.pt')))
model.eval()
plotepoch='epoch'+str(testepoch)
plotepoch+=savenameAdd
for s in sampleidx.keys():
#     if s in ['disease13']:
#         continue
    print(s)
    if s in ['disease13','disease8']:
        trainInputAll, valInputAll, testInputAll, trainLabelsAll,valLabelsAll,testLabelsAll,trainCoordAll,valCoordAll,testCoordAll=loadImage.loadandsplitPlaque_overlap_regrs(plaqueMaskImg,plaqueSizeFactor,areaThresh,plaqueCentroids[s][['Y','X']].to_numpy().astype(int),plaqueCutoffRadius,sampleidx[s],datadir,diamThresh_mul,overlap,valNodes,testNodes,ifFlip=False,minCutoff=minThresh_mul,seed=seed,returnPos=True)
    if s in ['control13','control8']:
        trainInputAll, valInputAll, testInputAll, trainLabelsAll,valLabelsAll,testLabelsAll,trainCoordAll,valCoordAll,testCoordAll=loadImage.loadandsplit(sampleidx[s],datadir,diamThresh_mul,overlap,valNodes,testNodes,ifFlip=False,minCutoff=minThresh_mul,seed=seed,clf=True,returnPos=True)
        
    trainClusterAll=getClusterLabels(gaeclusterlabels[sampleNames==s],gaeCoord[sampleNames==s],minPt,trainCoordAll,diamThresh_mul,imageSizes[s])
    valClusterAll=getClusterLabels(gaeclusterlabels[sampleNames==s],gaeCoord[sampleNames==s],minPt,valCoordAll,diamThresh_mul,imageSizes[s])
    testClusterAll=getClusterLabels(gaeclusterlabels[sampleNames==s],gaeCoord[sampleNames==s],minPt,testCoordAll,diamThresh_mul,imageSizes[s])

    for c in np.unique(trainClusterAll):
        print(c)
    
        cidx=trainClusterAll==float(c)
        trainInputnp=trainInputAll[cidx]
        trainLabelsnp=trainLabelsAll[cidx]
        trainCoordnp=trainCoordAll[cidx]
        vidx=valClusterAll==float(c)
        valInputnp=valInputAll[vidx]
        valLabelsnp=valLabelsAll[vidx]
        valCoordnp=valCoordAll[vidx]
        tidx=testClusterAll==float(c)
        testInputnp=testInputAll[tidx]
        testLabelsnp=testLabelsAll[tidx]
        testCoordnp=testCoordAll[tidx]
        
        plotLoss(testInputnp,testLabelsnp,testCoordnp,'test',os.path.join(plotsavepath,str(c)),plotepoch)
        plotLoss(trainInputnp,trainLabelsnp,trainCoordnp,'train',os.path.join(plotsavepath,str(c)),plotepoch)
        plotLoss(valInputnp,valLabelsnp,valCoordnp,'val',os.path.join(plotsavepath,str(c)),plotepoch)
        plotLoss(np.concatenate((trainInputnp,valInputnp,testInputnp),axis=0),np.concatenate((trainLabelsnp,valLabelsnp,testLabelsnp)),np.concatenate((trainCoordnp,valCoordnp,testCoordnp),axis=0),'all',os.path.join(plotsavepath,str(c)),plotepoch)
    
    trainInputnp, valInputnp, testInputnp, trainLabelsnp,valLabelsnp,testLabelsnp,trainCoordnp,valCoordnp,testCoordnp,trainClusterAll,valClusterAll,testClusterAll=None,None,None,None,None,None,None,None,None,None,None,None

disease13
plaque1959
no plaque2434
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
no cells
-1.0
test results loss positive: 260084480.0000 loss negative: 0.0000
test results loss positive binary: 0.0000 loss negative binary: 0.0000
train results loss positive: 9972315.3333 loss negative: 0.0000
train results loss positive binary: 1.0000 loss negative binary: 0.0000


  loss_test_posLoss=np.sum(loss_test_all[posidx])/np.sum(posidx)


val results loss positive: nan loss negative: 0.0000


  loss_test_posLoss_binary=np.sum(loss_test_all_binary[posidx])/np.sum(posidx)


val results loss positive binary: nan loss negative binary: 0.0000
all results loss positive: 72500356.5000 loss negative: 0.0000
all results loss positive binary: 0.7500 loss negative binary: 0.0000
0.0
test results loss positive: 91220876.6654 loss negative: 754544.7354
test results loss positive binary: 0.1765 loss negative binary: 0.0667
train results loss positive: 8585659.3649 loss negative: 0.0000
train results loss positive binary: 0.0588 loss negative binary: 0.0000
val results loss positive: 78802178.9917 loss negative: 0.0000
val results loss positive binary: 0.0000 loss negative binary: 0.0000
all results loss positive: 19001110.5874 loss negative: 87399.0041
all results loss positive binary: 0.0636 loss negative binary: 0.0077
1.0
test results loss positive: 297148038.0560 loss negative: 19205.3887
test results loss positive binary: 0.1277 loss negative binary: 0.2500
train results loss positive: 20289986.7703 loss negative: 0.0000
train results loss positive binary: 0.091

all results loss positive: nan loss negative: 44033.6357
all results loss positive binary: nan loss negative binary: 0.0044
2.0
test results loss positive: nan loss negative: 18085.8147
test results loss positive binary: nan loss negative binary: 0.0172
train results loss positive: nan loss negative: 0.0000
train results loss positive binary: nan loss negative binary: 0.0000
val results loss positive: nan loss negative: 6099.8548
val results loss positive binary: nan loss negative binary: 0.0417
all results loss positive: nan loss negative: 2506.0247
all results loss positive binary: nan loss negative binary: 0.0042
3.0
test results loss positive: nan loss negative: 4514290.0952
test results loss positive binary: nan loss negative binary: 0.0238
train results loss positive: nan loss negative: 11588.7331
train results loss positive binary: nan loss negative binary: 0.0029
val results loss positive: nan loss negative: 11618.5009
val results loss positive binary: nan loss negative binary:

  loss_test_negLoss=np.sum(loss_test_all[negidx])/np.sum(negidx)
  loss_test_negLoss_binary=np.sum(loss_test_all_binary[negidx])/np.sum(negidx)


val results loss positive: nan loss negative: nan
val results loss positive binary: nan loss negative binary: nan
all results loss positive: nan loss negative: 0.0000
all results loss positive binary: nan loss negative binary: 0.0000
disease8
plaque758


MemoryError: Unable to allocate 19.3 GiB for an array with shape (4041, 1, 800, 800) and data type float64