In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from dataset import *
from model import *
import os
import SimpleITK as sitk
import math
from itkwidgets import view 
%matplotlib widget

In [2]:
# Compute backend for this notebook: 'gpu' or 'cpu' (read by the device-setup cell).
mode: str = 'gpu'

In [3]:
# Select the compute device and the global default tensor type.
# NOTE: the default tensor type is process-wide state — restart the kernel
# after switching devices.
if mode == 'gpu' and torch.cuda.is_available():
    # Make GPU index 1 the current CUDA device; torch.device('cuda') (no index)
    # then resolves to it when tensors are moved with .to(device).
    torch.cuda.set_device(1)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    if mode == 'gpu':
        # Previously this path crashed inside torch.cuda.set_device.
        print("mode='gpu' but CUDA is not available; falling back to CPU.")
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

## Testing
### Initialization: load the trained model checkpoint

In [15]:
# Load the UNet1024 weights saved at the given training epoch.
epoch = 35
output_dir = 'Models/UNet1024'
# map_location restores the checkpoint tensors onto the device chosen above,
# so a checkpoint saved on GPU can also be loaded on the CPU fallback path.
checkpoint = torch.load(f'{output_dir}/epoch_{epoch}_checkpoint.pth.tar', map_location=device)
model = UNet1024()
model.load_state_dict(checkpoint['model_state_dict'])
# This notebook only runs inference: switch dropout / batch-norm layers
# (if the model has any) to evaluation behavior.
model.eval()
# NOTE(review): `net` is not used by the cells below (they call `model`
# directly). DataParallel also expects the module on device_ids[0] (cuda:0),
# while the device cell presumably places the model on cuda:1 — confirm
# before switching inference to `net`.
net = torch.nn.DataParallel(model, device_ids=[0, 1])

# params 120237649, # conv layers 62


### Run inference and save predicted masks

In [16]:
# Run the trained model on a fixed-size patch cropped around the annotated
# lymph nodes of every MED case, and save the sigmoid output as NIfTI.
root_dir = '/home/sci/hdai/Projects/Dataset/LymphNodes'
patch_size = 128
field_list = ['Series UID', 'Collection', '3rd Party Analysis',
              'Data Description URI', 'Subject ID', 'Study UID',
              'Study Description', 'Study Date', 'Series Description',
              'Manufacturer', 'Modality', 'SOP Class Name',
              'SOP Class UID', 'Number of Images', 'File Size',
              'File Location', 'Download Timestamp']
# newline='' is what the csv module expects for files handed to csv.reader.
with open(f'{root_dir}/metadata.csv', mode='r', newline='') as infile:
    case_info = [dict(zip(field_list, row)) for row in csv.reader(infile)]

# Skip the first 87 rows of metadata.csv; the remaining rows are the MED cases.
# NOTE(review): presumably a header row plus the ABD cases — confirm.
case_info = case_info[87:]

half_patch_size = patch_size // 2
# Patch centre used on any axis where the mask centroid is rejected.
# NOTE(review): z=300 assumes at least 364 slices; shorter volumes give a crop
# that is silently shorter than patch_size along z.
default_centroid = (256, 256, 300)


def _patch_centroid(mask, half, default):
    """Return the (x, y, z) centre of the patch to crop from `mask`'s volume.

    Each axis uses the integer-truncated mean index of the nonzero voxels of
    `mask` when that value lies strictly inside (half, mask.shape[axis] - half);
    otherwise the matching entry of `default` is used. An empty mask returns
    `default`. The strict bounds deliberately match the crop in the
    Visualization section, so saved predictions line up with the crops there.
    """
    centroid = list(default)
    nonzero_idx = torch.where(mask != 0)
    if nonzero_idx[0].numel() == 0:
        return tuple(centroid)
    for axis, axis_idx in enumerate(nonzero_idx):
        mean_idx = int(torch.mean(axis_idx.float()))
        if half < mean_idx < mask.shape[axis] - half:
            centroid[axis] = mean_idx
    return tuple(centroid)


