# Create an extra weak-label input channel
Partial labels such as boxes and masks are added as an extra input channel for nnUNetv2.
* 2D bbox of key slice
* 3D bbox
* 2 orthogonal 2D bboxes of key slices
* 2D mask of key slice
* 2 orthogonal 2D masks of key slices

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [None]:
# import cc3d
# labels, n_components = cc3d.connected_components(seg_data, return_N=True)

## 2D bbox of key slice


In [None]:
# Raw dataset folder holding the training images and their labels.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset001_3dlesion/")
train_images = data_folder.joinpath("imagesTr")
train_labels = data_folder.joinpath("labelsTr")

# Keep only the files ending in "_0000.nii.gz" (one per training case).
train_names = [
    name for name in os.listdir(train_images) if name.endswith("_0000.nii.gz")
]

In [None]:
def channel_with_box_on_key_slice(label):
    """Build a channel containing the 2D bounding box of the key slice.

    The key slice is the index along axis 0 whose slice has the largest
    label sum. The tight bounding box around the non-zero voxels of that
    slice is filled with ones; every other voxel stays zero.

    Args:
        label: 3D NumPy label volume; non-zero voxels count as foreground.

    Returns:
        Array of zeros and ones (``np.zeros`` default dtype) with the same
        shape as ``label``. All zeros when the key slice has no foreground.
    """
    mask = np.zeros(label.shape)

    key_index = label.sum(axis=(1, 2)).argmax()
    key_slice = label[key_index]

    # Sorted indices of the rows / columns that contain any foreground.
    y_coords = np.flatnonzero(key_slice.any(axis=1))
    x_coords = np.flatnonzero(key_slice.any(axis=0))

    # Without foreground there is no box to draw; taking min/max of the empty
    # index arrays would raise, so return the empty channel instead.
    if y_coords.size == 0:
        return mask

    # The indices are sorted, so first/last are the min/max extents.
    y_min, y_max = y_coords[0], y_coords[-1]
    x_min, x_max = x_coords[0], x_coords[-1]

    mask[key_index, y_min:y_max + 1, x_min:x_max + 1] = 1
    return mask

In [None]:
# Index of the generated input channel. It is formatted as a zero-padded
# four-digit suffix below so the file name stays well-formed for any index.
# NOTE(review): nnU-Net expects consecutive channel identifiers — confirm
# that a "_0001" channel exists alongside "_0000" before using index 2.
channel_n = 2

# Destination folder for the new channel; it does not depend on the case,
# so it is built once instead of on every iteration.
output_path = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset003_3dlesion_2dkeybox/imagesTr/")

for f in tqdm(train_names):
    # The CT image is read for its geometry, which is copied to the new channel.
    ct_path = train_images / f
    ct_img = sitk.ReadImage(ct_path)

    # The label file shares the case name but carries no channel suffix.
    label_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
    label_img = sitk.ReadImage(label_path)
    label = sitk.GetArrayFromImage(label_img)

    key_slice_box = channel_with_box_on_key_slice(label)
    new_img = sitk.GetImageFromArray(key_slice_box)

    # Carry over the CT's origin, spacing and direction so the new channel
    # shares its geometry.
    new_img.SetOrigin(ct_img.GetOrigin())
    new_img.SetSpacing(ct_img.GetSpacing())
    new_img.SetDirection(ct_img.GetDirection())

    new_path = output_path / f.replace("_0000.nii.gz", f"_{channel_n:04d}.nii.gz")
    print(new_path)

    # Dry run: writing is disabled here; uncomment to save the channel.
    # sitk.WriteImage(new_img, new_path)

## 2D bboxes of horizontal and sagittal key slices


