In [None]:
import os
import glob
import pickle
import multiprocessing
from multiprocessing import Pool

import numpy as np
import pandas as pd
from scipy import ndimage
import pydicom
from lungmask.lungmask import mask
import SimpleITK as sitk
from joblib import Parallel, delayed

In [None]:
# Intensity bounds read by normalize(): values are mapped linearly so that
# HU_MIN -> 0.0 and HU_MAX -> 1.0, and the result is clipped to [0, 1].
# (Presumably Hounsfield units, given the get_pixels_hu() step -- confirm.)
HU_MIN = -1000
HU_MAX = 780
# Offset subtracted by zero_center(). Presumably the dataset-wide mean of the
# normalized pixels -- TODO confirm how this value was derived.
PIXEL_MEAN = 0.51973975

In [None]:
def get_label(s, data_fields):
    """Return the ``pe_present_on_image`` entry recorded for one slice.

    Args:
        s: Object exposing a ``SOPInstanceUID`` attribute.
        data_fields: DataFrame with ``SOPInstanceUID`` and
            ``pe_present_on_image`` columns.

    Returns:
        The ``pe_present_on_image`` value of the first row whose
        ``SOPInstanceUID`` equals ``s.SOPInstanceUID``.

    Raises:
        IndexError: If no row matches the slice's UID.
        KeyError: If one of the two columns is missing from ``data_fields``.
    """
    uid_matches = data_fields["SOPInstanceUID"] == s.SOPInstanceUID
    first_match = data_fields[uid_matches].iloc[0]
    return first_match.at["pe_present_on_image"]


def load_scan(path, data_fields):
    """Read every DICOM file in a directory and collect per-slice labels.

    Args:
        path: Directory whose entries are all read with ``pydicom.dcmread``.
        data_fields: Metadata table handed to ``get_label`` for label lookup.

    Returns:
        A tuple ``(slices, slice_thickness, pixel_spacing, labels)``:
        ``slices`` are the datasets sorted by the last component of
        ``ImagePositionPatient``; ``slice_thickness`` is the first slice's
        ``SliceThickness`` as a float; ``pixel_spacing`` is the first slice's
        ``PixelSpacing`` as a list; ``labels`` is one label per slice, or
        ``None`` when the labels cannot be looked up (e.g. the test split).

    Raises:
        IndexError: If ``path`` contains no files.
    """
    slices = [pydicom.dcmread(os.path.join(path, fname)) for fname in os.listdir(path)]
    slices.sort(key=lambda s: s.ImagePositionPatient[-1])
    slice_thickness = float(slices[0].SliceThickness)
    pixel_spacing = list(slices[0].PixelSpacing)
    try:
        # Narrow the table to this scan's rows once, so the per-slice lookups
        # below do not each re-scan the complete metadata table.
        scan_uids = [str(s.SOPInstanceUID) for s in slices]
        scan_fields = data_fields[data_fields["SOPInstanceUID"].isin(scan_uids)]
        labels = [get_label(s, scan_fields) for s in slices]  # train
    except Exception:
        # Best effort: any lookup failure means "no labels available".
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate so a long batch run can still be cancelled.
        labels = None  # test
    return slices, slice_thickness, pixel_spacing, labels


