In [6]:
import os
from pathlib import Path

import nibabel as nib
import numpy as np
from skimage.transform import resize
import torch

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

In [7]:
# ---- Tunable preprocessing settings ----
# Destination for processed patches.
# NOTE(review): not referenced elsewhere in this notebook; the final cell passes
# "../sample_outputs" directly — confirm which location is intended.
OUTPUT_FOLDER = "../processed_data"
# Slab geometry along the volume's third axis (consumed by extract_patches).
SLICE_DEPTH = 1   # slices stacked per patch; becomes the tensor's leading (channel) dim
SLICE_DROP = 5    # slices skipped between the end of one patch and the start of the next
# Spatial side length every slice is resized to (INPUT_SIZE x INPUT_SIZE).
INPUT_SIZE = 128


In [8]:
def load_nii_and_normalize(path, hu_min=-1000, hu_max=1000):
    """Load a NIfTI volume and linearly map an intensity window onto [0, 1].

    Voxels are clipped to ``[hu_min, hu_max]`` and then rescaled so that
    ``hu_min -> 0`` and ``hu_max -> 1``. The defaults reproduce the original
    fixed window of -1000..1000.

    Args:
        path: Path to a ``.nii`` / ``.nii.gz`` file (anything ``nib.load`` accepts).
        hu_min: Lower edge of the window; values at or below it map to 0.
        hu_max: Upper edge of the window; values at or above it map to 1.

    Returns:
        np.ndarray with the same shape as the stored image, values in [0, 1]
        (``get_fdata`` returns float64 by default).

    Raises:
        ValueError: if ``hu_max <= hu_min`` (empty or inverted window, which would
            divide by zero or flip the intensity scale).
    """
    if hu_max <= hu_min:
        raise ValueError(f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})")
    img = nib.load(path).get_fdata()
    # NOTE(review): assumes stored values (after nibabel's header scaling) are
    # Hounsfield units — confirm for this dataset.
    img = np.clip(img, hu_min, hu_max)
    return (img - hu_min) / (hu_max - hu_min)


In [9]:
def extract_patches(volume, depth=5, drop=2, size=128):
    """Cut a volume into fixed-depth slabs along its third axis and resize them.

    Slabs start at slice 0 and every ``depth + drop`` slices after that. Any
    trailing slices that cannot fill a complete slab are discarded.

    Args:
        volume: Array indexed as ``volume[:, :, k]`` for slice ``k`` (H x W x S).
        depth: Slices per slab; becomes the leading (channel) dim of each output.
        drop: Slices skipped between consecutive slabs. A negative value gives
            overlapping slabs, as long as ``depth + drop >= 1``.
        size: Side length each slice is resized to.

    Returns:
        List of ``torch.float32`` tensors, each of shape ``(depth, size, size)``.
        The list is empty if the volume has fewer than ``depth`` slices.

    Raises:
        ValueError: if ``depth < 1``, or if ``depth + drop < 1`` (the start
            index would never advance, looping forever).
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    stride = depth + drop
    if stride < 1:
        raise ValueError(f"depth + drop must be >= 1, got {depth} + {drop} = {stride}")

    n_slices = volume.shape[2]
    outputs = []
    # The last valid start index is n_slices - depth, so every slab is complete.
    for start in range(0, n_slices - depth + 1, stride):
        chunk = volume[:, :, start:start + depth]   # H x W x depth
        chunk = np.transpose(chunk, (2, 0, 1))      # depth x H x W
        chunk = resize(chunk, (depth, size, size), mode='constant')
        outputs.append(torch.tensor(chunk, dtype=torch.float32))
    return outputs


In [None]:
def _strip_nifti_suffix(filename):
    """Return ``filename`` without a trailing '.nii.gz' or '.nii' extension."""
    # Path.stem only removes the last suffix ('x.nii.gz' -> 'x.nii'), so strip explicitly.
    for ext in (".nii.gz", ".nii"):
        if filename.endswith(ext):
            return filename[: -len(ext)]
    return filename


def preprocess_dataset(root_dir, output_dir, save_as_image=False, pattern="*/*.nii.gz"):
    """Convert every NIfTI case under ``root_dir`` into per-patch files.

    For each split in ("train", "valid"), files matching ``pattern`` under
    ``root_dir/<split>`` are loaded with ``load_nii_and_normalize``. They are cut
    into patches with ``extract_patches`` using the notebook-level ``SLICE_DEPTH``,
    ``SLICE_DROP`` and ``INPUT_SIZE``. Patches are written to
    ``output_dir/<split>/<case>/<case>_patch_<i>.{png|pt}``, where ``<case>``
    is the file name without its ``.nii.gz`` extension.

    Args:
        root_dir: Dataset root containing ``train`` / ``valid`` subdirectories.
        output_dir: Destination root; created if missing.
        save_as_image: If True, save each patch as PNG via torchvision's
            ``save_image``; otherwise save the tensor with ``torch.save``.
            NOTE(review): ``save_image`` encodes the (depth, H, W) tensor as one
            image, which likely only works for 1, 3 or 4 channels — confirm
            before using other SLICE_DEPTH values.
        pattern: Glob, relative to each split directory, selecting input files.
            The default is the original one-subdirectory layout.

    Missing split directories are skipped silently; their glob matches nothing.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_root = Path(output_dir)
    for split in ["train", "valid"]:
        split_path = Path(root_dir) / split
        # Sorted so cases are processed in a reproducible order.
        for case in sorted(split_path.glob(pattern)):
            vol = load_nii_and_normalize(str(case))
            patches = extract_patches(vol, SLICE_DEPTH, SLICE_DROP, INPUT_SIZE)

            case_name = _strip_nifti_suffix(case.name)
            out_path = output_root / split / case_name
            out_path.mkdir(parents=True, exist_ok=True)

            for i, patch in enumerate(patches):
                if save_as_image:
                    # Pass the path so torchvision opens and closes the file itself
                    # (an unclosed open() handle was leaked per patch before).
                    save_image(patch, out_path / f"{case_name}_patch_{i}.png")
                else:
                    torch.save(patch, out_path / f"{case_name}_patch_{i}.pt")


In [13]:
# Run preprocessing over the CT-RATE train/valid splits, saving patches as PNGs.
# NOTE(review): writes to "../sample_outputs" rather than OUTPUT_FOLDER
# ("../processed_data") — confirm which destination is intended.
preprocess_dataset(
    root_dir="../CT_rate/dataset",
    output_dir="../sample_outputs",
    save_as_image=True,
)