In [1]:
import cv2
import torch
import torchvision
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image
from monai.transforms import (
    apply_transform,
    Randomizable,
    Compose,
    OneOf,
    EnsureChannelFirstDict,
    LoadImageDict,
    SpacingDict,
    OrientationDict,
    DivisiblePadDict,
    CropForegroundDict,
    ResizeDict,
    RandZoomDict,
    ZoomDict,
    RandRotateDict,
    HistogramNormalizeDict,
    ScaleIntensityDict,
    ScaleIntensityRangeDict,
    ToTensorDict,
    Transform
)

In [2]:
import os
import glob
from tqdm.auto import tqdm

def glob_files(folders: str = None, extension: str = "*.nii.gz"):
    assert folders is not None
    paths = [
        glob.glob(os.path.join(folder, extension), recursive=True)
        for folder in folders
    ]
    files = sorted([item for sublist in paths for item in sublist])
    print(len(files))
    print(files[:1])
    return files

ct_folders = [
    "/home/quantm/data/ChestXRLungSegmentation/NSCLC/processed/train/images",
]
ct_images = glob_files(
    ct_folders, 
    extension="*.nii.gz"
)

402
['/home/quantm/data/ChestXRLungSegmentation/NSCLC/processed/train/images/LUNG1-001_0000.nii.gz']


In [3]:
from pytorch3d.renderer.cameras import (
    FoVPerspectiveCameras,
    FoVOrthographicCameras,
    look_at_view_transform,
)

def make_cameras_dea(
    dist: torch.Tensor,
    elev: torch.Tensor,
    azim: torch.Tensor,
    fov: int = 40,
    znear: int = 4.0,
    zfar: int = 8.0,
    is_orthogonal: bool = False,
):
    assert dist.device == elev.device == azim.device
    _device = dist.device
    R, T = look_at_view_transform(dist=dist, elev=elev * 90, azim=azim * 180)
    if is_orthogonal:
        return FoVOrthographicCameras(R=R, T=T, znear=znear, zfar=zfar).to(_device)
    return FoVPerspectiveCameras(R=R, T=T, fov=fov, znear=znear, zfar=zfar).to(_device)


In [4]:
# Define the transformation pipeline for "image2d"
val_transforms = Compose(
    [
        LoadImageDict(keys=["image3d"]),
        EnsureChannelFirstDict(keys=["image3d"],),
        SpacingDict(
            keys=["image3d"],
            pixdim=(1.0, 1.0, 1.0),
            mode=["bilinear"],
            align_corners=True,
        ),
        OrientationDict(keys=("image3d"), axcodes="ASL"),
        ScaleIntensityRangeDict(
            keys=["image3d"],
            clip=True,
            a_min=-1024,
            a_max=+3071,
            b_min=0.0,
            b_max=1.0,
        ),
        # ScaleIntensityRangeDict(
        #     keys=["image3d"],
        #     clip=True,
        #     a_min=-1024,
        #     a_max=+3071,
        #     b_min=-1024,
        #     b_max=+3071,
        # ),
        # CropForegroundDict(
        #     keys=["image3d"],
        #     source_key="image3d",
        #     select_fn=(lambda x: x > 0),
        #     margin=0,
        # ),
        # ZoomDict(keys=["image3d"], zoom=0.95, padding_mode="constant", mode=["area"]),
        ResizeDict(
            keys=["image3d"],
            spatial_size=256,
            size_mode="longest",
            mode=["trilinear"],
            align_corners=True,
        ),
        DivisiblePadDict(
            keys=["image3d"],
            k=256,
            mode="constant",
            constant_values=0,
        ),
        ToTensorDict(keys=["image3d"],),
    ]
)

def rescaled(x, val=64, eps=1e-8):
    return (x + eps) / (val + eps)

def minimized(x, eps=1e-8):
    return (x + eps) / (x.max() + eps)

def normalized(x, eps=1e-8):
    return (x - x.min() + eps) / (x.max() - x.min() + eps)

def standardized(x, eps=1e-8):
    return (x - x.mean()) / (x.std() + eps)  # 1e-6 to avoid zero division

def transform_hu_to_density(
        volume, 
        bone_attenuation_multiplier=2, 
        # v_min=-800, 
        # v_max=+350, 
        v_min=-256, 
        v_max=+1024
    ):
    # volume can be loaded as int16, need to convert to float32 to use float bone_attenuation_multiplier
    volume = volume.to(torch.float32)
    air = torch.where(volume <= v_min)
    soft_tissue = torch.where((v_min < volume) & (volume <= v_max))
    bone = torch.where(v_max < volume)

    density = torch.empty_like(volume)
    density[air] = volume[soft_tissue].min()
    density[soft_tissue] = volume[soft_tissue]
    density[bone] = volume[bone] * bone_attenuation_multiplier
    density = normalized(density)
    return density

