In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F



from pytorch3d.renderer.cameras import FoVOrthographicCameras, FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
from pytorch3d.renderer import VolumeRenderer, NDCMultinomialRaysampler, EmissionAbsorptionRaymarcher
from pytorch3d.structures import Pointclouds, Volumes, Meshes
from pytorch3d.ops import add_pointclouds_to_volumes, cubify

def make_cameras_dea(
    dist: torch.Tensor,
    elev: torch.Tensor,
    azim: torch.Tensor,
    fov: float = 10,
    znear: float = 18.0,
    zfar: float = 22.0,
    is_orthogonal: bool = False,
):
    """Build a batch of PyTorch3D FoV cameras from distance / elevation / azimuth.

    ``elev`` and ``azim`` are *normalized*: they are multiplied by 90 and 180
    respectively before being passed (in degrees) to ``look_at_view_transform``.

    Args:
        dist: camera distances, one entry per camera.
        elev: normalized elevations; degrees = ``elev * 90``.
        azim: normalized azimuths; degrees = ``azim * 180``.
        fov: field of view in degrees (perspective cameras only; ignored when
            ``is_orthogonal`` is True).
        znear: near clipping plane passed to the camera.
        zfar: far clipping plane passed to the camera.
        is_orthogonal: build ``FoVOrthographicCameras`` instead of
            ``FoVPerspectiveCameras``.

    Returns:
        A cameras object whose R/T live on the same device as ``dist``.

    Raises:
        AssertionError: if ``dist``, ``elev`` and ``azim`` are on different devices.
    """
    assert dist.device == elev.device == azim.device, (
        f"dist/elev/azim must share a device, got "
        f"{dist.device}, {elev.device}, {azim.device}"
    )
    device = dist.device
    # Build R/T directly on the target device (look_at_view_transform otherwise
    # defaults to CPU, forcing a device round trip for non-CPU inputs).
    R, T = look_at_view_transform(
        dist=dist.float(),
        elev=elev.float() * 90,
        azim=azim.float() * 180,
        device=device,
    )
    if is_orthogonal:
        return FoVOrthographicCameras(R=R, T=T, znear=znear, zfar=zfar, device=device)
    return FoVPerspectiveCameras(R=R, T=T, fov=fov, znear=znear, zfar=zfar, device=device)


In [4]:
SEED = 42  # global torch RNG seed; makes the random camera azimuths reproducible
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb8103bb3d0>

In [5]:
DEVICE_TYPE = "cpu"  # device used for the tensors and cameras built in later cells
_device = torch.device(DEVICE_TYPE)

In [6]:
# Initial toy-scene configuration. NOTE: the very next cell redefines all of
# these (with batch_size = 2), so these values are superseded.
batch_size, n_channels, shape, n_pts_per_ray = 1, 4, 8, 16
dtype = torch.float32

In [7]:
# Toy-scene configuration: number of cameras, feature channels, grid/image
# side length (`shape`) and number of depth samples per ray.
batch_size, n_channels, shape, n_pts_per_ray = 2, 4, 8, 16
dtype = torch.float32

In [8]:
import os
from datamodule import UnpairedDataModule

from typing import NamedTuple, Optional, Union

class Hparams(NamedTuple):
    """Experiment hyper-parameters used to configure the unpaired data module.

    Every field has a default, so ``Hparams()`` is valid; ``datadir`` must be
    set to a real directory before dataset paths are joined onto it.
    """

    datadir: Optional[str] = None  # root folder every dataset sub-path is joined onto
    train_samples: int = 100       # forwarded as UnpairedDataModule(train_samples=...)
    val_samples: int = 100         # forwarded as UnpairedDataModule(val_samples=...)
    test_samples: int = 100        # forwarded as UnpairedDataModule(test_samples=...)
    img_shape: int = 256           # forwarded as UnpairedDataModule(img_shape=...)
    vol_shape: int = 256           # forwarded as UnpairedDataModule(vol_shape=...)
    batch_size: int = 1            # forwarded as UnpairedDataModule(batch_size=...)

# Experiment configuration. The data root can be overridden with the DATADIR
# environment variable so the notebook runs on other machines; the original
# absolute path is kept as the fallback.
hparams = Hparams(
    datadir=os.environ.get("DATADIR", "/home/quantm/data"),
    train_samples=1,
    val_samples=1,
    test_samples=1,
    img_shape=256,
    vol_shape=256,
    batch_size=1,
)

# Dataset folders, all resolved under `hparams.datadir`. Commented-out entries
# are alternative datasets kept for quick toggling.


def _under_datadir(subdirs):
    """Join each relative sub-path onto `hparams.datadir`, preserving order."""
    return [os.path.join(hparams.datadir, subdir) for subdir in subdirs]