pred_dir = f'{output_dir}/PredResult'
os.makedirs(pred_dir, exist_ok=True)
model.eval()  # inference only

for case in tqdm(case_info):
    # Rebuild the 3-D CT volume from its DICOM folder, e.g.
    # '<root_dir>/CT Lymph Nodes/ABD_LYMPH_003/09-14-2014-ABDLYMPH003-abdominallymphnodes-39052/abdominallymphnodes-65663'
    relative_ct_folder_path = case['File Location'][1:].replace('\\', '/')
    ct_folder_path = f'{root_dir}{relative_ct_folder_path}'
    # NOTE(review): slices are ordered by file name, not by DICOM
    # InstanceNumber / ImagePositionPatient — assumes zero-padded names.
    slice_list = [
        torch.from_numpy(pd.dcmread(f'{ct_folder_path}/{slice_name}').pixel_array.transpose())
        for slice_name in sorted(os.listdir(ct_folder_path))
    ]
    img = torch.stack(slice_list, -1).to(device)

    # Characters 17:30 of 'File Location' hold the case id (e.g. 'MED_LYMPH_001').
    case_name = case['File Location'][17:30].replace('\\', '/')
    mask_path = f'{root_dir}/MED_ABD_LYMPH_MASKS/{case_name}/{case_name}_mask.nii.gz'
    # Only the mask's nonzero support is used (to place the patch). It is
    # therefore left unbinarised and on the CPU, like the Visualization crop.
    mask = torch.from_numpy(nib.load(mask_path).get_fdata())

    cx, cy, cz = _patch_centroid(mask, half_patch_size, default_centroid)
    h = half_patch_size
    img_patch = img[cx - h:cx + h, cy - h:cy + h, cz - h:cz + h]

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        logits = model(img_patch.unsqueeze(0).unsqueeze(0))
        mask_pred = torch.sigmoid(logits).squeeze()

    tqdm.write(case_name)  # prints without breaking the progress bar
    pred_path = f'{pred_dir}/{case_name}_pred_mask.nii.gz'
    # NOTE(review): affine=None, so the file does not record where the patch
    # sits inside the original CT volume.
    nib.save(nib.Nifti1Image(mask_pred.cpu().numpy(), None), pred_path)

  0%|          | 0/88 [00:00<?, ?it/s]

MED_LYMPH_001


  1%|          | 1/88 [00:04<07:14,  5.00s/it]

MED_LYMPH_002


  2%|▏         | 2/88 [00:09<07:05,  4.95s/it]

MED_LYMPH_003


  3%|▎         | 3/88 [00:15<07:25,  5.25s/it]

MED_LYMPH_004


  5%|▍         | 4/88 [00:20<07:23,  5.28s/it]

MED_LYMPH_005


  6%|▌         | 5/88 [00:25<07:05,  5.13s/it]

MED_LYMPH_006


  7%|▋         | 6/88 [00:30<06:50,  5.01s/it]

MED_LYMPH_007


  8%|▊         | 7/88 [00:35<06:49,  5.05s/it]

MED_LYMPH_008


  9%|▉         | 8/88 [00:40<06:43,  5.04s/it]

MED_LYMPH_009


 10%|█         | 9/88 [00:45<06:37,  5.03s/it]

MED_LYMPH_010


 11%|█▏        | 10/88 [00:50<06:35,  5.07s/it]

MED_LYMPH_011


 12%|█▎        | 11/88 [00:55<06:20,  4.94s/it]

MED_LYMPH_012


 14%|█▎        | 12/88 [00:59<06:02,  4.76s/it]

MED_LYMPH_013


 15%|█▍        | 13/88 [01:04<05:56,  4.75s/it]

MED_LYMPH_014


 16%|█▌        | 14/88 [01:09<05:49,  4.72s/it]

MED_LYMPH_015


 17%|█▋        | 15/88 [01:13<05:45,  4.74s/it]

MED_LYMPH_016


 18%|█▊        | 16/88 [01:17<05:24,  4.50s/it]

