In [1]:
import os
import glob
from PIL import Image

import typing
from typing import Any, Callable, Optional, Sequence, Union, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import monai 
from monai.data import Dataset, ArrayDataset, CacheDataset, DataLoader
from monai.data import list_data_collate, decollate_batch
from monai.utils import first, set_determinism, get_seed, MAX_SEED
from monai.transforms import (
    apply_transform, 
    Transform, MapTransform, DivisiblePad,
    AddChanneld,
    Compose, OneOf, 
    LoadImaged, SaveImaged, Spacingd,
    Orientationd, DivisiblePadd, 
    RandFlipd, RandZoomd, RandScaleCropd, RandSpatialCropd,
    RandHistogramShiftd, RandAffined,
    Resized, Rotate90d, Flipd, 
    ScaleIntensityd,
    ScaleIntensityRanged, HistogramNormalized,
    ToTensord,
)
from pytorch_lightning import LightningDataModule

import pytorch3d
from pytorch3d.structures import Volumes
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import (
    camera_position_from_spherical_angles,
    CamerasBase,
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    get_world_to_view_transform,
    look_at_rotation,
    look_at_view_transform,
    OpenGLOrthographicCameras,
    OpenGLPerspectiveCameras,
    OrthographicCameras,
    PerspectiveCameras,
    SfMOrthographicCameras,
    SfMPerspectiveCameras,
)
from pytorch3d.transforms import Transform3d
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.renderer import VolumeRenderer
from pytorch3d.renderer import VolumeSampler
from pytorch3d.renderer import NDCMultinomialRaysampler
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer import EmissionAbsorptionRaymarcher

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class UnpairedDataset(monai.data.Dataset, monai.transforms.Randomizable):
    """Dataset that assembles each sample from independently drawn items.

    For every requested index, one random element is picked from each entry
    of ``data`` and stored under the matching entry of ``keys``; the optional
    ``transform`` is then applied to the resulting dictionary.

    Args:
        keys: names under which the drawn items are stored, one per source.
        data: the sources to draw from, paired positionally with ``keys``
            (``zip`` is used, so extra entries on either side are ignored).
        transform: optional callable applied to the assembled dictionary.
        length: reported dataset length. When ``None`` the length of the
            shortest source is used.
        batch_size: stored on the instance; not used by this class itself.
    """

    def __init__(
        self,
        keys: Sequence, 
        data: Sequence, 
        transform: Optional[Callable] = None,
        length: Optional[int] = None, 
        batch_size: int = 32, 
    ) -> None:
        # NOTE(review): the parent constructors are intentionally not called
        # (same as before); the attributes are assigned directly instead.
        self.keys = keys
        self.data = data
        self.length = length
        self.batch_size = batch_size
        self.transform = transform
        
    def __len__(self) -> int:
        """Return the explicit ``length`` if given, else the shortest source length."""
        if self.length is not None:
            return self.length
        return min(len(dataset) for dataset in self.data)

    def _transform(self, index: int):
        """Build the sample for ``index``.

        The random state is re-seeded with ``index`` first, so the items drawn
        for a given index are the same on every call.
        Presumably invoked by ``monai.data.Dataset.__getitem__`` — confirm
        against the installed MONAI version.
        """
        self.R.seed(index)
        # One draw per source, in the order the sources were given.
        sample = {
            key: dataset[self.R.randint(0, len(dataset))]
            for key, dataset in zip(self.keys, self.data)
        }

        if self.transform is not None:
            sample = apply_transform(self.transform, sample)

        return sample