# 3D CT volume folders (shared by the train and val splits below).
_IMAGE3D_SUBDIRS = [
    'ChestXRLungSegmentation/NSCLC/processed/train/images',
    'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-0',
    'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-1',
    'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-2',
    'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-3',
    'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-4',
    # 'ChestXRLungSegmentation/Imagenglab/processed/train/images',
    'ChestXRLungSegmentation/MELA2022/raw/train/images',
    'ChestXRLungSegmentation/MELA2022/raw/val/images',
    # 'ChestXRLungSegmentation/AMOS2022/raw/train/images',
    # 'ChestXRLungSegmentation/AMOS2022/raw/val/images',

    # 'SpineXRVertSegmentation/Verse2019/raw/train/rawdata/',
    # 'SpineXRVertSegmentation/Verse2020/raw/train/rawdata/',
    # 'SpineXRVertSegmentation/Verse2019/raw/val/rawdata/',
    # 'SpineXRVertSegmentation/Verse2020/raw/val/rawdata/',
    # 'SpineXRVertSegmentation/Verse2019/raw/test/rawdata/',
    # 'SpineXRVertSegmentation/Verse2020/raw/test/rawdata/',

    # 'SpineXRVertSegmentation/UWSpine/processed/train/images',
    # 'SpineXRVertSegmentation/UWSpine/processed/test/images/',
]

train_image3d_folders = _under_datadir(_IMAGE3D_SUBDIRS)
train_label3d_folders = []

train_image2d_folders = _under_datadir([
    # 'ChestXRLungSegmentation/JSRT/processed/images/',
    # 'ChestXRLungSegmentation/ChinaSet/processed/images/',
    # 'ChestXRLungSegmentation/Montgomery/processed/images/',
    'ChestXRLungSegmentation/VinDr/v1/processed/train/images/',
    # 'ChestXRLungSegmentation/VinDr/v1/processed/test/images/',

    # 'SpineXRVertSegmentation/T62020/20200501/raw/images',
    # 'SpineXRVertSegmentation/T62021/20211101/raw/images',
    # 'SpineXRVertSegmentation/VinDr/v1/processed/train/images/',
    # 'SpineXRVertSegmentation/VinDr/v1/processed/test/images/',
])
train_label2d_folders = []

# NOTE(review): validation uses exactly the same CT folders as training
# (including the `.../train/images` paths). Confirm this is intended for the
# unpaired setup and not training data leaking into validation.
val_image3d_folders = _under_datadir(_IMAGE3D_SUBDIRS)

val_image2d_folders = _under_datadir([
    # 'ChestXRLungSegmentation/JSRT/processed/images/',
    # 'ChestXRLungSegmentation/ChinaSet/processed/images/',
    # 'ChestXRLungSegmentation/Montgomery/processed/images/',
    # 'ChestXRLungSegmentation/VinDr/v1/processed/train/images/',
    'ChestXRLungSegmentation/VinDr/v1/processed/test/images/',
    # 'SpineXRVertSegmentation/T62020/20200501/raw/images',
    # 'SpineXRVertSegmentation/T62021/20211101/raw/images',
    # 'SpineXRVertSegmentation/VinDr/v1/processed/train/images/',
    # 'SpineXRVertSegmentation/VinDr/v1/processed/test/images/',
])

# The test split mirrors validation; copies avoid sharing the same list objects.
test_image3d_folders = list(val_image3d_folders)
test_image2d_folders = list(val_image2d_folders)

# Build the unpaired 2D/3D data module from the folder lists and hparams.
dm = UnpairedDataModule(
    train_image3d_folders=train_image3d_folders,
    train_image2d_folders=train_image2d_folders,
    val_image3d_folders=val_image3d_folders,
    val_image2d_folders=val_image2d_folders,
    test_image3d_folders=test_image3d_folders,
    test_image2d_folders=test_image2d_folders,
    train_samples=hparams.train_samples,
    val_samples=hparams.val_samples,
    test_samples=hparams.test_samples,
    batch_size=hparams.batch_size,
    img_shape=hparams.img_shape,
    vol_shape=hparams.vol_shape,
)
dm.setup()

# Sanity check: print the shape of every validation 3D batch
# (the batch index is not needed, so no enumerate).
for batch in dm.val_dataloader():
    image3d = batch["image3d"]
    print(image3d.shape)

