In [None]:
import torch
from munch import munchify

from diffusion_3d.headct.ldm_uncontrolled.dataset import HeadCTDataset
from diffusion_3d.utils.visualize import plot_scans

# Check dataset

In [38]:
# Final volume shape produced by the train augmentation pipeline: the last
# `random_crop` step below uses it as its target shape, and `random_resize`
# derives its min/max bounds from it.
image_size = (32, 256, 256)

# Dataset / loader configuration, wrapped with `munchify` so that entries can
# be read with attribute access as well as by key.
config = munchify(
    {
        # --- Locations on disk ---
        "csvpath": r"/raid3/arjun/ct_pretraining/csvs/sources.csv",
        "datapath": r"/raid3/arjun/ct_pretraining/scans/",
        "checkpointspath": r"/raid3/arjun/checkpoints/diffusion_3d/",
        # --- Optional cap on dataset size (None here) ---
        "limited_dataset_size": None,
        # --- Scan selection filters ---
        # NOTE(review): the (lo, hi) pairs look like per-axis ranges with -1
        # meaning "no bound" -- confirm against the dataset implementation.
        "sources": ["vrad", "medall", "fts"],
        "bodyparts": ["head"],
        "allowed_spacings": ((0.4, 7), (-1, -1), (-1, -1)),
        "allowed_shapes": ((64, -1), (-1, -1), (-1, -1)),
        # --- Train-time augmentations, listed in order ---
        # Each dict names a transform via "__fn_name__"; the remaining keys
        # are that transform's arguments.
        "train_augmentations": [
            {
                "__fn_name__": "pad_to_target_shape",
                "target_shape": (36, 512, 512),
                "mode": "random",
            },
            {
                "__fn_name__": "random_crop",
                "target_shape": (-1, 512, 512),
            },
            {
                "__fn_name__": "random_rotate",
                "degrees": 25,
            },
            {
                # Resize to a random shape between `image_size` and a slightly
                # larger shape (+20% on the first axis, +10% on the other two,
                # with the latter two capped at 512).
                "__fn_name__": "random_resize",
                "min_shape": image_size,
                "max_shape": (
                    int(image_size[0] * 1.2),
                    min(int(image_size[1] * 1.1), 512),
                    min(int(image_size[2] * 1.1), 512),
                ),
                "interpolation_mode": "trilinear",
            },
            {
                # Bring the (possibly larger) resized volume back to exactly
                # `image_size`.
                "__fn_name__": "random_crop",
                "target_shape": image_size,
            },
            {
                # One entry per window preset; `sampling_probability` has one
                # weight per preset, in the same order.
                # NOTE(review): each entry is presumably
                # [(width, level), (width_std, level_std)] -- TODO confirm.
                "__fn_name__": "random_windowing",
                "hotspots_and_stds": [
                    [(80, 40), (7, 2)],  # Brain window
                    [(37, 37), (4, 2)],  # Stroke window
                    [(3400, 650), (360, 35)],  # Bone window
                    [(8, 32), (0.5, 2)],  # Another stroke window
                    [(210, 75), (10, 4)],  # Subdural window
                    [(375, 40), (10, 2)],  # Soft tissue window
                ],
                "sampling_probability": [0.4, 0.3, 0.15, 0.05, 0.05, 0.05],
                "normalize_range": (0, 1),
            },
            {
                "__fn_name__": "random_horizontal_flip",
                "probability": 0.5,
            },
            # Nested list form: a list of three weights followed by three
            # augmentation lists (empty, blur, unsharp mask).
            # NOTE(review): presumably one of the three lists is sampled with
            # the given weights -- confirm against the augmentation parser.
            [
                [0.4, 0.3, 0.3],
                [],
                [
                    {
                        "__fn_name__": "random_gaussian_blurring",
                        "sigma_range": (0, 1),
                    }
                ],
                [
                    {
                        "__fn_name__": "random_unsharp_masking",
                        "sigma_range": (0, 1),
                        "alpha_range": (0.5, 2),
                    }
                ],
            ],
        ],
        # --- Loader settings ---
        "num_workers": 6,
        # Batch size is derived from the total memory of CUDA device 0, so
        # this cell requires a visible GPU. Alternative divisors kept for
        # reference:
        #   total_memory // 1.25e9  -> (32, 384, 384) 100M
        #   total_memory // 3.2e9   -> (48, 384, 384) 100M
        # NOTE(review): the active divisor is labelled for (64, 384, 384)
        # while `image_size` above is (32, 256, 256) -- confirm it is still
        # the intended setting. The result is 0 when total memory < 4e9 bytes.
        "batch_size": int(
            torch.cuda.get_device_properties(0).total_memory // 4e9
        ),  # (64, 384, 384) 100M
        # Previously 168_000.
        "train_sample_size": 70_000,  # (64, 384, 384)
        "sample_balance_cols": ["Source", "BodyPart"],
    }
)

In [None]:
# Build the dataset from `config`; "train" is passed as the second argument
# (presumably the split name -- the cell output reports "train datapoints").
dataset = HeadCTDataset(config, "train")

train:   0%|          | 0/194552 [00:00<?, ?it/s]

No. of train datapoints: 77860


In [40]:
# Fetch one item and sanity-check its contents.
datapoint = dataset[0]

# Print the metadata entries, one per line, in a fixed order.
for key in ("spacing", "uid", "index"):
    print(datapoint[key])

# Show the scan's shape, then browse its first element along the leading axis
# with the interactive viewer (kept as the last expression so it is displayed).
scan = datapoint["scan"]
print(scan.shape)
plot_scans(scan[0])

tensor([5.1445, 0.9580, 0.9399], dtype=torch.float16)
1.2.840.113619.2.278.3.2831165736.943.1356932353.328
1
torch.Size([1, 32, 256, 256])


interactive(children=(IntSlider(value=0, description='z', max=31), Output()), _dom_classes=('widget-interact',…