In [1]:
import numpy as np
import pandas as pd
from glob import glob
import os
import torch
import SimpleITK as sitk
from SUMNet_bn import SUMNet
from torchvision import transforms
import torch.nn.functional as F
import cv2
from tqdm import tqdm_notebook as tq

In [2]:
def load_itk_image(filename):
    """Read an image file with SimpleITK and return its voxels plus geometry.

    Returns a tuple ``(voxels, origin, spacing)``:
      * ``voxels``  - the array produced by ``sitk.GetArrayFromImage``;
      * ``origin``  - ``image.GetOrigin()`` in reversed order, as a NumPy array;
      * ``spacing`` - ``image.GetSpacing()`` in reversed order, as a NumPy array.

    SimpleITK reports origin/spacing as (x, y, z); reversing them matches the
    (z, y, x) axis order of the array returned by ``GetArrayFromImage``.
    """
    image = sitk.ReadImage(filename)
    voxels = sitk.GetArrayFromImage(image)
    origin_zyx = np.array(image.GetOrigin()[::-1])
    spacing_zyx = np.array(image.GetSpacing()[::-1])
    return voxels, origin_zyx, spacing_zyx

In [3]:
# Load the trained SUMNet segmentation model (1 input channel, 2 output classes)
# and the intensity normalisation applied to every slice before inference.
seg_model_loadPath = '../train_codes/Results/SUMNet_new/Adam_1e-4_ep100_CE+Lov/'
netS = SUMNet(in_ch=1, out_ch=2)
# map_location='cpu' lets the checkpoint load whatever device it was saved from;
# the weights are moved to the GPU on the next line.
netS.load_state_dict(torch.load(os.path.join(seg_model_loadPath, 'sumnet_best.pt'), map_location='cpu'))
netS = netS.cuda()
# Inference only: eval() makes normalisation layers use their running statistics
# (the module name SUMNet_bn suggests BatchNorm) and disables dropout if present,
# instead of using statistics of the single slice being scored.
netS.eval()
# NOTE(review): presumably the mean / std of raw slice intensities used at training
# time — confirm against the training code.
apply_norm = transforms.Normalize([-460.466], [444.421])

In [4]:
cand_path = "../dataset/candidates.csv"
b_sz = 8
df_node = pd.read_csv(cand_path)
subset = ['1']#,'5']
running_correct = 0
count = 0

orig_list = []
pred_list = []
for s in subset:
    print('Subset:',s)
    luna_subset_path = '../dataset/subset'+str(s)+'/'    
    all_files = os.listdir(luna_subset_path)
    mhd_files = []
    for f in all_files:
        if '.mhd' in f:
            mhd_files.append(f)
    count = 0
    for m in tq(mhd_files):   
        print(m)
        mini_df = df_node[df_node["seriesuid"]==m[:-4]]
        itk_img = sitk.ReadImage(luna_subset_path+m) 
        img_array = sitk.GetArrayFromImage(itk_img)
        origin = np.array(itk_img.GetOrigin())      # x,y,z  Origin in world coordinates (mm)
        spacing = np.array(itk_img.GetSpacing())   
        slice_list = []
        if len(mini_df)>0:
            for i in range(len(mini_df)):
                fName = mini_df['seriesuid'].values[i]
                z_coord = mini_df['coordZ'].values[i]
                print(z_coord)
                orig_class = mini_df['class'].values[i]
                print(orig_class)
                pred = 0
                v_center =np.rint((z_coord-origin[2])/spacing[2])   
                img_slice = img_array[int(v_center)]
                mid_mean = img_slice[100:400,100:400].mean()    
                img_slice[img_slice==img_slice.min()] = mid_mean
                img_slice[img_slice==img_slice.max()] = mid_mean
                img_slice_tensor = torch.from_numpy(img_slice).unsqueeze(0).float()
                img_slice_norm = apply_norm(img_slice_tensor).unsqueeze(0)
                
                out = F.softmax(netS(img_slice_norm.cuda()),dim=1)
                out_np = np.asarray(out[0,1].squeeze(0).detach().cpu().numpy()*255,dtype=np.uint8)

                ret, thresh = cv2.threshold(out_np,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
                connectivity = 4  
                output = cv2.connectedComponentsWithStats(thresh, connectivity, cv2.CV_32S)
                stats = output[2]
                temp = stats[1:, cv2.CC_STAT_AREA]
                if len(temp)>0:
                    largest_label = 1 + np.argmax(temp)    
                    areas = stats[1:, cv2.CC_STAT_AREA]
                    max_area = np.max(areas)
                    if max_area>150:
                        pred = 1
                if pred == orig_class:                    
                    running_correct += 1
                pred_list.append(pred)
                orig_list.append(orig_class)
                count += 1                                                                      

Subset: 1


HBox(children=(IntProgress(value=0, max=89), HTML(value='')))

1.3.6.1.4.1.14519.5.2.1.6279.6001.315214756157389122376518747372.mhd
1.3.6.1.4.1.14519.5.2.1.6279.6001.193408384740507320589857096592.mhd
1.3.6.1.4.1.14519.5.2.1.6279.6001.259543921154154401875872845498.mhd


KeyboardInterrupt: 

In [None]:
# Candidate-level accuracy in percent over all candidates scored above;
# guarded because an interrupted run may have scored no candidates at all.
if count > 0:
    print('Accuracy:', (running_correct / count) * 100)
else:
    print('Accuracy: undefined (no candidates were scored)')

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Rows are the true class, columns the predicted class. Passing labels=[0, 1]
# keeps the matrix 2x2 even when only one class occurs in the collected lists,
# so the four-way unpacking below cannot fail.
cf = confusion_matrix(orig_list, pred_list, labels=[0, 1])
tn, fp, fn, tp = cf.ravel()

In [None]:
# Confusion matrix with labelled axes: rows = true class, columns = predicted class
# (row/column positions follow sklearn's sorted class order).
pd.DataFrame(cf).rename_axis(index='true', columns='predicted')

In [None]:
# Sensitivity (recall, true-positive rate): share of actual class-1 candidates predicted as 1.
actual_positives = tp + fn
sensitivity = tp / actual_positives
print('Sensitivity:', sensitivity)

In [None]:
# Specificity (true-negative rate): share of actual class-0 candidates predicted as 0.
actual_negatives = tn + fp
specificity = tn / actual_negatives
print('Specificity:', specificity)

In [None]:
# Precision (positive predictive value): share of class-1 predictions that are truly class 1.
predicted_positives = tp + fp
precision = tp / predicted_positives
print('Precision:', precision)