In [1]:
import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim
from torch.utils.data import DataLoader

import models.loadImg as loadImg
import models.modelsCNN as modelsCNN
import models.optimizer as optimizer

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import gc
from skimage import io
import scipy.stats
from sklearn.metrics import pairwise_distances

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
use_cuda=True
datadir='/media/xinyi/data'
name='exp0'
plotsavepath='/media/xinyi/plots/cnnvae'+name
sampledir=plotsavepath
clustersavedir_alltrain=os.path.join(sampledir,'cluster_alltrain_reordered')
ep=311
with open(os.path.join(datadir,'processed','train_cnnvae_names'), 'rb') as input:
    allImgNames=pickle.load(input)
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList=np.copy(allImgNames)
for s in np.unique(allImgNames):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList[allImgNames==s]=prog_s
    




In [3]:
metadata=pd.read_csv(os.path.join(datadir,'Supplementary Table 1_v1.csv'),header=0)
metadata.index=metadata.sample_id

In [4]:
allPatientIDs=np.repeat('none',allImgNames.size).astype(object)
for s in np.unique(allImgNames):
    sidx=allImgNames==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [5]:
savenamesample='alltrain'

neworder=[1, 5, 3, 7, 2, 0, 4, 6]
#use chosen subcluster number and save plots
scanpy.settings.verbosity = 3
# subcluster=8
subclusterDict={0:[4],1:[6],2:[8],3:[6],4:[6],5:[6],6:[6],7:[4]}
ncluster=8

plotepoch=311
clusterplotdir=os.path.join(clustersavedir_alltrain,'plots')
n_pcs=50
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_alltrain,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes=pickle.load(output)

