In [1]:
import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim
from torch.utils.data import DataLoader

import models.loadImg as loadImg
import models.modelsCNN as modelsCNN
import models.optimizer as optimizer

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import gc
from skimage import io
import scipy.stats
from sklearn.metrics import pairwise_distances

In [2]:
# Expose only GPUs 1-3 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3'
use_cuda=True
datadir='/media/xinyi/dcis2idc/data'
name='exp0'
plotsavepath='/media/xinyi/dcis2idc/plots/cnnvae'+name
sampledir=plotsavepath
clustersavedir_alltrain=os.path.join(sampledir,'cluster_alltrain_reordered')
ep=311
# One image name per training patch.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(os.path.join(datadir,'processed','train_cnnvae_names'), 'rb') as namesFile:
    allImgNames=pickle.load(namesFile)
#plot by disease progression
# Spec sheets of the three tissue arrays, indexed by core position.
br1003aSpecs=pd.read_excel('/media/xinyi/dcis2idc/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/dcis2idc/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/dcis2idc/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

# Image names start with the array id and end with the core position.
specsByArray={'br1003a':br1003aSpecs,'br301':br301Specs,'br8018a':br8018aSpecs}

# progList[k] is the pathology diagnosis of the core that patch k belongs to.
# NOTE(review): progList inherits the dtype of allImgNames; the truncated diagnosis
# strings matched further down suggest a fixed-width string dtype — confirm.
progList=np.copy(allImgNames)
for s in np.unique(allImgNames):
    ssplit=s.split('_')
    if ssplit[0] not in specsByArray:
        # Previously an unknown prefix silently reused the diagnosis of the
        # previous sample; fail loudly instead.
        raise ValueError('unknown tissue array prefix in image name: '+str(s))
    prog_s=specsByArray[ssplit[0]].loc[(ssplit[-1],'Pathology diagnosis')]
    progList[allImgNames==s]=prog_s
    




In [3]:
savenamesample='alltrain'

# The sub-clustering for cluster id k is read from the directory numbered neworder[k].
neworder=[1, 5, 3, 7, 2, 0, 4, 6]
#use chosen subcluster number and save plots
scanpy.settings.verbosity = 3
# subcluster=8
subclusterDict={0:[4],1:[6],2:[8],3:[6],4:[6],5:[6],6:[6],7:[4]}
ncluster=8

plotepoch=311
clusterplotdir=os.path.join(clustersavedir_alltrain,'plots')
n_pcs=50
clusterStem='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
# Top-level cluster id of every training patch.
with open(os.path.join(clustersavedir_alltrain,clusterStem+'_all'), 'rb') as clusterFile:
    clusterRes=pickle.load(clusterFile)

# Sub-cluster label per patch, formatted '<cluster>-<subcluster>'; '-1.0' marks "not assigned".
kmeans_sub=np.full(clusterRes.size,-1.0).astype(str)
savenameAdd='_plottingIdx_progBalanced_'+str(0)
subclusternumbers=[4,6,8,6,6,6,6,4]
savenamecluster=clusterStem+savenameAdd
for cnew in np.unique(clusterRes):
    c=neworder[cnew]

    subclustersavedir_alltrain=os.path.join(clustersavedir_alltrain,savenamecluster+'_subcluster'+str(c))
    subclusterFileName='minibatchkmean_ncluster'+str(subclusternumbers[c])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'
    with open(os.path.join(subclustersavedir_alltrain,subclusterFileName), 'rb') as subclusterFile:
        subclusterRes=pickle.load(subclusterFile)
    print(np.unique(subclusterRes))
    clusterPrefix=np.full(subclusterRes.size,str(cnew)+'-')
    kmeans_sub[clusterRes==cnew]=np.char.add(clusterPrefix,subclusterRes.astype(str))
   

[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3]
[0 1 2 3 4 5 6 7]
[0 1 2 3]
[0 1 2 3 4 5]
[0 1 2 3 4 5]


In [4]:
# Coordinates of every training patch (same ordering as allImgNames).
coordPath=os.path.join(datadir,'processed','train_cnnvae_coord')
with open(coordPath, 'rb') as coordFile:
    coordlist=pickle.load(coordFile)

In [5]:
# Collapse detailed diagnoses into the coarser progression classes used as labels.
# The two long diagnoses are stored truncated (how many characters survive depends on
# the string width of progList), so they are matched by prefix rather than against a
# literal that hard-codes one particular truncation length.
progMergeExact={'Ductal carcinoma in situ':'DCIS and breast tissue',
                'Ductal carcinoma in situ and breast tissue':'DCIS and breast tissue'}
progMergePrefix=('Ductal carcinoma in situ with early infiltrati',
                 'Micropapillary type ductal carcinoma in situ w')
for p in np.unique(progList):
    if p in progMergeExact:
        progList[progList==p]=progMergeExact[p]
    elif str(p).startswith(progMergePrefix):
        progList[progList==p]='DCIS with early infiltration'
#     elif p=='Atypical hyperlasia':
#         progList[progList==p]='Hyperplasia'

In [6]:
# Progression classes kept for the classifier; every other diagnosis is filtered out.
progInclude=np.array([
    'Breast tissue',
    'Cancer adjacent normal breast tissue',
    'Hyperplasia',
    'DCIS and breast tissue',
    'DCIS with early infiltration',
    'Invasive ductal carcinoma and breast tissue',
    'Invasive ductal carcinoma',
])

In [7]:
# Keep only the patches whose diagnosis is one of the classes in progInclude.
progIncludeIdx=np.isin(progList,progInclude)

coordlist,allImgNames,clusterRes,kmeans_sub,progList=(
    perPatch[progIncludeIdx]
    for perPatch in (coordlist,allImgNames,clusterRes,kmeans_sub,progList)
)

In [8]:
# One entry per core: sidx_start points at the first patch of each unique image name.
sUnique,sidx_start=np.unique(allImgNames,return_index=True)
# labels_train[k] is the index into progUnique of core k's diagnosis.
progUnique,labels_train,progCounts=np.unique(progList[sidx_start],return_inverse=True,return_counts=True)
for progName,progCount in zip(progUnique,progCounts):
    print(progName)
    print(progCount)

Breast tissue
20
Cancer adjacent normal breast tissue
13
DCIS and breast tissue
16
DCIS with early infiltration
30
Hyperplasia
41
Invasive ductal carcinoma
70
Invasive ductal carcinoma and breast tissue
8


In [9]:
#get cluster composition
# Per core: fraction of its patches falling in each cluster / sub-cluster.
clusterUnique=np.unique(clusterRes)
subclusterUnique=np.unique(kmeans_sub)
inputCluster=np.zeros((sUnique.size,clusterUnique.size))
inputSubcluster=np.zeros((sUnique.size,subclusterUnique.size))
for i,sampleName in enumerate(sUnique):
    inSample=allImgNames==sampleName
    clusterRes_i=clusterRes[inSample]
    kmeans_sub_i=kmeans_sub[inSample]
    inputCluster[i]=[np.sum(clusterRes_i==clusterId) for clusterId in clusterUnique]
    inputSubcluster[i]=[np.sum(kmeans_sub_i==subclusterId) for subclusterId in subclusterUnique]
# Turn counts into row-wise proportions.
inputCluster=inputCluster/inputCluster.sum(axis=1,keepdims=True)
inputSubcluster=inputSubcluster/inputSubcluster.sum(axis=1,keepdims=True)

In [10]:
def getHistMatrix_clusters(labels,ctlist,nrow=None,ncol=None):
    """Row-normalised co-occurrence histogram of two paired integer label lists.

    Parameters
    ----------
    labels : array-like
        Row label of every pair (values outside 0..nrow-1 are ignored).
    ctlist : array-like
        Column label of every pair, same length as ``labels``
        (values outside 0..ncol-1 are ignored).
    nrow, ncol : int, optional
        Matrix size. Default to the module-level ``ncluster``, looked up when the
        function is called so a later change of ``ncluster`` is honoured.

    Returns
    -------
    numpy.ndarray of shape (nrow, ncol)
        ``res[l, c]`` is the fraction of pairs with row label ``l`` whose column
        label is ``c``. Rows whose label never occurs stay all-zero.
    """
    if nrow is None:
        nrow=ncluster
    if ncol is None:
        ncol=ncluster
    # Accept plain lists as well as arrays.
    labels=np.asarray(labels)
    ctlist=np.asarray(ctlist)
    res=np.zeros((nrow,ncol))
    for l in range(nrow):
        inRow=labels==l
        nl=np.sum(inRow)
        if nl==0:
            # No pair carries this row label: leave the row at zero (avoids 0/0).
            continue
        ctlist_l=ctlist[inRow]
        for c in range(ncol):
            res[l,c]=np.sum(ctlist_l==c)
        res[l]=res[l]/nl
    return res

# Distance cut-off for the neighbourhood features: two patches of the same core count as
# neighbours when their pairwise distance is > 0 and < this value.
# NOTE(review): units follow coordlist — presumably pixels; confirm.
neighborhoodSize=16*9

In [11]:
#get neighborhood composition
# Per core: flattened (ncluster x ncluster) matrix giving, for patches of each cluster,
# the cluster proportions among their neighbours.

inputNeighborhood=np.zeros((sUnique.size,ncluster*ncluster))
for i,imgN in enumerate(sUnique):
    inSample=allImgNames==imgN
    nsamples=np.sum(inSample)
    cluster_i=clusterRes[inSample]
    # neighbor_i[r,c] = cluster of patch c, self_i[r,c] = cluster of patch r
    neighbor_i,self_i=np.meshgrid(cluster_i,cluster_i)

    dist=pairwise_distances(coordlist[inSample],n_jobs=-1)
    # Neighbours: closer than the cut-off, excluding zero-distance pairs (incl. the patch itself).
    distIn=(dist<neighborhoodSize)&(dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])

    inputNeighborhood[i]=res.ravel()