class CustomDataModule(LightningDataModule):
    """Lightning data module serving unpaired 3D volumes and 2D images.

    File lists are collected at construction time by recursively globbing the
    given folders for ``*.nii.gz`` (3D) and ``*.png`` (2D) files. The train
    and validation loaders wrap those lists in an ``UnpairedDataset``.

    Args:
        train_image3d_folders, train_image2d_folders,
        val_image3d_folders, val_image2d_folders,
        test_image3d_folders, test_image2d_folders:
            a folder path or a sequence of folder paths to search.
        shape: target size used by the resize and pad transforms.
        batch_size: batch size used by both data loaders.
    """

    def __init__(self, 
        train_image3d_folders: Union[str, Sequence[str]] = "path/to/folder", 
        train_image2d_folders: Union[str, Sequence[str]] = "path/to/folder", 
        val_image3d_folders: Union[str, Sequence[str]] = "path/to/folder", 
        val_image2d_folders: Union[str, Sequence[str]] = "path/to/folder", 
        test_image3d_folders: Union[str, Sequence[str]] = "path/to/folder", 
        test_image2d_folders: Union[str, Sequence[str]] = "path/to/dir", 
        shape: int = 256,
        batch_size: int = 32
    ):
        super().__init__()

        self.batch_size = batch_size
        self.shape = shape
        self.train_image3d_folders = train_image3d_folders
        self.train_image2d_folders = train_image2d_folders
        self.val_image3d_folders = val_image3d_folders
        self.val_image2d_folders = val_image2d_folders
        self.test_image3d_folders = test_image3d_folders
        self.test_image2d_folders = test_image2d_folders

        self.train_image3d_files = self._glob_files(folders=train_image3d_folders, extension='**/*.nii.gz')
        self.train_image2d_files = self._glob_files(folders=train_image2d_folders, extension='**/*.png')
        
        self.val_image3d_files = self._glob_files(folders=val_image3d_folders, extension='**/*.nii.gz') # TODO
        self.val_image2d_files = self._glob_files(folders=val_image2d_folders, extension='**/*.png')
        
        self.test_image3d_files = self._glob_files(folders=test_image3d_folders, extension='**/*.nii.gz') # TODO
        self.test_image2d_files = self._glob_files(folders=test_image2d_folders, extension='**/*.png')

    @staticmethod
    def _glob_files(folders: Union[str, Sequence[str], None] = None, extension: str = '*.nii.gz') -> List[str]:
        """Return the sorted list of files matching ``extension`` under ``folders``.

        A single path is accepted as well as a sequence of paths. Without this,
        a bare string would be iterated character by character and every ``/``
        would be globbed recursively from the filesystem root.
        """
        assert folders is not None
        if isinstance(folders, (str, os.PathLike)):
            folders = [folders]
        files = sorted(
            path
            for folder in folders
            for path in glob.glob(os.path.join(folder, extension), recursive=True)
        )
        # Quick sanity output: how many files were found and the first one.
        print(len(files))
        print(files[:1])
        return files

    def _load_and_normalize_transforms(self) -> list:
        """Transforms shared by train and val: load, orient and rescale intensities."""
        return [
            LoadImaged(keys=["image3d", "image2d"]),
            AddChanneld(keys=["image3d", "image2d"]),
            Spacingd(keys=["image3d"], pixdim=(1.0, 1.0, 1.0), mode=["bilinear"]),
            Rotate90d(keys=["image2d"], k=3),
            # One of these orientations is applied per sample.
            # The LPI / RPI / LAI / RAI variants are currently disabled.
            OneOf([
                Orientationd(keys=["image3d"], axcodes="ALI"),
                Orientationd(keys=["image3d"], axcodes="PLI"),
                Orientationd(keys=["image3d"], axcodes="ARI"),
                Orientationd(keys=["image3d"], axcodes="PRI"),
            ]),
            ScaleIntensityd(keys=["image2d"], minv=0.0, maxv=1.0),
            # Full intensity range; a narrower window (-200, 1500) was tried before.
            ScaleIntensityRanged(keys=["image3d"], clip=True,
                                 a_min=-500,
                                 a_max=3071,
                                 b_min=0.0,
                                 b_max=1.0),
        ]

    def _resize_and_pad_transforms(self) -> list:
        """Transforms shared by train and val: resize/pad to ``self.shape`` and tensorize."""
        return [
            Resized(keys=["image3d"], spatial_size=self.shape, size_mode="longest", mode=["trilinear"], align_corners=True),
            Resized(keys=["image2d"], spatial_size=self.shape, size_mode="longest", mode=["area"]),
            DivisiblePadd(keys=["image3d", "image2d"], k=self.shape, mode="constant", constant_values=0),
            ToTensord(keys=["image3d", "image2d"]),
        ]

    def setup(self, stage: Optional[str]=None):
        """Fix the random seed; called on every process in DDP."""
        # make assignments here (val/train/test split)
        set_determinism(seed=2222)

    def train_dataloader(self):
        """Build and return the training loader (with random augmentation)."""
        augmentations = [
            RandZoomd(keys=["image3d"], prob=1.0, min_zoom=0.9, max_zoom=1.0, padding_mode='constant', mode=["trilinear"], align_corners=True),
            RandZoomd(keys=["image2d"], prob=1.0, min_zoom=0.9, max_zoom=1.0, padding_mode='constant', mode=["area"]),
            RandFlipd(keys=["image2d"], prob=0.5, spatial_axis=1),
            RandScaleCropd(keys=["image3d"],
                           roi_scale=(0.9, 0.9, 0.8),
                           max_roi_scale=(1.0, 1.0, 0.8),
                           random_center=True,
                           random_size=True),
            RandAffined(keys=["image3d"], rotate_range=None, shear_range=None, translate_range=20, scale_range=None),
        ]
        self.train_transforms = Compose(
            self._load_and_normalize_transforms()
            + augmentations
            + self._resize_and_pad_transforms()
        )

        self.train_datasets = UnpairedDataset(
            keys=["image3d", "image2d"],
            data=[self.train_image3d_files, self.train_image2d_files], 
            transform=self.train_transforms,
            length=1000,
            batch_size=self.batch_size,
        )

        self.train_loader = DataLoader(
            self.train_datasets, 
            batch_size=self.batch_size, 
            num_workers=8, 
            collate_fn=list_data_collate,
            shuffle=True,
        )
        return self.train_loader

    def val_dataloader(self):
        """Build and return the validation loader (no random augmentation stage)."""
        self.val_transforms = Compose(
            self._load_and_normalize_transforms()
            + self._resize_and_pad_transforms()
        )

        self.val_datasets = UnpairedDataset(
            keys=["image3d", "image2d"],
            data=[self.val_image3d_files, self.val_image2d_files], 
            transform=self.val_transforms,
            length=200,
            batch_size=self.batch_size,
        )
        
        # NOTE(review): shuffle=True is kept as it was; validation loaders are
        # usually left unshuffled — confirm this is intended.
        self.val_loader = DataLoader(
            self.val_datasets, 
            batch_size=self.batch_size, 
            num_workers=4, 
            collate_fn=list_data_collate,
            shuffle=True,
        )
        return self.val_loader


