In [None]:
import os
from pathlib import Path

import nibabel as nib
import numpy as np
from skimage.transform import resize
import torch
from torchvision.utils import save_image
import yaml
from utils import get_affine_from_metadata, get_metadata_for_volume, convert_to_hu, resample_volume, crop_or_pad, preprocess_ct_volume, batch_preprocess_ct_rate

# Functions


In [None]:
def load_nii_and_normalize(path, hu_min=-1000, hu_max=1000):
    """Load a NIfTI volume and rescale its intensities to the range [0, 1].

    The voxel data are read with nibabel, clipped to ``[hu_min, hu_max]`` and
    mapped linearly so that ``hu_min`` becomes 0 and ``hu_max`` becomes 1.
    With the default bounds this is the original ``(img + 1000) / 2000``.

    Args:
        path: Location of the NIfTI file; passed straight to ``nib.load``.
        hu_min: Lower clipping bound (default -1000).
        hu_max: Upper clipping bound (default 1000).

    Returns:
        The array produced by ``get_fdata()`` after clipping and rescaling.

    Raises:
        ValueError: If ``hu_max`` is not strictly greater than ``hu_min``
            (the rescaling would divide by zero or flip the range).
    """
    if not hu_max > hu_min:
        raise ValueError(
            f"hu_max must be greater than hu_min (got hu_min={hu_min}, hu_max={hu_max})"
        )

    img = nib.load(path).get_fdata()
    # The stored values are treated as Hounsfield units; everything outside
    # the window is saturated before scaling.
    img = np.clip(img, hu_min, hu_max)
    return (img - hu_min) / (hu_max - hu_min)


In [None]:
def extract_patches(volume, depth=5, drop=2, size=128):
    """Cut a volume into fixed-depth slabs along its third axis.

    Starting at slice 0, ``depth`` consecutive slices are taken, then ``drop``
    slices are skipped, and so on. A trailing group with fewer than ``depth``
    slices is discarded. Each slab is moved to depth-first layout, resized
    in-plane to ``size`` x ``size`` and converted to a float32 tensor.

    Args:
        volume: Array indexed as ``[row, col, slice]``.
        depth: Number of slices per slab; must be at least 1.
        drop: Number of slices skipped between consecutive slabs. May be
            negative (overlapping slabs) as long as ``depth + drop >= 1``.
        size: Edge length of the resized in-plane square.

    Returns:
        A list of ``torch.float32`` tensors, each built from an array resized
        to ``(depth, size, size)``. Empty when the volume has fewer than
        ``depth`` slices.

    Raises:
        ValueError: If ``depth < 1`` or ``depth + drop < 1``. Previously such
            arguments made the loop run forever because the start index never
            advanced.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1 (got {depth})")
    stride = depth + drop
    if stride < 1:
        raise ValueError(
            f"depth + drop must be >= 1 so the window advances (got depth={depth}, drop={drop})"
        )

    n_slices = volume.shape[2]
    patches = []
    # range() stops at the last start index for which a full slab still fits.
    for start in range(0, n_slices - depth + 1, stride):
        slab = volume[:, :, start:start + depth]   # H x W x depth
        slab = np.transpose(slab, (2, 0, 1))       # depth x H x W
        slab = resize(slab, (depth, size, size), mode="constant")
        patches.append(torch.tensor(slab, dtype=torch.float32))
    return patches


In [None]:
def preprocess_dataset(root_dir, output_dir, slice_depth, slice_drop, input_size, save_as_image=False):
    """Convert every ``.nii.gz`` volume under ``root_dir`` into saved patches.

    For each of the ``train`` and ``valid`` splits, files matching
    ``<root_dir>/<split>/*/*.nii.gz`` are loaded with
    ``load_nii_and_normalize`` and cut up with ``extract_patches``. Patches
    are written to ``<output_dir>/<split>/<case>/<case>_patch_<i>.pt`` (or
    ``.png`` when ``save_as_image`` is true), where ``<case>`` is the file
    name without its ``.nii.gz`` suffix.

    Args:
        root_dir: Directory containing the ``train`` and ``valid`` folders.
        output_dir: Destination root; created if missing.
        slice_depth: Forwarded to ``extract_patches`` as ``depth``.
        slice_drop: Forwarded to ``extract_patches`` as ``drop``.
        input_size: Forwarded to ``extract_patches`` as ``size``.
        save_as_image: Write PNGs via ``save_image`` instead of ``torch.save``.

    Returns:
        None. The function is run for its side effect of writing files.
    """
    output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    nifti_suffix = ".nii.gz"
    for split in ("train", "valid"):
        split_path = Path(root_dir) / split
        # Sorted so that the processing order is the same on every run.
        for case in sorted(split_path.glob("*/*.nii.gz")):
            vol = load_nii_and_normalize(str(case))
            patches = extract_patches(vol, slice_depth, slice_drop, input_size)

            # Path.stem would only strip ".gz" and leave "name.nii"; remove
            # the whole double suffix so outputs are named after the case.
            case_name = case.name[: -len(nifti_suffix)]
            out_path = output_dir / split / case_name
            out_path.mkdir(parents=True, exist_ok=True)

            for i, patch in enumerate(patches):
                if save_as_image:
                    # Pass the path itself rather than an open() handle that
                    # was never closed (one leaked descriptor per patch).
                    # NOTE(review): the patch is handed over as a single
                    # (depth, H, W) image, so depth acts as the channel count;
                    # PNG encoding presumably fails for depths such as 5.
                    # TODO confirm the intended image layout.
                    save_image(patch, str(out_path / f"{case_name}_patch_{i}.png"))
                else:
                    torch.save(patch, out_path / f"{case_name}_patch_{i}.pt")


# Process


In [None]:
# Keyword arguments for the CT-RATE batch run, collected in one place so they
# can be inspected or tweaked before the call. The paths are absolute and
# machine-specific; adjust them when running on another workstation.
ct_rate_options = {
    # Input volumes, output location and the metadata table.
    "input_root": r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\dataset",
    "output_root": r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\reconstructed",
    "metadata_csv": r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\dataset\train_metadata.csv",
    # Resampling settings (presumably voxel spacing in mm and output grid,
    # with -1 leaving the third axis unconstrained; verify against utils).
    "target_spacing": (0.75, 0.75, 1.5),
    "target_shape": (512, 512, -1),
    "hu_range": (-1000, 1000),
    "resample_order": 1,
    # Output options: NIfTI volumes only, no per-slice images.
    "save_nifti": True,
    "save_slices": False,
    "slice_axis": 2,
    "slice_format": "png",
}

batch_preprocess_ct_rate(**ct_rate_options)