# Task: Generate 3D PTV Mask and 2D GTV Mask for All Patients

In [None]:
import numpy as np
import nrrd
import nibabel as nib
from IPython.core.interactiveshell import InteractiveShell
import matplotlib.pyplot as plt
import matplotlib
import sys
import mpl_toolkits.mplot3d
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import scipy
import skimage
import pathlib
from pathlib import Path
from tqdm.notebook import tqdm
from typing import Union, Optional, Tuple, Sized

# Display the value of every bare expression in a cell, not only the last one
# (the version-check cell below consists solely of bare expressions).
InteractiveShell.ast_node_interactivity = 'all'
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline

## Check Package Versions

In [None]:
# Every bare expression below is displayed because the setup cell sets
# `ast_node_interactivity = 'all'`. numpy, nibabel and pynrrd are included
# because this notebook uses APIs whose availability depends on their versions.
f'python: {sys.version}'
f'numpy: {np.__version__}'
f'matplotlib: {matplotlib.__version__}'
f'skimage: {skimage.__version__}'
f'scipy: {scipy.__version__}'
f'nibabel: {nib.__version__}'
f'pynrrd: {getattr(nrrd, "__version__", "unknown")}'

## Utils: Plot Functions

In [None]:
def plot_3d(image, threshold=0, step_size=1):
    """Render the iso-surface of a 3D volume as a translucent mesh.

    The volume is transposed to (axis 2, axis 1, axis 0) order, its surface
    at ``threshold`` is extracted with the Lewiner marching-cubes algorithm,
    and the mesh is drawn on a single 3D axes.

    :param image: 3D array (e.g. a binary mask) to render.
    :param threshold: Iso-surface level passed to marching cubes.
    :param step_size: Voxel step size; larger values give a coarser, faster mesh.
    """
    p = image.transpose(2, 1, 0)

    # `marching_cubes_lewiner` is deprecated (scikit-image 0.17) and removed in
    # newer releases; there the same algorithm is
    # `marching_cubes(..., method='lewiner')`.
    legacy_lewiner = getattr(measure, 'marching_cubes_lewiner', None)
    if legacy_lewiner is not None:
        verts, faces, _, _ = legacy_lewiner(
            p,
            threshold,
            step_size=step_size,
            allow_degenerate=False,
        )
    else:
        verts, faces, _, _ = measure.marching_cubes(
            p,
            threshold,
            step_size=step_size,
            allow_degenerate=False,
            method='lewiner',
        )

    fig = plt.figure(figsize=(10, 10), dpi=200)
    ax = fig.add_subplot(111, projection='3d')

    mesh = Poly3DCollection(verts[faces], alpha=0.10)
    face_color = [0.45, 0.45, 0.75]
    mesh.set_facecolor(face_color)
    ax.add_collection3d(mesh)

    ax.set_xlim(0, p.shape[0])
    ax.set_ylim(0, p.shape[1])
    ax.set_zlim(0, p.shape[2])

    plt.show()
    