In [12]:
# Feature matrix per core: [cluster fractions | sub-cluster fractions | neighbourhood matrix | patch count]
_,inputCounts=np.unique(allImgNames,return_counts=True)
inputAll=np.hstack((inputCluster,inputSubcluster,inputNeighborhood,inputCounts.reshape(-1,1)))

In [13]:
#val cores (as validation cores) and val samples (as test cores)
clustersavedir_valcores=os.path.join(sampledir,'cluster_valcores_reordered')
clustersavedir_valsamples=os.path.join(sampledir,'cluster_valsamples_reordered')

# Patch coordinates of both held-out sets.
processedDir=os.path.join(datadir,'processed')
with open(os.path.join(processedDir,'train_cnnvae_coord_valcores'), 'rb') as coordFile:
    coordlist_valcores=pickle.load(coordFile)
with open(os.path.join(processedDir,'train_cnnvae_coord_valsamples'), 'rb') as coordFile:
    coordlist_valsamples=pickle.load(coordFile)

# Top-level cluster assignments of both held-out sets.
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_valcores,savenamecluster+'_all'), 'rb') as clusterFile:
    clusterRes_valcores=pickle.load(clusterFile)
with open(os.path.join(clustersavedir_valsamples,savenamecluster+'_all'), 'rb') as clusterFile:
    clusterRes_valsamples=pickle.load(clusterFile)

# Pieces shared by every sub-cluster directory / file name.
subclusterDirStem=savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'
subclusterFileTail='n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'

# Sub-cluster label per patch, formatted '<cluster>-<subcluster>'.
kmeans_sub_valcores=np.full(clusterRes_valcores.size,-1.0).astype(str)
for c in np.unique(clusterRes_valcores):
    subclustersavedir=os.path.join(clustersavedir_valcores,subclusterDirStem+str(neworder[c]))
    subclusterFileName='minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+subclusterFileTail
    with open(os.path.join(subclustersavedir,subclusterFileName), 'rb') as subclusterFile:
        subclusterRes=pickle.load(subclusterFile)
    clusterPrefix=np.full(subclusterRes.size,str(c)+'-')
    kmeans_sub_valcores[clusterRes_valcores==c]=np.char.add(clusterPrefix,subclusterRes.astype(str))

kmeans_sub_valsamples=np.full(clusterRes_valsamples.size,-1.0).astype(str)
for c in np.unique(clusterRes_valsamples):
    subclustersavedir=os.path.join(clustersavedir_valsamples,subclusterDirStem+str(neworder[c]))
    subclusterFileName='minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+subclusterFileTail
    with open(os.path.join(subclustersavedir,subclusterFileName), 'rb') as subclusterFile:
        subclusterRes=pickle.load(subclusterFile)
    clusterPrefix=np.full(subclusterRes.size,str(c)+'-')
    kmeans_sub_valsamples[clusterRes_valsamples==c]=np.char.add(clusterPrefix,subclusterRes.astype(str))
                

In [14]:
# Image name of every patch in the two held-out sets.
namesPath_valcores=os.path.join(datadir,'processed','train_cnnvae_names_valcores')
namesPath_valsamples=os.path.join(datadir,'processed','train_cnnvae_names_valsamples')
with open(namesPath_valcores, 'rb') as namesFile:
    allImgNames_valcores=pickle.load(namesFile)
with open(namesPath_valsamples, 'rb') as namesFile:
    allImgNames_valsamples=pickle.load(namesFile)

In [15]:
#plot by disease progression
# Spec sheets of the three tissue arrays, indexed by core position.
br1003aSpecs=pd.read_excel('/media/xinyi/dcis2idc/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/dcis2idc/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/dcis2idc/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

# Image names start with the array id and end with the core position.
specsByArray={'br1003a':br1003aSpecs,'br301':br301Specs,'br8018a':br8018aSpecs}

# Per-patch pathology diagnosis of the validation cores.
progList_valcores=np.copy(allImgNames_valcores)
for s in np.unique(allImgNames_valcores):
    ssplit=s.split('_')
    if ssplit[0] not in specsByArray:
        # Previously an unknown prefix silently reused the previous sample's diagnosis.
        raise ValueError('unknown tissue array prefix in image name: '+str(s))
    prog_s=specsByArray[ssplit[0]].loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valcores[allImgNames_valcores==s]=prog_s

# Per-patch pathology diagnosis of the test cores.
progList_valsamples=np.copy(allImgNames_valsamples)
for s in np.unique(allImgNames_valsamples):
    ssplit=s.split('_')
    if ssplit[0] not in specsByArray:
        raise ValueError('unknown tissue array prefix in image name: '+str(s))
    prog_s=specsByArray[ssplit[0]].loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valsamples[allImgNames_valsamples==s]=prog_s
    



In [16]:
# Inspect the distinct diagnosis strings of the test cores before they are relabelled
# (the displayed dtype shows the width at which long names are cut off).
np.unique(progList_valsamples)


array(['Atypical hyperplasia', 'Breast tissue',
       'Cancer adjacent normal breast tissue', 'Ductal carcinoma in situ',
       'Ductal carcinoma in situ and breast tissue',
       'Ductal carcinoma in situ with early infiltrati', 'Hyperplasia',
       'Hyperplasia with saccular dilatation',
       'Invasive ductal carcinoma',
       'Invasive ductal carcinoma (breast tissue)',
       'Invasive ductal carcinoma and breast tissue',
       'Micropapillary type ductal carcinoma in situ w'], dtype='<U46')

In [17]:
# Collapse detailed diagnoses of the validation cores into the coarser progression classes.
# The two long diagnoses are stored truncated (how many characters survive depends on the
# string width of progList_valcores), so they are matched by prefix rather than against a
# literal that hard-codes one particular truncation length.
progMergeExact={'Ductal carcinoma in situ':'DCIS and breast tissue',
                'Ductal carcinoma in situ and breast tissue':'DCIS and breast tissue'}
progMergePrefix=('Ductal carcinoma in situ with early infiltrati',
                 'Micropapillary type ductal carcinoma in situ w')
for p in np.unique(progList_valcores):
    if p in progMergeExact:
        progList_valcores[progList_valcores==p]=progMergeExact[p]
    elif str(p).startswith(progMergePrefix):
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'
#     elif p=='Atypical hyperlasia':
#         progList_valcores[progList_valcores==p]='Hyperplasia'

In [18]:
# Collapse detailed diagnoses of the test cores into the coarser progression classes.
# The two long diagnoses are stored truncated (how many characters survive depends on the
# string width of progList_valsamples), so they are matched by prefix rather than against a
# literal that hard-codes one particular truncation length.
progMergeExact={'Ductal carcinoma in situ':'DCIS and breast tissue',
                'Ductal carcinoma in situ and breast tissue':'DCIS and breast tissue'}
progMergePrefix=('Ductal carcinoma in situ with early infiltrati',
                 'Micropapillary type ductal carcinoma in situ w')
for p in np.unique(progList_valsamples):
    if p in progMergeExact:
        progList_valsamples[progList_valsamples==p]=progMergeExact[p]
    elif str(p).startswith(progMergePrefix):
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'
#     elif p=='Atypical hyperlasia':
#         progList_valsamples[progList_valsamples==p]='Hyperplasia'

In [19]:
# Keep only the held-out patches whose diagnosis is one of the classes in progInclude.
progIncludeIdx_valcores=np.isin(progList_valcores,progInclude)
progIncludeIdx_valsamples=np.isin(progList_valsamples,progInclude)

(coordlist_valcores,
 allImgNames_valcores,
 clusterRes_valcores,
 kmeans_sub_valcores,
 progList_valcores)=(
    perPatch[progIncludeIdx_valcores]
    for perPatch in (coordlist_valcores,allImgNames_valcores,clusterRes_valcores,
                     kmeans_sub_valcores,progList_valcores)
)

(coordlist_valsamples,
 allImgNames_valsamples,
 clusterRes_valsamples,
 kmeans_sub_valsamples,
 progList_valsamples)=(
    perPatch[progIncludeIdx_valsamples]
    for perPatch in (coordlist_valsamples,allImgNames_valsamples,clusterRes_valsamples,
                     kmeans_sub_valsamples,progList_valsamples)
)


In [20]:
# Number of validation cores per progression class (one entry per unique image name).
sUnique_valcores,sidx_start_valcores=np.unique(allImgNames_valcores,return_index=True)
progUnique_valcores,progCounts_valcores=np.unique(progList_valcores[sidx_start_valcores],return_counts=True)
for progName,progCount in zip(progUnique_valcores,progCounts_valcores):
    print(progName)
    print(progCount)

Breast tissue
20
Cancer adjacent normal breast tissue
1
Hyperplasia
35
Invasive ductal carcinoma
97


In [21]:
# Number of test cores per progression class (one entry per unique image name).
sUnique_valsamples,sidx_start_valsamples=np.unique(allImgNames_valsamples,return_index=True)
progUnique_valsamples,progCounts_valsamples=np.unique(progList_valsamples[sidx_start_valsamples],return_counts=True)
for progName,progCount in zip(progUnique_valsamples,progCounts_valsamples):
    print(progName)
    print(progCount)

Breast tissue
14
Cancer adjacent normal breast tissue
4
DCIS and breast tissue
16
DCIS with early infiltration
29
Hyperplasia
25
Invasive ductal carcinoma
66
Invasive ductal carcinoma and breast tissue
8


In [22]:
#construct labels
# Encode each validation core's diagnosis with the class indices of the training set
# (position in progUnique). Labels start at -1 so that a diagnosis missing from
# progUnique is detected instead of silently becoming class 0.
prog_valcores=progList_valcores[sidx_start_valcores]
labels_valcores=np.full(prog_valcores.size,-1.0)
for i in range(progUnique.size):
    labels_valcores[prog_valcores==progUnique[i]]=i
if np.any(labels_valcores<0):
    raise ValueError('validation cores with a diagnosis absent from progUnique: '
                     +str(np.unique(prog_valcores[labels_valcores<0])))

In [23]:
#construct labels
# Encode each test core's diagnosis with the class indices of the training set
# (position in progUnique). Labels start at -1 so that a diagnosis missing from
# progUnique is detected instead of silently becoming class 0.
prog_valsamples=progList_valsamples[sidx_start_valsamples]
labels_valsamples=np.full(prog_valsamples.size,-1.0)
for i in range(progUnique.size):
    labels_valsamples[prog_valsamples==progUnique[i]]=i
if np.any(labels_valsamples<0):
    raise ValueError('test cores with a diagnosis absent from progUnique: '
                     +str(np.unique(prog_valsamples[labels_valsamples<0])))

In [24]:
#get cluster composition
# Per validation core: fraction of its patches in each cluster / sub-cluster, using the
# cluster and sub-cluster vocabularies of the training set so columns line up.
inputCluster_valcores=np.zeros((sUnique_valcores.size,clusterUnique.size))
inputSubcluster_valcores=np.zeros((sUnique_valcores.size,subclusterUnique.size))
for i,sampleName in enumerate(sUnique_valcores):
    inSample=allImgNames_valcores==sampleName
    clusterRes_i=clusterRes_valcores[inSample]
    kmeans_sub_i=kmeans_sub_valcores[inSample]
    inputCluster_valcores[i]=[np.sum(clusterRes_i==clusterId) for clusterId in clusterUnique]
    inputSubcluster_valcores[i]=[np.sum(kmeans_sub_i==subclusterId) for subclusterId in subclusterUnique]
# Turn counts into row-wise proportions.
inputCluster_valcores=inputCluster_valcores/inputCluster_valcores.sum(axis=1,keepdims=True)
inputSubcluster_valcores=inputSubcluster_valcores/inputSubcluster_valcores.sum(axis=1,keepdims=True)

In [25]:
#get cluster composition
# Per test core: fraction of its patches in each cluster / sub-cluster, using the
# cluster and sub-cluster vocabularies of the training set so columns line up.
inputCluster_valsamples=np.zeros((sUnique_valsamples.size,clusterUnique.size))
inputSubcluster_valsamples=np.zeros((sUnique_valsamples.size,subclusterUnique.size))
for i,sampleName in enumerate(sUnique_valsamples):
    inSample=allImgNames_valsamples==sampleName
    clusterRes_i=clusterRes_valsamples[inSample]
    kmeans_sub_i=kmeans_sub_valsamples[inSample]
    inputCluster_valsamples[i]=[np.sum(clusterRes_i==clusterId) for clusterId in clusterUnique]
    inputSubcluster_valsamples[i]=[np.sum(kmeans_sub_i==subclusterId) for subclusterId in subclusterUnique]
# Turn counts into row-wise proportions.
inputCluster_valsamples=inputCluster_valsamples/inputCluster_valsamples.sum(axis=1,keepdims=True)
inputSubcluster_valsamples=inputSubcluster_valsamples/inputSubcluster_valsamples.sum(axis=1,keepdims=True)

In [26]:
# Per validation core: flattened (ncluster x ncluster) neighbourhood composition matrix.
inputNeighborhood_valcores=np.zeros((sUnique_valcores.size,ncluster*ncluster))
for i,imgN in enumerate(sUnique_valcores):
    inSample=allImgNames_valcores==imgN
    nsamples=np.sum(inSample)
    cluster_i=clusterRes_valcores[inSample]
    # neighbor_i[r,c] = cluster of patch c, self_i[r,c] = cluster of patch r
    neighbor_i,self_i=np.meshgrid(cluster_i,cluster_i)

    dist=pairwise_distances(coordlist_valcores[inSample],n_jobs=-1)
    # Neighbours: closer than the cut-off, excluding zero-distance pairs (incl. the patch itself).
    distIn=(dist<neighborhoodSize)&(dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])

    inputNeighborhood_valcores[i]=res.ravel()

