In [1]:
import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim
from torch.utils.data import DataLoader

import models.loadImg as loadImg
import models.modelsCNN as modelsCNN
import models.optimizer as optimizer

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import gc
from skimage import io
import scipy.stats
from sklearn.metrics import pairwise_distances

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
use_cuda=True
datadir='/media/xinyi/data'
name='exp0'
plotsavepath='/media/xinyi/plots/cnnvae'+name
sampledir=plotsavepath
clustersavedir_alltrain=os.path.join(sampledir,'cluster_alltrain_reordered')
ep=311
with open(os.path.join(datadir,'processed','train_cnnvae_names'), 'rb') as input:
    allImgNames=pickle.load(input)
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList=np.copy(allImgNames)
for s in np.unique(allImgNames):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList[allImgNames==s]=prog_s
    




In [3]:
metadata=pd.read_csv(os.path.join(datadir,'Supplementary Table 1_v1.csv'),header=0)
metadata.index=metadata.sample_id

In [4]:
allPatientIDs=np.repeat('none',allImgNames.size).astype(object)
for s in np.unique(allImgNames):
    sidx=allImgNames==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [5]:
savenamesample='alltrain'

neworder=[1, 5, 3, 7, 2, 0, 4, 6]
#use chosen subcluster number and save plots
scanpy.settings.verbosity = 3
# subcluster=8
subclusterDict={0:[4],1:[6],2:[8],3:[6],4:[6],5:[6],6:[6],7:[4]}
ncluster=8

plotepoch=311
clusterplotdir=os.path.join(clustersavedir_alltrain,'plots')
n_pcs=50
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_alltrain,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes=pickle.load(output)