def get_pixels_hu(slices, apply_window=False):
    """Convert stored pixel values to rescaled (HU) values as int16.

    Args:
        slices: Either a list of slice objects or a single slice object. Each
            must expose ``pixel_array``, ``RescaleSlope`` and
            ``RescaleIntercept`` (plus ``WindowCenter``/``WindowWidth`` when
            ``apply_window`` is true). For a list, the rescale and window
            parameters are taken from the first slice.
        apply_window: When true, clip the rescaled values to the slice's
            window, itself limited to the range [-1000, 1000].

    Returns:
        ``np.int16`` array: the stacked volume for a list input, or the single
        image for a single slice input.
    """
    if isinstance(slices, list):
        image = np.stack([s.pixel_array for s in slices])
        reference = slices[0]
    else:
        image = slices.pixel_array
        reference = slices

    def _as_int(value):
        # Window tags may hold one value or several; use the first if several.
        try:
            return int(value)
        except (TypeError, ValueError):
            return int(value[0])

    # Rescale in float64 so unsigned pixel data combined with a negative
    # intercept cannot wrap around or overflow.
    slope = float(reference.RescaleSlope)
    intercept = float(reference.RescaleIntercept)
    image = image.astype(np.float64) * slope + intercept

    if apply_window:
        # Window centre/width are expressed in rescaled units, so the clip has
        # to happen after the slope/intercept transform, not before it.
        window_center = _as_int(reference.WindowCenter)
        window_width = _as_int(reference.WindowWidth)
        image_min = max(-1000, window_center - window_width // 2)
        image_max = min(1000, window_center + window_width // 2)
        image = np.clip(image, image_min, image_max)

    return image.astype(np.int16)


def resample(
    image, labels, slice_thickness, pixel_spacing, output_pixel_spacing=(1, 1, 1)
):
    """Resample a volume (and its per-slice labels) to a new voxel spacing.

    Args:
        image: 3D array of stacked slices, first axis indexing the slices.
        labels: One label per slice along the first axis, or ``None``.
        slice_thickness: Spacing used for the first axis.
        pixel_spacing: The two in-plane spacings (any sequence of length 2).
        output_pixel_spacing: Target spacing for the three axes.

    Returns:
        ``(resampled_image, resampled_labels)``; ``resampled_labels`` is
        ``None`` when ``labels`` is ``None``.
    """
    # Unpacking accepts any sequence for pixel_spacing, not only a list.
    original_spacing = np.array([slice_thickness, *pixel_spacing], dtype=np.float32)
    resample_rate = original_spacing / np.asarray(output_pixel_spacing)
    # `mode` only controls how the borders are extended; the image itself is
    # interpolated with scipy's default spline order.
    resampled_image = ndimage.zoom(image, resample_rate, mode="nearest")
    if labels is None:
        return resampled_image, None  # test
    # Labels are categorical: order=0 (nearest neighbour) keeps every output
    # value equal to one of the input labels instead of spline-blending them.
    resampled_labels = ndimage.zoom(
        np.asarray(labels), resample_rate[0], order=0, mode="nearest"
    )
    return resampled_image, resampled_labels  # train


def mask_to_bbox(mask):
    """Compute the inclusive bounding box of the voxels equal to 1.

    Args:
        mask: 3D array; only entries equal to 1 count as foreground.

    Returns:
        ``(D_min, D_max, W_min, W_max, H_min, H_max)``: first and last index
        holding a 1 along axes 0, 1 and 2. When an axis holds no 1 at all, its
        minimum is the last index and its maximum is 0, so ``min > max``
        whenever that axis is longer than one element.
    """
    D, W, H = mask.shape
    foreground = mask == 1

    def _extent(occupied, size):
        # `occupied[i]` tells whether index i of this axis contains any 1.
        hits = np.flatnonzero(occupied)
        if hits.size:
            return int(hits[0]), int(hits[-1])
        # Nothing found: same end state as scanning inward from both ends.
        return max(size - 1, 0), min(size - 1, 0)

    D_min, D_max = _extent(foreground.any(axis=(1, 2)), D)
    W_min, W_max = _extent(foreground.any(axis=(0, 2)), W)
    H_min, H_max = _extent(foreground.any(axis=(0, 1)), H)

    return D_min, D_max, W_min, W_max, H_min, H_max


def normalize(image):
    """Linearly map ``[HU_MIN, HU_MAX]`` onto ``[0, 1]`` and clip the rest.

    Args:
        image: Array of intensities; it is converted to float32 first.

    Returns:
        float32 array in which ``HU_MIN`` maps to 0.0, ``HU_MAX`` maps to 1.0
        and everything outside that range is clipped to the nearest bound.
    """
    as_float = image.astype(np.float32)
    shifted = as_float - HU_MIN
    scaled = shifted / (HU_MAX - HU_MIN)
    return np.clip(scaled, 0.0, 1.0)


def zero_center(image):
    """Shift ``image`` by subtracting the module-level ``PIXEL_MEAN``.

    Args:
        image: Value or array supporting subtraction of a float.

    Returns:
        ``image - PIXEL_MEAN``.
    """
    centered = image - PIXEL_MEAN
    return centered


def save_slices_as_pickle(path, save_dir, data_fields, remove=False):
    """Preprocess one scan directory and store it as one pickle per slice.

    The scan is loaded, rescaled, resampled and cropped to the bounding box of
    the non-zero region returned by ``mask.apply`` (lung localization). Slice
    ``i`` is written to ``<save_dir>/<i>.pkl`` and the slice count to
    ``<save_dir>/num_slices.pkl``.

    Args:
        path: Directory containing the scan's DICOM files.
        save_dir: Output directory; created if missing.
        data_fields: Metadata table used to look up per-slice labels.
        remove: When false, a scan that already has ``num_slices.pkl`` is
            skipped and existing slice files are kept. When true, the scan is
            reprocessed and existing files are overwritten.

    Returns:
        None.
    """
    marker = os.path.join(save_dir, "num_slices.pkl")
    if os.path.exists(marker):
        if not remove:
            return
        # Drop the stale marker first so the directory is not reported as
        # complete while its slices are being rewritten.
        os.remove(marker)

    slices, slice_thickness, pixel_spacing, labels = load_scan(path, data_fields)
    image = get_pixels_hu(slices, apply_window=False)
    image, labels = resample(image, labels, slice_thickness, pixel_spacing)
    # lung localization
    image_sitk = sitk.GetImageFromArray(image)
    segmentation = (mask.apply(image=image_sitk) > 0).astype("uint8")
    D_min, D_max, W_min, W_max, H_min, H_max = mask_to_bbox(segmentation)
    image = image[D_min : D_max + 1, W_min : W_max + 1, H_min : H_max + 1]
    if labels is not None:
        # Labels are absent for the test split; slicing None would raise.
        labels = labels[D_min : D_max + 1]

    # exist_ok avoids the check-then-create race between parallel workers.
    os.makedirs(save_dir, exist_ok=True)

    for i in range(image.shape[0]):
        filename = os.path.join(save_dir, str(i) + ".pkl")
        if os.path.exists(filename) and not remove:
            continue
        # Mode "wb" truncates, so an existing file is simply overwritten.
        with open(filename, "wb") as f:
            if labels is None:
                pickle.dump(image[i], f)  # test
            else:
                pickle.dump((image[i], labels[i]), f)  # train

    # Written last so it acts as a completion marker: an interrupted run
    # leaves no marker and the scan is picked up again on the next pass.
    with open(marker, "wb") as f:
        pickle.dump(image.shape[0], f)


def check_saved_pickle(save_dir):
    """Check that a directory holds every slice pickle its marker announces.

    Args:
        save_dir: Directory expected to contain ``num_slices.pkl`` and the
            files ``0.pkl`` ... ``<num_slices - 1>.pkl``.

    Returns:
        ``(save_dir, ok)`` where ``ok`` is true only if ``num_slices.pkl``
        exists, can be read, and every announced slice file exists.
    """
    marker = os.path.join(save_dir, "num_slices.pkl")
    if not os.path.exists(marker):
        return save_dir, False
    try:
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # directories produced by this pipeline.
        with open(marker, "rb") as f:
            num_slices = pickle.load(f)
    except (EOFError, pickle.UnpicklingError):
        # An empty or truncated marker (e.g. an interrupted write) means the
        # directory is incomplete; it must not abort the whole check.
        return save_dir, False
    complete = all(
        os.path.exists(os.path.join(save_dir, str(i) + ".pkl"))
        for i in range(num_slices)
    )
    return save_dir, complete

In [None]:
# Dataset root; "<mode>.csv" and the "<mode>/" scan directory are read from it.
root_dir = "data/"
# Selects both the CSV and the scan directory used below.
mode = "train"
# Metadata table; passed to save_slices_as_pickle() for per-slice label lookup.
data_fields = pd.read_csv(os.path.join(root_dir, mode + ".csv"))
# Every path exactly two levels below "<root_dir><mode>/"
# (presumably one directory per scan -- confirm against the dataset layout).
paths = glob.glob(root_dir + mode + "/*/*")

In [None]:
# Convert every scan in `paths` to per-slice pickles under "<root_dir>/pkl/",
# mirroring each scan's location relative to root_dir. Scans that already
# have a num_slices.pkl are skipped because remove=False.
outputs = Parallel(n_jobs=10, backend="multiprocessing", verbose=10)(
    delayed(save_slices_as_pickle)(
        path,
        # relpath replaces path[5:], which silently assumed len(root_dir) == 5
        # and would produce wrong output paths for any other root_dir.
        os.path.join(root_dir, "pkl", os.path.relpath(path, root_dir)),
        data_fields,
        remove=False,
    )
    for path in paths
)

In [None]:
# Verify the converted scans: `outputs` becomes a list of (save_dir, ok)
# tuples, one per entry of `paths`, as returned by check_saved_pickle().
outputs = Parallel(n_jobs=100, backend="multiprocessing", verbose=10)(
    # relpath replaces path[5:], which silently assumed len(root_dir) == 5.
    delayed(check_saved_pickle)(
        os.path.join(root_dir, "pkl", os.path.relpath(path, root_dir))
    )
    for path in paths
)