MED_LYMPH_017


 19%|█▉        | 17/88 [01:21<05:10,  4.37s/it]

MED_LYMPH_018


 20%|██        | 18/88 [01:26<05:11,  4.46s/it]

MED_LYMPH_019


 22%|██▏       | 19/88 [01:31<05:14,  4.56s/it]

MED_LYMPH_020


 23%|██▎       | 20/88 [01:36<05:14,  4.62s/it]

MED_LYMPH_022


 24%|██▍       | 21/88 [01:40<05:05,  4.56s/it]

MED_LYMPH_023


 25%|██▌       | 22/88 [01:45<05:01,  4.57s/it]

MED_LYMPH_024


 26%|██▌       | 23/88 [01:49<04:54,  4.54s/it]

MED_LYMPH_025


 27%|██▋       | 24/88 [01:54<04:56,  4.63s/it]

MED_LYMPH_026


 28%|██▊       | 25/88 [01:59<04:52,  4.64s/it]

MED_LYMPH_027


 30%|██▉       | 26/88 [02:03<04:40,  4.53s/it]

MED_LYMPH_028


 31%|███       | 27/88 [02:08<04:43,  4.64s/it]

MED_LYMPH_029


 32%|███▏      | 28/88 [02:13<04:43,  4.72s/it]

MED_LYMPH_030


 33%|███▎      | 29/88 [02:17<04:34,  4.66s/it]

MED_LYMPH_031


 34%|███▍      | 30/88 [02:22<04:28,  4.63s/it]

MED_LYMPH_032


 35%|███▌      | 31/88 [02:26<04:15,  4.48s/it]

MED_LYMPH_033


 36%|███▋      | 32/88 [02:30<04:10,  4.48s/it]

MED_LYMPH_034


 38%|███▊      | 33/88 [02:35<04:10,  4.56s/it]

MED_LYMPH_035


 39%|███▊      | 34/88 [02:40<04:05,  4.55s/it]

MED_LYMPH_036


 40%|███▉      | 35/88 [02:44<03:49,  4.33s/it]

MED_LYMPH_037


 41%|████      | 36/88 [02:48<03:47,  4.38s/it]

MED_LYMPH_038


 42%|████▏     | 37/88 [02:53<03:47,  4.46s/it]

MED_LYMPH_039


 43%|████▎     | 38/88 [02:56<03:33,  4.27s/it]

MED_LYMPH_040


 44%|████▍     | 39/88 [03:01<03:30,  4.30s/it]

MED_LYMPH_041


 45%|████▌     | 40/88 [03:05<03:30,  4.39s/it]

MED_LYMPH_042


 47%|████▋     | 41/88 [03:10<03:25,  4.37s/it]

MED_LYMPH_043


 48%|████▊     | 42/88 [03:14<03:17,  4.29s/it]

MED_LYMPH_044


 49%|████▉     | 43/88 [03:18<03:08,  4.18s/it]

MED_LYMPH_045


 50%|█████     | 44/88 [03:22<03:09,  4.30s/it]

MED_LYMPH_046


 51%|█████     | 45/88 [03:27<03:10,  4.44s/it]

MED_LYMPH_047


 52%|█████▏    | 46/88 [03:32<03:09,  4.52s/it]

MED_LYMPH_048


 53%|█████▎    | 47/88 [03:36<02:59,  4.38s/it]

MED_LYMPH_049


 55%|█████▍    | 48/88 [03:40<02:50,  4.25s/it]

MED_LYMPH_050


 56%|█████▌    | 49/88 [03:44<02:44,  4.22s/it]

MED_LYMPH_051


 57%|█████▋    | 50/88 [03:48<02:37,  4.15s/it]

MED_LYMPH_052


 58%|█████▊    | 51/88 [03:52<02:32,  4.11s/it]

MED_LYMPH_053


 59%|█████▉    | 52/88 [03:56<02:27,  4.11s/it]

MED_LYMPH_054


 60%|██████    | 53/88 [04:00<02:21,  4.05s/it]