In [27]:
# Per test core: flattened (ncluster x ncluster) neighbourhood composition matrix.
inputNeighborhood_valsamples=np.zeros((sUnique_valsamples.size,ncluster*ncluster))
for i,imgN in enumerate(sUnique_valsamples):
    inSample=allImgNames_valsamples==imgN
    nsamples=np.sum(inSample)
    cluster_i=clusterRes_valsamples[inSample]
    # neighbor_i[r,c] = cluster of patch c, self_i[r,c] = cluster of patch r
    neighbor_i,self_i=np.meshgrid(cluster_i,cluster_i)

    dist=pairwise_distances(coordlist_valsamples[inSample],n_jobs=-1)
    # Neighbours: closer than the cut-off, excluding zero-distance pairs (incl. the patch itself).
    distIn=(dist<neighborhoodSize)&(dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])

    inputNeighborhood_valsamples[i]=res.ravel()

In [28]:
# Validation feature matrix: [cluster fractions | sub-cluster fractions | neighbourhood matrix | patch count]
_,inputCounts_valcores=np.unique(allImgNames_valcores,return_counts=True)
inputAll_valcores=np.hstack((inputCluster_valcores,inputSubcluster_valcores,
                             inputNeighborhood_valcores,inputCounts_valcores.reshape(-1,1)))

In [29]:
# Test feature matrix: [cluster fractions | sub-cluster fractions | neighbourhood matrix | patch count]
_,inputCounts_valsamples=np.unique(allImgNames_valsamples,return_counts=True)
inputAll_valsamples=np.hstack((inputCluster_valsamples,inputSubcluster_valsamples,
                               inputNeighborhood_valsamples,inputCounts_valsamples.reshape(-1,1)))

In [30]:
#normalize count
# Divide the patch-count feature (last column) by an image size looked up per core.
# NOTE(review): the size is looked up by each core's diagnosis label, i.e. the feature is
# derived from the classification target — confirm this is intended.
# NOTE(review): the division is in place, so re-running this cell divides again.
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as sizeFile:
    imgSize_median=pickle.load(sizeFile)
progPerCore_train=progUnique[labels_train.astype(int)]
area_train=np.array([imgSize_median[prog] for prog in progPerCore_train],dtype=float)
inputAll[:,-1]=inputAll[:,-1]/area_train

In [31]:
# Same count normalisation for the validation cores: divide the last column by the
# image size looked up from each core's diagnosis label.
# NOTE(review): in-place division — re-running this cell divides again.
progPerCore_valcores=progUnique[labels_valcores.astype(int)]
area_valcores=np.array([imgSize_median[prog] for prog in progPerCore_valcores],dtype=float)
inputAll_valcores[:,-1]=inputAll_valcores[:,-1]/area_valcores

In [32]:
# Same count normalisation for the test cores: divide the last column by the
# image size looked up from each core's diagnosis label.
# NOTE(review): in-place division — re-running this cell divides again.
progPerCore_valsamples=progUnique[labels_valsamples.astype(int)]
area_valsamples=np.array([imgSize_median[prog] for prog in progPerCore_valsamples],dtype=float)
inputAll_valsamples[:,-1]=inputAll_valsamples[:,-1]/area_valsamples

In [33]:
#concatenate cores
# Stack train, validation and test cores (in that order) into one data set.
# NOTE(review): inputAll is overwritten, so re-running this cell appends the held-out rows again.
inputAll=np.concatenate((inputAll,inputAll_valcores,inputAll_valsamples),axis=0)
imgNamesAll=np.concatenate((allImgNames[sidx_start],
                            allImgNames_valcores[sidx_start_valcores],
                            allImgNames_valsamples[sidx_start_valsamples]))
labelsAll=np.concatenate((labels_train,labels_valcores,labels_valsamples))

In [34]:
# Inverse-frequency class weights: total number of cores divided by the cores in each class.
_,progCountsAll=np.unique(labelsAll,return_counts=True)
weights_train=progCountsAll.sum()/progCountsAll

In [35]:
# Training hyper-parameters and output locations for the pathology classifier.
seed=3
epochs=6000
saveFreq=200   # checkpoint every saveFreq epochs
lr=0.001 #initial learning rate
weight_decay=0

# batchsize=4
batchsize=6000
model_str='fc3'

# Hidden-layer widths of the fully connected model.
fc_dim1=64
fc_dim2=64
fc_dim3=64


# NOTE(review): the visible model construction passes a literal 0.5 rather than this
# variable — confirm which dropout value is intended.
dropout=0.01

name='exp0_pathologyClf_neighbor_clusters_exp0_subset_countAreaNorm_crossVal'
logsavepath='/media/xinyi/dcis2idc/log/cnnvae'+name
modelsavepath='/media/xinyi/dcis2idc/models/cnnvae'+name
plotsavepath='/media/xinyi/dcis2idc/plots/cnnvae'+name


# Create the output directories, including missing parents; existing ones are left alone.
for outDir in (logsavepath,modelsavepath,plotsavepath):
    os.makedirs(outDir,exist_ok=True)



In [36]:
# Standardise every feature column to zero mean / unit variance.
# A constant column has zero standard deviation and z-scores to NaN, which would turn the
# training loss into NaN; such columns carry no information and are set to 0 instead.
featureStd=np.std(inputAll,axis=0)
inputAll=scipy.stats.zscore(inputAll,axis=0)
inputAll[:,featureStd==0]=0

In [37]:
def train(epoch,trainInput,labels_train):
    """Run one full-batch optimisation step and return the loss as a float.

    Uses the module-level ``model``, ``optimizer`` and ``lossCE``.

    epoch        -- current epoch index; the loss is printed every 500th epoch.
    trainInput   -- input tensor passed to ``model``.
    labels_train -- target tensor passed to ``lossCE`` together with the prediction.
    """
    model.train()
    optimizer.zero_grad()

    loss=lossCE(model(trainInput),labels_train)
    loss.backward()
    optimizer.step()

    lossValue=loss.item()
    if not epoch%500:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.4f}'.format(lossValue))
    return lossValue