In [None]:
class Hparams:
    """Plain container for the notebook's run settings."""

    def __init__(self, datadir, shape=256, batch_size=1):
        # Root directory of the datasets, image size, and loader batch size.
        self.datadir, self.shape, self.batch_size = datadir, shape, batch_size


hparams = Hparams(datadir='data')

In [None]:
# Create data module
# 3D volume folders (relative to hparams.datadir): VerSe 2019/2020 splits,
# UWSpine, and the chest CT collections.
train_image3d_folders = [
    os.path.join(hparams.datadir, relative_dir)
    for relative_dir in (
        'SpineXRVertSegmentation/Verse2019/raw/train/rawdata/',
        'SpineXRVertSegmentation/Verse2020/raw/train/rawdata/',
        'SpineXRVertSegmentation/Verse2019/raw/val/rawdata/',
        'SpineXRVertSegmentation/Verse2020/raw/val/rawdata/',
        'SpineXRVertSegmentation/Verse2019/raw/test/rawdata/',
        'SpineXRVertSegmentation/Verse2020/raw/test/rawdata/',

        'SpineXRVertSegmentation/UWSpine/processed/train/images',
        'SpineXRVertSegmentation/UWSpine/processed/test/images/',

        'ChestXRLungSegmentation/NSCLC/processed/train/images',
        # MOSMED is split into five sub-folders, CT-0 through CT-4.
        *(
            f'ChestXRLungSegmentation/MOSMED/processed/train/images/CT-{severity}'
            for severity in range(5)
        ),
        'ChestXRLungSegmentation/Imagenglab/processed/train/images',
    )
]
train_label3d_folders = []

# 2D image folders used for training. Currently disabled sources:
#   ChestXRLungSegmentation/JSRT/processed/images/
#   ChestXRLungSegmentation/ChinaSet/processed/images/
#   ChestXRLungSegmentation/Montgomery/processed/images/
#   ChestXRLungSegmentation/VinDr/v1/processed/test/images/
#   SpineXRVertSegmentation/T62020/20200501/raw/images
#   SpineXRVertSegmentation/T62021/20211101/raw/images
#   SpineXRVertSegmentation/VinDr/v1/processed/test/images/
train_image2d_folders = [
    os.path.join(hparams.datadir, relative_dir)
    for relative_dir in (
        'ChestXRLungSegmentation/VinDr/v1/processed/train/images/',
        'SpineXRVertSegmentation/VinDr/v1/processed/train/images/',
    )
]
train_label2d_folders = []

# Validation reuses the very same 3D folder list object as training.
val_image3d_folders = train_image3d_folders
# 2D image folders used for validation. Currently disabled sources:
#   ChestXRLungSegmentation/JSRT/processed/images/
#   ChestXRLungSegmentation/ChinaSet/processed/images/
#   ChestXRLungSegmentation/Montgomery/processed/images/
#   ChestXRLungSegmentation/VinDr/v1/processed/train/images/
#   SpineXRVertSegmentation/T62020/20200501/raw/images
#   SpineXRVertSegmentation/T62021/20211101/raw/images
#   SpineXRVertSegmentation/VinDr/v1/processed/train/images/
val_image2d_folders = [
    os.path.join(hparams.datadir, relative_dir)
    for relative_dir in (
        'ChestXRLungSegmentation/VinDr/v1/processed/test/images/',
        'SpineXRVertSegmentation/VinDr/v1/processed/test/images/',
    )
]

# The test split points at the validation folders.
test_image3d_folders = val_image3d_folders
test_image2d_folders = val_image2d_folders

