In [1]:
import os
from pathlib import Path

import nibabel as nib
import numpy as np
from skimage.transform import resize
import torch
from torchvision.utils import save_image
import yaml
from utils import get_affine_from_metadata, get_metadata_for_volume, convert_to_hu, resample_volume, crop_or_pad, preprocess_ct_volume, batch_preprocess_ct_rate

# Functions


In [2]:
def load_nii_and_normalize(path, hu_range=(-1000, 1000)):
    """Load a NIfTI volume and linearly map a Hounsfield-unit window onto [0, 1].

    Voxels are clipped to ``hu_range``; ``hu_range[0]`` then maps to 0 and
    ``hu_range[1]`` maps to 1.

    Args:
        path: path to a ``.nii``/``.nii.gz`` file readable by ``nibabel.load``.
            Assumes the intensities returned by ``get_fdata()`` are already in HU
            — TODO confirm for every CT-RATE file.
        hu_range: ``(low, high)`` HU window. The default ``(-1000, 1000)`` matches
            the previously hard-coded window and the ``hu_range`` used by the
            batch preprocessing cell.

    Returns:
        Floating-point array (nibabel's ``get_fdata`` default dtype) with the
        image's shape and values in ``[0, 1]``.

    Raises:
        ValueError: if ``high <= low``; the rescale would divide by zero or invert.
    """
    low, high = hu_range
    if high <= low:
        raise ValueError(f"hu_range must satisfy low < high, got {hu_range!r}")
    img = nib.load(path).get_fdata()
    img = np.clip(img, low, high)  # Clip Hounsfield units to the window
    return (img - low) / (high - low)  # Normalize to 0–1


In [3]:
def extract_patches(volume, depth=5, drop=2, size=128):
    """Cut a volume into fixed-depth slabs along its last axis and resize them in-plane.

    Starting at slice 0, each slab takes ``depth`` consecutive slices; the next
    slab starts ``depth + drop`` slices later. Trailing slices that cannot fill a
    whole slab are discarded.

    Args:
        volume: 3-D array indexed as (H, W, slices).
        depth: number of consecutive slices per slab; must be >= 1.
        drop: number of slices skipped between slabs. A negative value makes
            slabs overlap, but ``depth + drop`` must stay >= 1.
        size: in-plane output size; every slab is resized to (depth, size, size).

    Returns:
        List of float32 torch tensors of shape (depth, size, size). An empty list
        is returned when the volume has fewer than ``depth`` slices.

    Raises:
        ValueError: if ``depth < 1`` or ``depth + drop < 1``. A non-positive
            stride would make the slab start index never advance.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    stride = depth + drop
    if stride < 1:
        raise ValueError(f"depth + drop must be >= 1, got {stride}")

    num_slices = volume.shape[2]
    outputs = []
    # range() stops before a slab would run past the last slice.
    for start in range(0, num_slices - depth + 1, stride):
        chunk = np.transpose(volume[:, :, start:start + depth], (2, 0, 1))  # depth x H x W
        chunk = resize(chunk, (depth, size, size), mode="constant")
        outputs.append(torch.tensor(chunk, dtype=torch.float32))
    return outputs


In [4]:
def preprocess_dataset(root_dir, output_dir, slice_depth, slice_drop, input_size,
                       save_as_image=False, splits=("train", "valid")):
    """Convert every ``<root_dir>/<split>/*/*.nii.gz`` volume into per-patch files.

    Each volume is loaded with ``load_nii_and_normalize`` and cut into slabs with
    ``extract_patches``. Slab ``i`` of case ``<case>`` (the file name without
    ``.nii.gz``) is written to ``<output_dir>/<split>/<case>/<case>_patch_<i>.pt``,
    or to ``.png`` when ``save_as_image`` is true.

    Args:
        root_dir: dataset root containing one sub-directory per split.
        output_dir: destination root; created if missing.
        slice_depth: forwarded to ``extract_patches`` as ``depth``.
        slice_drop: forwarded to ``extract_patches`` as ``drop``.
        input_size: forwarded to ``extract_patches`` as ``size``.
        save_as_image: write PNGs with torchvision's ``save_image`` instead of
            ``torch.save``.
        splits: split sub-directories to process. A missing split directory
            simply yields no cases.
    """
    nii_suffix = ".nii.gz"
    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    for split in splits:
        split_path = Path(root_dir) / split
        # Sorted so the processing order (and any partial output) is deterministic.
        for case in sorted(split_path.glob("*/*.nii.gz")):
            vol = load_nii_and_normalize(str(case))
            patches = extract_patches(vol, slice_depth, slice_drop, input_size)

            # `case.stem` only strips ".gz" and would leave "name.nii"; drop the full suffix.
            case_name = case.name[: -len(nii_suffix)]
            out_path = output_root / split / case_name
            out_path.mkdir(parents=True, exist_ok=True)

            for i, patch in enumerate(patches):
                if save_as_image:
                    # Pass the path so torchvision opens and closes the file itself.
                    # NOTE(review): save_image treats a (depth, H, W) tensor as one
                    # depth-channel image; depths PIL cannot encode (e.g. the
                    # default 5) likely fail here — verify.
                    save_image(patch, out_path / f"{case_name}_patch_{i}.png")
                else:
                    torch.save(patch, out_path / f"{case_name}_patch_{i}.pt")


# Process


In [None]:
# Settings for the utils batch preprocessing of CT-RATE. Judging by the argument
# names it resamples to `target_spacing`/`target_shape` and writes NIfTI output —
# see `utils.batch_preprocess_ct_rate` for the exact semantics.
# NOTE(review): these are absolute Windows paths, while the demo section below
# points at a Linux checkout; adjust before running on another machine.
CT_RATE_PREPROCESS_KWARGS = dict(
    input_root=r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\dataset",
    output_root=r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\reconstructed",
    metadata_csv=r"D:\Work\QUMLG\text_guided_3D_generation\CT_RATE\dataset\train_metadata.csv",
    target_spacing=(0.75, 0.75, 1.5),
    target_shape=(512, 512, -1),  # presumably -1 keeps the native depth — TODO confirm
    hu_range=(-1000, 1000),
    resample_order=1,
    save_nifti=True,
    save_slices=False,
    slice_axis=2,
    slice_format="png",
)

batch_preprocess_ct_rate(**CT_RATE_PREPROCESS_KWARGS)

# Demo test

In [2]:
from ct_dataset_esrgan import CTRateDatasetBase, RealESRGANCustomCTDataset
import pandas as pd


In [3]:
# Load the Real-ESRGAN fine-tuning config. `with` closes the file handle promptly.
# The safe loader rejects arbitrary Python-object tags. The C-accelerated variant is
# used only when this PyYAML build ships it (CLoader/CSafeLoader need libyaml).
with open("finetune_realesrgan_x4plus_ct.yml", encoding="utf-8") as config_file:
    opt = yaml.load(config_file, Loader=getattr(yaml, "CSafeLoader", yaml.SafeLoader))

In [4]:
# Point both dataset splits at the local CT-RATE copy (absolute paths on the
# training machine — adjust when running elsewhere).
opt["datasets"]["train"].update(
    dataroot_gt="/home/promit/Promit/text_guided_3d_generation/CT_RATE/dataset/train",
    metadata_csv="/home/promit/Promit/text_guided_3d_generation/CT_RATE/dataset/train_metadata.csv",
)
opt["datasets"]["val"].update(
    dataroot_gt="/home/promit/Promit/text_guided_3d_generation/CT_RATE/dataset/valid",
    metadata_csv="/home/promit/Promit/text_guided_3d_generation/CT_RATE/dataset/validation_metadata.csv",
)

# Optional override (disabled): set `depth = 1` for both splits.
# opt["datasets"]["train"]["depth"] = 1
# opt["datasets"]["val"]["depth"] = 1


In [5]:
# Dataset options = top-level config, with the "train" split's keys layered on top
# (split values win on key collisions).
train_ds_opt = dict(opt)
train_ds_opt.update(opt["datasets"]["train"])
ds = RealESRGANCustomCTDataset(train_ds_opt)

Filtering Samples: 100%|██████████| 9/9 [00:00<00:00, 2525.34it/s]
Scanning Volumes for Patch Counts: 100%|██████████| 9/9 [00:00<00:00, 16355.60it/s]


In [6]:
# Per-volume slice counts followed by their total; compare the total with
# `ds.total_patches` in the next cell. Counts are queried once per volume
# instead of twice.
slice_counts = [ds.base_dataset.get_volume_slice_count(i) for i in range(len(ds.base_dataset))]
for count in slice_counts:
    print(count)
print(sum(slice_counts))

192
192
237
473
210
303
303
290
290
2490


In [7]:
# Bookkeeping of the training dataset, displayed as one tuple:
# (total patches, total slices, number of volumes, target shape).
base_ds = ds.base_dataset
(ds.total_patches, base_ds.num_slices_total, len(base_ds), base_ds.target_shape)

(2490, 2490, 9, (512, 512, -1))

In [8]:
# Length reported by the wrapper dataset; compare with `ds.total_patches`.
num_train_items = len(ds)
num_train_items

2490

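Index checks: global dataset index → (volume, slice within that volume). Volume 8 starts at global index 2200, the sum of the first eight slice counts. Entries marked "ok" were recorded as loading; unmarked entries were presumably not yet verified or failing — TODO confirm.

- 2489: (Vol 8, Slice 289)
- 2480: (Vol 8, Slice 280)
- 2470: (Vol 8, Slice 270)
- 2460: (Vol 8, Slice 260)
- 2450: (Vol 8, Slice 250)
- 2440: (Vol 8, Slice 240)
- 2439: (Vol 8, Slice 239)
- 2437: (Vol 8, Slice 237)
- 2436: (Vol 8, Slice 236)
- 2435: (Vol 8, Slice 235) ok
- 2430: (Vol 8, Slice 230) ok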


In [9]:
# File backing volume index 8, the volume referenced in the index-check notes.
VOLUME_UNDER_TEST = 8
ds.base_dataset.samples[VOLUME_UNDER_TEST]

PosixPath('/home/promit/Promit/text_guided_3d_generation/CT_RATE/dataset/train/train_10000_a/train_10000_a_2.nii.gz')

In [10]:
# Shape of volume 8 as stored on disk. `vol.shape` is read from the NIfTI header,
# so nothing is decoded. `vol.get_fdata()` would materialise the whole volume as
# float64 (about 600 MB for 512x512x290) just to report the same shape.
# Call `vol.get_fdata()` only when voxel values are actually needed.
vol = nib.load(ds.base_dataset.samples[8])
vol.shape

(512, 512, 290)

In [11]:
# Smoke test: fetch global index 2435 (volume 8, slice 235 per the index notes).
# Bound to a named variable rather than `_`: once the user assigns `_`, IPython
# stops updating its `_`/`__`/`___` output-history variables for the session.
sample_2435 = ds[2435]

In [12]:
# Smoke test: fetch global index 2436 (volume 8, slice 236 per the index notes).
# Bound to a named variable rather than `_`: once the user assigns `_`, IPython
# stops updating its `_`/`__`/`___` output-history variables for the session.
sample_2436 = ds[2436]