In [1]:
import time
import os

import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim
from torch.utils.data import DataLoader

import models.loadImg as loadImg
import models.modelsCNN as modelsCNN
import models.optimizer as optimizer

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import gc
from skimage import io
import scipy.stats
from sklearn.metrics import pairwise_distances

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
use_cuda=True
datadir='/media/xinyi/dcis2idc1/data'
name='exp0'
plotsavepath='/media/xinyi/dcis2idc1/plots/cnnvae'+name
sampledir=plotsavepath
clustersavedir_alltrain=os.path.join(sampledir,'cluster_alltrain_reordered')
ep=311
with open(os.path.join(datadir,'processed','train_cnnvae_names'), 'rb') as input:
    allImgNames=pickle.load(input)
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList=np.copy(allImgNames)
for s in np.unique(allImgNames):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList[allImgNames==s]=prog_s
    




In [4]:
savenamesample='alltrain'

neworder=[1, 5, 3, 7, 2, 0, 4, 6]
#use chosen subcluster number and save plots
scanpy.settings.verbosity = 3
# subcluster=8
subclusterDict={0:[4],1:[6],2:[8],3:[6],4:[6],5:[6],6:[6],7:[4]}
ncluster=8

plotepoch=311
clusterplotdir=os.path.join(clustersavedir_alltrain,'plots')
n_pcs=50
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_alltrain,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes=pickle.load(output)

