In [1]:
import cv2
import torch
import torchvision
from tqdm import tqdm
import numpy as np
from monai.transforms import (
    apply_transform,
    Randomizable,
    Compose,
    OneOf,
    EnsureChannelFirstDict,
    LoadImageDict,
    SpacingDict,
    OrientationDict,
    DivisiblePadDict,
    CropForegroundDict,
    ResizeDict,
    RandZoomDict,
    ZoomDict,
    RandRotateDict,
    HistogramNormalizeDict,
    ScaleIntensityDict,
    ScaleIntensityRangeDict,
    ToTensorDict,
    Transform
)

In [2]:
%%bash
# Create the folder that receives the rendered PNGs (written as demo/ct/<name>.png further below).
mkdir -p "demo/ct/"

In [3]:
import os
import glob
from tqdm.auto import tqdm

In [4]:
def glob_files(folders: "str | Iterable[str] | None" = None, extension: str = "*.nii.gz"):
    """
    Collect the files matching ``extension`` inside one or more folders.

    Args:
        folders: a single folder path, or an iterable of folder paths. Required:
            passing ``None`` fails the assertion below.
        extension: glob pattern joined onto every folder (default ``"*.nii.gz"``).
            ``recursive=True`` is set, so a pattern containing ``**`` also
            descends into sub-folders.

    Returns:
        Sorted list of all matching paths across every folder. As a quick
        sanity check, the number of matches and the first path are printed.
    """
    assert folders is not None, "glob_files: `folders` must be given."
    # A lone path would otherwise be iterated character by character.
    if isinstance(folders, (str, os.PathLike)):
        folders = [folders]
    files = sorted(
        path
        for folder in folders
        for path in glob.glob(os.path.join(folder, extension), recursive=True)
    )
    print(len(files))
    print(files[:1])
    return files

In [5]:
# Training-split image folders of the three CT sources.
# MOSMED is stored as five sub-folders, CT-0 ... CT-4.
ct_folders = (
    ["data/ChestXRLungSegmentation/NSCLC/processed/train/images"]
    + [
        f"data/ChestXRLungSegmentation/MOSMED/processed/train/images/CT-{idx}"
        for idx in range(5)
    ]
    + ["data/ChestXRLungSegmentation/Imagenglab/processed/train/images"]
)
ct_images = glob_files(ct_folders, extension="*.nii.gz")

1512
['data/ChestXRLungSegmentation/MOSMED/processed/train/images/CT-0/study_0001.nii.gz']


In [6]:
class UnSqueezeDim(Transform):
    """
    Insert a new unitary (size-1) dimension into an array via ``np.expand_dims``.
    """

    def __init__(self, dim=0):
        """
        Args:
            dim (int): position at which the new size-1 axis is inserted;
                must be an int >= -1. Default: 0 (prepend a leading axis).
                ``None`` is rejected because ``np.expand_dims`` has no
                "no axis" mode.
        """
        # Validate eagerly so a bad dim fails at construction, not on first call.
        assert isinstance(dim, int) and dim >= -1, 'invalid channel dimension.'
        self.dim = dim

    def __call__(self, img):
        """
        Args:
            img: array-like input accepted by ``np.expand_dims``.

        Returns:
            ``img`` with a new size-1 axis inserted at position ``self.dim``.
        """
        return np.expand_dims(img, self.dim)

class SqueezeDim(Transform):
    """
    Drop unitary (size-1) dimensions from an array via ``np.squeeze``.
    """

    def __init__(self, dim=None):
        """
        Args:
            dim (int): the axis to remove; when given it must be an int >= -1.
                Default: None (every size-1 dimension is removed).
        """
        dim_ok = dim is None or (isinstance(dim, int) and dim >= -1)
        assert dim_ok, 'invalid channel dimension.'
        self.dim = dim

    def __call__(self, img):
        """
        Args:
            img: array-like input accepted by ``np.squeeze``.

        Returns:
            ``img`` with the size-1 axis ``self.dim`` removed (or all size-1
            axes when ``self.dim`` is None). NumPy raises if ``self.dim``
            points at an axis whose length is not 1.
        """
        return np.squeeze(img, axis=self.dim)

