In [1]:
# Reads the MM-WHS CT test volumes and their labels, and saves every slice as a tf-record in the same format as the training/validation images.
# The records are written under 'ct_test_tfs', and the list of record paths is written to the file 'ct_test_list'.
# Follows the code at https://github.com/cchen-cc/SIFA/blob/SIFA-v1/evaluate.py

In [2]:
import importlib
import os

import numpy as np
import tensorflow as tf
import nibabel as nib

import io_utils as io

In [3]:
# Location of the raw CT test volumes (image_ct_<id>.nii) and labels (gth_ct_<id>.nii).
test_img_dir = "./data/mmwhs/test_ct_image&labels/"
# Scan ids of the CT test set.
files = ["1003", "1008", "1014", "1019"]

# Root folder of the released train/val data; test records are written alongside it.
data_dir = "./data/mmwhs/PnpAda_release_data/train&val/"
# Text file listing the written test records (paths relative to data_dir).
# Derived from data_dir so the two paths cannot drift apart.
test_file_list = data_dir + "ct_test_list"

In [4]:
# Export every slice (along axis 2) of each CT test volume as one tf-record.
# Each record holds a 3-channel stack: slices idx-1, idx, idx+1, with the
# neighbours clamped at the volume boundaries.
# NOTE(review): Y also stacks 3 neighbouring label slices; only the centre
# channel corresponds to slice idx — confirm downstream evaluation uses that channel.
test_files = []
fn_idx = 0

# Records are written to data_dir + "ct_test_tfs/"; make sure the folder exists
# so the record writer does not fail on a missing parent directory.
os.makedirs(data_dir + "ct_test_tfs", exist_ok=True)

for scan in files:
    slices = nib.load(test_img_dir + "image_ct_" + scan + ".nii").get_fdata()
    labels = nib.load(test_img_dir + "gth_ct_" + scan + ".nii").get_fdata()

    # Flip both in-plane axes, per https://github.com/cchen-cc/SIFA/blob/master/evaluate.py line 160
    slices = np.flip(slices, axis=(0, 1))
    labels = np.flip(labels, axis=(0, 1))

    # The records are written at a fixed 256x256 resolution.
    assert slices.shape[:2] == (256, 256), f"scan {scan}: unexpected in-plane shape {slices.shape[:2]}"
    assert labels.shape == slices.shape, f"scan {scan}: label shape {labels.shape} != image shape {slices.shape}"

    print(np.min(slices), np.max(slices), np.mean(slices), np.var(slices))

    n_slices = slices.shape[2]
    for idx in range(n_slices):
        # Neighbouring slice indices, clamped to the valid range of axis 2
        # (the axis being iterated — not axis 0).
        neighbors = [min(max(idx + offset, 0), n_slices - 1) for offset in (-1, 0, 1)]

        # (256, 256, 3) float32, C-contiguous, channel k = slice neighbors[k]
        X = np.ascontiguousarray(slices[..., neighbors], dtype=np.float32)
        Y = np.ascontiguousarray(labels[..., neighbors], dtype=np.float32)

        fn = f"ct_test_tfs/ct_test_slice{fn_idx}.tfrecords"
        test_files.append(fn)
        io.to_tfrecord(X, Y, data_dir + fn)
        fn_idx += 1

-1.508024935050452 2.368554272081745 -3.5548422022236e-05 0.999788836343351
-1.763460938640936 1.3978457339762702 -1.9370614215015328e-05 0.9999155158331816
-1.1669894842597788 2.2552393552573076 6.939882422179848e-05 0.9997865691462293
-1.4178956401328702 1.6743763779368976 6.119588110340898e-05 0.9999489784833527


In [5]:
# Persist the record paths (relative to data_dir), one per line, so later
# evaluation code can enumerate the test slices.
with open(test_file_list, 'w') as list_handle:
    list_handle.writelines("%s\n" % record_path for record_path in test_files)

In [6]:
# Sanity check: read a small batch back from the freshly written test records.
# Reload io_utils so edits made to it while this kernel is running are picked up.
importlib.reload(io)

# Reuse data_dir rather than repeating the same path literal.
# NOTE(review): data_type='mr' is passed although these are CT records —
# confirm how io.sample_batch uses data_type before relying on this batch.
im, la = io.sample_batch(data_dir,
                         test_files,
                         data_type='mr',
                         batch_size=3)