kmeans_sub=(np.zeros(clusterRes.size)-1).astype(str)
savenameAdd='_plottingIdx_progBalanced_'+str(0)
subclusternumbers=[4,6,8,6,6,6,6,4]
savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+savenameAdd
for cnew in np.unique(clusterRes):
#     print('cluster'+str(c))
    c=neworder[cnew]
    
    subclustersavedir_alltrain=os.path.join(clustersavedir_alltrain,savenamecluster+'_subcluster'+str(c))
    with open(os.path.join(subclustersavedir_alltrain,'minibatchkmean_ncluster'+str(subclusternumbers[c])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    print(np.unique(subclusterRes))
    kmeans_sub[clusterRes==cnew]=np.char.add(np.repeat(str(cnew)+'-',subclusterRes.size),subclusterRes.astype(str))
   

[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3 4 5]
[0 1 2 3]
[0 1 2 3 4 5 6 7]
[0 1 2 3]
[0 1 2 3 4 5]
[0 1 2 3 4 5]


In [5]:
with open(os.path.join(datadir,'processed','train_cnnvae_coord'), 'rb') as output:
    coordlist=pickle.load(output)

In [6]:
for p in np.unique(progList):
    if p=='Ductal carcinoma in situ':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList[progList==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList[progList==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList[progList==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList[progList==p]='Hyperplasia'

In [7]:
progInclude=np.array(['Breast tissue','Cancer adjacent normal breast tissue','Hyperplasia','DCIS and breast tissue',  'DCIS with early infiltration','Invasive ductal carcinoma and breast tissue','Invasive ductal carcinoma'])

In [8]:
progIncludeIdx=np.repeat(False,progList.size)
for p in progInclude:
    progIncludeIdx[progList==p]=True
    
coordlist=coordlist[progIncludeIdx]
allImgNames=allImgNames[progIncludeIdx]
clusterRes=clusterRes[progIncludeIdx]
kmeans_sub=kmeans_sub[progIncludeIdx]
progList=progList[progIncludeIdx]

In [9]:
sUnique,sidx_start=np.unique(allImgNames,return_index=True)
progUnique,labels_train,progCounts=np.unique(progList[sidx_start],return_counts=True,return_inverse=True)
for p in range(progUnique.size):
    print(progUnique[p])
    print(progCounts[p])

Breast tissue
20
Cancer adjacent normal breast tissue
13
DCIS and breast tissue
16
DCIS with early infiltration
30
Hyperplasia
41
Invasive ductal carcinoma
70
Invasive ductal carcinoma and breast tissue
8


In [10]:
def getHistMatrix_clusters(labels,ctlist,nrow=ncluster,ncol=ncluster):
    res=np.zeros((nrow,ncol))
    for li in range(res.shape[0]):
        l=li
        nl=np.sum(labels==l)
        ctlist_l=ctlist[labels==l]
        for ci in range(res.shape[1]):
            c=ci
            res[l,c]=np.sum(ctlist_l==c)
#             res[li,ci]=np.sum(ctlist_l==c)/nl
        if nl!=0:
            res[li]=res[li]/nl
    return res

neighborhoodSize=16*9

In [11]:
#get neighborhood composition

inputNeighborhood=np.zeros((sUnique.size,ncluster*ncluster))
for i in range(sUnique.size):
    imgN=sUnique[i]
    nsamples=np.sum(allImgNames==imgN)
    cluster_i=clusterRes[allImgNames==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist[allImgNames==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood[i]=res.flatten()

In [12]:
_,inputCounts=np.unique(allImgNames,return_counts=True)
inputAll_train=np.concatenate((inputNeighborhood,inputCounts.reshape(-1,1)),axis=1)

In [13]:
#val cores (as validation cores) and val samples (as test cores)
clustersavedir_valcores=os.path.join(sampledir,'cluster_valcores_reordered')
clustersavedir_valsamples=os.path.join(sampledir,'cluster_valsamples_reordered')

with open(os.path.join(datadir,'processed','train_cnnvae_coord_valcores'), 'rb') as output:
    coordlist_valcores=pickle.load(output)
with open(os.path.join(datadir,'processed','train_cnnvae_coord_valsamples'), 'rb') as output:
    coordlist_valsamples=pickle.load(output)

savenamecluster='minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)
with open(os.path.join(clustersavedir_valcores,savenamecluster+'_all'), 'rb') as output:
    clusterRes_valcores=pickle.load(output)
with open(os.path.join(clustersavedir_valsamples,'minibatchkmean_ncluster'+str(ncluster)+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
    clusterRes_valsamples=pickle.load(output)
    
kmeans_sub_valcores=(np.zeros(clusterRes_valcores.size)-1).astype(str)
for c in np.unique(clusterRes_valcores):
    subclustersavedir=os.path.join(clustersavedir_valcores,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valcores[clusterRes_valcores==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
    
kmeans_sub_valsamples=(np.zeros(clusterRes_valsamples.size)-1).astype(str)
for c in np.unique(clusterRes_valsamples):
    subclustersavedir=os.path.join(clustersavedir_valsamples,savenamecluster+'_plottingIdx_progBalanced_'+str(0)+'_subcluster'+str(neworder[c]))
    with open(os.path.join(subclustersavedir,'minibatchkmean_ncluster'+str(subclusterDict[neworder[c]][0])+'n_pcs'+str(n_pcs)+'epoch'+str(plotepoch)+'_all'), 'rb') as output:
        subclusterRes=pickle.load(output)
    kmeans_sub_valsamples[clusterRes_valsamples==c]=np.char.add(np.repeat(str(c)+'-',subclusterRes.size),subclusterRes.astype(str))
                

In [14]:
with open(os.path.join(datadir,'processed','train_cnnvae_names_valcores'), 'rb') as input:
    allImgNames_valcores=pickle.load(input)
with open(os.path.join(datadir,'processed','train_cnnvae_names_valsamples'), 'rb') as input:
    allImgNames_valsamples=pickle.load(input)

In [16]:
#plot by disease progression
br1003aSpecs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR1003a specs.xlsx',header=10)
br301Specs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR301 specs.xlsx',header=10)
br8018aSpecs=pd.read_excel('/media/xinyi/dcis2idc1/data/BR8018a specs.xlsx',header=10)
br1003aSpecs.index=br1003aSpecs.loc[:,'Position']
br301Specs.index=br301Specs.loc[:,'Position']
br8018aSpecs.index=br8018aSpecs.loc[:,'Position']

progList_valcores=np.copy(allImgNames_valcores)
for s in np.unique(allImgNames_valcores):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valcores[allImgNames_valcores==s]=prog_s
    
progList_valsamples=np.copy(allImgNames_valsamples)
for s in np.unique(allImgNames_valsamples):
    ssplit=s.split('_')
    if 'br1003a'==ssplit[0]:
        prog_s=br1003aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br301'==ssplit[0]:
        prog_s=br301Specs.loc[(ssplit[-1],'Pathology diagnosis')]
    elif 'br8018a'==ssplit[0]:
        prog_s=br8018aSpecs.loc[(ssplit[-1],'Pathology diagnosis')]
    progList_valsamples[allImgNames_valsamples==s]=prog_s
    



In [17]:
for p in np.unique(progList_valcores):
    if p=='Ductal carcinoma in situ':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valcores[progList_valcores==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltratio':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ wi':
        progList_valcores[progList_valcores==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valcores[progList_valcores==p]='Hyperplasia'

In [18]:
for p in np.unique(progList_valsamples):
    if p=='Ductal carcinoma in situ':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ and breast tissue':
        progList_valsamples[progList_valsamples==p]='DCIS and breast tissue'
    elif p=='Ductal carcinoma in situ with early infiltrati':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'
    
    elif p=='Micropapillary type ductal carcinoma in situ w':
        progList_valsamples[progList_valsamples==p]='DCIS with early infiltration'    
#     elif p=='Atypical hyperlasia':
#         progList_valsamples[progList_valsamples==p]='Hyperplasia'

In [19]:
progIncludeIdx_valcores=np.repeat(False,progList_valcores.size)
for p in progInclude:
    progIncludeIdx_valcores[progList_valcores==p]=True
    
progIncludeIdx_valsamples=np.repeat(False,progList_valsamples.size)
for p in progInclude:
    progIncludeIdx_valsamples[progList_valsamples==p]=True
    
coordlist_valcores=coordlist_valcores[progIncludeIdx_valcores]
allImgNames_valcores=allImgNames_valcores[progIncludeIdx_valcores]
clusterRes_valcores=clusterRes_valcores[progIncludeIdx_valcores]
kmeans_sub_valcores=kmeans_sub_valcores[progIncludeIdx_valcores]
progList_valcores=progList_valcores[progIncludeIdx_valcores]

coordlist_valsamples=coordlist_valsamples[progIncludeIdx_valsamples]
allImgNames_valsamples=allImgNames_valsamples[progIncludeIdx_valsamples]
clusterRes_valsamples=clusterRes_valsamples[progIncludeIdx_valsamples]
kmeans_sub_valsamples=kmeans_sub_valsamples[progIncludeIdx_valsamples]
progList_valsamples=progList_valsamples[progIncludeIdx_valsamples]


In [20]:
sUnique_valcores,sidx_start_valcores=np.unique(allImgNames_valcores,return_index=True)
progUnique_valcores,progCounts_valcores=np.unique(progList_valcores[sidx_start_valcores],return_counts=True)
for p in range(progUnique_valcores.size):
    print(progUnique_valcores[p])
    print(progCounts_valcores[p])

Breast tissue
20
Cancer adjacent normal breast tissue
1
Hyperplasia
35
Invasive ductal carcinoma
97


In [21]:
sUnique_valsamples,sidx_start_valsamples=np.unique(allImgNames_valsamples,return_index=True)
progUnique_valsamples,progCounts_valsamples=np.unique(progList_valsamples[sidx_start_valsamples],return_counts=True)
for p in range(progUnique_valsamples.size):
    print(progUnique_valsamples[p])
    print(progCounts_valsamples[p])

Breast tissue
14
Cancer adjacent normal breast tissue
4
DCIS and breast tissue
16
DCIS with early infiltration
29
Hyperplasia
25
Invasive ductal carcinoma
66
Invasive ductal carcinoma and breast tissue
8


In [22]:
#construct labels
labels_valcores=np.zeros(progList_valcores[sidx_start_valcores].size)
for i in range(progUnique.size):
    labels_valcores[progList_valcores[sidx_start_valcores]==progUnique[i]]=i

In [23]:
#construct labels
labels_valsamples=np.zeros(progList_valsamples[sidx_start_valsamples].size)
for i in range(progUnique.size):
    labels_valsamples[progList_valsamples[sidx_start_valsamples]==progUnique[i]]=i

In [25]:
inputNeighborhood_valcores=np.zeros((sUnique_valcores.size,ncluster*ncluster))
for i in range(sUnique_valcores.size):
    imgN=sUnique_valcores[i]
    nsamples=np.sum(allImgNames_valcores==imgN)
    cluster_i=clusterRes_valcores[allImgNames_valcores==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist_valcores[allImgNames_valcores==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood_valcores[i]=res.flatten()

In [26]:
inputNeighborhood_valsamples=np.zeros((sUnique_valsamples.size,ncluster*ncluster))
for i in range(sUnique_valsamples.size):
    imgN=sUnique_valsamples[i]
    nsamples=np.sum(allImgNames_valsamples==imgN)
    cluster_i=clusterRes_valsamples[allImgNames_valsamples==imgN]
    neighbor_i=np.tile(cluster_i,(nsamples,1))
    self_i=np.tile(cluster_i.reshape((-1,1)),(1,nsamples))

    dist=pairwise_distances(coordlist_valsamples[allImgNames_valsamples==imgN],n_jobs=-1)
    distIn=np.logical_and(dist<neighborhoodSize,dist>0)
    res=getHistMatrix_clusters(self_i[distIn],neighbor_i[distIn])
    
    inputNeighborhood_valsamples[i]=res.flatten()

In [27]:
_,inputCounts_valcores=np.unique(allImgNames_valcores,return_counts=True)
inputAll_valcores=np.concatenate((inputNeighborhood_valcores,inputCounts_valcores.reshape(-1,1)),axis=1)

In [28]:
_,inputCounts_valsamples=np.unique(allImgNames_valsamples,return_counts=True)
inputAll_valsamples=np.concatenate((inputNeighborhood_valsamples,inputCounts_valsamples.reshape(-1,1)),axis=1)

In [29]:
#concatenate cores
inputAll=np.concatenate((inputAll_train,np.concatenate((inputAll_valcores,inputAll_valsamples),axis=0)),axis=0)
imgNamesAll=np.concatenate((allImgNames[sidx_start],np.concatenate((allImgNames_valcores[sidx_start_valcores],allImgNames_valsamples[sidx_start_valsamples]))))
labelsAll=np.concatenate((labels_train,np.concatenate((labels_valcores,labels_valsamples))))

In [30]:
_,progCountsAll=np.unique(labelsAll,return_counts=True)
weights_train=np.sum(progCountsAll)/progCountsAll

In [None]:
#img sizes
allImgNamesAll=np.concatenate((allImgNames,np.concatenate((allImgNames_valcores,allImgNames_valsamples))))
progListAll=np.concatenate((progList,np.concatenate((progList_valcores,progList_valsamples))))
sidxAll=np.concatenate((sidx_start,np.concatenate((sidx_start_valcores,sidx_start_valsamples))))
coordlistAll=np.concatenate((coordlist,np.concatenate((coordlist_valcores,coordlist_valsamples),axis=0)),axis=0)

imgSizeAll={}
for p in progUnique:
    img_cores=allImgNamesAll[sidxAll[progListAll[sidxAll]==p]]
    pSizes=np.zeros((img_cores.size))
    for si in range(img_cores.size):
        scoord=coordlistAll[allImgNamesAll==img_cores[si]]
        hsize=np.pi*np.square(np.max(scoord[:,0])-np.min(scoord[:,0]))
        vsize=np.pi*np.square(np.max(scoord[:,1])-np.min(scoord[:,1]))
        pSizes[si]=min(hsize,vsize)
    imgSizeAll[p]=pSizes



In [None]:
imgSize_median={}
for p in progUnique:
    imgSize_median[p]=np.median(imgSizeAll[p])
plt.violinplot(list(imgSizeAll.values()))
plt.scatter(np.arange(progUnique.size)+1,list(imgSize_median.values()))
plt.xticks(np.arange(progUnique.size)+1,list(imgSizeAll),rotation=90)

In [31]:
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'wb') as output:
    pickle.dump(imgSize_median,output,pickle.HIGHEST_PROTOCOL)

In [31]:
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as output:
    imgSize_median=pickle.load(output)

In [32]:
#normalize count
with open(os.path.join(datadir,'processed','imgSizeByPath'), 'rb') as output:
    imgSize_median=pickle.load(output)
areaAll=np.zeros(labelsAll.size)
for s in range(labelsAll.size):
    areaAll[s]=imgSize_median[progUnique[labelsAll.astype(int)][s]]
inputAll[:,-1]=inputAll[:,-1]/areaAll

In [33]:
seed=3
epochs=6000
saveFreq=200
lr=0.001 #initial learning rate
weight_decay=0 

# batchsize=4
batchsize=6000
model_str='logistic'

fc_dim1=64
fc_dim2=64
fc_dim3=64


dropout=0.01

name='exp0_pathologyClf_neighbor_clusters_exp0_subset_neighborOnly_crossVal_countAreaNorm_logistic'
logsavepath='/media/xinyi/dcis2idc1/log/cnnvae'+name
modelsavepath='/media/xinyi/dcis2idc1/models/cnnvae'+name
plotsavepath='/media/xinyi/dcis2idc1/plots/cnnvae'+name


if not os.path.exists(logsavepath):
    os.mkdir(logsavepath)
if not os.path.exists(modelsavepath):
    os.mkdir(modelsavepath)
if not os.path.exists(plotsavepath):
    os.mkdir(plotsavepath)



In [34]:
inputAll=scipy.stats.zscore(inputAll,axis=0)

In [35]:
def train(epoch,trainInput,labels_train):
    t = time.time()
    model.train()
    optimizer.zero_grad()

    pred = model(trainInput)

    loss=lossCE(pred,labels_train)

    loss.backward()
    optimizer.step()

    if epoch%500==0:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.4f}'.format(loss))
    return loss.item()

In [39]:
inputAll=torch.tensor(inputAll).cuda().float()
labelsAll=torch.tensor(labelsAll).cuda().long()

testepoch=5800
predtest=np.zeros((inputAll.shape[0],np.unique(labels_train).size))
for sampleIdx in range(inputAll.shape[0]):

    trainIdx=np.arange(inputAll.shape[0])!=sampleIdx
    
    seed=3
    torch.manual_seed(seed)
    nclasses=np.unique(labels_train).size
    if use_cuda:
        torch.cuda.manual_seed(seed)

    nfeatures=inputAll.shape[1]
    if model_str=='fc3':
        model = modelsCNN.FC_l3(nfeatures,fc_dim1,fc_dim2,fc_dim3,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='logistic':
        model = modelsCNN.LogisticReg(nfeatures,nclasses)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc5':
        model = modelsCNN.FC_l5(nfeatures,fc_dim1,fc_dim2,fc_dim3,fc_dim4,fc_dim5,nclasses,0.5,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc1':
        model = modelsCNN.FC_l1(nfeatures,fc_dim1,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())
    if model_str=='fc0':
        model = modelsCNN.FC_l0(nfeatures,nclasses,regrs=False)
        lossCE=torch.nn.CrossEntropyLoss(torch.tensor(weights_train).cuda().float())

    if use_cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss_ep=[None]*epochs
    val_loss_ep=[None]*epochs
    t_ep=time.time()

    for ep in range(epochs):
        train_loss_ep[ep]=train(ep,inputAll[trainIdx],labelsAll[trainIdx])


        if ep%saveFreq == 0 and ep!=0:
            torch.save(model.cpu().state_dict(), os.path.join(modelsavepath,imgNamesAll[sampleIdx]+'_'+str(ep)+'.pt'))
        if use_cuda:
            model.cuda()
            torch.cuda.empty_cache()
    print(' total time: {:.4f}s'.format(time.time() - t_ep))

    with open(os.path.join(logsavepath,imgNamesAll[sampleIdx]+'_train_loss'), 'wb') as output:
        pickle.dump(train_loss_ep, output, pickle.HIGHEST_PROTOCOL)

    model.load_state_dict(torch.load(os.path.join(modelsavepath,imgNamesAll[sampleIdx]+'_'+str(testepoch)+'.pt')))
    with torch.no_grad():
        model.cuda()
        model.eval()
        pred = model(inputAll[[sampleIdx]])
        predtest[sampleIdx]=pred.cpu().detach().numpy()

        loss_test=lossCE(pred,labelsAll[[sampleIdx]]).item()

    print(loss_test)

  inputAll=torch.tensor(inputAll).cuda().float()
  labelsAll=torch.tensor(labelsAll).cuda().long()


Epoch: 0000 loss_train: 2.0710
Epoch: 0500 loss_train: 0.9134
Epoch: 1000 loss_train: 0.7261
Epoch: 1500 loss_train: 0.6266
Epoch: 2000 loss_train: 0.5612
Epoch: 2500 loss_train: 0.5126
Epoch: 3000 loss_train: 0.4734
Epoch: 3500 loss_train: 0.4402
Epoch: 4000 loss_train: 0.4115
Epoch: 4500 loss_train: 0.3861
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3427
 total time: 9.2363s
0.02995162457227707
Epoch: 0000 loss_train: 2.0712
Epoch: 0500 loss_train: 0.9127
Epoch: 1000 loss_train: 0.7258
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4111
Epoch: 4500 loss_train: 0.3857
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3423
 total time: 3.2104s
0.4176146984100342
Epoch: 0000 loss_train: 2.0696
Epoch: 0500 loss_train: 0.9109
Epoch: 1000 loss_train: 0.7232
Epoch: 1500 loss_train: 0.6235
Epoch: 2000 loss_train: 0.5578
Epoch: 2500 loss_tra

 total time: 3.0294s
0.014234201051294804
Epoch: 0000 loss_train: 2.0709
Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7257
Epoch: 1500 loss_train: 0.6263
Epoch: 2000 loss_train: 0.5607
Epoch: 2500 loss_train: 0.5119
Epoch: 3000 loss_train: 0.4724
Epoch: 3500 loss_train: 0.4391
Epoch: 4000 loss_train: 0.4101
Epoch: 4500 loss_train: 0.3844
Epoch: 5000 loss_train: 0.3613
Epoch: 5500 loss_train: 0.3403
 total time: 3.0087s
3.4868245124816895
Epoch: 0000 loss_train: 2.0711
Epoch: 0500 loss_train: 0.9131
Epoch: 1000 loss_train: 0.7262
Epoch: 1500 loss_train: 0.6267
Epoch: 2000 loss_train: 0.5612
Epoch: 2500 loss_train: 0.5125
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3423
 total time: 2.9918s
0.3474266827106476
Epoch: 0000 loss_train: 2.0707
Epoch: 0500 loss_train: 0.9107
Epoch: 1000 loss_train: 0.7233
Epoch: 1500 loss_train: 0.6235
Epoch: 200

Epoch: 5500 loss_train: 0.3429
 total time: 2.9528s
0.0401199571788311
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9094
Epoch: 1000 loss_train: 0.7212
Epoch: 1500 loss_train: 0.6207
Epoch: 2000 loss_train: 0.5543
Epoch: 2500 loss_train: 0.5050
Epoch: 3000 loss_train: 0.4654
Epoch: 3500 loss_train: 0.4322
Epoch: 4000 loss_train: 0.4036
Epoch: 4500 loss_train: 0.3785
Epoch: 5000 loss_train: 0.3560
Epoch: 5500 loss_train: 0.3357
 total time: 3.0855s
12.722186088562012
Epoch: 0000 loss_train: 2.0707
Epoch: 0500 loss_train: 0.9136
Epoch: 1000 loss_train: 0.7265
Epoch: 1500 loss_train: 0.6270
Epoch: 2000 loss_train: 0.5615
Epoch: 2500 loss_train: 0.5127
Epoch: 3000 loss_train: 0.4734
Epoch: 3500 loss_train: 0.4403
Epoch: 4000 loss_train: 0.4115
Epoch: 4500 loss_train: 0.3861
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3427
 total time: 3.1136s
0.00014804698002990335
Epoch: 0000 loss_train: 2.0712
Epoch: 0500 loss_train: 0.9126
Epoch: 1000 loss_train: 0.7259
Epoch: 1

Epoch: 5000 loss_train: 0.3616
Epoch: 5500 loss_train: 0.3410
 total time: 3.1199s
2.9257853031158447
Epoch: 0000 loss_train: 2.0715
Epoch: 0500 loss_train: 0.9146
Epoch: 1000 loss_train: 0.7268
Epoch: 1500 loss_train: 0.6268
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5121
Epoch: 3000 loss_train: 0.4727
Epoch: 3500 loss_train: 0.4395
Epoch: 4000 loss_train: 0.4107
Epoch: 4500 loss_train: 0.3853
Epoch: 5000 loss_train: 0.3625
Epoch: 5500 loss_train: 0.3419
 total time: 3.0540s
2.655111312866211
Epoch: 0000 loss_train: 2.0710
Epoch: 0500 loss_train: 0.9112
Epoch: 1000 loss_train: 0.7234
Epoch: 1500 loss_train: 0.6237
Epoch: 2000 loss_train: 0.5582
Epoch: 2500 loss_train: 0.5096
Epoch: 3000 loss_train: 0.4705
Epoch: 3500 loss_train: 0.4375
Epoch: 4000 loss_train: 0.4090
Epoch: 4500 loss_train: 0.3837
Epoch: 5000 loss_train: 0.3610
Epoch: 5500 loss_train: 0.3403
 total time: 3.2504s
1.5101361274719238
Epoch: 0000 loss_train: 2.0700
Epoch: 0500 loss_train: 0.9061
Epoch: 1000 l

Epoch: 4500 loss_train: 0.3863
Epoch: 5000 loss_train: 0.3634
Epoch: 5500 loss_train: 0.3428
 total time: 2.9180s
0.010241220705211163
Epoch: 0000 loss_train: 2.0715
Epoch: 0500 loss_train: 0.9137
Epoch: 1000 loss_train: 0.7264
Epoch: 1500 loss_train: 0.6268
Epoch: 2000 loss_train: 0.5612
Epoch: 2500 loss_train: 0.5125
Epoch: 3000 loss_train: 0.4733
Epoch: 3500 loss_train: 0.4402
Epoch: 4000 loss_train: 0.4114
Epoch: 4500 loss_train: 0.3860
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3426
 total time: 2.9034s
0.3462642431259155
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.8965
Epoch: 1000 loss_train: 0.7098
Epoch: 1500 loss_train: 0.6129
Epoch: 2000 loss_train: 0.5503
Epoch: 2500 loss_train: 0.5037
Epoch: 3000 loss_train: 0.4661
Epoch: 3500 loss_train: 0.4341
Epoch: 4000 loss_train: 0.4062
Epoch: 4500 loss_train: 0.3813
Epoch: 5000 loss_train: 0.3589
Epoch: 5500 loss_train: 0.3386
 total time: 2.9788s
7.873163223266602
Epoch: 0000 loss_train: 2.0705
Epoch: 0500

Epoch: 4000 loss_train: 0.4115
Epoch: 4500 loss_train: 0.3860
Epoch: 5000 loss_train: 0.3632
Epoch: 5500 loss_train: 0.3425
 total time: 3.0582s
0.6365519762039185
Epoch: 0000 loss_train: 2.0696
Epoch: 0500 loss_train: 0.9119
Epoch: 1000 loss_train: 0.7250
Epoch: 1500 loss_train: 0.6256
Epoch: 2000 loss_train: 0.5602
Epoch: 2500 loss_train: 0.5114
Epoch: 3000 loss_train: 0.4721
Epoch: 3500 loss_train: 0.4388
Epoch: 4000 loss_train: 0.4099
Epoch: 4500 loss_train: 0.3843
Epoch: 5000 loss_train: 0.3613
Epoch: 5500 loss_train: 0.3405
 total time: 3.1174s
1.2012877464294434
Epoch: 0000 loss_train: 2.0663
Epoch: 0500 loss_train: 0.9162
Epoch: 1000 loss_train: 0.7284
Epoch: 1500 loss_train: 0.6286
Epoch: 2000 loss_train: 0.5630
Epoch: 2500 loss_train: 0.5142
Epoch: 3000 loss_train: 0.4749
Epoch: 3500 loss_train: 0.4417
Epoch: 4000 loss_train: 0.4129
Epoch: 4500 loss_train: 0.3875
Epoch: 5000 loss_train: 0.3647
Epoch: 5500 loss_train: 0.3440
 total time: 3.1026s
0.05064012482762337
Epoch: 0000

Epoch: 3500 loss_train: 0.4402
Epoch: 4000 loss_train: 0.4115
Epoch: 4500 loss_train: 0.3861
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3426
 total time: 3.0193s
0.24295607209205627
Epoch: 0000 loss_train: 2.0707
Epoch: 0500 loss_train: 0.9123
Epoch: 1000 loss_train: 0.7254
Epoch: 1500 loss_train: 0.6261
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5122
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4113
Epoch: 4500 loss_train: 0.3859
Epoch: 5000 loss_train: 0.3632
Epoch: 5500 loss_train: 0.3425
 total time: 3.0158s
0.4757787883281708
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9118
Epoch: 1000 loss_train: 0.7248
Epoch: 1500 loss_train: 0.6254
Epoch: 2000 loss_train: 0.5601
Epoch: 2500 loss_train: 0.5116
Epoch: 3000 loss_train: 0.4725
Epoch: 3500 loss_train: 0.4395
Epoch: 4000 loss_train: 0.4108
Epoch: 4500 loss_train: 0.3855
Epoch: 5000 loss_train: 0.3627
Epoch: 5500 loss_train: 0.3421
 total time: 3.0130s

Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 2.9583s
4.768370445162873e-07
Epoch: 0000 loss_train: 2.0702
Epoch: 0500 loss_train: 0.9118
Epoch: 1000 loss_train: 0.7246
Epoch: 1500 loss_train: 0.6251
Epoch: 2000 loss_train: 0.5598
Epoch: 2500 loss_train: 0.5114
Epoch: 3000 loss_train: 0.4724
Epoch: 3500 loss_train: 0.4395
Epoch: 4000 loss_train: 0.4109
Epoch: 4500 loss_train: 0.3856
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.1407s
1.141287088394165
Epoch: 0000 loss_train: 2.0702
Epoch: 0500 loss_train: 0.9114
Epoch: 1000 loss_train: 0.7250
Epoch: 1500 loss_train: 0.6259
Epoch: 2000 loss_train: 0.5606
Epoch: 2500 loss_train: 0.5120
Epoch: 3000 loss_train: 0.4728
Epoch: 3500 loss_train: 0.4398
Epoch: 4000 loss_train: 0.4111
Epoch: 4500 loss_train: 0.3857
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_tr

Epoch: 2500 loss_train: 0.5114
Epoch: 3000 loss_train: 0.4720
Epoch: 3500 loss_train: 0.4389
Epoch: 4000 loss_train: 0.4101
Epoch: 4500 loss_train: 0.3846
Epoch: 5000 loss_train: 0.3618
Epoch: 5500 loss_train: 0.3411
 total time: 2.9106s
3.382819414138794
Epoch: 0000 loss_train: 2.0702
Epoch: 0500 loss_train: 0.9116
Epoch: 1000 loss_train: 0.7248
Epoch: 1500 loss_train: 0.6255
Epoch: 2000 loss_train: 0.5602
Epoch: 2500 loss_train: 0.5116
Epoch: 3000 loss_train: 0.4724
Epoch: 3500 loss_train: 0.4393
Epoch: 4000 loss_train: 0.4107
Epoch: 4500 loss_train: 0.3853
Epoch: 5000 loss_train: 0.3626
Epoch: 5500 loss_train: 0.3420
 total time: 3.6332s
1.142520546913147
Epoch: 0000 loss_train: 2.0704
Epoch: 0500 loss_train: 0.9120
Epoch: 1000 loss_train: 0.7249
Epoch: 1500 loss_train: 0.6254
Epoch: 2000 loss_train: 0.5601
Epoch: 2500 loss_train: 0.5116
Epoch: 3000 loss_train: 0.4726
Epoch: 3500 loss_train: 0.4397
Epoch: 4000 loss_train: 0.4111
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train:

Epoch: 2000 loss_train: 0.5609
Epoch: 2500 loss_train: 0.5122
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.1079s
0.0010189585154876113
Epoch: 0000 loss_train: 2.0702
Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.1230s
1.1920927533992653e-07
Epoch: 0000 loss_train: 2.0704
Epoch: 0500 loss_train: 0.9127
Epoch: 1000 loss_train: 0.7257
Epoch: 1500 loss_train: 0.6263
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5121
Epoch: 3000 loss_train: 0.4729
Epoch: 3500 loss_train: 0.4398
Epoch: 4000 loss_train: 0.4111
Epoch: 4500 lo

Epoch: 1500 loss_train: 0.6260
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4732
Epoch: 3500 loss_train: 0.4401
Epoch: 4000 loss_train: 0.4114
Epoch: 4500 loss_train: 0.3860
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3426
 total time: 3.0218s
0.13870351016521454
Epoch: 0000 loss_train: 2.0717
Epoch: 0500 loss_train: 0.9128
Epoch: 1000 loss_train: 0.7256
Epoch: 1500 loss_train: 0.6261
Epoch: 2000 loss_train: 0.5607
Epoch: 2500 loss_train: 0.5120
Epoch: 3000 loss_train: 0.4728
Epoch: 3500 loss_train: 0.4397
Epoch: 4000 loss_train: 0.4109
Epoch: 4500 loss_train: 0.3856
Epoch: 5000 loss_train: 0.3629
Epoch: 5500 loss_train: 0.3423
 total time: 3.0214s
0.518147349357605
Epoch: 0000 loss_train: 2.0715
Epoch: 0500 loss_train: 0.9117
Epoch: 1000 loss_train: 0.7251
Epoch: 1500 loss_train: 0.6260
Epoch: 2000 loss_train: 0.5609
Epoch: 2500 loss_train: 0.5124
Epoch: 3000 loss_train: 0.4733
Epoch: 3500 loss_train: 0.4402
Epoch: 4000 loss_trai

Epoch: 5500 loss_train: 0.3376
 total time: 3.0580s
11.373620986938477
Epoch: 0000 loss_train: 2.0712
Epoch: 0500 loss_train: 0.9103
Epoch: 1000 loss_train: 0.7229
Epoch: 1500 loss_train: 0.6234
Epoch: 2000 loss_train: 0.5581
Epoch: 2500 loss_train: 0.5095
Epoch: 3000 loss_train: 0.4703
Epoch: 3500 loss_train: 0.4371
Epoch: 4000 loss_train: 0.4081
Epoch: 4500 loss_train: 0.3824
Epoch: 5000 loss_train: 0.3593
Epoch: 5500 loss_train: 0.3382
 total time: 3.1074s
5.129616737365723
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6265
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3859
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 3.0423s
0.00273199868388474
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 

Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 3.0758s
0.001116367639042437
Epoch: 0000 loss_train: 2.0700
Epoch: 0500 loss_train: 0.9128
Epoch: 1000 loss_train: 0.7258
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5609
Epoch: 2500 loss_train: 0.5122
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3424
 total time: 3.0893s
0.0001530530134914443
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6265
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 3.1657s
0.0029560700058937073
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9130
Epoc

Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3424
 total time: 3.1859s
0.006162925623357296
Epoch: 0000 loss_train: 2.0542
Epoch: 0500 loss_train: 0.9087
Epoch: 1000 loss_train: 0.7225
Epoch: 1500 loss_train: 0.6242
Epoch: 2000 loss_train: 0.5593
Epoch: 2500 loss_train: 0.5110
Epoch: 3000 loss_train: 0.4719
Epoch: 3500 loss_train: 0.4390
Epoch: 4000 loss_train: 0.4105
Epoch: 4500 loss_train: 0.3853
Epoch: 5000 loss_train: 0.3627
Epoch: 5500 loss_train: 0.3421
 total time: 3.0041s
2.084806442260742
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9125
Epoch: 1000 loss_train: 0.7256
Epoch: 1500 loss_train: 0.6262
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5121
Epoch: 3000 loss_train: 0.4729
Epoch: 3500 loss_train: 0.4398
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 2.8894s
0.14037643373012543
Epoch: 000

Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.0813s
2.264974000354414e-06
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9128
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.1032s
0.05210544541478157
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9130
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_

Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3424
 total time: 3.1119s
1.1444026313256472e-05
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9130
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6265
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 3.0289s
2.0265558760002023e-06
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9123
Epoch: 1000 loss_train: 0.7252
Epoch: 1500 loss_train: 0.6257
Epoch: 2000 loss_train: 0.5601
Epoch: 2500 loss_train: 0.5113
Epoch: 3000 loss_train: 0.4720
Epoch: 3500 loss_train: 0.4388
Epoch: 4000 loss_train: 0.4100
Epoch: 4500 l

Epoch: 1500 loss_train: 0.6262
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4732
Epoch: 3500 loss_train: 0.4402
Epoch: 4000 loss_train: 0.4115
Epoch: 4500 loss_train: 0.3861
Epoch: 5000 loss_train: 0.3633
Epoch: 5500 loss_train: 0.3426
 total time: 3.0298s
0.17439743876457214
Epoch: 0000 loss_train: 2.0689
Epoch: 0500 loss_train: 0.9104
Epoch: 1000 loss_train: 0.7226
Epoch: 1500 loss_train: 0.6226
Epoch: 2000 loss_train: 0.5567
Epoch: 2500 loss_train: 0.5079
Epoch: 3000 loss_train: 0.4686
Epoch: 3500 loss_train: 0.4355
Epoch: 4000 loss_train: 0.4069
Epoch: 4500 loss_train: 0.3817
Epoch: 5000 loss_train: 0.3591
Epoch: 5500 loss_train: 0.3386
 total time: 3.1154s
7.961686134338379
Epoch: 0000 loss_train: 2.0701
Epoch: 0500 loss_train: 0.9114
Epoch: 1000 loss_train: 0.7244
Epoch: 1500 loss_train: 0.6246
Epoch: 2000 loss_train: 0.5589
Epoch: 2500 loss_train: 0.5100
Epoch: 3000 loss_train: 0.4706
Epoch: 3500 loss_train: 0.4373
Epoch: 4000 loss_trai

Epoch: 1000 loss_train: 0.7255
Epoch: 1500 loss_train: 0.6260
Epoch: 2000 loss_train: 0.5605
Epoch: 2500 loss_train: 0.5118
Epoch: 3000 loss_train: 0.4726
Epoch: 3500 loss_train: 0.4395
Epoch: 4000 loss_train: 0.4108
Epoch: 4500 loss_train: 0.3854
Epoch: 5000 loss_train: 0.3627
Epoch: 5500 loss_train: 0.3420
 total time: 3.2345s
0.5404962301254272
Epoch: 0000 loss_train: 2.0710
Epoch: 0500 loss_train: 0.9121
Epoch: 1000 loss_train: 0.7249
Epoch: 1500 loss_train: 0.6254
Epoch: 2000 loss_train: 0.5600
Epoch: 2500 loss_train: 0.5113
Epoch: 3000 loss_train: 0.4721
Epoch: 3500 loss_train: 0.4390
Epoch: 4000 loss_train: 0.4103
Epoch: 4500 loss_train: 0.3849
Epoch: 5000 loss_train: 0.3622
Epoch: 5500 loss_train: 0.3416
 total time: 3.1945s
0.9555199146270752
Epoch: 0000 loss_train: 2.0713
Epoch: 0500 loss_train: 0.9127
Epoch: 1000 loss_train: 0.7257
Epoch: 1500 loss_train: 0.6262
Epoch: 2000 loss_train: 0.5607
Epoch: 2500 loss_train: 0.5120
Epoch: 3000 loss_train: 0.4728
Epoch: 3500 loss_trai

Epoch: 0500 loss_train: 0.9078
Epoch: 1000 loss_train: 0.7205
Epoch: 1500 loss_train: 0.6207
Epoch: 2000 loss_train: 0.5549
Epoch: 2500 loss_train: 0.5059
Epoch: 3000 loss_train: 0.4664
Epoch: 3500 loss_train: 0.4331
Epoch: 4000 loss_train: 0.4041
Epoch: 4500 loss_train: 0.3785
Epoch: 5000 loss_train: 0.3556
Epoch: 5500 loss_train: 0.3348
 total time: 2.8996s
5.008342266082764
Epoch: 0000 loss_train: 2.0708
Epoch: 0500 loss_train: 0.9130
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6263
Epoch: 2000 loss_train: 0.5608
Epoch: 2500 loss_train: 0.5121
Epoch: 3000 loss_train: 0.4728
Epoch: 3500 loss_train: 0.4397
Epoch: 4000 loss_train: 0.4110
Epoch: 4500 loss_train: 0.3856
Epoch: 5000 loss_train: 0.3629
Epoch: 5500 loss_train: 0.3423
 total time: 2.9292s
0.6478436589241028
Epoch: 0000 loss_train: 2.0720
Epoch: 0500 loss_train: 0.9025
Epoch: 1000 loss_train: 0.7127
Epoch: 1500 loss_train: 0.6123
Epoch: 2000 loss_train: 0.5466
Epoch: 2500 loss_train: 0.4979
Epoch: 3000 loss_train

Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3859
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3425
 total time: 2.9654s
0.00015639036428183317
Epoch: 0000 loss_train: 2.0713
Epoch: 0500 loss_train: 0.9143
Epoch: 1000 loss_train: 0.7272
Epoch: 1500 loss_train: 0.6276
Epoch: 2000 loss_train: 0.5621
Epoch: 2500 loss_train: 0.5133
Epoch: 3000 loss_train: 0.4740
Epoch: 3500 loss_train: 0.4408
Epoch: 4000 loss_train: 0.4121
Epoch: 4500 loss_train: 0.3866
Epoch: 5000 loss_train: 0.3638
Epoch: 5500 loss_train: 0.3431
 total time: 2.9249s
0.011585456319153309
Epoch: 0000 loss_train: 2.0722
Epoch: 0500 loss_train: 0.9146
Epoch: 1000 loss_train: 0.7271
Epoch: 1500 loss_train: 0.6275
Epoch: 2000 loss_train: 0.5619
Epoch: 2500 loss_train: 0.5131
Epoch: 3000 los

Epoch: 0500 loss_train: 0.9157
Epoch: 1000 loss_train: 0.7290
Epoch: 1500 loss_train: 0.6294
Epoch: 2000 loss_train: 0.5639
Epoch: 2500 loss_train: 0.5151
Epoch: 3000 loss_train: 0.4758
Epoch: 3500 loss_train: 0.4425
Epoch: 4000 loss_train: 0.4137
Epoch: 4500 loss_train: 0.3882
Epoch: 5000 loss_train: 0.3653
Epoch: 5500 loss_train: 0.3446
 total time: 3.2601s
0.013173097744584084
Epoch: 0000 loss_train: 2.0700
Epoch: 0500 loss_train: 0.9124
Epoch: 1000 loss_train: 0.7254
Epoch: 1500 loss_train: 0.6261
Epoch: 2000 loss_train: 0.5607
Epoch: 2500 loss_train: 0.5121
Epoch: 3000 loss_train: 0.4729
Epoch: 3500 loss_train: 0.4398
Epoch: 4000 loss_train: 0.4110
Epoch: 4500 loss_train: 0.3856
Epoch: 5000 loss_train: 0.3629
Epoch: 5500 loss_train: 0.3422
 total time: 3.2389s
0.5204038619995117
Epoch: 0000 loss_train: 2.0694
Epoch: 0500 loss_train: 0.9142
Epoch: 1000 loss_train: 0.7269
Epoch: 1500 loss_train: 0.6272
Epoch: 2000 loss_train: 0.5617
Epoch: 2500 loss_train: 0.5129
Epoch: 3000 loss_tr

Epoch: 0500 loss_train: 0.9129
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5609
Epoch: 2500 loss_train: 0.5122
Epoch: 3000 loss_train: 0.4730
Epoch: 3500 loss_train: 0.4399
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3630
Epoch: 5500 loss_train: 0.3424
 total time: 3.1191s
0.017283186316490173
Epoch: 0000 loss_train: 2.0701
Epoch: 0500 loss_train: 0.9100
Epoch: 1000 loss_train: 0.7233
Epoch: 1500 loss_train: 0.6240
Epoch: 2000 loss_train: 0.5588
Epoch: 2500 loss_train: 0.5104
Epoch: 3000 loss_train: 0.4715
Epoch: 3500 loss_train: 0.4387
Epoch: 4000 loss_train: 0.4103
Epoch: 4500 loss_train: 0.3851
Epoch: 5000 loss_train: 0.3626
Epoch: 5500 loss_train: 0.3421
 total time: 3.0650s
2.1359643936157227
Epoch: 0000 loss_train: 2.0706
Epoch: 0500 loss_train: 0.9127
Epoch: 1000 loss_train: 0.7258
Epoch: 1500 loss_train: 0.6264
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_tr

Epoch: 0500 loss_train: 0.9130
Epoch: 1000 loss_train: 0.7260
Epoch: 1500 loss_train: 0.6265
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_train: 0.4731
Epoch: 3500 loss_train: 0.4400
Epoch: 4000 loss_train: 0.4112
Epoch: 4500 loss_train: 0.3858
Epoch: 5000 loss_train: 0.3631
Epoch: 5500 loss_train: 0.3424
 total time: 3.1856s
6.437280717364047e-06
Epoch: 0000 loss_train: 2.0703
Epoch: 0500 loss_train: 0.9119
Epoch: 1000 loss_train: 0.7248
Epoch: 1500 loss_train: 0.6252
Epoch: 2000 loss_train: 0.5597
Epoch: 2500 loss_train: 0.5110
Epoch: 3000 loss_train: 0.4717
Epoch: 3500 loss_train: 0.4386
Epoch: 4000 loss_train: 0.4098
Epoch: 4500 loss_train: 0.3843
Epoch: 5000 loss_train: 0.3615
Epoch: 5500 loss_train: 0.3408
 total time: 3.1580s
3.139277696609497
Epoch: 0000 loss_train: 2.0705
Epoch: 0500 loss_train: 0.9130
Epoch: 1000 loss_train: 0.7259
Epoch: 1500 loss_train: 0.6265
Epoch: 2000 loss_train: 0.5610
Epoch: 2500 loss_train: 0.5123
Epoch: 3000 loss_tr

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [40]:
with open(os.path.join(logsavepath,'crossVal_loss'), 'wb') as output:
    pickle.dump(predtest, output, pickle.HIGHEST_PROTOCOL)

In [41]:
predtest_label=np.argmax(predtest,axis=1)

In [42]:
res=pd.DataFrame({'sampleName':imgNamesAll,'true':progUnique[labelsAll.cpu().numpy()],'predicted':progUnique[predtest_label]})
res.to_csv(os.path.join(plotsavepath,'predictions.csv'))

In [None]:
progInclude=np.array(['Hyperplasia','Atypical hyperplasia','DCIS and breast tissue',  'DCIS with early infiltration'])

In [45]:
#plot confusion
def plotCTcomp(labels,ctlist,savepath,savenamecluster,byCT,addname='',order=progInclude):
    res=np.zeros((order.size,order.size))
    for li in range(res.shape[0]):
        l=order[li]
        nl=np.sum(labels==l)
        ctlist_l=ctlist[labels==l]
        for ci in range(res.shape[1]):
            c=order[ci]
            res[li,ci]=np.sum(ctlist_l==c)
#             res[li,ci]=np.sum(ctlist_l==c)/nl
    if not byCT:
        addname+=''
        for li in range(res.shape[0]):
            l=order[li]
            nl=np.sum(labels==l)
            res[li]=res[li]/nl
    else:
        addname+='_normbyCT'
        for ci in range(res.shape[1]):
            c=order[ci]
            nc=np.sum(ctlist==c)
            res[:,ci]=res[:,ci]/nc
    
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(res,cmap='binary',vmin=0,vmax=1)
    fig.colorbar(im)
    ax.set_yticks(np.arange(order.size))
    ax.set_yticklabels(order)
    ax.set_xticks(np.arange(order.size))
    ax.set_xticklabels(order)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
    fig.tight_layout()
    plt.savefig(os.path.join(savepath,savenamecluster+addname+'.pdf'))
    plt.close()
    
plotCTcomp(progUnique[labelsAll.cpu().numpy()],progUnique[predtest_label],plotsavepath,'confusion'+str(testepoch),False)

In [None]:
plotCTcomp(progUnique[labelsAll.cpu().numpy()][:sidx_start.size],progUnique[predtest_label][:sidx_start.size],plotsavepath,'confusion_excludeValSamples'+str(testepoch),False)

In [39]:
res

Unnamed: 0,sampleName,true,predicted
0,br1003a_1_cytokeratin_555_aSMA_647_hoechst_I1,Breast tissue,Breast tissue
1,br1003a_1_cytokeratin_555_aSMA_647_hoechst_I10,Breast tissue,Breast tissue
2,br1003a_1_cytokeratin_555_aSMA_647_hoechst_I2,Breast tissue,DCIS with early infiltration
3,br1003a_1_cytokeratin_555_aSMA_647_hoechst_I3,Breast tissue,Breast tissue
4,br1003a_1_cytokeratin_555_aSMA_647_hoechst_I7,Breast tissue,Breast tissue
...,...,...,...
407,br8018a_2_cytokeratin_555_ki67_647_hoechst_G6,Invasive ductal carcinoma,Invasive ductal carcinoma
408,br8018a_2_cytokeratin_555_ki67_647_hoechst_H2,Cancer adjacent normal breast tissue,Invasive ductal carcinoma
409,br8018a_2_cytokeratin_555_ki67_647_hoechst_H3,Cancer adjacent normal breast tissue,Cancer adjacent normal breast tissue
410,br8018a_2_cytokeratin_555_ki67_647_hoechst_H4,Cancer adjacent normal breast tissue,Cancer adjacent normal breast tissue