In [None]:
def channel_with_box_on_key_sagittal_slice(label):
    """Build a channel containing the 2D bounding box of the key sagittal slice.

    The key slice is the index along axis 2 whose slice has the largest
    label sum. The tight bounding box around the non-zero voxels of that
    slice is filled with ones; every other voxel stays zero.

    Args:
        label: 3D NumPy label volume; non-zero voxels count as foreground.

    Returns:
        Array of zeros and ones (``np.zeros`` default dtype) with the same
        shape as ``label``. All zeros when the key slice has no foreground.
    """
    mask = np.zeros(label.shape)

    key_index = label.sum(axis=(0, 1)).argmax()
    key_slice = label[:, :, key_index]

    # Sorted indices along axis 0 / axis 1 that contain any foreground.
    y_coords = np.flatnonzero(key_slice.any(axis=1))
    x_coords = np.flatnonzero(key_slice.any(axis=0))

    # Without foreground there is no box to draw; taking min/max of the empty
    # index arrays would raise, so return the empty channel instead.
    if y_coords.size == 0:
        return mask

    # The indices are sorted, so first/last are the min/max extents.
    y_min, y_max = y_coords[0], y_coords[-1]
    x_min, x_max = x_coords[0], x_coords[-1]

    mask[y_min:y_max + 1, x_min:x_max + 1, key_index] = 1
    return mask

In [None]:
# Dataset folder for the orthogonal key-box variant; new channels are
# written next to the existing images in "imagesTr".
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset005_3dlesion_2dkeyboxes_orthogonal/")
train_images = data_folder.joinpath("imagesTr")
train_labels = data_folder.joinpath("labelsTr")
# Keep only the files ending in "_0000.nii.gz" (one per training case).
train_names = [
    name for name in os.listdir(train_images) if name.endswith("_0000.nii.gz")
]

In [None]:
channel_n = 2

for f in tqdm(train_names):
    # Resolve both input paths; the label file has no channel suffix.
    ct_path = train_images / f
    label_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")

    ct_img = sitk.ReadImage(ct_path)
    label_img = sitk.ReadImage(label_path)
    label = sitk.GetArrayFromImage(label_img)

    # Union of the axis-0 key-slice box and the sagittal key-slice box;
    # voxels covered by both are capped at 1.
    horizontal_slice_box = channel_with_box_on_key_slice(label)
    sagittal_slice_box = channel_with_box_on_key_sagittal_slice(label)
    boxes_mask = np.minimum(horizontal_slice_box + sagittal_slice_box, 1)
    new_img = sitk.GetImageFromArray(boxes_mask)

    # Carry over the CT's origin, spacing and direction.
    new_img.SetOrigin(ct_img.GetOrigin())
    new_img.SetSpacing(ct_img.GetSpacing())
    new_img.SetDirection(ct_img.GetDirection())

    # The new channel is stored alongside the source images.
    new_path = train_images / f.replace("_0000.nii.gz", f"_000{channel_n}.nii.gz")
    sitk.WriteImage(new_img, new_path)

In [None]:
# plt.imshow(label.sum(axis=(2)))
# plt.imshow(key_slice_box.sum(axis=(2)))

## 2D mask of key slice


In [None]:
# Dataset folder for the key-slice mask variant.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset006_3dlesion_2dkeymask")
train_images = data_folder.joinpath("imagesTr")
train_labels = data_folder.joinpath("labelsTr")

# Keep only the files ending in "_0000.nii.gz" (one per training case).
train_names = [
    name for name in os.listdir(train_images) if name.endswith("_0000.nii.gz")
]
# Display the number of cases found.
len(train_names)

In [None]:
def channel_with_mask_on_key_slice(label):
    """Return a channel that keeps the label values of the key slice only.

    The key slice is the index along axis 0 whose slice has the largest
    label sum; its values are copied into an otherwise all-zero volume of
    the same shape as ``label``.
    """
    per_slice_total = np.sum(label, axis=(1, 2))
    key_index = np.argmax(per_slice_total)

    channel = np.zeros(label.shape)
    channel[key_index, ...] = label[key_index, ...]
    return channel

In [None]:
# Index of the generated input channel. It is formatted as a zero-padded
# four-digit suffix below so the file name stays well-formed for any index.
channel_n = 2

# The channel is written next to the images it was derived from. Reusing
# `train_images` (instead of a second hard-coded copy of the same path)
# keeps the read and write locations from drifting apart, and the folder
# is resolved once rather than on every iteration.
output_path = train_images