def plot_3d_grid(names, imgs, threshold=0, step_size=1, n_col=3):
    """Render several 3D volumes side by side, one titled subplot each.

    Each volume is transposed to (axis 2, axis 1, axis 0) order and its
    iso-surface at ``threshold`` is drawn as a translucent mesh. Subplots are
    laid out row by row on a grid with ``n_col`` columns.

    :param names: Subplot titles, paired with ``imgs`` in order.
    :param imgs: 3D arrays (e.g. binary masks) to render.
    :param threshold: Iso-surface level passed to marching cubes.
    :param step_size: Voxel step size; larger values give a coarser, faster mesh.
    :param n_col: Number of subplot columns.
    """
    names, imgs = list(names), list(imgs)
    if not imgs:  # nothing to draw; do not show an empty figure
        return

    # Ceiling division so every image gets its own cell for any `n_col`.
    n_row = -(-len(imgs) // n_col)
    # Height/width ratio of the grid (1/3 for a single row of three).
    fig = plt.figure(figsize=plt.figaspect(n_row / n_col), dpi=200)

    # `marching_cubes_lewiner` is deprecated (scikit-image 0.17) and removed in
    # newer releases; there the same algorithm is
    # `marching_cubes(..., method='lewiner')`.
    legacy_lewiner = getattr(measure, 'marching_cubes_lewiner', None)

    for position, (name, img) in enumerate(zip(names, imgs), start=1):
        ax = fig.add_subplot(n_row, n_col, position, projection='3d')

        p = img.transpose(2, 1, 0)

        if legacy_lewiner is not None:
            verts, faces, _, _ = legacy_lewiner(
                p,
                threshold,
                step_size=step_size,
                allow_degenerate=False,
            )
        else:
            verts, faces, _, _ = measure.marching_cubes(
                p,
                threshold,
                step_size=step_size,
                allow_degenerate=False,
                method='lewiner',
            )

        mesh = Poly3DCollection(verts[faces], alpha=0.10)
        face_color = [0.45, 0.45, 0.75]
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
        ax.title.set_text(name)
        ax.set_xlim(0, p.shape[0])
        ax.set_ylim(0, p.shape[1])
        ax.set_zlim(0, p.shape[2])

    plt.show()

## Function for One Case (Input: Case Directory; Output: Case ID, 3D PTV and 2D GTV NumPy Arrays)

In [None]:
def _voxel_spacing_mm(header: dict) -> np.ndarray:
    """Return the physical size of one voxel along each array axis.

    Uses the length of each row of the NRRD ``space directions`` matrix, so a
    flipped (negative) or oblique direction still yields a positive spacing.
    Units follow the NRRD header; this notebook assumes millimetres.
    """
    directions = np.asarray(header['space directions'], dtype=float)
    return np.linalg.norm(directions, axis=1)


def _expand_mask_mm(mask: np.ndarray, spacing: np.ndarray,
                    margin: float) -> np.ndarray:
    """Return all voxels whose Euclidean distance to ``mask`` is <= ``margin``.

    Distances are measured in physical units via ``spacing``, so the margin
    is the same in every direction even when voxels are anisotropic (e.g.
    thick slices). The distance transform runs only on the bounding box of
    ``mask`` padded by the margin, which keeps memory proportional to the
    structure size rather than the whole scan.
    """
    from scipy import ndimage

    expanded = np.zeros(mask.shape, dtype=bool)
    coords = np.argwhere(mask)
    if coords.size == 0:  # empty mask: nothing to expand
        return expanded

    # Any voxel outside this padded box is farther than `margin` from the mask.
    pad = np.ceil(margin / spacing).astype(int) + 1
    lower = np.maximum(coords.min(axis=0) - pad, 0)
    upper = np.minimum(coords.max(axis=0) + pad + 1, mask.shape)
    box = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))

    # distance_transform_edt gives each non-zero voxel its distance to the
    # nearest zero voxel; invert the mask so background voxels get their
    # distance to the mask (mask voxels get 0).
    distance = ndimage.distance_transform_edt(~mask[box], sampling=spacing)
    expanded[box] = distance <= margin
    return expanded


def _max_area_slice_mask(mask: np.ndarray) -> np.ndarray:
    """Keep only the slice along the last axis where ``mask`` has the most voxels.

    Ties resolve to the lowest slice index. An empty mask selects slice 0 and
    therefore returns an all-False volume.
    """
    areas = mask.sum(axis=(0, 1))
    max_index = int(np.argmax(areas))
    slice_mask = np.zeros(mask.shape, dtype=bool)
    slice_mask[:, :, max_index] = mask[:, :, max_index]
    return slice_mask