In [38]:
# Leave-one-core-out cross-validation: for every core, train a fresh classifier on all
# other cores, checkpoint it every saveFreq epochs, then score the held-out core with the
# checkpoint saved at epoch `testepoch`.
# Honour use_cuda instead of calling .cuda() unconditionally.
device=torch.device('cuda' if use_cuda else 'cpu')
inputAll=torch.as_tensor(inputAll).to(device).float()
labelsAll=torch.as_tensor(labelsAll).to(device).long()

testepoch=5800
# The evaluation checkpoint is only written at positive multiples of saveFreq below
# `epochs`; check up front instead of failing after a complete training run.
if testepoch<=0 or testepoch>=epochs or testepoch%saveFreq!=0:
    raise ValueError('testepoch must be a positive multiple of saveFreq smaller than epochs')

# Quantities that are identical for every fold.
nsamplesAll=inputAll.shape[0]
nfeatures=inputAll.shape[1]
nclasses=np.unique(labels_train).size
seed=3
classWeights=torch.tensor(weights_train).to(device).float()

# Raw model output for each held-out core.
predtest=np.zeros((nsamplesAll,nclasses))
for sampleIdx in range(nsamplesAll):

    trainIdx=np.arange(nsamplesAll)!=sampleIdx

    # Same seed for every fold.
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

    # NOTE(review): 0.5 is passed literally; the `dropout` variable defined with the other
    # hyper-parameters is not used here — confirm which value is intended.
    # `model`, `lossCE` and `optimizer` are read as globals by train().
    if model_str=='fc3':
        model = modelsCNN.FC_l3(nfeatures,fc_dim1,fc_dim2,fc_dim3,nclasses,0.5,regrs=False)
    elif model_str=='fc5':
        # NOTE(review): fc_dim4 / fc_dim5 are not defined in the visible cells.
        model = modelsCNN.FC_l5(nfeatures,fc_dim1,fc_dim2,fc_dim3,fc_dim4,fc_dim5,nclasses,0.5,regrs=False)
    elif model_str=='fc1':
        model = modelsCNN.FC_l1(nfeatures,fc_dim1,nclasses,regrs=False)
    elif model_str=='fc0':
        model = modelsCNN.FC_l0(nfeatures,nclasses,regrs=False)
    else:
        raise ValueError('unknown model_str: '+str(model_str))
    lossCE=torch.nn.CrossEntropyLoss(classWeights)

    model.to(device)

    # NOTE(review): this rebinds `optimizer`, which the import cell bound to the
    # models.optimizer module.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # The training subset does not change between epochs; slice it once per fold.
    trainInput=inputAll[trainIdx]
    trainLabels=labelsAll[trainIdx]

    train_loss_ep=[None]*epochs
    t_ep=time.time()

    for ep in range(epochs):
        train_loss_ep[ep]=train(ep,trainInput,trainLabels)

        if ep%saveFreq == 0 and ep!=0:
            torch.save(model.cpu().state_dict(), os.path.join(modelsavepath,imgNamesAll[sampleIdx]+'_'+str(ep)+'.pt'))
            # The model was moved to the CPU for saving; move it back only in that case.
            if use_cuda:
                model.cuda()
                torch.cuda.empty_cache()
    print(' total time: {:.4f}s'.format(time.time() - t_ep))

    with open(os.path.join(logsavepath,imgNamesAll[sampleIdx]+'_train_loss'), 'wb') as output:
        pickle.dump(train_loss_ep, output, pickle.HIGHEST_PROTOCOL)

    # Evaluate the held-out core with the checkpoint from epoch `testepoch`.
    model.load_state_dict(torch.load(os.path.join(modelsavepath,imgNamesAll[sampleIdx]+'_'+str(testepoch)+'.pt')))
    with torch.no_grad():
        model.to(device)
        model.eval()
        pred = model(inputAll[[sampleIdx]])
        predtest[sampleIdx]=pred.cpu().detach().numpy()

        loss_test=lossCE(pred,labelsAll[[sampleIdx]]).item()

    print(loss_test)

Epoch: 0000 loss_train: 1.9516
Epoch: 0500 loss_train: 0.2583
Epoch: 1000 loss_train: 0.0794
Epoch: 1500 loss_train: 0.0371
Epoch: 2000 loss_train: 0.0629
Epoch: 2500 loss_train: 0.0445
Epoch: 3000 loss_train: 0.0501
Epoch: 3500 loss_train: 0.0294
Epoch: 4000 loss_train: 0.0273
Epoch: 4500 loss_train: 0.0181
Epoch: 5000 loss_train: 0.0277
Epoch: 5500 loss_train: 0.0141
 total time: 8.6591s
0.0
Epoch: 0000 loss_train: 1.9516
Epoch: 0500 loss_train: 0.2455
Epoch: 1000 loss_train: 0.0776
Epoch: 1500 loss_train: 0.0430
Epoch: 2000 loss_train: 0.0491
Epoch: 2500 loss_train: 0.0337
Epoch: 3000 loss_train: 0.0444
Epoch: 3500 loss_train: 0.0297
Epoch: 4000 loss_train: 0.0459
Epoch: 4500 loss_train: 0.0342
Epoch: 5000 loss_train: 0.0159
Epoch: 5500 loss_train: 0.0232
 total time: 8.2596s
0.0
Epoch: 0000 loss_train: 1.9515
Epoch: 0500 loss_train: 0.2472
Epoch: 1000 loss_train: 0.0851
Epoch: 1500 loss_train: 0.0413
Epoch: 2000 loss_train: 0.0491
Epoch: 2500 loss_train: 0.0472
Epoch: 3000 loss_tra

 total time: 7.8276s
2.062299427052494e-05
Epoch: 0000 loss_train: 1.9520
Epoch: 0500 loss_train: 0.2528
Epoch: 1000 loss_train: 0.0872
Epoch: 1500 loss_train: 0.0436
Epoch: 2000 loss_train: 0.0669
Epoch: 2500 loss_train: 0.0563
Epoch: 3000 loss_train: 0.0353
Epoch: 3500 loss_train: 0.0438
Epoch: 4000 loss_train: 0.0199
Epoch: 4500 loss_train: 0.0456
Epoch: 5000 loss_train: 0.0285
Epoch: 5500 loss_train: 0.0178
 total time: 7.7368s
5.411955135059543e-05
Epoch: 0000 loss_train: 1.9521
Epoch: 0500 loss_train: 0.2506
Epoch: 1000 loss_train: 0.0919
Epoch: 1500 loss_train: 0.0426
Epoch: 2000 loss_train: 0.0515
Epoch: 2500 loss_train: 0.0526
Epoch: 3000 loss_train: 0.0483
Epoch: 3500 loss_train: 0.0410
Epoch: 4000 loss_train: 0.0552
Epoch: 4500 loss_train: 0.0244
Epoch: 5000 loss_train: 0.0224
Epoch: 5500 loss_train: 0.0705
 total time: 7.3657s
0.0
Epoch: 0000 loss_train: 1.9521
Epoch: 0500 loss_train: 0.2538
Epoch: 1000 loss_train: 0.0887
Epoch: 1500 loss_train: 0.0353
Epoch: 2000 loss_trai

Epoch: 5500 loss_train: 0.0122
 total time: 7.7695s
2.3841855067985307e-07
Epoch: 0000 loss_train: 1.9524
Epoch: 0500 loss_train: 0.2498
Epoch: 1000 loss_train: 0.0964
Epoch: 1500 loss_train: 0.0402
Epoch: 2000 loss_train: 0.0410
Epoch: 2500 loss_train: 0.0385
Epoch: 3000 loss_train: 0.0318
Epoch: 3500 loss_train: 0.0326
Epoch: 4000 loss_train: 0.0235
Epoch: 4500 loss_train: 0.0411
Epoch: 5000 loss_train: 0.0221
Epoch: 5500 loss_train: 0.0282
 total time: 8.0629s
0.003133864840492606
Epoch: 0000 loss_train: 1.9525
Epoch: 0500 loss_train: 0.2386
Epoch: 1000 loss_train: 0.0856
Epoch: 1500 loss_train: 0.0376
Epoch: 2000 loss_train: 0.0613
Epoch: 2500 loss_train: 0.0318
Epoch: 3000 loss_train: 0.0436
Epoch: 3500 loss_train: 0.0207
Epoch: 4000 loss_train: 0.0171
Epoch: 4500 loss_train: 0.0453
Epoch: 5000 loss_train: 0.0171
Epoch: 5500 loss_train: 0.0203
 total time: 8.1978s
1.1920927533992653e-07
Epoch: 0000 loss_train: 1.9526
Epoch: 0500 loss_train: 0.2452
Epoch: 1000 loss_train: 0.0761
Ep

Epoch: 5000 loss_train: 0.0598
Epoch: 5500 loss_train: 0.0340
 total time: 8.2112s
0.16008363664150238
Epoch: 0000 loss_train: 1.9532
Epoch: 0500 loss_train: 0.2414
Epoch: 1000 loss_train: 0.1068
Epoch: 1500 loss_train: 0.0374
Epoch: 2000 loss_train: 0.0890
Epoch: 2500 loss_train: 0.0497
Epoch: 3000 loss_train: 0.0314
Epoch: 3500 loss_train: 0.0397
Epoch: 4000 loss_train: 0.0245
Epoch: 4500 loss_train: 0.0179
Epoch: 5000 loss_train: 0.0278
Epoch: 5500 loss_train: 0.0162
 total time: 8.2325s
0.572195291519165
Epoch: 0000 loss_train: 1.9530
Epoch: 0500 loss_train: 0.2415
Epoch: 1000 loss_train: 0.0850
Epoch: 1500 loss_train: 0.0322
Epoch: 2000 loss_train: 0.0535
Epoch: 2500 loss_train: 0.0461
Epoch: 3000 loss_train: 0.0318
Epoch: 3500 loss_train: 0.0229
Epoch: 4000 loss_train: 0.0189
Epoch: 4500 loss_train: 0.0407
Epoch: 5000 loss_train: 0.0136
Epoch: 5500 loss_train: 0.0176
 total time: 8.1927s