In [5]:
# Run forward pass
device = torch.device('cuda:0')
B = 1

fov=12.0
znear=7
zfar=9

i = 22

In [6]:
dist_hidden = 8 * torch.ones(B, device=device)
elev_hidden = torch.zeros(B, device=device)
azim_hidden = torch.zeros(B, device=device)
view_hidden = make_cameras_dea(
    dist_hidden, 
    elev_hidden, 
    azim_hidden, 
    fov=fov, 
    znear=znear, 
    zfar=zfar,
).to(device)

dist_lateral = 8 * torch.ones(B, device=device)
elev_lateral = torch.zeros(B, device=device)
azim_lateral = torch.ones(B, device=device) * 0.25
view_lateral = make_cameras_dea(
    dist_lateral, 
    elev_lateral, 
    azim_lateral, 
    fov=fov, 
    znear=znear, 
    zfar=zfar,
).to(device)

In [7]:
from dvr.renderer import ObjectCentricXRayVolumeRenderer


# Initialize the renderer outside the loop to avoid reinitialization
fwd_renderer = ObjectCentricXRayVolumeRenderer(
    image_width=256,
    image_height=256,
    n_pts_per_ray=800,
    min_depth=znear,
    max_depth=zfar,
    ndc_extent=1.0,
)

# Loop through v_min values with tqdm
for v_min in tqdm(range(-1024, 3071, 50), desc="Outer Loop (v_min)"): 
    # Loop through v_max values from v_min + 1 to 3071 with tqdm
    for v_max in tqdm(range(v_min + 1, 3072, 50), desc="Inner Loop (v_max)", leave=False):
        try:
            # Prepare data for rendering
            data = {}
            data["image3d"] = ct_images[i]
            image3d = val_transforms(data)

            image3d = image3d["image3d"].to(device)
            image3d = image3d.unsqueeze(0)

            # Scale from 0 1 to -1024 4095
            image3d *= 4095
            image3d -= 1024  
            
            # Transform the image using the current v_min and v_max
            image3d = transform_hu_to_density(image3d, v_min=v_min, v_max=v_max)

            # Render frontal view
            output_frontal = fwd_renderer.forward(
                image3d=image3d, 
                cameras=view_hidden, 
            ).clamp_(0, 1).squeeze().detach().cpu()
            
            output_frontal = (255 * output_frontal).astype(torch.uint8).unsqueeze(0)
            
            # Save frontal view image
            filename_frontal = f"win/ObjectCentricXRayFrontalImage_vmin_{v_min}_vmax_{v_max}.jpg"
            torchvision.io.write_jpeg(output_frontal, filename_frontal)

            # Render lateral view
            output_lateral = fwd_renderer.forward(
                image3d=image3d, 
                cameras=view_lateral, 
            ).clamp_(0, 1).squeeze().detach().cpu()
            
            output_lateral = (255 * output_lateral).astype(torch.uint8).unsqueeze(0)
            
            # Save lateral view image
            filename_lateral = f"win/ObjectCentricXRayLateralImage_vmin_{v_min}_vmax_{v_max}.jpg"
            torchvision.io.write_jpeg(output_lateral, filename_lateral)

            del data 
        except RuntimeError as e:
            if "Expected reduction dim to be specified" in str(e):
                print(f"Skipping due to RuntimeError: {e}")
                continue  # Skip this iteration and continue with the next one
            else:
                raise  # Re-raise the exception if it's not the one we're handling
        
print("Rendering complete for all specified v_min and v_max values.")

Outer Loop (v_min):   0%|          | 0/82 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/82 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/81 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/80 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/79 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/78 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/77 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/76 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/75 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/74 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/73 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/72 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/71 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/70 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/69 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/68 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/67 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/66 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/65 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/64 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/63 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/62 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/61 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/60 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/59 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/58 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/57 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/56 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/55 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/54 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/53 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/52 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/51 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/50 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/49 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/48 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/47 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/46 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/45 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/44 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/43 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/42 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/41 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/40 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/39 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/38 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/37 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/36 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/35 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/34 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/33 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/32 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/31 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/30 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/29 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/28 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/27 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/26 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/25 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/24 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/23 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/22 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/21 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/20 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/19 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/18 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/17 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/16 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/15 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/14 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/13 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/12 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/11 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/10 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/9 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/8 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/7 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/6 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/5 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/4 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/3 [00:00<?, ?it/s]

Inner Loop (v_max):   0%|          | 0/2 [00:00<?, ?it/s]

Skipping due to RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


Inner Loop (v_max):   0%|          | 0/1 [00:00<?, ?it/s]

Rendering complete for all specified v_min and v_max values.