MED_LYMPH_055


 61%|██████▏   | 54/88 [04:04<02:20,  4.14s/it]

MED_LYMPH_056


 62%|██████▎   | 55/88 [04:09<02:22,  4.32s/it]

MED_LYMPH_057


 64%|██████▎   | 56/88 [04:14<02:21,  4.41s/it]

MED_LYMPH_058


 65%|██████▍   | 57/88 [04:18<02:18,  4.46s/it]

MED_LYMPH_059


 66%|██████▌   | 58/88 [04:23<02:13,  4.46s/it]

MED_LYMPH_060


 67%|██████▋   | 59/88 [04:27<02:10,  4.51s/it]

MED_LYMPH_061


 68%|██████▊   | 60/88 [04:32<02:07,  4.54s/it]

MED_LYMPH_062


 69%|██████▉   | 61/88 [04:37<02:05,  4.64s/it]

MED_LYMPH_063


 70%|███████   | 62/88 [04:42<02:02,  4.71s/it]

MED_LYMPH_064


 72%|███████▏  | 63/88 [04:46<01:53,  4.53s/it]

MED_LYMPH_065


 73%|███████▎  | 64/88 [04:50<01:45,  4.38s/it]

MED_LYMPH_066


 74%|███████▍  | 65/88 [04:54<01:40,  4.36s/it]

MED_LYMPH_067


 75%|███████▌  | 66/88 [04:59<01:37,  4.42s/it]

MED_LYMPH_068


 76%|███████▌  | 67/88 [05:03<01:30,  4.33s/it]

MED_LYMPH_069


 77%|███████▋  | 68/88 [05:08<01:28,  4.44s/it]

MED_LYMPH_070


 78%|███████▊  | 69/88 [05:12<01:23,  4.42s/it]

MED_LYMPH_071


 80%|███████▉  | 70/88 [05:17<01:21,  4.51s/it]

MED_LYMPH_072


 81%|████████  | 71/88 [05:21<01:17,  4.55s/it]

MED_LYMPH_074


 82%|████████▏ | 72/88 [05:25<01:09,  4.36s/it]

MED_LYMPH_075


 83%|████████▎ | 73/88 [05:29<01:02,  4.17s/it]

MED_LYMPH_076


 84%|████████▍ | 74/88 [05:33<00:59,  4.26s/it]

MED_LYMPH_077


 85%|████████▌ | 75/88 [05:38<00:56,  4.35s/it]

MED_LYMPH_078


 86%|████████▋ | 76/88 [05:43<00:52,  4.39s/it]

MED_LYMPH_079


 88%|████████▊ | 77/88 [05:47<00:48,  4.43s/it]

MED_LYMPH_080


 89%|████████▊ | 78/88 [05:51<00:44,  4.44s/it]

MED_LYMPH_081


 90%|████████▉ | 79/88 [05:56<00:40,  4.46s/it]

MED_LYMPH_082


 91%|█████████ | 80/88 [06:00<00:35,  4.45s/it]

MED_LYMPH_083


 92%|█████████▏| 81/88 [06:05<00:31,  4.56s/it]

MED_LYMPH_084


 93%|█████████▎| 82/88 [06:10<00:26,  4.48s/it]

MED_LYMPH_085


 94%|█████████▍| 83/88 [06:13<00:21,  4.32s/it]

MED_LYMPH_086


 95%|█████████▌| 84/88 [06:18<00:17,  4.45s/it]

MED_LYMPH_087


 97%|█████████▋| 85/88 [06:23<00:13,  4.64s/it]

MED_LYMPH_088


 98%|█████████▊| 86/88 [06:29<00:09,  4.90s/it]

MED_LYMPH_089


 99%|█████████▉| 87/88 [06:34<00:04,  4.86s/it]

MED_LYMPH_090


100%|██████████| 88/88 [06:38<00:00,  4.53s/it]


## Visualization