12.96768569946289
Epoch: 0000 loss_train: 1.9534
Epoch: 0500 loss_train: 0.2517
Epoch: 1000 l

Epoch: 4500 loss_train: 0.0452
Epoch: 5000 loss_train: 0.0168
Epoch: 5500 loss_train: 0.0255
 total time: 7.5860s
0.0002847504511009902
Epoch: 0000 loss_train: 1.9540
Epoch: 0500 loss_train: 0.2467
Epoch: 1000 loss_train: 0.0922
Epoch: 1500 loss_train: 0.0365
Epoch: 2000 loss_train: 0.0565
Epoch: 2500 loss_train: 0.0489
Epoch: 3000 loss_train: 0.0823
Epoch: 3500 loss_train: 0.0295
Epoch: 4000 loss_train: 0.0356
Epoch: 4500 loss_train: 0.0281
Epoch: 5000 loss_train: 0.0192
Epoch: 5500 loss_train: 0.0738
 total time: 7.7124s
4.303362584323622e-05
Epoch: 0000 loss_train: 1.9534
Epoch: 0500 loss_train: 0.2298
Epoch: 1000 loss_train: 0.0915
Epoch: 1500 loss_train: 0.0570
Epoch: 2000 loss_train: 0.0336
Epoch: 2500 loss_train: 0.0544
Epoch: 3000 loss_train: 0.0239
Epoch: 3500 loss_train: 0.0316
Epoch: 4000 loss_train: 0.0349
Epoch: 4500 loss_train: 0.0311
Epoch: 5000 loss_train: 0.0143
Epoch: 5500 loss_train: 0.0152
 total time: 7.6889s
52.4119873046875
Epoch: 0000 loss_train: 1.9549
Epoch: 0

Epoch: 4000 loss_train: 0.0331
Epoch: 4500 loss_train: 0.0252
Epoch: 5000 loss_train: 0.0269
Epoch: 5500 loss_train: 0.0463
 total time: 7.6957s
63.99615478515625
Epoch: 0000 loss_train: 1.9524
Epoch: 0500 loss_train: 0.2265
Epoch: 1000 loss_train: 0.0928
Epoch: 1500 loss_train: 0.0513
Epoch: 2000 loss_train: 0.0607
Epoch: 2500 loss_train: 0.0417
Epoch: 3000 loss_train: 0.0507
Epoch: 3500 loss_train: 0.0179
Epoch: 4000 loss_train: 0.0215
Epoch: 4500 loss_train: 0.0494
Epoch: 5000 loss_train: 0.0216
Epoch: 5500 loss_train: 0.0263
 total time: 7.6100s
0.000710592488758266
Epoch: 0000 loss_train: 1.9512
Epoch: 0500 loss_train: 0.2466
Epoch: 1000 loss_train: 0.0719
Epoch: 1500 loss_train: 0.0612
Epoch: 2000 loss_train: 0.0615
Epoch: 2500 loss_train: 0.0401
Epoch: 3000 loss_train: 0.0315
Epoch: 3500 loss_train: 0.0297
Epoch: 4000 loss_train: 0.0192
Epoch: 4500 loss_train: 0.0362
Epoch: 5000 loss_train: 0.0481
Epoch: 5500 loss_train: 0.0116
 total time: 7.7179s
7.629103492945433e-05
Epoch: 0

Epoch: 3500 loss_train: 0.0319
Epoch: 4000 loss_train: 0.0266
Epoch: 4500 loss_train: 0.0439
Epoch: 5000 loss_train: 0.0268
Epoch: 5500 loss_train: 0.0141
 total time: 7.7561s
0.0008172033121809363
Epoch: 0000 loss_train: 1.9529
Epoch: 0500 loss_train: 0.2147
Epoch: 1000 loss_train: 0.0716
Epoch: 1500 loss_train: 0.0541
Epoch: 2000 loss_train: 0.0427
Epoch: 2500 loss_train: 0.0374
Epoch: 3000 loss_train: 0.0345
Epoch: 3500 loss_train: 0.0239
Epoch: 4000 loss_train: 0.0308
Epoch: 4500 loss_train: 0.0170
Epoch: 5000 loss_train: 0.0242
Epoch: 5500 loss_train: 0.0140
 total time: 7.7301s
0.07231848686933517
Epoch: 0000 loss_train: 1.9528
Epoch: 0500 loss_train: 0.2295
Epoch: 1000 loss_train: 0.0843
Epoch: 1500 loss_train: 0.0772
Epoch: 2000 loss_train: 0.0450
Epoch: 2500 loss_train: 0.0566
Epoch: 3000 loss_train: 0.0304
Epoch: 3500 loss_train: 0.0239
Epoch: 4000 loss_train: 0.0251
Epoch: 4500 loss_train: 0.0247
Epoch: 5000 loss_train: 0.0192
Epoch: 5500 loss_train: 0.0175
 total time: 7.71

Epoch: 0500 loss_train: 0.2282
Epoch: 1000 loss_train: 0.0758
Epoch: 1500 loss_train: 0.0592
Epoch: 2000 loss_train: 0.0571
Epoch: 2500 loss_train: 0.0387
Epoch: 3000 loss_train: 0.0249
Epoch: 3500 loss_train: 0.0251
Epoch: 4000 loss_train: 0.0259
Epoch: 4500 loss_train: 0.0379
Epoch: 5000 loss_train: 0.0310
Epoch: 5500 loss_train: 0.0172
 total time: 7.7279s
5.501033306121826
Epoch: 0000 loss_train: 1.9528
Epoch: 0500 loss_train: 0.2475
Epoch: 1000 loss_train: 0.0941
Epoch: 1500 loss_train: 0.0678
Epoch: 2000 loss_train: 0.0430
Epoch: 2500 loss_train: 0.0336
Epoch: 3000 loss_train: 0.0252
Epoch: 3500 loss_train: 0.0379
Epoch: 4000 loss_train: 0.0231
Epoch: 4500 loss_train: 0.0160
Epoch: 5000 loss_train: 0.0214
Epoch: 5500 loss_train: 0.0429
 total time: 7.7205s
0.0
Epoch: 0000 loss_train: 1.9528
Epoch: 0500 loss_train: 0.2291
Epoch: 1000 loss_train: 0.0749
Epoch: 1500 loss_train: 0.0596
Epoch: 2000 loss_train: 0.0452
Epoch: 2500 loss_train: 0.0322
Epoch: 3000 loss_train: 0.0315
Epoch:

Epoch: 2500 loss_train: 0.0483
Epoch: 3000 loss_train: 0.0500
Epoch: 3500 loss_train: 0.0452
Epoch: 4000 loss_train: 0.0175
Epoch: 4500 loss_train: 0.0413
Epoch: 5000 loss_train: 0.0333
Epoch: 5500 loss_train: 0.0312
 total time: 7.7256s
0.0
Epoch: 0000 loss_train: 1.9536
Epoch: 0500 loss_train: 0.2088
Epoch: 1000 loss_train: 0.0991
Epoch: 1500 loss_train: 0.0646
Epoch: 2000 loss_train: 0.0585
Epoch: 2500 loss_train: 0.0753
Epoch: 3000 loss_train: 0.0737
Epoch: 3500 loss_train: 0.0204
Epoch: 4000 loss_train: 0.0347
Epoch: 4500 loss_train: 0.0303
Epoch: 5000 loss_train: 0.0153
Epoch: 5500 loss_train: 0.0358
 total time: 7.7418s
1.1920927533992653e-07
Epoch: 0000 loss_train: 1.9536
Epoch: 0500 loss_train: 0.1933
Epoch: 1000 loss_train: 0.1219
Epoch: 1500 loss_train: 0.0522
Epoch: 2000 loss_train: 0.0728
Epoch: 2500 loss_train: 0.0672
Epoch: 3000 loss_train: 0.0774
Epoch: 3500 loss_train: 0.0353
Epoch: 4000 loss_train: 0.0289
Epoch: 4500 loss_train: 0.0347
Epoch: 5000 loss_train: 0.0372
E

Epoch: 5500 loss_train: 0.0280
 total time: 7.6253s
0.0
Epoch: 0000 loss_train: 1.9535
Epoch: 0500 loss_train: 0.1797
Epoch: 1000 loss_train: 0.1042
Epoch: 1500 loss_train: 0.0719
Epoch: 2000 loss_train: 0.0734
Epoch: 2500 loss_train: 0.0431
Epoch: 3000 loss_train: 0.0651
Epoch: 3500 loss_train: 0.0320
Epoch: 4000 loss_train: 0.0350
Epoch: 4500 loss_train: 0.0274
Epoch: 5000 loss_train: 0.0190
Epoch: 5500 loss_train: 0.0369
 total time: 7.7583s
0.0
Epoch: 0000 loss_train: 1.9535
Epoch: 0500 loss_train: 0.1802
Epoch: 1000 loss_train: 0.1229
Epoch: 1500 loss_train: 0.0610
Epoch: 2000 loss_train: 0.0732
Epoch: 2500 loss_train: 0.0544
Epoch: 3000 loss_train: 0.0731
Epoch: 3500 loss_train: 0.0204
Epoch: 4000 loss_train: 0.0333
Epoch: 4500 loss_train: 0.0446
Epoch: 5000 loss_train: 0.0192
Epoch: 5500 loss_train: 0.0139
 total time: 7.6969s
0.0
Epoch: 0000 loss_train: 1.9534
Epoch: 0500 loss_train: 0.1850
Epoch: 1000 loss_train: 0.1144
Epoch: 1500 loss_train: 0.0415
Epoch: 2000 loss_train: 0.