def func(
        case_path: Path,
        dialated_size: Union[int, float] = 15,
        need_plot: bool = False,
        return_result: bool = False,
) -> Optional[Tuple[str, np.ndarray, np.ndarray]]:
    """Build the PTV shell mask and the max-area GTV slice mask for one case.

    Reads the GTV label map ``Segmentation-label.nrrd`` from ``case_path``
    and writes two NRRD files next to it, both stored as int16 and reusing
    the label map's header:

    * ``ptv_seg.nrrd``: voxels within ``dialated_size`` mm (Euclidean, using
      the per-axis voxel spacing) of the GTV, excluding the GTV itself.
    * ``gtv_slice_seg.nrrd``: zero everywhere except the slice along the
      last array axis where the GTV covers the most voxels.

    :param case_path: Directory of one case. ``case_path.stem`` is the case
        ID; the directory must contain ``<case_id>.nii``,
        ``Segmentation-label.nrrd`` and ``Segmentation.seg.nrrd``.
    :param dialated_size: Dilation margin in mm; must be a positive number.
    :param need_plot: Whether to plot the 3D results. Defaults to False.
    :param return_result: Whether to return the results instead of None.
    :return: ``(case_id, ptv_seg, gtv_slice_seg)`` with boolean arrays when
        ``return_result`` is True, otherwise None.
    :raises FileNotFoundError: If the case directory or a required file is
        missing.
    :raises ValueError: If ``dialated_size`` is not a positive number.
    :raises TypeError: If ``need_plot`` or ``return_result`` is not a bool.
    """
    if not case_path.exists():
        raise FileNotFoundError(f'Case directory not found: {case_path}')
    if (isinstance(dialated_size, bool)
            or not isinstance(dialated_size,
                              (int, float, np.integer, np.floating))
            or not dialated_size > 0):  # `not > 0` also rejects NaN
        raise ValueError(
            f'dialated_size must be a positive number (mm), '
            f'got {dialated_size!r}')
    if not isinstance(need_plot, bool):
        raise TypeError(f'need_plot must be a bool, got {need_plot!r}')
    if not isinstance(return_result, bool):
        raise TypeError(f'return_result must be a bool, got {return_result!r}')

    case_id = case_path.stem
    nii_f = case_path / f'{case_id}.nii'
    label_nrrd_f = case_path / 'Segmentation-label.nrrd'
    seg_nrrd_f = case_path / 'Segmentation.seg.nrrd'
    for required_f in (nii_f, label_nrrd_f, seg_nrrd_f):
        if not required_f.exists():
            raise FileNotFoundError(f'Missing required file: {required_f}')

    # Only the GTV label map is needed. The CT volume in `nii_f` was previously
    # read with nibabel's removed `get_data()` API but never used, so it is
    # no longer loaded.
    gtv_seg, header = nrrd.read(str(label_nrrd_f))
    gtv_seg = gtv_seg.astype(bool)  # `np.bool` was removed in NumPy 1.24

    # Step 1: PTV shell = everything within `dialated_size` mm of the GTV,
    # minus the GTV. Per-axis spacing keeps the margin correct for
    # anisotropic voxels and for flipped (negative) direction vectors.
    spacing = _voxel_spacing_mm(header)
    dilated_gtv_seg = _expand_mask_mm(gtv_seg, spacing, dialated_size)
    ptv_seg = dilated_gtv_seg & ~gtv_seg
    if need_plot:
        plot_3d(ptv_seg)
    ptv_f = case_path / 'ptv_seg.nrrd'
    nrrd.write(filename=str(ptv_f),
               data=ptv_seg.astype(np.int16),
               header=header)

    # Step 2: 2D GTV = the GTV slice (along the last axis) with the largest area.
    gtv_slice_seg = _max_area_slice_mask(gtv_seg)
    if need_plot:
        plot_3d(gtv_slice_seg)
    gtv_slice_f = case_path / 'gtv_slice_seg.nrrd'
    nrrd.write(filename=str(gtv_slice_f),
               data=gtv_slice_seg.astype(np.int16),
               header=header)

    if return_result:
        return case_id, ptv_seg, gtv_slice_seg
    return None

## Process All Patients

In [None]:
# Process every case directory under `segmented_path`. Set `keep_results` to
# True to collect (case_id, ptv_seg, gtv_slice_seg) tuples for the optional
# check cells below; note this keeps every case's volumes in memory.
results = []
keep_results = False

# Raw string: '\s' is not a valid escape sequence (a warning on newer Pythons).
segmented_path = Path(r'F:\segmentation_samples')

# Sorted so cases are processed (and printed) in a deterministic order.
for one_path in sorted(segmented_path.iterdir()):
    # Skip plain files and Jupyter's '.ipynb_checkpoints' folder; test only the
    # entry's own name so a parent folder name cannot exclude every case.
    if not one_path.is_dir() or 'checkpoint' in one_path.name:
        continue
    print('one_path =', one_path)
    result = func(case_path=one_path, return_result=keep_results)
    if keep_results:
        results.append(result)

## (Optional) Check Results
### Check PTV Mask

In [None]:
# Plot each collected case's 3D PTV mask (element 1 of each result tuple).
if results:
    plot_3d_grid(
        names=[i[0] for i in results],
        imgs=[i[1] for i in results],
        threshold=0,
        step_size=1,
        n_col=3,
    )
else:
    print('`results` is empty: collect func(..., return_result=True) outputs '
          'into `results` in the processing cell before running this check.')

### Check 2D GTV Mask

In [None]:
# Plot each collected case's single-slice GTV mask (element 2 of each result tuple).
if results:
    plot_3d_grid(
        names=[i[0] for i in results],
        imgs=[i[2] for i in results],
        threshold=0,
        step_size=1,
        n_col=3,
    )
else:
    print('`results` is empty: collect func(..., return_result=True) outputs '
          'into `results` in the processing cell before running this check.')