In [1]:
from datetime import datetime
import json
import numpy as np
import os
import random

import nibabel as nib
import tensorflow as tf

import tf_loader

import csv

In [2]:
# Datalist files per modality and split, relative to the notebook's working
# directory ("source" = MR, "target" = CT, judging by the file names).
_datalist_dir = './data/datalist/'
source_train_pth = _datalist_dir + 'training_mr.txt'
target_train_pth = _datalist_dir + 'training_ct.txt'
source_val_pth = _datalist_dir + 'validation_mr.txt'
target_val_pth = _datalist_dir + 'validation_ct.txt'

# NOTE(review): presumably the number of segmentation classes — TODO confirm.
num_cls = 5

In [3]:
class tf2nii:
    """The tfrecord to nii.gz converter.

    Instantiating the class performs the whole conversion: it loads the four
    splits listed in the module-level ``*_pth`` datalists via
    ``tf_loader.load_data`` and writes every (image, label) pair to
    ``data/<output_dir>/<split>/{slices,labels}/*.nii.gz``. It also writes one
    CSV index (image name, label name) per split under
    ``data/<output_dir>/datalist/``.

    Args:
        output_dir: Name of the dataset folder created under ``./data``.
    """

    def __init__(self, output_dir):
        self._source_train_pth = source_train_pth
        self._target_train_pth = target_train_pth
        self._source_val_pth = source_val_pth
        self._target_val_pth = target_val_pth
        # Stored for reference only; the conversion below does not use it.
        self._num_cls = num_cls

        self._output_dir = output_dir
        # exist_ok: a pre-existing data/<output_dir> (e.g. from a previous run)
        # must not prevent datalist/ from being created, nor raise.
        os.makedirs(os.path.join("data", self._output_dir, "datalist"), exist_ok=True)

        # Load Dataset from the dataset folder
        source_slices = tf_loader.load_data(self._source_train_pth)
        target_slices = tf_loader.load_data(self._target_train_pth)

        val_source_slices = tf_loader.load_data(self._source_val_pth)
        val_target_slices = tf_loader.load_data(self._target_val_pth)

        # Save Images Counter
        self.count = 0

        self.looper(source_slices, export_dir="mr_train", export_csv="mr_train.csv")
        self.looper(val_source_slices, export_dir="mr_val", export_csv="mr_val.csv")
        self.looper(target_slices, export_dir="ct_train", export_csv="ct_train.csv")
        self.looper(val_target_slices, export_dir="ct_val", export_csv="ct_val.csv")

    def looper(self, slices, export_dir=None, export_csv=None):
        """Export every ``(image, gt)`` pair of one split as NIfTI files.

        Args:
            slices: Iterable of ``(image, gt)`` pairs whose elements expose
                ``.numpy()`` (presumably eager TF tensors from
                ``tf_loader.load_data`` — TODO confirm).
            export_dir: Split folder name under ``data/<output_dir>``.
            export_csv: CSV index file name under ``data/<output_dir>/datalist``.
        """
        self.count = 0
        self.export_dir = os.path.join("data", self._output_dir, export_dir)
        self.export_csv = os.path.join("data", self._output_dir, "datalist", export_csv)

        # The counter restarts at 0 for every split, so the index is rewritten
        # ('w') rather than appended to: re-running the notebook must not
        # duplicate rows. Opening it once also avoids a reopen per slice.
        with open(self.export_csv, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for image, gt in slices:
                self.save_file(image.numpy(), gt.numpy(), writer=writer)

    def save_file(self, image, label, writer=None):
        """Write one image/label pair to disk and record it in the CSV index.

        Args:
            image: Array saved as ``slices/coronal_slice_XXXX.nii.gz``.
            label: Array saved as ``labels/coronal_slice_labelXXXX.nii.gz``.
            writer: Optional ``csv.writer`` for the split's index. When omitted,
                the row is appended to ``self.export_csv`` (previous behavior).
        """
        self.count += 1

        image_name = "coronal_slice_{:04d}".format(self.count)
        # No underscore before the counter here; kept unchanged because
        # existing datalists / consumers may depend on these file names.
        label_name = "coronal_slice_label{:04d}".format(self.count)

        slices_dir = os.path.join(self.export_dir, "slices")
        labels_dir = os.path.join(self.export_dir, "labels")
        # exist_ok: robust when export_dir exists but a subfolder does not.
        os.makedirs(slices_dir, exist_ok=True)
        os.makedirs(labels_dir, exist_ok=True)

        row = [image_name, label_name]
        if writer is None:
            with open(self.export_csv, 'a', newline='') as csvfile:
                csv.writer(csvfile, delimiter=',', quotechar='|',
                           quoting=csv.QUOTE_MINIMAL).writerow(row)
        else:
            writer.writerow(row)

        # Identity affine: voxel indices are written with no spatial transform.
        affine = np.eye(4)
        nii_image = nib.Nifti1Image(image, affine)
        nii_label = nib.Nifti1Image(label, affine)

        nib.save(nii_image, os.path.join(slices_dir, image_name + ".nii.gz"))
        nib.save(nii_label, os.path.join(labels_dir, label_name + ".nii.gz"))


In [4]:
def main(output_dir):
    """Run the TFRecord -> NIfTI conversion into ``data/<output_dir>``.

    Args:
        output_dir: Dataset folder name created under ``./data``.

    Returns:
        The ``tf2nii`` converter instance (the conversion runs in its
        constructor), so its state can be inspected from the notebook.
    """
    return tf2nii(output_dir)

In [5]:
if __name__ == '__main__':
    # Converted dataset is written under ./data/<output_dir>.
    output_dir = "ct_mr_dataset"
    main(output_dir=output_dir)

9600
{(1, 256, 256)}
8400
{(1, 256, 256)}
2400
{(1, 256, 256)}
1200
{(1, 256, 256)}