kmeans_sub=(np.zeros(clusterRes.size)-1).astype(str)
savenameAdd='_plottingIdx_progBalanced_'+str(0)
subclusternumbers=[4,6,8,6,6,6,6,4]
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+savenameAdd
for cnew in np.unique(clusterRes):
#     print('cluster'+str(c))
    c=neworder[cnew]
    
    subclustersavedir_alltrain=os.path.join(clustersavedir_alltrain,savenamecluster+'_subcluster'+str(c))
    with open(os.path.join(subclustersavedir_alltrain,'minibatchkmean_ncluster'+str(subclusternumbers[c])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    print(np.unique(subclusterRes))
    kmeans_sub[clusterRes==cnew]=np.char.add(np.repeat(str(cnew)+'-',subclusterRes.size),subclusterRes.astype(str))
   

[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3]
[0 1 2 3 4 5 6 7]
[0 1 2 3]
[0 1 2 3 4 5]
[0 1 2 3 4 5]


In [6]:
with open(os.path.join(datadir,'processed','train_cnnvae_coord'), 'rb') as output:
    coordlist=pickle.load(output)

In [7]:
for p in np.unique(progList):
    if p=='Ductal carcinoma in situ':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList[progList==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList[progList==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList[progList==p]='Hyperplasia'

In [8]:
progInclude=np.array(['Breast tissue','Cancer adjacent normal breast tissue','Hyperplasia','Atypical hyperplasia','DCIS and breast tissue',  'DCIS with early infiltration','Invasive ductal carcinoma and breast tissue','Invasive ductal carcinoma'])

In [9]:
progIncludeIdx=np.repeat(False,progList.size)
for p in progInclude:
    progIncludeIdx[progList==p]=True
    
coordlist=coordlist[progIncludeIdx]
allImgNames=allImgNames[progIncludeIdx]
clusterRes=clusterRes[progIncludeIdx]
kmeans_sub=kmeans_sub[progIncludeIdx]
progList=progList[progIncludeIdx]

In [10]:
sUnique,sidx_start=np.unique(allImgNames,return_index=True)
progUnique,labels_train,progCounts=np.unique(progList[sidx_start],return_counts=True,return_inverse=True)
for p in range(progUnique.size):
    print(progUnique[p])
    print(progCounts[p])

Atypical hyperplasia
14
Breast tissue
20
Cancer adjacent normal breast tissue
13
DCIS and breast tissue
16
DCIS with early infiltration
30
Hyperplasia
41
Invasive ductal carcinoma
70
Invasive ductal carcinoma and breast tissue
8


In [11]:
#get cluster composition
clusterUnique=np.unique(clusterRes)
subclusterUnique=np.unique(kmeans_sub)
inputCluster=np.zeros((sUnique.size,clusterUnique.size))
inputSubcluster=np.zeros((sUnique.size,subclusterUnique.size))
for i in range(sUnique.size):
    clusterRes_i=clusterRes[allImgNames==sUnique[i]]
    kmeans_sub_i=kmeans_sub[allImgNames==sUnique[i]]
    for j in range(clusterUnique.size):
        inputCluster[i,j]=np.sum(clusterRes_i==clusterUnique[j])
    for j in range(subclusterUnique.size):
        inputSubcluster[i,j]=np.sum(kmeans_sub_i==subclusterUnique[j])
inputCluster=inputCluster/np.sum(inputCluster,axis=1,keepdims=True)
inputSubcluster=inputSubcluster/np.sum(inputSubcluster,axis=1,keepdims=True)

In [12]:
_,inputCounts=np.unique(allImgNames,return_counts=True)
inputAll=np.concatenate((np.concatenate((inputCluster,inputSubcluster),axis=1),inputCounts.reshape(-1,1)),axis=1)

In [13]:
#val cores (as validation cores) and val samples (as test cores)
clustersavedir_valcores=os.path.join(sampledir,'cluster_valcores_reordered')
clustersavedir_valsamples=os.path.join(sampledir,'cluster_valsamples_reordered')

with open(os.path.join(datadir,'processed','train_cnnvae_coord_valcores'), 'rb') as output:
    coordlist_valcores=pickle.load(output)
with open(os.path.join(datadir,'processed','train_cnnvae_coord_valsamples'), 'rb') as output:
    coordlist_valsamples=pickle.load(output)

savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_valcores,savenamecluster+'_all'), 'rb') as output:
    clusterRes_valcores=pickle.load(output)
with open(os.path.join(clustersavedir_valsamples,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes_valsamples=pickle.load(output)
    
kmeans_sub_valcores=(np.zeros(clusterRes_valcores.size)-1).astype(str)
for c in np.unique(clusterRes_valcores):
    subclustersavedir=os.path.join(clustersavedir_valcores,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valcores[clusterRes_valcores==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
    
kmeans_sub_valsamples=(np.zeros(clusterRes_valsamples.size)-1).astype(str)
for c in np.unique(clusterRes_valsamples):
    subclustersavedir=os.path.join(clustersavedir_valsamples,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valsamples[clusterRes_valsamples==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
                

In [14]:
with open(os.path.join(datadir,'processed','train_cnnvae_names_valcores'), 'rb') as input:
    allImgNames_valcores=pickle.load(input)
with open(os.path.join(datadir,'processed','train_cnnvae_names_valsamples'), 'rb') as input:
    allImgNames_valsamples=pickle.load(input)

In [15]:
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList_valcores=np.copy(allImgNames_valcores)
for s in np.unique(allImgNames_valcores):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valcores[allImgNames_valcores==s]=prog_s
    
progList_valsamples=np.copy(allImgNames_valsamples)
for s in np.unique(allImgNames_valsamples):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valsamples[allImgNames_valsamples==s]=prog_s
    



In [16]:
np.unique(progList_valsamples)


array(['Atypical hyperplasia', 'Breast tissue',
       'Cancer adjacent normal breast tissue', 'Ductal carcinoma in situ',
       'Ductal carcinoma in situ and breast tissue',
       'Ductal carcinoma in situ with early infiltrati', 'Hyperplasia',
       'Hyperplasia with saccular dilatation',
       'Invasive ductal carcinoma',
       'Invasive ductal carcinoma (breast tissue)',
       'Invasive ductal carcinoma and breast tissue',
       'Micropapillary type ductal carcinoma in situ w'], dtype='<U46')

In [17]:
for p in np.unique(progList_valcores):
    if p=='Ductal carcinoma in situ':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valcores[progList_valcores==p]='Hyperplasia'

In [18]:
for p in np.unique(progList_valsamples):
    if p=='Ductal carcinoma in situ':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltrati':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ w':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valsamples[progList_valsamples==p]='Hyperplasia'

In [19]:
progIncludeIdx_valcores=np.repeat(False,progList_valcores.size)
for p in progInclude:
    progIncludeIdx_valcores[progList_valcores==p]=True
    
progIncludeIdx_valsamples=np.repeat(False,progList_valsamples.size)
for p in progInclude:
    progIncludeIdx_valsamples[progList_valsamples==p]=True
    
coordlist_valcores=coordlist_valcores[progIncludeIdx_valcores]
allImgNames_valcores=allImgNames_valcores[progIncludeIdx_valcores]
clusterRes_valcores=clusterRes_valcores[progIncludeIdx_valcores]
kmeans_sub_valcores=kmeans_sub_valcores[progIncludeIdx_valcores]
progList_valcores=progList_valcores[progIncludeIdx_valcores]

coordlist_valsamples=coordlist_valsamples[progIncludeIdx_valsamples]
allImgNames_valsamples=allImgNames_valsamples[progIncludeIdx_valsamples]
clusterRes_valsamples=clusterRes_valsamples[progIncludeIdx_valsamples]
kmeans_sub_valsamples=kmeans_sub_valsamples[progIncludeIdx_valsamples]
progList_valsamples=progList_valsamples[progIncludeIdx_valsamples]


In [20]:
sUnique_valcores,sidx_start_valcores=np.unique(allImgNames_valcores,return_index=True)
progUnique_valcores,progCounts_valcores=np.unique(progList_valcores[sidx_start_valcores],return_counts=True)
for p in range(progUnique_valcores.size):
    print(progUnique_valcores[p])
    print(progCounts_valcores[p])

Atypical hyperplasia
15
Breast tissue
20
Cancer adjacent normal breast tissue
1
Hyperplasia
35
Invasive ductal carcinoma
97


In [21]:
sUnique_valsamples,sidx_start_valsamples=np.unique(allImgNames_valsamples,return_index=True)
progUnique_valsamples,progCounts_valsamples=np.unique(progList_valsamples[sidx_start_valsamples],return_counts=True)
for p in range(progUnique_valsamples.size):
    print(progUnique_valsamples[p])
    print(progCounts_valsamples[p])

Atypical hyperplasia
10
Breast tissue
14
Cancer adjacent normal breast tissue
4
DCIS and breast tissue
16
DCIS with early infiltration
29
Hyperplasia
25
Invasive ductal carcinoma
66
Invasive ductal carcinoma and breast tissue
8


In [22]:
#construct labels
labels_valcores=np.zeros(progList_valcores[sidx_start_valcores].size)
for i in range(progUnique.size):
    labels_valcores[progList_valcores[sidx_start_valcores]==progUnique[i]]=i

In [23]:
#construct labels
labels_valsamples=np.zeros(progList_valsamples[sidx_start_valsamples].size)
for i in range(progUnique.size):
    labels_valsamples[progList_valsamples[sidx_start_valsamples]==progUnique[i]]=i

In [24]:
#get cluster composition
inputCluster_valcores=np.zeros((sUnique_valcores.size,clusterUnique.size))
inputSubcluster_valcores=np.zeros((sUnique_valcores.size,subclusterUnique.size))
for i in range(sUnique_valcores.size):
    clusterRes_i=clusterRes_valcores[allImgNames_valcores==sUnique_valcores[i]]
    kmeans_sub_i=kmeans_sub_valcores[allImgNames_valcores==sUnique_valcores[i]]
    for j in range(clusterUnique.size):
        inputCluster_valcores[i,j]=np.sum(clusterRes_i==clusterUnique[j])
    for j in range(subclusterUnique.size):
        inputSubcluster_valcores[i,j]=np.sum(kmeans_sub_i==subclusterUnique[j])
inputCluster_valcores=inputCluster_valcores/np.sum(inputCluster_valcores,axis=1,keepdims=True)
inputSubcluster_valcores=inputSubcluster_valcores/np.sum(inputSubcluster_valcores,axis=1,keepdims=True)

In [25]:
#get cluster composition
inputCluster_valsamples=np.zeros((sUnique_valsamples.size,clusterUnique.size))
inputSubcluster_valsamples=np.zeros((sUnique_valsamples.size,subclusterUnique.size))
for i in range(sUnique_valsamples.size):
    clusterRes_i=clusterRes_valsamples[allImgNames_valsamples==sUnique_valsamples[i]]
    kmeans_sub_i=kmeans_sub_valsamples[allImgNames_valsamples==sUnique_valsamples[i]]
    for j in range(clusterUnique.size):
        inputCluster_valsamples[i,j]=np.sum(clusterRes_i==clusterUnique[j])
    for j in range(subclusterUnique.size):
        inputSubcluster_valsamples[i,j]=np.sum(kmeans_sub_i==subclusterUnique[j])
inputCluster_valsamples=inputCluster_valsamples/np.sum(inputCluster_valsamples,axis=1,keepdims=True)
inputSubcluster_valsamples=inputSubcluster_valsamples/np.sum(inputSubcluster_valsamples,axis=1,keepdims=True)

In [26]:
_,inputCounts_valcores=np.unique(allImgNames_valcores,return_counts=True)
inputAll_valcores=np.concatenate((np.concatenate((inputCluster_valcores,inputSubcluster_valcores),axis=1),inputCounts_valcores.reshape(-1,1)),axis=1)

In [27]:
_,inputCounts_valsamples=np.unique(allImgNames_valsamples,return_counts=True)
inputAll_valsamples=np.concatenate((np.concatenate((inputCluster_valsamples,inputSubcluster_valsamples),axis=1),inputCounts_valsamples.reshape(-1,1)),axis=1)

In [28]:
allPatientIDs_valcores=np.repeat('none',allImgNames_valcores.size).astype(object)
for s in np.unique(allImgNames_valcores):
    sidx=allImgNames_valcores==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs_valcores[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [29]:
allPatientIDs_valsamples=np.repeat('none',allImgNames_valsamples.size).astype(object)
for s in np.unique(allImgNames_valsamples):
    sidx=allImgNames_valsamples==s
    slideID=s.split('_')[0]
    coreID=s.split('_')[-1]
    allPatientIDs_valsamples[sidx]=metadata.patient_id[slideID+'_'+str.lower(coreID)]

In [30]:
#normalize count
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as output:
    imgSize_median=pickle.load(output)
area_train=np.zeros(labels_train.size)
for s in range(labels_train.size):
    area_train[s]=imgSize_median[progUnique[labels_train.astype(int)][s]]
inputAll[:,-1]=inputAll[:,-1]/area_train

In [31]:
area_valcores=np.zeros(labels_valcores.size)
for s in range(labels_valcores.size):
    area_valcores[s]=imgSize_median[progUnique[labels_valcores.astype(int)][s]]
inputAll_valcores[:,-1]=inputAll_valcores[:,-1]/area_valcores

In [32]:
area_valsamples=np.zeros(labels_valsamples.size)
for s in range(labels_valsamples.size):
    area_valsamples[s]=imgSize_median[progUnique[labels_valsamples.astype(int)][s]]
inputAll_valsamples[:,-1]=inputAll_valsamples[:,-1]/area_valsamples

In [33]:
#concatenate cores
inputAll=np.concatenate((inputAll,np.concatenate((inputAll_valcores,inputAll_valsamples),axis=0)),axis=0)
imgNamesAll=np.concatenate((allImgNames[sidx_start],np.concatenate((allImgNames_valcores[sidx_start_valcores],allImgNames_valsamples[sidx_start_valsamples]))))
labelsAll=np.concatenate((labels_train,np.concatenate((labels_valcores,labels_valsamples))))

In [34]:
_,progCountsAll=np.unique(labelsAll,return_counts=True)
weights_train=np.sum(progCountsAll)/progCountsAll

In [35]:
seed=3
epochs=6000
saveFreq=200
lr=0.001 #initial learning rate
weight_decay=0 

# batchsize=4
batchsize=6000
model_str='fc3'

fc_dim1=64
fc_dim2=64
fc_dim3=64


dropout=0.01

name='exp0_pathologyClf_neighbor_clusters_exp0_subset_clusterOnly_countAreaNorm_crossVal_wAH_byPatientCorr'
logsavepath='/media/xinyi/log/cnnvae'+name
modelsavepath='/media/xinyi/models/cnnvae'+name
plotsavepath='/media/xinyi/plots/cnnvae'+name


if not os.path.exists(logsavepath):
    os.mkdir(logsavepath)
if not os.path.exists(modelsavepath):
    os.mkdir(modelsavepath)
if not os.path.exists(plotsavepath):
    os.mkdir(plotsavepath)



In [36]:
inputAll=scipy.stats.zscore(inputAll,axis=0)

In [37]:
def train(epoch,trainInput,labels_train):
    t = time.time()
    model.train()
    optimizer.zero_grad()

    pred = model(trainInput)

    loss=lossCE(pred,labels_train)

    loss.backward()
    optimizer.step()

    if epoch%500==0:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.4f}'.format(loss))
    return loss.item()

In [38]:
allPatientIDs=np.concatenate((allPatientIDs,np.concatenate((allPatientIDs_valcores,allPatientIDs_valsamples))))
pIDList=allPatientIDs[np.concatenate((sidx_start,np.concatenate((sidx_start_valcores,sidx_start_valsamples))))]


In [39]:
inputAll=torch.tensor(inputAll).cuda().float()
labelsAll=torch.tensor(labelsAll).cuda().long()

testepoch=5000
predtest=np.zeros((inputAll.shape[0],np.unique(labels_train).size))

for patientIDX in range(np.unique(pIDList).size):
    patientID=np.unique(pIDList)[patientIDX]
    print(patientID)
    sampleIdx=np.arange(inputAll.shape[0])[pIDList==patientID]
    trainIdx=np.arange(inputAll.shape[0])[pIDList!=patientID]
    
    seed=3
    torch.manual_seed(seed)
    nclasses=np.unique(labels_train).size
    if use_cuda:
        torch.cuda.manual_seed(seed)

    nfeatures=inputAll.shape[1]
    if model_str=='fc3':
        model = modelsCNN.FC_l3(nfeatures,fc_dim1,fc_dim2,fc_dim3,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc5':
        model = modelsCNN.FC_l5(nfeatures,fc_dim1,fc_dim2,fc_dim3,fc_dim4,fc_dim5,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc1':
        model = modelsCNN.FC_l1(nfeatures,fc_dim1,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc0':
        model = modelsCNN.FC_l0(nfeatures,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())

    if use_cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss_ep=[None]*epochs
    val_loss_ep=[None]*epochs
    t_ep=time.time()

    for ep in range(epochs):
        train_loss_ep[ep]=train(ep,inputAll[trainIdx],labelsAll[trainIdx])


        if ep%saveFreq == 0 and ep!=0:
            torch.save(model.cpu().state_dict(), os.path.join(modelsavepath,patientID+'_'+str(ep)+'.pt'))
        if use_cuda:
            model.cuda()
            torch.cuda.empty_cache()
    print(' total time: {:.4f}s'.format(time.time() - t_ep))

    with open(os.path.join(logsavepath,patientID+'_train_loss'), 'wb') as output:
        pickle.dump(train_loss_ep, output, pickle.HIGHEST_PROTOCOL)

    model.load_state_dict(torch.load(os.path.join(modelsavepath,patientID+'_'+str(testepoch)+'.pt')))
    with torch.no_grad():
        model.cuda()
        model.eval()
        pred = model(inputAll[[sampleIdx]])
        predtest[sampleIdx]=pred.cpu().detach().numpy()

        loss_test=lossCE(pred,labelsAll[[sampleIdx]]).item()

    print(loss_test)

fmg010001
Epoch: 0000 loss_train: 2.0791
Epoch: 0500 loss_train: 0.4608
Epoch: 1000 loss_train: 0.2853
Epoch: 1500 loss_train: 0.1553
Epoch: 2000 loss_train: 0.1476
Epoch: 2500 loss_train: 0.0951
Epoch: 3000 loss_train: 0.1173
Epoch: 3500 loss_train: 0.0823
Epoch: 4000 loss_train: 0.0650
Epoch: 4500 loss_train: 0.0787
Epoch: 5000 loss_train: 0.0502
Epoch: 5500 loss_train: 0.0761
 total time: 9.2582s
5.642299175262451
fmg010016
Epoch: 0000 loss_train: 2.0804
Epoch: 0500 loss_train: 0.4780
Epoch: 1000 loss_train: 0.2783
Epoch: 1500 loss_train: 0.2591
Epoch: 2000 loss_train: 0.1337
Epoch: 2500 loss_train: 0.1207
Epoch: 3000 loss_train: 0.0937
Epoch: 3500 loss_train: 0.1101
Epoch: 4000 loss_train: 0.0813
Epoch: 4500 loss_train: 0.0707
Epoch: 5000 loss_train: 0.0714
Epoch: 5500 loss_train: 0.0728
 total time: 9.4755s
0.23730091750621796
fmg010027
Epoch: 0000 loss_train: 2.0838
Epoch: 0500 loss_train: 0.4343
Epoch: 1000 loss_train: 0.2521
Epoch: 1500 loss_train: 0.2128
Epoch: 2000 loss_train

Epoch: 3000 loss_train: 0.1104
Epoch: 3500 loss_train: 0.0980
Epoch: 4000 loss_train: 0.0612
Epoch: 4500 loss_train: 0.0639
Epoch: 5000 loss_train: 0.0715
Epoch: 5500 loss_train: 0.0570
 total time: 9.3878s
0.008433650247752666
fmg020034
Epoch: 0000 loss_train: 2.0757
Epoch: 0500 loss_train: 0.4988
Epoch: 1000 loss_train: 0.2792
Epoch: 1500 loss_train: 0.1919
Epoch: 2000 loss_train: 0.1138
Epoch: 2500 loss_train: 0.1474
Epoch: 3000 loss_train: 0.1069
Epoch: 3500 loss_train: 0.0798
Epoch: 4000 loss_train: 0.0634
Epoch: 4500 loss_train: 0.0602
Epoch: 5000 loss_train: 0.1052
Epoch: 5500 loss_train: 0.0578
 total time: 9.5239s
1.6450813973278855e-06
fmg020088
Epoch: 0000 loss_train: 2.0753
Epoch: 0500 loss_train: 0.4374
Epoch: 1000 loss_train: 0.2847
Epoch: 1500 loss_train: 0.2216
Epoch: 2000 loss_train: 0.1592
Epoch: 2500 loss_train: 0.1240
Epoch: 3000 loss_train: 0.0750
Epoch: 3500 loss_train: 0.1109
Epoch: 4000 loss_train: 0.0745
Epoch: 4500 loss_train: 0.0500
Epoch: 5000 loss_train: 0.

 total time: 9.3019s
19.073293685913086
fmg040104
Epoch: 0000 loss_train: 2.0795
Epoch: 0500 loss_train: 0.5025
Epoch: 1000 loss_train: 0.2816
Epoch: 1500 loss_train: 0.2079
Epoch: 2000 loss_train: 0.1444
Epoch: 2500 loss_train: 0.1232
Epoch: 3000 loss_train: 0.1246
Epoch: 3500 loss_train: 0.1055
Epoch: 4000 loss_train: 0.0848
Epoch: 4500 loss_train: 0.0977
Epoch: 5000 loss_train: 0.0688
Epoch: 5500 loss_train: 0.0549
 total time: 9.3156s
0.002214620588347316
fmg040346
Epoch: 0000 loss_train: 2.0798
Epoch: 0500 loss_train: 0.4619
Epoch: 1000 loss_train: 0.2518
Epoch: 1500 loss_train: 0.2147
Epoch: 2000 loss_train: 0.1305
Epoch: 2500 loss_train: 0.1175
Epoch: 3000 loss_train: 0.1212
Epoch: 3500 loss_train: 0.0953
Epoch: 4000 loss_train: 0.0725
Epoch: 4500 loss_train: 0.0946
Epoch: 5000 loss_train: 0.0716
Epoch: 5500 loss_train: 0.0506
 total time: 9.3637s
16.806142807006836
fmg040703
Epoch: 0000 loss_train: 2.0789
Epoch: 0500 loss_train: 0.4478
Epoch: 1000 loss_train: 0.3073
Epoch: 1500

Epoch: 2500 loss_train: 0.1350
Epoch: 3000 loss_train: 0.1017
Epoch: 3500 loss_train: 0.1205
Epoch: 4000 loss_train: 0.1293
Epoch: 4500 loss_train: 0.0667
Epoch: 5000 loss_train: 0.0814
Epoch: 5500 loss_train: 0.0685
 total time: 12.4901s
34.643043518066406
fmg150231
Epoch: 0000 loss_train: 2.0784
Epoch: 0500 loss_train: 0.4643
Epoch: 1000 loss_train: 0.3079
Epoch: 1500 loss_train: 0.1734
Epoch: 2000 loss_train: 0.1231
Epoch: 2500 loss_train: 0.1192
Epoch: 3000 loss_train: 0.1017
Epoch: 3500 loss_train: 0.1311
Epoch: 4000 loss_train: 0.0831
Epoch: 4500 loss_train: 0.0618
Epoch: 5000 loss_train: 0.0812
Epoch: 5500 loss_train: 0.0761
 total time: 10.0563s
20.548988342285156
fmg150258
Epoch: 0000 loss_train: 2.0803
Epoch: 0500 loss_train: 0.4330
Epoch: 1000 loss_train: 0.3411
Epoch: 1500 loss_train: 0.1626
Epoch: 2000 loss_train: 0.1395
Epoch: 2500 loss_train: 0.0901
Epoch: 3000 loss_train: 0.0880
Epoch: 3500 loss_train: 0.1021
Epoch: 4000 loss_train: 0.0700
Epoch: 4500 loss_train: 0.0763

In [40]:
with open(os.path.join(logsavepath,'crossVal_loss'), 'wb') as output:
    pickle.dump(predtest, output, pickle.HIGHEST_PROTOCOL)

In [41]:
predtest_label=np.argmax(predtest,axis=1)

In [42]:
res=pd.DataFrame({'patientID':pIDList,'sampleName':imgNamesAll,'true':progUnique[labelsAll.cpu().numpy()],'predicted':progUnique[predtest_label]})
res.to_csv(os.path.join(plotsavepath,'predictions.csv'))

In [8]:
# res=pd.read_csv(os.path.join(plotsavepath,'predictions.csv'),index_col=0)

In [10]:
# res=res.to_numpy()

In [43]:
#plot confusion
def plotCTcomp(labels,ctlist,savepath,savenamecluster,byCT,addname='',order=progInclude,showNumber=False):
    res=np.zeros((order.size,order.size))
    
    for li in range(res.shape[0]):
        l=order[li]
        nl=np.sum(labels==l)
        ctlist_l=ctlist[labels==l]
        for ci in range(res.shape[1]):
            c=order[ci]
            res[li,ci]=np.sum(ctlist_l==c)
#             res[li,ci]=np.sum(ctlist_l==c)/nl
    resCounts=np.copy(res)
    if not byCT:
        addname+=''
        for li in range(res.shape[0]):
            l=order[li]
            nl=np.sum(labels==l)
            res[li]=res[li]/nl
    else:
        addname+='_normbyCT'
        for ci in range(res.shape[1]):
            c=order[ci]
            nc=np.sum(ctlist==c)
            res[:,ci]=res[:,ci]/nc
    
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(res,cmap='binary',vmin=0,vmax=1)
    if showNumber:
        addname=addname+'_showNumber'
        for i in range(res.shape[0]):
            for j in range(res.shape[1]):
                text = ax.text(j, i, f'{resCounts[i, j]:.0f}',
                               ha="center", va="center", color="c",fontsize='xx-large')
    fig.colorbar(im)
    ax.set_yticks(np.arange(order.size))
    ax.set_yticklabels(order)
    ax.set_xticks(np.arange(order.size))
    ax.set_xticklabels(order)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
    fig.tight_layout()
    plt.savefig(os.path.join(savepath,savenamecluster+addname+'.pdf'))
    plt.close()
# plotCTcomp(res[:,2],res[:,3],plotsavepath,'confusion'+str(5000),False,showNumber=True)
plotCTcomp(progUnique[labelsAll.cpu().numpy()],progUnique[predtest_label],plotsavepath,'confusion'+str(testepoch),False,showNumber=True)
plotCTcomp(progUnique[labelsAll.cpu().numpy()],progUnique[predtest_label],plotsavepath,'confusion'+str(testepoch),False,showNumber=False)