Epoch: 2000 loss_train: 0.0349
Epoch: 2500 loss_train: 0.0607
Epoch: 3000 loss_train: 0.0499
Epoch: 3500 loss_train: 0.0384
Epoch: 4000 loss_train: 0.0490
Epoch: 4500 loss_train: 0.0662
Epoch: 5000 loss_train: 0.0374
Epoch: 5500 loss_train: 0.0146
 total time: 7.7996s
0.0006317288498394191
Epoch: 0000 loss_train: 1.9521
Epoch: 0500 loss_train: 0.2089
Epoch: 1000 loss_train: 0.0830
Epoch: 1500 loss_train: 0.0515
Epoch: 2000 loss_train: 0.0445
Epoch: 2500 loss_train: 0.0544
Epoch: 3000 loss_train: 0.0589
Epoch: 3500 loss_train: 0.0359
Epoch: 4000 loss_train: 0.0513
Epoch: 4500 loss_train: 0.0264
Epoch: 5000 loss_train: 0.0245
Epoch: 5500 loss_train: 0.0232
 total time: 7.8673s
30.80598258972168
Epoch: 0000 loss_train: 1.9521
Epoch: 0500 loss_train: 0.2218
Epoch: 1000 loss_train: 0.0859
Epoch: 1500 loss_train: 0.0480
Epoch: 2000 loss_train: 0.0374
Epoch: 2500 loss_train: 0.0669
Epoch: 3000 loss_train: 0.0355
Epoch: 3500 loss_train: 0.0359
Epoch: 4000 loss_train: 0.0463
Epoch: 4500 loss_tr

Epoch: 2500 loss_train: 0.0499
Epoch: 3000 loss_train: 0.0434
Epoch: 3500 loss_train: 0.0403
Epoch: 4000 loss_train: 0.0317
Epoch: 4500 loss_train: 0.0358
Epoch: 5000 loss_train: 0.0175
Epoch: 5500 loss_train: 0.0174
 total time: 8.1986s
7.152555099310121e-07
Epoch: 0000 loss_train: 1.9510
Epoch: 0500 loss_train: 0.2025
Epoch: 1000 loss_train: 0.1043
Epoch: 1500 loss_train: 0.0475
Epoch: 2000 loss_train: 0.0599
Epoch: 2500 loss_train: 0.0564
Epoch: 3000 loss_train: 0.0404
Epoch: 3500 loss_train: 0.0263
Epoch: 4000 loss_train: 0.0288
Epoch: 4500 loss_train: 0.0210
Epoch: 5000 loss_train: 0.0222
Epoch: 5500 loss_train: 0.0212
 total time: 8.2148s
1.1920927533992653e-07
Epoch: 0000 loss_train: 1.9510
Epoch: 0500 loss_train: 0.2009
Epoch: 1000 loss_train: 0.1009
Epoch: 1500 loss_train: 0.0468
Epoch: 2000 loss_train: 0.0739
Epoch: 2500 loss_train: 0.0562
Epoch: 3000 loss_train: 0.0545
Epoch: 3500 loss_train: 0.0547
Epoch: 4000 loss_train: 0.0381
Epoch: 4500 loss_train: 0.0289
Epoch: 5000 lo

Epoch: 1500 loss_train: 0.0615
Epoch: 2000 loss_train: 0.0525
Epoch: 2500 loss_train: 0.0397
Epoch: 3000 loss_train: 0.0340
Epoch: 3500 loss_train: 0.0334
Epoch: 4000 loss_train: 0.0272
Epoch: 4500 loss_train: 0.0327
Epoch: 5000 loss_train: 0.0294
Epoch: 5500 loss_train: 0.0126
 total time: 7.7699s
8.088946342468262
Epoch: 0000 loss_train: 1.9504
Epoch: 0500 loss_train: 0.1887
Epoch: 1000 loss_train: 0.0794
Epoch: 1500 loss_train: 0.0515
Epoch: 2000 loss_train: 0.0527
Epoch: 2500 loss_train: 0.0613
Epoch: 3000 loss_train: 0.0402
Epoch: 3500 loss_train: 0.0343
Epoch: 4000 loss_train: 0.0362
Epoch: 4500 loss_train: 0.0293
Epoch: 5000 loss_train: 0.0216
Epoch: 5500 loss_train: 0.0364
 total time: 7.7695s
8.344646857949556e-07
Epoch: 0000 loss_train: 1.9504
Epoch: 0500 loss_train: 0.1951
Epoch: 1000 loss_train: 0.0816
Epoch: 1500 loss_train: 0.0644
Epoch: 2000 loss_train: 0.0507
Epoch: 2500 loss_train: 0.0483
Epoch: 3000 loss_train: 0.0416
Epoch: 3500 loss_train: 0.0567
Epoch: 4000 loss_tr

Epoch: 3500 loss_train: 0.0573
Epoch: 4000 loss_train: 0.0283
Epoch: 4500 loss_train: 0.0336
Epoch: 5000 loss_train: 0.0399
Epoch: 5500 loss_train: 0.0288
 total time: 7.6580s
0.0
Epoch: 0000 loss_train: 1.9506
Epoch: 0500 loss_train: 0.1833
Epoch: 1000 loss_train: 0.0956
Epoch: 1500 loss_train: 0.0676
Epoch: 2000 loss_train: 0.0405
Epoch: 2500 loss_train: 0.0503
Epoch: 3000 loss_train: 0.0342
Epoch: 3500 loss_train: 0.0387
Epoch: 4000 loss_train: 0.0521
Epoch: 4500 loss_train: 0.0271
Epoch: 5000 loss_train: 0.0181
Epoch: 5500 loss_train: 0.0302
 total time: 7.7698s
0.0
Epoch: 0000 loss_train: 1.9506
Epoch: 0500 loss_train: 0.1986
Epoch: 1000 loss_train: 0.0938
Epoch: 1500 loss_train: 0.0473
Epoch: 2000 loss_train: 0.0390
Epoch: 2500 loss_train: 0.0442
Epoch: 3000 loss_train: 0.0786
Epoch: 3500 loss_train: 0.0365
Epoch: 4000 loss_train: 0.0396
Epoch: 4500 loss_train: 0.0232
Epoch: 5000 loss_train: 0.0271
Epoch: 5500 loss_train: 0.0363
 total time: 7.8638s
0.0
Epoch: 0000 loss_train: 1.

Epoch: 0500 loss_train: 0.1926
Epoch: 1000 loss_train: 0.0974
Epoch: 1500 loss_train: 0.0558
Epoch: 2000 loss_train: 0.0477
Epoch: 2500 loss_train: 0.0554
Epoch: 3000 loss_train: 0.0516
Epoch: 3500 loss_train: 0.0412
Epoch: 4000 loss_train: 0.0337
Epoch: 4500 loss_train: 0.0216
Epoch: 5000 loss_train: 0.0168
Epoch: 5500 loss_train: 0.0211
 total time: 7.7731s
0.0
Epoch: 0000 loss_train: 1.9509
Epoch: 0500 loss_train: 0.1862
Epoch: 1000 loss_train: 0.1008
Epoch: 1500 loss_train: 0.0431
Epoch: 2000 loss_train: 0.0505
Epoch: 2500 loss_train: 0.0475
Epoch: 3000 loss_train: 0.0292
Epoch: 3500 loss_train: 0.0463
Epoch: 4000 loss_train: 0.0405
Epoch: 4500 loss_train: 0.0190
Epoch: 5000 loss_train: 0.0261
Epoch: 5500 loss_train: 0.0232
 total time: 7.7937s
0.0
Epoch: 0000 loss_train: 1.9509
Epoch: 0500 loss_train: 0.1977
Epoch: 1000 loss_train: 0.0767
Epoch: 1500 loss_train: 0.0520
Epoch: 2000 loss_train: 0.0454
Epoch: 2500 loss_train: 0.0606
Epoch: 3000 loss_train: 0.0358
Epoch: 3500 loss_tra

Epoch: 4000 loss_train: 0.0358
Epoch: 4500 loss_train: 0.0462
Epoch: 5000 loss_train: 0.0191
Epoch: 5500 loss_train: 0.0157
 total time: 7.6857s
0.0
Epoch: 0000 loss_train: 1.9512
Epoch: 0500 loss_train: 0.1903
Epoch: 1000 loss_train: 0.1039
Epoch: 1500 loss_train: 0.0429
Epoch: 2000 loss_train: 0.0472
Epoch: 2500 loss_train: 0.0540
Epoch: 3000 loss_train: 0.0838
Epoch: 3500 loss_train: 0.0520
Epoch: 4000 loss_train: 0.0343
Epoch: 4500 loss_train: 0.0228
Epoch: 5000 loss_train: 0.0160
Epoch: 5500 loss_train: 0.0319
 total time: 7.6773s
0.0
Epoch: 0000 loss_train: 1.9512
Epoch: 0500 loss_train: 0.1858
Epoch: 1000 loss_train: 0.1065
Epoch: 1500 loss_train: 0.0435
Epoch: 2000 loss_train: 0.0503
Epoch: 2500 loss_train: 0.0448
Epoch: 3000 loss_train: 0.0298
Epoch: 3500 loss_train: 0.0440
Epoch: 4000 loss_train: 0.0336
Epoch: 4500 loss_train: 0.0371
Epoch: 5000 loss_train: 0.0176
Epoch: 5500 loss_train: 0.0282
 total time: 7.6931s
0.0
Epoch: 0000 loss_train: 1.9512
Epoch: 0500 loss_train: 0.

Epoch: 1500 loss_train: 0.0516
Epoch: 2000 loss_train: 0.0501
Epoch: 2500 loss_train: 0.0436
Epoch: 3000 loss_train: 0.0452
Epoch: 3500 loss_train: 0.0374
Epoch: 4000 loss_train: 0.0458
Epoch: 4500 loss_train: 0.0192
Epoch: 5000 loss_train: 0.0149
Epoch: 5500 loss_train: 0.0242
 total time: 7.6009s