2392
['/home/quantm/data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
15000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/train/images/000434271f63a053c4128a0ba6352c7f.png']
2392
['/home/quantm/data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
3000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/test/images/002a34c58c5b758217ed1f584ccbcfe9.png']
2392
['/home/quantm/data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
3000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/test/images/002a34c58c5b758217ed1f584ccbcfe9.png']


monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.
<class 'monai.transforms.utility.dictionary.AddChanneld'>: Class `AddChanneld` has been deprecated since version 0.8. It will be removed in version 1.3. please use MetaTensor data type and monai.transforms.EnsureChannelFirstd instead with `channel_dim='no_channel'`.
monai.transforms.utility.dictionary EnsureChannelFirstd.__init__:meta_keys: Argument `meta_keys` has been deprecated since version 0.9. not needed if image is type `MetaTensor`.


torch.Size([1, 1, 256, 256, 256])


In [9]:
# Normalized poses, one per camera: azimuth uniform in [-1, 1) (make_cameras_dea
# multiplies it by 180 to get degrees), elevation 0, distance 6.
cam_azim = torch.rand(batch_size, device=_device) * 2 - 1
cam_elev = torch.zeros(batch_size, device=_device)
cam_dist = torch.full((batch_size,), 6.0, device=_device)
cameras = make_cameras_dea(dist=cam_dist, elev=cam_elev, azim=cam_azim)

for camera_param in (cameras.R, cameras.T):
    print(camera_param.shape)

torch.Size([2, 3, 3])
torch.Size([2, 3])


In [10]:
# Ray sampling: one ray per pixel of a `shape` x `shape` NDC image, with
# `n_pts_per_ray` samples between depth 4 and 8. NOTE(review): the depth range
# presumably brackets a unit volume at the origin seen from distance 6 — confirm.
_raysampler_kwargs = {
    "image_width": shape,
    "image_height": shape,
    "n_pts_per_ray": n_pts_per_ray,
    "min_depth": 4.0,
    "max_depth": 8.0,
}
raymarcher = EmissionAbsorptionRaymarcher()
raysampler = NDCMultinomialRaysampler(**_raysampler_kwargs)
renderer = VolumeRenderer(raysampler=raysampler, raymarcher=raymarcher)

In [24]:
def get_frustum(cameras=None, n_pts_per_ray=None):
    """Return the 3D sample points of every camera ray as a ``Pointclouds``.

    Rays are generated by the module-level ``raysampler``; all sample points of
    one camera's rays are flattened into a single cloud, giving one cloud per
    camera in the batch.

    Args:
        cameras: batch of PyTorch3D cameras (required).
        n_pts_per_ray: optional override for the number of samples per ray.
            When None, the raysampler's own configured value is used, so the
            result does not depend on whatever a global ``n_pts_per_ray``
            holds at call time.

    Returns:
        ``Pointclouds`` with ``cameras.R.shape[0]`` clouds.

    Raises:
        ValueError: if ``cameras`` is None.
    """
    if cameras is None:
        raise ValueError("get_frustum() requires a batch of cameras")
    n_cameras = cameras.R.shape[0]

    # Only the ray geometry is needed, so sample rays directly instead of
    # rendering a throw-away volume through the full renderer.
    sampler_kwargs = {}
    if n_pts_per_ray is not None:
        sampler_kwargs["n_pts_per_ray"] = n_pts_per_ray
    ray_bundle = raysampler(cameras=cameras, **sampler_kwargs)

    ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(n_cameras, -1, 3)
    return Pointclouds(ray_points)


frustum = get_frustum(cameras=cameras)

In [13]:
# PyTorch3D plotly visualization helpers (plot_scene and Lighting are used below).
from pytorch3d.vis.plotly_vis import (  # noqa: F401
    get_camera_wireframe,  # noqa: F401
    plot_batch_individually,  # noqa: F401
    plot_scene, 
    Lighting
)



In [27]:
from pytorch3d.transforms import random_rotations
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
from pytorch3d.structures import Meshes, Pointclouds
# Smoke-test plot_scene with a single randomly posed PerspectiveCameras.
# plot_scene also accepts RayBundle / HeterogeneousRayBundle, Meshes and
# Pointclouds entries in the same "scene" dict.
# Demo sizes use their own names so they never overwrite the notebook-level
# `n_pts_per_ray` read by the raysampler and get_frustum cells.
n_demo_cameras = 1
fig = plot_scene({
    "scene": {
        "camera": PerspectiveCameras(
            R=random_rotations(n_demo_cameras, device=_device),
            T=torch.randn(n_demo_cameras, 3, device=_device),
        ),
    }
})
fig.show()

In [28]:
# One plotly trace per camera and per frustum point cloud. The count is taken
# from the cameras themselves rather than the global `batch_size`, so the cell
# stays correct if the config cell is edited after the cameras were built.
n_views = cameras.R.shape[0]
cam_dict = {f"cam{idx}": cameras[[idx]] for idx in range(n_views)}
ray_dict = {f"ray{idx}": frustum[[idx]] for idx in range(n_views)}


def _hidden_axis(limit=10):
    """Plotly scene-axis settings: fixed [-limit, limit] range, no grid/zero line, hidden axis."""
    return {
        'range': [-limit, limit],
        'showgrid': False,  # thin lines in the background
        'zeroline': False,  # thick line through the axis origin
        'visible': False,   # hide the axis itself (ticks and labels)
    }


fig = plot_scene(
    plots={
        "Field of View": {
            **cam_dict,
            **ray_dict,
        },
    },
    lighting=Lighting(
        ambient=0.1,
        diffuse=0.5,
        fresnel=0.0,
        specular=1.0,
        roughness=0.9,
        facenormalsepsilon=1e-6,
        vertexnormalsepsilon=1e-12,
    ),
    # Fresh dict per axis, as in the original three separate literals.
    xaxis=_hidden_axis(),
    yaxis=_hidden_axis(),
    zaxis=_hidden_axis(),
    pointcloud_marker_size=1,
)
fig.update_layout(
    autosize=True,
    width=800,
    height=800,
)
fig.show()