for f in tqdm(train_names):
    # The CT image is read for its geometry, which is copied to the new channel.
    ct_path = train_images / f
    ct_img = sitk.ReadImage(ct_path)

    # The label file shares the case name but carries no channel suffix.
    label_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
    label_img = sitk.ReadImage(label_path)
    label = sitk.GetArrayFromImage(label_img)

    key_slice_mask = channel_with_mask_on_key_slice(label)
    new_img = sitk.GetImageFromArray(key_slice_mask)

    # Carry over the CT's origin, spacing and direction so the new channel
    # shares its geometry.
    new_img.SetOrigin(ct_img.GetOrigin())
    new_img.SetSpacing(ct_img.GetSpacing())
    new_img.SetDirection(ct_img.GetDirection())

    new_path = output_path / f.replace("_0000.nii.gz", f"_{channel_n:04d}.nii.gz")

    sitk.WriteImage(new_img, new_path)

## 2D masks of key horizontal and sagittal slices


In [None]:
# Dataset folder for the orthogonal key-mask variant.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset007_3dlesion_2dkeymasks_orthogonal")
train_images = data_folder.joinpath("imagesTr")
train_labels = data_folder.joinpath("labelsTr")

# Keep only the files ending in "_0000.nii.gz" (one per training case).
train_names = [
    name for name in os.listdir(train_images) if name.endswith("_0000.nii.gz")
]
# Display the number of cases found.
len(train_names)

In [None]:
## 2D masks of horizontal and sagittal key slices
def channel_with_masks_on_key_sagittal_slice(label):
    """Return a channel that keeps the label values of the key sagittal slice.

    The key slice is the index along axis 2 whose slice has the largest
    label sum; its values are copied into an otherwise all-zero volume of
    the same shape as ``label``.
    """
    per_slice_total = np.sum(label, axis=(0, 1))
    key_index = np.argmax(per_slice_total)

    channel = np.zeros(label.shape)
    channel[..., key_index] = label[..., key_index]
    return channel

In [None]:
channel_n = 2

for f in tqdm(train_names):
    # Resolve both input paths; the label file has no channel suffix.
    ct_path = train_images / f
    label_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")

    ct_img = sitk.ReadImage(ct_path)
    label_img = sitk.ReadImage(label_path)
    label = sitk.GetArrayFromImage(label_img)

    # Combine the axis-0 key-slice mask with the sagittal key-slice mask;
    # every value above 1 is capped at 1.
    horizontal_slice_mask = channel_with_mask_on_key_slice(label)
    sagittal_slice_mask = channel_with_masks_on_key_sagittal_slice(label)
    mask = np.minimum(horizontal_slice_mask + sagittal_slice_mask, 1)
    new_img = sitk.GetImageFromArray(mask)

    # Carry over the CT's origin, spacing and direction.
    new_img.SetOrigin(ct_img.GetOrigin())
    new_img.SetSpacing(ct_img.GetSpacing())
    new_img.SetDirection(ct_img.GetDirection())

    # The new channel is stored in the dataset's "imagesTr" folder.
    output_path = data_folder / "imagesTr"
    new_path = output_path / f.replace("_0000.nii.gz", f"_000{channel_n}.nii.gz")

    sitk.WriteImage(new_img, new_path)

## Visualization Code

In [None]:
# Show one CT slice next to its segmentation for a quick visual check.
# NOTE(review): `result` is not defined in the visible part of this notebook,
# and `ct_img` must be an object whose 2D slices expose `.rot90()` (the
# SimpleITK images read in the cells above presumably do not) — confirm
# where both variables come from before running this cell.
res = result.squeeze()
# Visit every second index along the last axis.
for i in range(0, res.shape[-1], 2):
    seg_mask = res[..., i].rot90()
    # Plot only the first visited slice that contains any positive value.
    if (seg_mask > 0).sum() > 0:
        ct_slice = ct_img[:, :, i].rot90()
        fig, axes = plt.subplots(1, 2, figsize=(6, 3))
        axes[0].imshow(ct_slice, cmap="gray")
        # vmax=117 presumably matches the highest label id — TODO confirm.
        axes[1].imshow(seg_mask, vmin=0, vmax=117, cmap="gist_stern")
        plt.show()
        break