In [7]:
# Define the transformation pipeline for "image2d"
# Preprocessing pipeline for the 3D CT volumes stored under the "image3d" key.
# Input: {"image3d": <path to a .nii.gz file>}; output: the same dict with the
# processed volume as a tensor (logged shape: [1, 256, 256, 256]).
val_transforms = Compose(
    [
        LoadImageDict(keys=["image3d"]),
        EnsureChannelFirstDict(keys=["image3d"]),
        # Resample to 1.0 x 1.0 x 1.0 voxel spacing (units of the file header).
        SpacingDict(
            keys=["image3d"],
            pixdim=(1.0, 1.0, 1.0),
            mode=["bilinear"],
            align_corners=True,
        ),
        OrientationDict(keys=["image3d"], axcodes="ASL"),
        # Clip intensities to [-1024, 3071] (presumably Hounsfield units -- confirm
        # against the data) and map that window linearly onto [0, 1].
        ScaleIntensityRangeDict(
            keys=["image3d"],
            clip=True,
            a_min=-1024,
            a_max=+3071,
            b_min=0.0,
            b_max=1.0,
        ),
        # Crop to the bounding box of voxels above the window floor (value > 0).
        CropForegroundDict(
            keys=["image3d"],
            source_key="image3d",
            select_fn=(lambda x: x > 0),
            margin=0,
        ),
        # Zoom out by 5%; the vacated border is filled with a constant
        # (assumes MONAI's default keep_size=True -- confirm).
        ZoomDict(keys=["image3d"], zoom=0.95, padding_mode="constant", mode=["area"]),
        # Scale so the longest spatial side becomes 256 voxels ...
        ResizeDict(
            keys=["image3d"],
            spatial_size=256,
            size_mode="longest",
            mode=["trilinear"],
            align_corners=True,
        ),
        # ... then pad every spatial side up to a multiple of 256 (i.e. to 256),
        # giving a 256^3 volume.
        DivisiblePadDict(
            keys=["image3d"],
            k=256,
            mode="constant",
            constant_values=0,
        ),
        ToTensorDict(keys=["image3d"]),
    ]
)



In [8]:
from main_frustuminv_xray import NVLightningModule, make_cameras_dea

# Load the trained model and build the single fixed camera used for every volume.
# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
B = 1  # number of cameras built below (one per forward pass)
checkpoint_path = "logs/diffusion/version_4/checkpoints/last.ckpt"
# map_location lets a GPU-saved checkpoint load on the selected device.
# NOTE(review): strict=False silently ignores missing/unexpected state-dict keys.
model = NVLightningModule.load_from_checkpoint(
    checkpoint_path, map_location=device, strict=False
).to(device)
model.eval()  # inference only: disable training-mode behavior (dropout, norm stats)

# Camera parameters, presumably distance / elevation / azimuth per make_cameras_dea's
# name -- confirm against main_frustuminv_xray.
dist_hidden = 8 * torch.ones(B, device=device)
elev_hidden = torch.zeros(B, device=device)
azim_hidden = torch.zeros(B, device=device)
view_hidden = make_cameras_dea(
    dist_hidden,
    elev_hidden,
    azim_hidden,
    fov=16.0,
    znear=6.1,
    zfar=9.9,
).to(device)


(1048576, 1048576)


Seed set to 21


In [9]:

# Render every CT volume from the fixed camera `view_hidden` and save it as demo/ct/<name>.png.
# Create the output folder here as well, so this cell does not depend on the %%bash cell.
os.makedirs("demo/ct", exist_ok=True)

# Pure inference: skip building an autograd graph (also saves GPU memory).
with torch.no_grad():
    for image in tqdm(ct_images):
        data = {"image3d": image}
        sample = val_transforms(data)

        # Add a leading batch axis to the channel-first volume from val_transforms.
        sample = sample["image3d"].to(device).unsqueeze(0)

        # Assumes forward_screen's output squeezes down to a single 2D image -- confirm.
        output = model.forward_screen(
            image3d=sample,
            cameras=view_hidden,
        ).clamp_(0, 1).squeeze().detach().cpu()
        # torch.Tensor has no `.astype` (that is the NumPy API); cast with `.to`, which
        # truncates 255 * output toward zero. write_png takes a (C, H, W) uint8 tensor,
        # hence unsqueeze(0); transpose swaps the two image axes.
        output = (255 * output).to(torch.uint8).unsqueeze(0).transpose(-1, -2).contiguous()

        # Flatten the source path into a unique file name.
        filename = image.replace("/", "_").replace(".nii.gz", ".png")
        torchvision.io.write_png(output, f"demo/ct/{filename}")
    
    

  0%|          | 0/1512 [00:00<?, ?it/s]

torch.Size([1, 256, 256, 256])
torch.Size([1, 256, 256, 256])
torch.Size([1, 256, 256, 256])
torch.Size([1, 256, 256, 256])
