In [1]:
from glob import glob
import tifffile
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import gc
import time
import shutil
from itertools import product
import h5py
import torch
from src.dataset.df import df_dataset
import cv2
from joblib import Parallel, delayed

In [2]:
class cfg:
    """Settings for cutting the clipped kidney volumes into training tiles.

    Used as a plain namespace: attributes are read straight off the class (``cfg.size_xy`` ...)
    and the class itself is handed to ``df_dataset``.
    """

    # Version tag; selects the clipped input stack and names the output dataset directory.
    train_ver = "03"

    # Input locations.
    dir_raw = "/kaggle/working/dataset/stack_raw"
    dir_clipped = f"/kaggle/working/dataset/stack_train{train_ver}"

    # Tile geometry: edge length / stride on the first two axes, depth / stride on the third.
    size_xy, offset_xy = 256, 128
    size_z, offset_z = 1, 1

    # Output root; its name encodes the version and the tile geometry.
    dataset_path = f"/kaggle/working/dataset/train{train_ver}_xy_{size_xy}_{offset_xy}_z_{size_z}_{offset_z}/"

In [3]:
def pad(shape, size, offset):
    """Smallest length >= ``shape`` on which windows of ``size`` with stride ``offset`` tile exactly.

    The result ``r`` satisfies ``r >= size`` and ``(r - size) % offset == 0``, so the last
    window starting on the stride grid ends exactly at the padded edge.

    Args:
        shape: current length of the axis.
        size: window (tile) length.
        offset: stride between consecutive window starts; must be positive.

    Returns:
        The padded length, at least ``max(shape, size)``.

    Raises:
        ValueError: if ``offset`` is not positive.
    """
    if offset <= 0:
        raise ValueError(f"offset must be positive, got {offset}")
    # An axis shorter than one window has to grow to a full window, otherwise no tile fits at all
    # (e.g. a 100-long axis with size 256 / offset 128 must become 256, not 128).
    if shape <= size:
        return size
    # Round the overhang beyond the first window up to a whole number of strides.
    n_strides = -(-(shape - size) // offset)
    return size + n_strides * offset


def pad_npy(npy, size_xy, size_z, offset_xy, offset_z):
    """Zero-pad a 3D array at the far end of each axis so it tiles exactly.

    The first two axes are padded to ``pad(len, size_xy, offset_xy)`` and the third to
    ``pad(len, size_z, offset_z)``; nothing is added before index 0.

    Returns:
        A new array with the same dtype as ``npy``.
    """
    sx, sy, sz = npy.shape  # also enforces a 3D input
    widths = [
        (0, pad(sx, size_xy, offset_xy) - sx),
        (0, pad(sy, size_xy, offset_xy) - sy),
        (0, pad(sz, size_z, offset_z) - sz),
    ]
    return np.pad(npy, widths, mode="constant")


def shift_axis(array: (np.ndarray | torch.Tensor), axis: int) -> np.ndarray | torch.Tensor:
    perm = [axis, (axis + 1) % 3, (axis + 2) % 3]  # 軸の順番をシフト
    if isinstance(array, np.ndarray):
        array = array.transpose(*perm)
    if isinstance(array, torch.Tensor):
        array = array.permute(*perm)
    return array


def gen_fname(image, label, save_dir, kidney, axis, coords):
    """Build the matching image/label ``.npy`` paths for one tile.

    Both paths share the file name ``x{x}_y{y}_z{z}_std{std}_sum{sum}.npy``:

    - ``std`` is ``int(image.std() * 1000)`` zero-padded to 4 digits.
    - ``sum`` is ``int(label.sum())``.

    The paths differ only in their top-level ``image`` / ``label`` directory.

    Returns:
        ``(image_path, label_path)``
    """
    x_, y_, z_ = coords
    std_tag = f"{int(image.std() * 1000)}".zfill(4)
    sum_tag = f"{int(label.sum())}"
    tile_name = f"x{x_}_y{y_}_z{z_}_std{std_tag}_sum{sum_tag}.npy"
    return tuple(f"{save_dir}/{kind}/{kidney}_axis{axis}/{tile_name}" for kind in ("image", "label"))


def remove_large_contours(image, thresh):
    """Erase (fill with 0) every contour region whose area exceeds ``thresh``.

    Args:
        image: 2D mask handed to ``cv2.findContours``. It presumably must be uint8, since
            OpenCV rejects bool input (TODO confirm the callers' dtype).
        thresh: area threshold as measured by ``cv2.contourArea``. Contours strictly larger
            are filled with 0.

    Returns:
        ``image`` itself when it is all zeros; otherwise a modified copy.
    """
    if image.sum() == 0:
        return image  # nothing to remove; the input object is returned as-is

    cleaned = image.copy()
    contours, hierarchy = cv2.findContours(cleaned, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Areas depend only on the contour points, so they can all be measured before any erasing.
    oversized = [c for c, _ in zip(contours, hierarchy[0]) if cv2.contourArea(c) > thresh]
    for c in oversized:
        cv2.drawContours(cleaned, [c], 0, 0, -1)  # thickness -1: fill the contour interior with 0
    return cleaned


def crop_and_save(image, label, coords, image_shape, size_xy, size_z, kidney, axis, save_dir):
    """Crop one ``(size_xy, size_xy, size_z)`` tile starting at ``coords`` and save it as ``.npy``.

    Tiles that would run past ``image_shape`` on any axis are skipped silently. Paths come from
    ``gen_fname``: ``{save_dir}/image/{kidney}_axis{axis}/...`` for the image tile and the
    matching ``label/`` path for the label tile.

    Args:
        image: 3D volume to crop from.
        label: 3D mask aligned with ``image``, or None for an unlabeled volume. With None, only
            the image tile is written, and its name is built from an all-zero stand-in mask
            (``sum0``).
        coords: ``(x, y, z)`` tile origin.
        image_shape: shape of ``image``, used for the bounds check.
        size_xy: tile extent on the first two axes.
        size_z: tile extent on the third axis.
        kidney: volume id; used only in the output directory name.
        axis: orientation index; used only in the output directory name.
        save_dir: dataset root directory.
    """
    shape_x, shape_y, shape_z = image_shape
    x_, y_, z_ = coords

    if (x_ + size_xy > shape_x) or (y_ + size_xy > shape_y) or (z_ + size_z > shape_z):
        return

    window = (slice(x_, x_ + size_xy), slice(y_, y_ + size_xy), slice(z_, z_ + size_z))
    image_tile = image[window]
    # The notebook's main loop passes label=None when no labels file exists; slicing None would
    # raise TypeError, so build an empty stand-in purely for naming and skip the label file.
    label_tile = label[window] if label is not None else np.zeros(image_tile.shape, dtype=np.bool_)
    image_path, label_path = gen_fname(image_tile, label_tile, save_dir, kidney, axis, coords)

    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    np.save(image_path, image_tile)

    if label is not None:
        os.makedirs(os.path.dirname(label_path), exist_ok=True)
        np.save(label_path, label_tile)


def npys_to_hdf5(save_dir):
    """Pack every per-tile ``.npy`` file under ``save_dir`` into ``save_dir/dataset.hdf5``.

    Files matching ``<save_dir>/*/*/*.npy`` become one dataset each inside the ``/data`` group.
    Each dataset is keyed by its path relative to ``save_dir``, using "/" separators
    (e.g. ``/data/image/<kidney>_axis0/<tile>.npy``).

    Args:
        save_dir: dataset root; an existing ``dataset.hdf5`` there is overwritten.
    """
    hdf_path = os.path.join(save_dir, "dataset.hdf5")
    # os.path.join avoids the "//" that a plain f-string produces when save_dir ends with "/".
    file_list = sorted(glob(os.path.join(save_dir, "*", "*", "*.npy")))

    # The context manager closes (and flushes) the file even if a load fails midway.
    with h5py.File(hdf_path, mode="w") as f:
        group = f.create_group("/data")
        # Sequential writes keep dataset creation order deterministic.
        for file in tqdm(file_list, total=len(file_list)):
            arr = np.load(file)
            # Use a relative key. HDF5 resolves names beginning with "/" from the file root,
            # so an absolute path would place the dataset outside the /data group.
            key = os.path.relpath(file, save_dir).replace(os.sep, "/")
            group.create_dataset(name=key, data=arr)

In [4]:
# Kidney ids come from the "<kidney>_labels.npy" files in the clipped directory.
# - Match the full suffix: a bare "_label" substring test would also accept unrelated files
#   and yield bogus ids.
# - Sort the ids, because os.listdir order is arbitrary.
LABEL_SUFFIX = "_labels.npy"
SKIP_KIDNEYS = {"kidney_1_dense"}  # excluded from this tiling run (reason not recorded here)

kidneys = sorted(
    fname[: -len(LABEL_SUFFIX)] for fname in os.listdir(cfg.dir_clipped) if fname.endswith(LABEL_SUFFIX)
)

for kidney in kidneys:
    if kidney in SKIP_KIDNEYS:
        continue
    image_path = f"{cfg.dir_clipped}/{kidney}_images.npy"
    label_path = f"{cfg.dir_clipped}/{kidney}_labels.npy"

    # Tile the volume in three orientations; shift_axis rotates the axes so `axis` comes first.
    for axis in [0, 1, 2]:
        # The volume is reloaded per orientation instead of kept resident, presumably to bound
        # peak memory.
        image = np.load(image_path)
        image = shift_axis(image, axis)
        image = pad_npy(image, cfg.size_xy, cfg.size_z, cfg.offset_xy, cfg.offset_z).astype(np.float32)

        # Always true for ids derived from the label files above; the None branch is defensive.
        if os.path.exists(label_path):
            label = np.load(label_path)
            label = shift_axis(label, axis)
            label = pad_npy(label, cfg.size_xy, cfg.size_z, cfg.offset_xy, cfg.offset_z).astype(np.bool_)
        else:
            label = None

        shape_x, shape_y, shape_z = image.shape
        print(f"{kidney}_axis_{axis}{image.shape}")

        # Every stride-aligned tile origin; crop_and_save skips origins whose tile would overrun.
        tile_origins = list(
            product(range(0, shape_x, cfg.offset_xy), range(0, shape_y, cfg.offset_xy), range(0, shape_z, cfg.offset_z))
        )

        Parallel(n_jobs=4, backend="threading")(
            delayed(crop_and_save)(image, label, coords, image.shape, cfg.size_xy, cfg.size_z, kidney, axis, cfg.dataset_path)
            for coords in tqdm(tile_origins, total=len(tile_origins))
        )
        # Release the padded float32 volume before the next orientation is loaded.
        del image, label
        gc.collect()

# to_csv below needs the output directory even if no tile was written.
os.makedirs(cfg.dataset_path, exist_ok=True)
# Optional (disabled): pack the tiles into one HDF5 file.
# npys_to_hdf5(cfg.dataset_path)
df_dataset(cfg, f"{cfg.dir_clipped}/fold.json").to_csv(f"{cfg.dataset_path}/dataset.csv", index=False)
# Optional (disabled): remove the per-tile .npy directories after packing.
# shutil.rmtree(f"{cfg.dataset_path}/image")
# shutil.rmtree(f"{cfg.dataset_path}/label")