datamodule = CustomDataModule(
    train_image3d_folders=train_image3d_folders,
    train_image2d_folders=train_image2d_folders,
    val_image3d_folders=val_image3d_folders,
    val_image2d_folders=val_image2d_folders,
    test_image3d_folders=test_image3d_folders,
    test_image2d_folders=test_image2d_folders,
    batch_size=hparams.batch_size,
    shape=hparams.shape,
)
datamodule.setup()


In [None]:
def RandomCameras(
        dist: float = 3.0, 
        elev: float = 0.0, 
        azim: float = 0.0, 
        batch_size: int = 32, 
        random: bool = False, 
        dtype: torch.dtype = torch.float32,
    ): 
    """Create a batch of ``FoVPerspectiveCameras`` that all share one pose.

    A single look-at pose is built from ``dist``/``elev``/``azim`` and repeated
    ``batch_size`` times. With ``random=True`` the given ``elev`` and ``azim``
    are replaced by random values, and the field of view and aspect ratio get
    per-camera Gaussian noise.

    Args:
        dist: camera distance passed to ``look_at_view_transform``.
        elev: elevation angle; ignored when ``random`` is True.
        azim: azimuth angle; ignored when ``random`` is True.
        batch_size: number of cameras in the returned batch.
        random: draw the pose and intrinsics randomly instead of using the
            fixed values.
        dtype: dtype every camera tensor is converted to.

    Returns:
        A ``FoVPerspectiveCameras`` batch of size ``batch_size``.
    """
    if random:
        # NOTE(review): `rand` is already 0 or 180, so `180*rand` adds either
        # 0 or 32400 degrees — a whole number of turns, i.e. no actual flip.
        # Kept as-is because the intended behaviour is unclear; confirm.
        rand = 180*np.random.randint(0, 2)
        elev = np.random.uniform( -5,  5) + 180*rand
        azim = np.random.uniform( -5,  5) + 180*rand

    # One pose, replicated across the batch.
    R, T = look_at_view_transform(dist, elev, azim)
    R = R.repeat(batch_size, 1, 1)
    T = T.repeat(batch_size, 1)

    # Clipping planes are the same in both modes.
    znear = 0.1 * torch.ones(batch_size)
    zfar = 3.5 * torch.ones(batch_size)

    fov = torch.ones(batch_size) * 60
    aspect_ratio = 1.15 * torch.ones(batch_size)
    if random:
        # Draw order (fov first, then aspect ratio) is kept so that seeded
        # runs produce the same values as before.
        fov = fov + torch.randn(batch_size) * 5
        aspect_ratio = aspect_ratio + torch.randn(batch_size) * 0.1

    return FoVPerspectiveCameras(R=R.type(dtype), 
                                 T=T.type(dtype), 
                                 znear=znear.type(dtype), 
                                 zfar=zfar.type(dtype), 
                                 fov=fov.type(dtype), 
                                 aspect_ratio=aspect_ratio.type(dtype))


In [None]:
class VolumeModel(torch.nn.Module):
    """Render caller-supplied volumes into a single-channel, min-max scaled image.

    The module holds no learnable volume of its own: the volumes to render
    are passed to ``forward`` together with the cameras.

    Args:
        renderer: callable invoked as ``renderer(cameras=..., volumes=...)``
            that returns a pair whose first element is the rendered image.
        volume_size: accepted for backward compatibility; not used.
        voxel_size: stored on the instance; not used by ``forward``.
    """

    def __init__(self, renderer, volume_size=(256, 256, 256), voxel_size=0.1):
        super().__init__()
        self._voxel_size = voxel_size
        # Store the renderer module as well.
        self._renderer = renderer
        
    def forward(self, cameras, volumes):
        """Render ``volumes`` from ``cameras`` and return a normalized image.

        The renderer's second output (a representation of the sampled rays)
        is discarded. The first output is assumed to be channels-last with at
        least three channels — TODO confirm against the renderer in use.

        Returns:
            Tensor with a single channel at dim 1, scaled to [0, 1] using the
            minimum and maximum over the whole batch. A constant render
            yields all zeros rather than NaN.
        """
        screen_RGBA, _ = self._renderer(cameras=cameras, volumes=volumes)
        # Move the last axis to position 1 and swap the two middle axes.
        screen_RGBA = screen_RGBA.permute(0, 3, 2, 1)
        # Average the first three channels into one.
        screen_RGB = screen_RGBA[:, :3].mean(dim=1, keepdim=True)

        lowest = screen_RGB.min()
        # Clamp the range to the smallest positive normal value so a constant
        # image does not divide by zero; any non-degenerate range is untouched.
        span = torch.clamp(screen_RGB.max() - lowest,
                           min=torch.finfo(screen_RGB.dtype).tiny)
        return (screen_RGB - lowest) / span