0.0
Epoch: 0000 loss_train: 1.9513
Epoch: 0500 loss_train: 0.1906
Epoch: 1000 loss_train: 0.1112
Epoch: 1500 loss_train: 0.0608
Epoch: 2000 loss_train: 0.0536
Epoch: 2500 loss_train: 0.0723
Epoch: 3000 loss_train: 0.0543
Epoch: 3500 loss_train: 0.0412
Epoch: 4000 loss_train: 0.0437
Epoch: 4500 loss_train: 0.0255
Epoch: 5000 loss_train: 0.0153
Epoch: 5500 loss_train: 0.0326
 total time: 7.6525s
0.0
Epoch: 0000 loss_train: 1.9513
Epoch: 0500 loss_train: 0.1912
Epoch: 1000 loss_train: 0.1187
Epoch: 1500 loss_train: 0.0531
Epoch: 2000 loss_train: 0.0541
Epoch: 2500 loss_train: 0.0615
Epoch: 3000 loss_train: 0.0496
Epoch: 3500 loss_train: 0.0387
Epoch: 4000 loss_train: 0.0413
Epoch: 4500 loss_tra

Epoch: 5000 loss_train: 0.0100
Epoch: 5500 loss_train: 0.0266
 total time: 7.8047s
0.0
Epoch: 0000 loss_train: 1.9514
Epoch: 0500 loss_train: 0.1948
Epoch: 1000 loss_train: 0.1111
Epoch: 1500 loss_train: 0.0481
Epoch: 2000 loss_train: 0.0506
Epoch: 2500 loss_train: 0.0671
Epoch: 3000 loss_train: 0.0464
Epoch: 3500 loss_train: 0.0443
Epoch: 4000 loss_train: 0.0427
Epoch: 4500 loss_train: 0.0649
Epoch: 5000 loss_train: 0.0202
Epoch: 5500 loss_train: 0.0137
 total time: 7.6698s
0.0
Epoch: 0000 loss_train: 1.9514
Epoch: 0500 loss_train: 0.1889
Epoch: 1000 loss_train: 0.1128
Epoch: 1500 loss_train: 0.0536
Epoch: 2000 loss_train: 0.0353
Epoch: 2500 loss_train: 0.0604
Epoch: 3000 loss_train: 0.0407
Epoch: 3500 loss_train: 0.0318
Epoch: 4000 loss_train: 0.0484
Epoch: 4500 loss_train: 0.0190
Epoch: 5000 loss_train: 0.0243
Epoch: 5500 loss_train: 0.0134
 total time: 7.7709s
0.0
Epoch: 0000 loss_train: 1.9514
Epoch: 0500 loss_train: 0.1936
Epoch: 1000 loss_train: 0.1103
Epoch: 1500 loss_train: 0.

Epoch: 0500 loss_train: 0.1954
Epoch: 1000 loss_train: 0.1146
Epoch: 1500 loss_train: 0.0529
Epoch: 2000 loss_train: 0.0613
Epoch: 2500 loss_train: 0.0805
Epoch: 3000 loss_train: 0.0652
Epoch: 3500 loss_train: 0.0527
Epoch: 4000 loss_train: 0.0439
Epoch: 4500 loss_train: 0.0348
Epoch: 5000 loss_train: 0.0365
Epoch: 5500 loss_train: 0.0123
 total time: 7.7153s
0.0
Epoch: 0000 loss_train: 1.9515
Epoch: 0500 loss_train: 0.1902
Epoch: 1000 loss_train: 0.1054
Epoch: 1500 loss_train: 0.0503
Epoch: 2000 loss_train: 0.0499
Epoch: 2500 loss_train: 0.0584
Epoch: 3000 loss_train: 0.0704
Epoch: 3500 loss_train: 0.0885
Epoch: 4000 loss_train: 0.0468
Epoch: 4500 loss_train: 0.0250
Epoch: 5000 loss_train: 0.0211
Epoch: 5500 loss_train: 0.0342
 total time: 7.7132s
12.86246109008789
Epoch: 0000 loss_train: 1.9516
Epoch: 0500 loss_train: 0.1917
Epoch: 1000 loss_train: 0.1192
Epoch: 1500 loss_train: 0.0402
Epoch: 2000 loss_train: 0.0499
Epoch: 2500 loss_train: 0.0469
Epoch: 3000 loss_train: 0.0221
Epoch:

Epoch: 0500 loss_train: 0.2191
Epoch: 1000 loss_train: 0.0991
Epoch: 1500 loss_train: 0.0563
Epoch: 2000 loss_train: 0.0374
Epoch: 2500 loss_train: 0.0370
Epoch: 3000 loss_train: 0.0701
Epoch: 3500 loss_train: 0.0317
Epoch: 4000 loss_train: 0.0239
Epoch: 4500 loss_train: 0.0134
Epoch: 5000 loss_train: 0.0223
Epoch: 5500 loss_train: 0.0223
 total time: 7.7474s
9.179073458653875e-06
Epoch: 0000 loss_train: 1.9516
Epoch: 0500 loss_train: 0.2162
Epoch: 1000 loss_train: 0.1078
Epoch: 1500 loss_train: 0.0551
Epoch: 2000 loss_train: 0.0507
Epoch: 2500 loss_train: 0.0450
Epoch: 3000 loss_train: 0.0633
Epoch: 3500 loss_train: 0.0453
Epoch: 4000 loss_train: 0.0272
Epoch: 4500 loss_train: 0.0330
Epoch: 5000 loss_train: 0.0215
Epoch: 5500 loss_train: 0.0289
 total time: 7.6638s
8.702239938429557e-06
Epoch: 0000 loss_train: 1.9514
Epoch: 0500 loss_train: 0.2158
Epoch: 1000 loss_train: 0.1021
Epoch: 1500 loss_train: 0.0407
Epoch: 2000 loss_train: 0.0657
Epoch: 2500 loss_train: 0.0326
Epoch: 3000 los

Epoch: 0500 loss_train: 0.2240
Epoch: 1000 loss_train: 0.0926
Epoch: 1500 loss_train: 0.0366
Epoch: 2000 loss_train: 0.0629
Epoch: 2500 loss_train: 0.0587
Epoch: 3000 loss_train: 0.0248
Epoch: 3500 loss_train: 0.0417
Epoch: 4000 loss_train: 0.0206
Epoch: 4500 loss_train: 0.0293
Epoch: 5000 loss_train: 0.0130
Epoch: 5500 loss_train: 0.0138
 total time: 7.7312s
0.0002294515579706058
Epoch: 0000 loss_train: 1.9543
Epoch: 0500 loss_train: 0.2166
Epoch: 1000 loss_train: 0.0967
Epoch: 1500 loss_train: 0.0557
Epoch: 2000 loss_train: 0.0416
Epoch: 2500 loss_train: 0.0474
Epoch: 3000 loss_train: 0.0365
Epoch: 3500 loss_train: 0.0441
Epoch: 4000 loss_train: 0.0118
Epoch: 4500 loss_train: 0.0305
Epoch: 5000 loss_train: 0.0253
Epoch: 5500 loss_train: 0.0123
 total time: 7.7044s
0.0002658013836480677
Epoch: 0000 loss_train: 1.9542
Epoch: 0500 loss_train: 0.2274
Epoch: 1000 loss_train: 0.0889
Epoch: 1500 loss_train: 0.0495
Epoch: 2000 loss_train: 0.0613
Epoch: 2500 loss_train: 0.0426
Epoch: 3000 los

Epoch: 0500 loss_train: 0.2173
Epoch: 1000 loss_train: 0.1053
Epoch: 1500 loss_train: 0.0569
Epoch: 2000 loss_train: 0.0594
Epoch: 2500 loss_train: 0.0382
Epoch: 3000 loss_train: 0.0380
Epoch: 3500 loss_train: 0.0453
Epoch: 4000 loss_train: 0.0342
Epoch: 4500 loss_train: 0.0339
Epoch: 5000 loss_train: 0.0238
Epoch: 5500 loss_train: 0.0148
 total time: 7.4958s
0.002246477175503969
Epoch: 0000 loss_train: 1.9541
Epoch: 0500 loss_train: 0.2202
Epoch: 1000 loss_train: 0.1182
Epoch: 1500 loss_train: 0.0683
Epoch: 2000 loss_train: 0.0465
Epoch: 2500 loss_train: 0.0420
Epoch: 3000 loss_train: 0.0182
Epoch: 3500 loss_train: 0.0520
Epoch: 4000 loss_train: 0.0762
Epoch: 4500 loss_train: 0.0091
Epoch: 5000 loss_train: 0.0170
Epoch: 5500 loss_train: 0.0512
 total time: 7.5189s
0.036769527941942215
Epoch: 0000 loss_train: 1.9536
Epoch: 0500 loss_train: 0.2102
Epoch: 1000 loss_train: 0.1124
Epoch: 1500 loss_train: 0.0690
Epoch: 2000 loss_train: 0.0487
Epoch: 2500 loss_train: 0.0481
Epoch: 3000 loss_

Epoch: 0500 loss_train: 0.2277
Epoch: 1000 loss_train: 0.1158
Epoch: 1500 loss_train: 0.0582
Epoch: 2000 loss_train: 0.0687
Epoch: 2500 loss_train: 0.0590
Epoch: 3000 loss_train: 0.0194
Epoch: 3500 loss_train: 0.0271
Epoch: 4000 loss_train: 0.0352
Epoch: 4500 loss_train: 0.0217
Epoch: 5000 loss_train: 0.0322
Epoch: 5500 loss_train: 0.0131
 total time: 7.5437s
0.0004015354788862169
Epoch: 0000 loss_train: 1.9544
Epoch: 0500 loss_train: 0.2160
Epoch: 1000 loss_train: 0.0936
Epoch: 1500 loss_train: 0.0629
Epoch: 2000 loss_train: 0.0758
Epoch: 2500 loss_train: 0.0342
Epoch: 3000 loss_train: 0.0251
Epoch: 3500 loss_train: 0.0501
Epoch: 4000 loss_train: 0.0316
Epoch: 4500 loss_train: 0.0224
Epoch: 5000 loss_train: 0.0175
Epoch: 5500 loss_train: 0.0159
 total time: 7.5291s