kmeans_sub=(np.zeros(clusterRes.size)-1).astype(str)
savenameAdd='_plottingIdx_progBalanced_'+str(0)
subclusternumbers=[4,6,8,6,6,6,6,4]
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+savenameAdd
for cnew in np.unique(clusterRes):
#     print('cluster'+str(c))
    c=neworder[cnew]
    
    subclustersavedir_alltrain=os.path.join(clustersavedir_alltrain,savenamecluster+'_subcluster'+str(c))
    with open(os.path.join(subclustersavedir_alltrain,'minibatchkmean_ncluster'+str(subclusternumbers[c])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    print(np.unique(subclusterRes))
    kmeans_sub[clusterRes==cnew]=np.char.add(np.repeat(str(cnew)+'-',subclusterRes.size),subclusterRes.astype(str))
   

[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3]
[0 1 2 3 4 5 6 7]
[0 1 2 3]
[0 1 2 3 4 5]
[0 1 2 3 4 5]


In [6]:
with open(os.path.join(datadir,'processed','train_cnnvae_coord'), 'rb') as output:
    coordlist=pickle.load(output)

In [7]:
for p in np.unique(progList):
    if p=='Ductal carcinoma in situ':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList[progList==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList[progList==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList[progList==p]='Hyperplasia'

In [8]:
progInclude=np.array(['Breast tissue','Cancer adjacent normal breast tissue','Hyperplasia','Atypical hyperplasia','DCIS and breast tissue',  'DCIS with early infiltration','Invasive ductal carcinoma and breast tissue','Invasive ductal carcinoma'])

In [9]:
progIncludeIdx=np.repeat(False,progList.size)
for p in progInclude:
    progIncludeIdx[progList==p]=True
    
coordlist=coordlist[progIncludeIdx]
allImgNames=allImgNames[progIncludeIdx]
clusterRes=clusterRes[progIncludeIdx]
kmeans_sub=kmeans_sub[progIncludeIdx]
progList=progList[progIncludeIdx]

In [10]:
sUnique,sidx_start=np.unique(allImgNames,return_index=True)
progUnique,labels_train,progCounts=np.unique(progList[sidx_start],return_counts=True,return_inverse=True)
for p in range(progUnique.size):
    print(progUnique[p])
    print(progCounts[p])

Atypical hyperplasia
14
Breast tissue
20
Cancer adjacent normal breast tissue
13
DCIS and breast tissue
16
DCIS with early infiltration
30
Hyperplasia
41
Invasive ductal carcinoma
70
Invasive ductal carcinoma and breast tissue
8


In [11]:
def getHistMatrix_clusters(labels,ctlist,nrow=ncluster,ncol=ncluster):
    res=np.zeros((nrow,ncol))
    for li in range(res.shape[0]):
        l=li
        nl=np.sum(labels==l)
        ctlist_l=ctlist[labels==l]
        for ci in range(res.shape[1]):
            c=ci
            res[l,c]=np.sum(ctlist_l==c)
#             res[li,ci]=np.sum(ctlist_l==c)/nl
        if nl!=0:
            res[li]=res[li]/nl
    return res

neighborhoodSize=48*1.5

In [12]:
#get neighborhood composition

inputNeighborhood=np.zeros((sUnique.size,ncluster*ncluster))
for i in range(sUnique.size):
    imgN=sUnique[i]
    nsamples=np.sum(allImgNames==imgN)
    cluster_i=clusterRes[allImgNames==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist[allImgNames==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood[i]=res.flatten()

In [13]:
_,inputCounts=np.unique(allImgNames,return_counts=True)
inputAll_train=np.concatenate((inputNeighborhood,inputCounts.reshape(-1,1)),axis=1)

In [14]:
#val cores (as validation cores) and val samples (as test cores)
clustersavedir_valcores=os.path.join(sampledir,'cluster_valcores_reordered')
clustersavedir_valsamples=os.path.join(sampledir,'cluster_valsamples_reordered')

with open(os.path.join(datadir,'processed','train_cnnvae_coord_valcores'), 'rb') as output:
    coordlist_valcores=pickle.load(output)
with open(os.path.join(datadir,'processed','train_cnnvae_coord_valsamples'), 'rb') as output:
    coordlist_valsamples=pickle.load(output)

savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_valcores,savenamecluster+'_all'), 'rb') as output:
    clusterRes_valcores=pickle.load(output)
with open(os.path.join(clustersavedir_valsamples,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes_valsamples=pickle.load(output)
    
kmeans_sub_valcores=(np.zeros(clusterRes_valcores.size)-1).astype(str)
for c in np.unique(clusterRes_valcores):
    subclustersavedir=os.path.join(clustersavedir_valcores,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valcores[clusterRes_valcores==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
    
kmeans_sub_valsamples=(np.zeros(clusterRes_valsamples.size)-1).astype(str)
for c in np.unique(clusterRes_valsamples):
    subclustersavedir=os.path.join(clustersavedir_valsamples,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valsamples[clusterRes_valsamples==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
                

In [15]:
with open(os.path.join(datadir,'processed','train_cnnvae_names_valcores'), 'rb') as input:
    allImgNames_valcores=pickle.load(input)
with open(os.path.join(datadir,'processed','train_cnnvae_names_valsamples'), 'rb') as input:
    allImgNames_valsamples=pickle.load(input)

In [16]:
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList_valcores=np.copy(allImgNames_valcores)
for s in np.unique(allImgNames_valcores):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valcores[allImgNames_valcores==s]=prog_s
    
progList_valsamples=np.copy(allImgNames_valsamples)
for s in np.unique(allImgNames_valsamples):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valsamples[allImgNames_valsamples==s]=prog_s
    



In [17]:
for p in np.unique(progList_valcores):
    if p=='Ductal carcinoma in situ':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valcores[progList_valcores==p]='Hyperplasia'

In [18]:
for p in np.unique(progList_valsamples):
    if p=='Ductal carcinoma in situ':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltrati':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ w':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valsamples[progList_valsamples==p]='Hyperplasia'

In [19]:
progIncludeIdx_valcores=np.repeat(False,progList_valcores.size)
for p in progInclude:
    progIncludeIdx_valcores[progList_valcores==p]=True
    
progIncludeIdx_valsamples=np.repeat(False,progList_valsamples.size)
for p in progInclude:
    progIncludeIdx_valsamples[progList_valsamples==p]=True
    
coordlist_valcores=coordlist_valcores[progIncludeIdx_valcores]
allImgNames_valcores=allImgNames_valcores[progIncludeIdx_valcores]
clusterRes_valcores=clusterRes_valcores[progIncludeIdx_valcores]
kmeans_sub_valcores=kmeans_sub_valcores[progIncludeIdx_valcores]
progList_valcores=progList_valcores[progIncludeIdx_valcores]

coordlist_valsamples=coordlist_valsamples[progIncludeIdx_valsamples]
allImgNames_valsamples=allImgNames_valsamples[progIncludeIdx_valsamples]
clusterRes_valsamples=clusterRes_valsamples[progIncludeIdx_valsamples]
kmeans_sub_valsamples=kmeans_sub_valsamples[progIncludeIdx_valsamples]
progList_valsamples=progList_valsamples[progIncludeIdx_valsamples]


In [20]:
sUnique_valcores,sidx_start_valcores=np.unique(allImgNames_valcores,return_index=True)
progUnique_valcores,progCounts_valcores=np.unique(progList_valcores[sidx_start_valcores],return_counts=True)
for p in range(progUnique_valcores.size):
    print(progUnique_valcores[p])
    print(progCounts_valcores[p])

Atypical hyperplasia
15
Breast tissue
20
Cancer adjacent normal breast tissue
1
Hyperplasia
35
Invasive ductal carcinoma
97


In [21]:
sUnique_valsamples,sidx_start_valsamples=np.unique(allImgNames_valsamples,return_index=True)
progUnique_valsamples,progCounts_valsamples=np.unique(progList_valsamples[sidx_start_valsamples],return_counts=True)
for p in range(progUnique_valsamples.size):
    print(progUnique_valsamples[p])
    print(progCounts_valsamples[p])

Atypical hyperplasia
10
Breast tissue
14
Cancer adjacent normal breast tissue
4
DCIS and breast tissue
16
DCIS with early infiltration
29
Hyperplasia
25
Invasive ductal carcinoma
66
Invasive ductal carcinoma and breast tissue
8


In [22]:
#construct labels
labels_valcores=np.zeros(progList_valcores[sidx_start_valcores].size)
for i in range(progUnique.size):
    labels_valcores[progList_valcores[sidx_start_valcores]==progUnique[i]]=i

In [23]:
#construct labels
labels_valsamples=np.zeros(progList_valsamples[sidx_start_valsamples].size)
for i in range(progUnique.size):
    labels_valsamples[progList_valsamples[sidx_start_valsamples]==progUnique[i]]=i

In [24]:
inputNeighborhood_valcores=np.zeros((sUnique_valcores.size,ncluster*ncluster))
for i in range(sUnique_valcores.size):
    imgN=sUnique_valcores[i]
    nsamples=np.sum(allImgNames_valcores==imgN)
    cluster_i=clusterRes_valcores[allImgNames_valcores==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist_valcores[allImgNames_valcores==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood_valcores[i]=res.flatten()

In [25]:
inputNeighborhood_valsamples=np.zeros((sUnique_valsamples.size,ncluster*ncluster))
for i in range(sUnique_valsamples.size):
    imgN=sUnique_valsamples[i]
    nsamples=np.sum(allImgNames_valsamples==imgN)
    cluster_i=clusterRes_valsamples[allImgNames_valsamples==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist_valsamples[allImgNames_valsamples==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood_valsamples[i]=res.flatten()

In [26]:
_,inputCounts_valcores=np.unique(allImgNames_valcores,return_counts=True)
inputAll_valcores=np.concatenate((inputNeighborhood_valcores,inputCounts_valcores.reshape(-1,1)),axis=1)

In [27]:
_,inputCounts_valsamples=np.unique(allImgNames_valsamples,return_counts=True)
inputAll_valsamples=np.concatenate((inputNeighborhood_valsamples,inputCounts_valsamples.reshape(-1,1)),axis=1)

In [28]:
#concatenate cores
inputAll=np.concatenate((inputAll_train,np.concatenate((inputAll_valcores,inputAll_valsamples),axis=0)),axis=0)
imgNamesAll=np.concatenate((allImgNames[sidx_start],np.concatenate((allImgNames_valcores[sidx_start_valcores],allImgNames_valsamples[sidx_start_valsamples]))))
labelsAll=np.concatenate((labels_train,np.concatenate((labels_valcores,labels_valsamples))))

In [29]:
_,progCountsAll=np.unique(labelsAll,return_counts=True)
weights_train=np.sum(progCountsAll)/progCountsAll

In [30]:
allPatientIDs_valcores=np.repeat('none',allImgNames_valcores.size).astype(object)
for s in np.unique(allImgNames_valcores):
    sidx=allImgNames_valcores==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs_valcores[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [31]:
allPatientIDs_valsamples=np.repeat('none',allImgNames_valsamples.size).astype(object)
for s in np.unique(allImgNames_valsamples):
    sidx=allImgNames_valsamples==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs_valsamples[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [None]:
#img sizes
allImgNamesAll=np.concatenate((allImgNames,np.concatenate((allImgNames_valcores,allImgNames_valsamples))))
progListAll=np.concatenate((progList,np.concatenate((progList_valcores,progList_valsamples))))
sidxAll=np.concatenate((sidx_start,np.concatenate((sidx_start_valcores,sidx_start_valsamples))))
coordlistAll=np.concatenate((coordlist,np.concatenate((coordlist_valcores,coordlist_valsamples),axis=0)),axis=0)

imgSizeAll={}
for p in progUnique:
    img_cores=allImgNamesAll[sidxAll[progListAll[sidxAll]==p]]
    pSizes=np.zeros((img_cores.size))
    for si in range(img_cores.size):
        scoord=coordlistAll[allImgNamesAll==img_cores[si]]
        hsize=np.pi*np.square(np.max(scoord[:,0])-np.min(scoord[:,0]))
        vsize=np.pi*np.square(np.max(scoord[:,1])-np.min(scoord[:,1]))
        pSizes[si]=min(hsize,vsize)
    imgSizeAll[p]=pSizes



In [None]:
imgSize_median={}
for p in progUnique:
    imgSize_median[p]=np.median(imgSizeAll[p])
plt.violinplot(list(imgSizeAll.values()))
plt.scatter(np.arange(progUnique.size)+1,list(imgSize_median.values()))
plt.xticks(np.arange(progUnique.size)+1,list(imgSizeAll),rotation=90)

In [32]:
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'wb') as output:
    pickle.dump(imgSize_median,output,pickle.HIGHEST_PROTOCOL)

In [32]:
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as output:
    imgSize_median=pickle.load(output)

In [33]:
#normalize count
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as output:
    imgSize_median=pickle.load(output)
areaAll=np.zeros(labelsAll.size)
for s in range(labelsAll.size):
    areaAll[s]=imgSize_median[progUnique[labelsAll.astype(int)][s]]
inputAll[:,-1]=inputAll[:,-1]/areaAll

In [34]:
seed=3
epochs=6000
saveFreq=200
lr=0.001 #initial learning rate
weight_decay=0 

# batchsize=4
batchsize=6000
model_str='fc3'

fc_dim1=64
fc_dim2=64
fc_dim3=64


dropout=0.01

name='exp0_pathologyClf_neighbor_clusters_exp0_subset_neighborOnly_crossVal_countAreaNorm_s0_wAH_byPatientCorr'
logsavepath='/media/xinyi/log/cnnvae'+name
modelsavepath='/media/xinyi/models/cnnvae'+name
plotsavepath='/media/xinyi/plots/cnnvae'+name


if not os.path.exists(logsavepath):
    os.mkdir(logsavepath)
if not os.path.exists(modelsavepath):
    os.mkdir(modelsavepath)
if not os.path.exists(plotsavepath):
    os.mkdir(plotsavepath)



In [35]:
inputAll=scipy.stats.zscore(inputAll,axis=0)

In [36]:
def train(epoch,trainInput,labels_train):
    t = time.time()
    model.train()
    optimizer.zero_grad()

    pred = model(trainInput)

    loss=lossCE(pred,labels_train)

    loss.backward()
    optimizer.step()

    if epoch%500==0:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.4f}'.format(loss))
    return loss.item()

In [37]:
allPatientIDs=np.concatenate((allPatientIDs,np.concatenate((allPatientIDs_valcores,allPatientIDs_valsamples))))
pIDList=allPatientIDs[np.concatenate((sidx_start,np.concatenate((sidx_start_valcores,sidx_start_valsamples))))]


In [38]:
allPatientIDs.size

5324614

In [39]:
inputAll=torch.tensor(inputAll).cuda().float()
labelsAll=torch.tensor(labelsAll).cuda().long()

testepoch=5000
predtest=np.zeros((inputAll.shape[0],np.unique(labels_train).size))
for patientIDX in range(np.unique(pIDList).size):
    patientID=np.unique(pIDList)[patientIDX]
    sampleIdx=np.arange(inputAll.shape[0])[pIDList==patientID]
    trainIdx=np.arange(inputAll.shape[0])[pIDList!=patientID]
    
    seed=3
    torch.manual_seed(seed)
    nclasses=np.unique(labels_train).size
    if use_cuda:
        torch.cuda.manual_seed(seed)

    nfeatures=inputAll.shape[1]
    if model_str=='fc3':
        model = modelsCNN.FC_l3(nfeatures,fc_dim1,fc_dim2,fc_dim3,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc5':
        model = modelsCNN.FC_l5(nfeatures,fc_dim1,fc_dim2,fc_dim3,fc_dim4,fc_dim5,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc1':
        model = modelsCNN.FC_l1(nfeatures,fc_dim1,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc0':
        model = modelsCNN.FC_l0(nfeatures,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())

    if use_cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss_ep=[None]*epochs
    val_loss_ep=[None]*epochs
    t_ep=time.time()

    for ep in range(epochs):
        train_loss_ep[ep]=train(ep,inputAll[trainIdx],labelsAll[trainIdx])


        if ep%saveFreq == 0 and ep!=0:
            torch.save(model.cpu().state_dict(), os.path.join(modelsavepath,patientID+'_'+str(ep)+'.pt'))
        if use_cuda:
            model.cuda()
            torch.cuda.empty_cache()
    print(' total time: {:.4f}s'.format(time.time() - t_ep))

    with open(os.path.join(logsavepath,patientID+'_train_loss'), 'wb') as output:
        pickle.dump(train_loss_ep, output, pickle.HIGHEST_PROTOCOL)

    model.load_state_dict(torch.load(os.path.join(modelsavepath,patientID+'_'+str(testepoch)+'.pt')))
    with torch.no_grad():
        model.cuda()
        model.eval()
        pred = model(inputAll[[sampleIdx]])
        predtest[sampleIdx]=pred.cpu().detach().numpy()

        loss_test=lossCE(pred,labelsAll[[sampleIdx]]).item()

    print(loss_test)

Epoch: 0000 loss_train: 2.0783
Epoch: 0500 loss_train: 0.5197
Epoch: 1000 loss_train: 0.2380
Epoch: 1500 loss_train: 0.1908
Epoch: 2000 loss_train: 0.1671
Epoch: 2500 loss_train: 0.1322
Epoch: 3000 loss_train: 0.1264
Epoch: 3500 loss_train: 0.1261
Epoch: 4000 loss_train: 0.1219
Epoch: 4500 loss_train: 0.0856
Epoch: 5000 loss_train: 0.0562
Epoch: 5500 loss_train: 0.0824
 total time: 11.5175s
6.268845558166504
Epoch: 0000 loss_train: 2.0812
Epoch: 0500 loss_train: 0.4649
Epoch: 1000 loss_train: 0.2517
Epoch: 1500 loss_train: 0.1720
Epoch: 2000 loss_train: 0.1187
Epoch: 2500 loss_train: 0.1398
Epoch: 3000 loss_train: 0.1174
Epoch: 3500 loss_train: 0.1071
Epoch: 4000 loss_train: 0.0856
Epoch: 4500 loss_train: 0.0713
Epoch: 5000 loss_train: 0.0756
Epoch: 5500 loss_train: 0.1233
 total time: 10.8151s
0.00021417321113403887
Epoch: 0000 loss_train: 2.0827
Epoch: 0500 loss_train: 0.4851
Epoch: 1000 loss_train: 0.2640
Epoch: 1500 loss_train: 0.1611
Epoch: 2000 loss_train: 0.1379
Epoch: 2500 loss

Epoch: 0500 loss_train: 0.4893
Epoch: 1000 loss_train: 0.2931
Epoch: 1500 loss_train: 0.1860
Epoch: 2000 loss_train: 0.1226
Epoch: 2500 loss_train: 0.0953
Epoch: 3000 loss_train: 0.0901
Epoch: 3500 loss_train: 0.0781
Epoch: 4000 loss_train: 0.0689
Epoch: 4500 loss_train: 0.0771
Epoch: 5000 loss_train: 0.0651
Epoch: 5500 loss_train: 0.0956
 total time: 10.0404s
0.08859477937221527
Epoch: 0000 loss_train: 2.0848
Epoch: 0500 loss_train: 0.4342
Epoch: 1000 loss_train: 0.2760
Epoch: 1500 loss_train: 0.1895
Epoch: 2000 loss_train: 0.1127
Epoch: 2500 loss_train: 0.0953
Epoch: 3000 loss_train: 0.0854
Epoch: 3500 loss_train: 0.1724
Epoch: 4000 loss_train: 0.0780
Epoch: 4500 loss_train: 0.0477
Epoch: 5000 loss_train: 0.0892
Epoch: 5500 loss_train: 0.0448
 total time: 8.9466s
10.616114616394043
Epoch: 0000 loss_train: 2.0818
Epoch: 0500 loss_train: 0.4771
Epoch: 1000 loss_train: 0.2611
Epoch: 1500 loss_train: 0.1751
Epoch: 2000 loss_train: 0.1262
Epoch: 2500 loss_train: 0.1629
Epoch: 3000 loss_tr

Epoch: 0500 loss_train: 0.4017
Epoch: 1000 loss_train: 0.2601
Epoch: 1500 loss_train: 0.2073
Epoch: 2000 loss_train: 0.1357
Epoch: 2500 loss_train: 0.1653
Epoch: 3000 loss_train: 0.0913
Epoch: 3500 loss_train: 0.1066
Epoch: 4000 loss_train: 0.0721
Epoch: 4500 loss_train: 0.0872
Epoch: 5000 loss_train: 0.0754
Epoch: 5500 loss_train: 0.1010
 total time: 7.5114s
9.787761688232422
Epoch: 0000 loss_train: 2.0788
Epoch: 0500 loss_train: 0.4695
Epoch: 1000 loss_train: 0.2757
Epoch: 1500 loss_train: 0.1646
Epoch: 2000 loss_train: 0.1219
Epoch: 2500 loss_train: 0.1485
Epoch: 3000 loss_train: 0.1391
Epoch: 3500 loss_train: 0.1290
Epoch: 4000 loss_train: 0.0882
Epoch: 4500 loss_train: 0.0645
Epoch: 5000 loss_train: 0.0556
Epoch: 5500 loss_train: 0.0904
 total time: 7.5450s
6.1551713943481445
Epoch: 0000 loss_train: 2.0806
Epoch: 0500 loss_train: 0.4205
Epoch: 1000 loss_train: 0.2673
Epoch: 1500 loss_train: 0.1965
Epoch: 2000 loss_train: 0.1138
Epoch: 2500 loss_train: 0.1436
Epoch: 3000 loss_train

Epoch: 0500 loss_train: 0.4255
Epoch: 1000 loss_train: 0.2762
Epoch: 1500 loss_train: 0.1680
Epoch: 2000 loss_train: 0.1458
Epoch: 2500 loss_train: 0.1000
Epoch: 3000 loss_train: 0.1355
Epoch: 3500 loss_train: 0.0956
Epoch: 4000 loss_train: 0.0873
Epoch: 4500 loss_train: 0.1010
Epoch: 5000 loss_train: 0.0895
Epoch: 5500 loss_train: 0.0529
 total time: 7.4918s
17.266796112060547
Epoch: 0000 loss_train: 2.0807
Epoch: 0500 loss_train: 0.4248
Epoch: 1000 loss_train: 0.2907
Epoch: 1500 loss_train: 0.1589
Epoch: 2000 loss_train: 0.1463
Epoch: 2500 loss_train: 0.1398
Epoch: 3000 loss_train: 0.1324
Epoch: 3500 loss_train: 0.0843
Epoch: 4000 loss_train: 0.0870
Epoch: 4500 loss_train: 0.0736
Epoch: 5000 loss_train: 0.0579
Epoch: 5500 loss_train: 0.0709
 total time: 7.5583s
18.822568893432617
Epoch: 0000 loss_train: 2.0796
Epoch: 0500 loss_train: 0.4479
Epoch: 1000 loss_train: 0.2811
Epoch: 1500 loss_train: 0.1864
Epoch: 2000 loss_train: 0.1376
Epoch: 2500 loss_train: 0.1184
Epoch: 3000 loss_trai

In [40]:
with open(os.path.join(logsavepath,'crossVal_loss'), 'wb') as output:
    pickle.dump(predtest, output, pickle.HIGHEST_PROTOCOL)

In [41]:
predtest_label=np.argmax(predtest,axis=1)

In [42]:
imgNamesAll.size

552

In [43]:
res=pd.DataFrame({'patientID':pIDList,'sampleName':imgNamesAll,'true':progUnique[labelsAll.cpu().numpy()],'predicted':progUnique[predtest_label]})
res.to_csv(os.path.join(plotsavepath,'predictions.csv'))

In [44]:
#plot confusion
def plotCTcomp(labels,ctlist,savepath,savenamecluster,byCT,addname='',order=progInclude,showNumber=False):
    res=np.zeros((order.size,order.size))
    
    for li in range(res.shape[0]):
        l=order[li]
        nl=np.sum(labels==l)
        ctlist_l=ctlist[labels==l]
        for ci in range(res.shape[1]):
            c=order[ci]
            res[li,ci]=np.sum(ctlist_l==c)
#             res[li,ci]=np.sum(ctlist_l==c)/nl
    resCounts=np.copy(res)
    if not byCT:
        addname+=''
        for li in range(res.shape[0]):
            l=order[li]
            nl=np.sum(labels==l)
            res[li]=res[li]/nl
    else:
        addname+='_normbyCT'
        for ci in range(res.shape[1]):
            c=order[ci]
            nc=np.sum(ctlist==c)
            res[:,ci]=res[:,ci]/nc
    
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(res,cmap='binary',vmin=0,vmax=1)
    if showNumber:
        addname=addname+'_showNumber'
        for i in range(res.shape[0]):
            for j in range(res.shape[1]):
                text = ax.text(j, i, f'{resCounts[i, j]:.0f}',
                               ha="center", va="center", color="c",fontsize='xx-large')
    fig.colorbar(im)
    ax.set_yticks(np.arange(order.size))
    ax.set_yticklabels(order)
    ax.set_xticks(np.arange(order.size))
    ax.set_xticklabels(order)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
    fig.tight_layout()
    plt.savefig(os.path.join(savepath,savenamecluster+addname+'.pdf'))
    plt.close()
    
plotCTcomp(progUnique[labelsAll.cpu().numpy()],progUnique[predtest_label],plotsavepath,'confusion'+str(testepoch),False,showNumber=True)
plotCTcomp(progUnique[labelsAll.cpu().numpy()],progUnique[predtest_label],plotsavepath,'confusion'+str(testepoch),False,showNumber=False)

In [45]:
plotCTcomp(progUnique[labelsAll.cpu().numpy()][:sidx_start.size],progUnique[predtest_label][:sidx_start.size],plotsavepath,'confusion_excludeValSamples'+str(testepoch),False)

In [46]:
inputAll_train.shape

(212, 65)