In [4]:
# Reload the case list for the Visualization section, so it can run on its own.
root_dir = '/home/sci/hdai/Projects/Dataset/LymphNodes'
patch_size = 128
field_list = ['Series UID', 'Collection', '3rd Party Analysis',
              'Data Description URI', 'Subject ID', 'Study UID',
              'Study Description', 'Study Date', 'Series Description',
              'Manufacturer', 'Modality', 'SOP Class Name',
              'SOP Class UID', 'Number of Images', 'File Size',
              'File Location', 'Download Timestamp']
# newline='' is what the csv module expects for files handed to csv.reader.
with open(f'{root_dir}/metadata.csv', mode='r', newline='') as infile:
    case_info = [dict(zip(field_list, row)) for row in csv.reader(infile)]

# Skip the first 87 rows of metadata.csv; keep the MED cases.
# NOTE(review): presumably a header row plus the ABD cases — confirm.
case_info = case_info[87:]

In [15]:
# Load one case for inspection: its CT volume, ground-truth mask and the
# prediction patch saved by the inference loop.
idx = 50  # position within case_info

relative_ct_folder_path = case_info[idx]['File Location'][1:].replace('\\', '/')
ct_folder_path = f'{root_dir}{relative_ct_folder_path}'
# NOTE(review): slices are ordered by file name, not by DICOM InstanceNumber —
# assumes zero-padded names.
img = torch.stack(
    [torch.from_numpy(pd.dcmread(f'{ct_folder_path}/{slice_name}').pixel_array.transpose())
     for slice_name in sorted(os.listdir(ct_folder_path))],
    -1,
)

case_name = case_info[idx]['File Location'][17:30].replace('\\', '/')
mask_path = f'{root_dir}/MED_ABD_LYMPH_MASKS/{case_name}/{case_name}_mask.nii.gz'
mask = torch.from_numpy(nib.load(mask_path).get_fdata())
mask[mask > 1] = 1  # values above 1 become 1

mask_pred_path = f'/home/sci/hdai/Projects/LnSeg/Models/UNet1024/PredResult/{case_name}_pred_mask.nii.gz'
mask_pred = torch.from_numpy(nib.load(mask_pred_path).get_fdata())
# Set to a probability (e.g. 0.5) to view a binarised prediction instead of
# the raw sigmoid output; None keeps the probabilities.
pred_threshold = None
if pred_threshold is not None:
    mask_pred = (mask_pred >= pred_threshold).to(mask_pred.dtype)
# mask_pred[mask_pred<0.5] = 0

In [16]:
# Crop the CT and ground-truth mask to the same patch the inference loop fed to
# the model, using the same centroid rule (per-axis mean of nonzero mask voxels,
# falling back to (256, 256, 300) per axis).
# NOTE: this overwrites `img` and `mask`. Re-run the loading cell before
# running this cell again, or the already-cropped tensors get cropped twice.
half_patch_size = patch_size // 2
centroid = [256, 256, 300]
nonzero_idx = torch.where(mask != 0)
if nonzero_idx[0].numel() > 0:  # an empty mask keeps the fallback centre
    for axis, axis_idx in enumerate(nonzero_idx):
        mean_idx = int(torch.mean(axis_idx.float()))
        # Strict bounds, matching the inference crop.
        if half_patch_size < mean_idx < mask.shape[axis] - half_patch_size:
            centroid[axis] = mean_idx
centroid_x, centroid_y, centroid_z = centroid
patch = (slice(centroid_x - half_patch_size, centroid_x + half_patch_size),
         slice(centroid_y - half_patch_size, centroid_y + half_patch_size),
         slice(centroid_z - half_patch_size, centroid_z + half_patch_size))
img = img[patch]
mask = mask[patch]

In [17]:
# Show the cropped CT patch in an interactive itkwidgets viewer.
view(img)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageSS3; pr…

In [18]:
# Show the cropped ground-truth mask (labels above 1 were set to 1) in the viewer.
view(mask)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageD3; pro…

In [19]:
# Show the saved prediction patch (sigmoid output from the inference loop).
view(mask_pred)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageD3; pro…