0.0007277462864294648
Epoch: 0000 loss_train: 1.9561
Epoch: 0500 loss_train: 0.2103
Epoch: 1000 loss_train: 0.1081
Epoch: 1500 loss_train: 0.0646
Epoch: 2000 loss_train: 0.0608
Epoch: 2500 loss_train: 0.0462
Epoch: 3000 los

Epoch: 3000 loss_train: 0.0297
Epoch: 3500 loss_train: 0.0237
Epoch: 4000 loss_train: 0.0273
Epoch: 4500 loss_train: 0.0132
Epoch: 5000 loss_train: 0.0199
Epoch: 5500 loss_train: 0.0251
 total time: 7.4742s
3.6954195820726454e-05
Epoch: 0000 loss_train: 1.9564
Epoch: 0500 loss_train: 0.2147
Epoch: 1000 loss_train: 0.1048
Epoch: 1500 loss_train: 0.0479
Epoch: 2000 loss_train: 0.0700
Epoch: 2500 loss_train: 0.0401
Epoch: 3000 loss_train: 0.0356
Epoch: 3500 loss_train: 0.0332
Epoch: 4000 loss_train: 0.0292
Epoch: 4500 loss_train: 0.0278
Epoch: 5000 loss_train: 0.0133
Epoch: 5500 loss_train: 0.0194
 total time: 7.5114s
0.0
Epoch: 0000 loss_train: 1.9564
Epoch: 0500 loss_train: 0.2097
Epoch: 1000 loss_train: 0.1104
Epoch: 1500 loss_train: 0.0522
Epoch: 2000 loss_train: 0.0587
Epoch: 2500 loss_train: 0.0511
Epoch: 3000 loss_train: 0.0171
Epoch: 3500 loss_train: 0.0518
Epoch: 4000 loss_train: 0.0371
Epoch: 4500 loss_train: 0.0140
Epoch: 5000 loss_train: 0.0135
Epoch: 5500 loss_train: 0.0203
 

 total time: 7.5200s
0.0
Epoch: 0000 loss_train: 1.9566
Epoch: 0500 loss_train: 0.2108
Epoch: 1000 loss_train: 0.1122
Epoch: 1500 loss_train: 0.0588
Epoch: 2000 loss_train: 0.0929
Epoch: 2500 loss_train: 0.0719
Epoch: 3000 loss_train: 0.0314
Epoch: 3500 loss_train: 0.0583
Epoch: 4000 loss_train: 0.0274
Epoch: 4500 loss_train: 0.0124
Epoch: 5000 loss_train: 0.0252
Epoch: 5500 loss_train: 0.0248
 total time: 7.5146s
0.0
Epoch: 0000 loss_train: 1.9566
Epoch: 0500 loss_train: 0.2068
Epoch: 1000 loss_train: 0.1064
Epoch: 1500 loss_train: 0.0540
Epoch: 2000 loss_train: 0.0755
Epoch: 2500 loss_train: 0.0447
Epoch: 3000 loss_train: 0.0313
Epoch: 3500 loss_train: 0.0407
Epoch: 4000 loss_train: 0.0262
Epoch: 4500 loss_train: 0.0238
Epoch: 5000 loss_train: 0.0295
Epoch: 5500 loss_train: 0.0129
 total time: 7.4887s
0.0
Epoch: 0000 loss_train: 1.9566
Epoch: 0500 loss_train: 0.2142
Epoch: 1000 loss_train: 0.0986
Epoch: 1500 loss_train: 0.0630
Epoch: 2000 loss_train: 0.0839
Epoch: 2500 loss_train: 0.

Epoch: 3500 loss_train: 0.0462
Epoch: 4000 loss_train: 0.0214
Epoch: 4500 loss_train: 0.0227
Epoch: 5000 loss_train: 0.0171
Epoch: 5500 loss_train: 0.0235
 total time: 7.5178s
0.0
Epoch: 0000 loss_train: 1.9567
Epoch: 0500 loss_train: 0.2102
Epoch: 1000 loss_train: 0.1120
Epoch: 1500 loss_train: 0.0670
Epoch: 2000 loss_train: 0.0757
Epoch: 2500 loss_train: 0.0398
Epoch: 3000 loss_train: 0.0288
Epoch: 3500 loss_train: 0.0607
Epoch: 4000 loss_train: 0.0428
Epoch: 4500 loss_train: 0.0164
Epoch: 5000 loss_train: 0.0196
Epoch: 5500 loss_train: 0.0179
 total time: 7.4419s
0.0
Epoch: 0000 loss_train: 1.9568
Epoch: 0500 loss_train: 0.2194
Epoch: 1000 loss_train: 0.1054
Epoch: 1500 loss_train: 0.0632
Epoch: 2000 loss_train: 0.0828
Epoch: 2500 loss_train: 0.0522
Epoch: 3000 loss_train: 0.0492
Epoch: 3500 loss_train: 0.0395
Epoch: 4000 loss_train: 0.0776
Epoch: 4500 loss_train: 0.0216
Epoch: 5000 loss_train: 0.0125
Epoch: 5500 loss_train: 0.0144
 total time: 7.4691s
0.0
Epoch: 0000 loss_train: 1.

In [39]:
# Serialise `predtest` to <logsavepath>/crossVal_loss with the highest pickle protocol.
# NOTE(review): the file name says "loss" but the object written is `predtest` — confirm the name is intended.
with open(os.path.join(logsavepath,'crossVal_loss'), mode='wb') as output:
    pickle.dump(predtest, output, protocol=pickle.HIGHEST_PROTOCOL)

In [40]:
# Index of the maximum along axis 1 of predtest.
# NOTE(review): presumably predtest holds one row of class scores per sample, making this the
# predicted class index — confirm against the cell that fills predtest.
predtest_label=np.argmax(predtest,axis=1)

In [41]:
# Table with one row per entry of imgNamesAll: the sample name, the progUnique entry selected by
# labelsAll ('true') and the progUnique entry selected by predtest_label ('predicted').
res=pd.DataFrame(
    data={
        'sampleName':imgNamesAll,
        'true':progUnique[labelsAll.cpu().numpy()],
        'predicted':progUnique[predtest_label],
    }
)
# Written next to the plots so the table can be inspected alongside the confusion figure.
res.to_csv(path_or_buf=os.path.join(plotsavepath,'predictions.csv'))

In [42]:
#plot confusion
def plotCTcomp(labels,ctlist,savepath,savenamecluster,byCT,addname='',order=progInclude):
    """Save a normalised co-occurrence (confusion-style) heatmap of two label arrays as a PDF.

    Cell (i, j) of the matrix starts as the number of samples for which
    ``labels == order[i]`` and ``ctlist == order[j]``; it is then normalised
    either per row or per column (see ``byCT``) and drawn with a fixed 0-1
    colour scale.

    Parameters
    ----------
    labels : array-like
        One category per sample; selects the matrix row.
    ctlist : array-like
        One category per sample, aligned with ``labels``; selects the column.
    savepath : str
        Directory the PDF is written to.
    savenamecluster : str
        File-name stem of the PDF.
    byCT : bool
        If False, each row is divided by the number of samples whose
        ``labels`` value equals that row's category. If True, each column is
        divided by the number of samples whose ``ctlist`` value equals that
        column's category, and ``'_normbyCT'`` is appended to the file name.
    addname : str, optional
        Extra suffix inserted between the stem and any normalisation suffix.
    order : array-like, optional
        Categories to show, in display order, used for both axes.

    Returns
    -------
    None
        The figure is written to
        ``<savepath>/<savenamecluster><addname>[_normbyCT].pdf`` and closed.

    Notes
    -----
    A category with no samples keeps an all-zero row/column instead of being
    divided by zero (which would yield NaN and a RuntimeWarning).
    """
    # Coerce once so boolean masks and `.size` work for lists as well as arrays.
    labels=np.asarray(labels)
    ctlist=np.asarray(ctlist)
    order=np.asarray(order)
    ncat=order.size

    # Raw co-occurrence counts.
    res=np.zeros((ncat,ncat))
    for li in range(ncat):
        ctlist_l=ctlist[labels==order[li]]
        for ci in range(ncat):
            res[li,ci]=np.sum(ctlist_l==order[ci])

    if not byCT:
        # Row-normalise by the total number of samples carrying each label
        # (this can exceed the row sum when ctlist holds values outside `order`).
        for li in range(ncat):
            nl=np.sum(labels==order[li])
            if nl>0:
                res[li]=res[li]/nl
    else:
        addname+='_normbyCT'
        # Column-normalise by the total number of samples carrying each ctlist value.
        for ci in range(ncat):
            nc=np.sum(ctlist==order[ci])
            if nc>0:
                res[:,ci]=res[:,ci]/nc

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        im = ax.imshow(res,cmap='binary',vmin=0,vmax=1)
        fig.colorbar(im)
        ticks=np.arange(ncat)
        ax.set_yticks(ticks)
        ax.set_yticklabels(order)
        ax.set_xticks(ticks)
        ax.set_xticklabels(order)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
        fig.tight_layout()
        # Save through the figure object rather than the implicit "current figure".
        fig.savefig(os.path.join(savepath,savenamecluster+addname+'.pdf'))
    finally:
        # Always release this figure, even if drawing or saving fails.
        plt.close(fig)
    
# Row-normalised (byCT=False) heatmap of the progUnique entries selected by labelsAll (rows)
# against those selected by predtest_label (columns); the file name carries testepoch.
plotCTcomp(
    labels=progUnique[labelsAll.cpu().numpy()],
    ctlist=progUnique[predtest_label],
    savepath=plotsavepath,
    savenamecluster='confusion'+str(testepoch),
    byCT=False,
)