In [1]:
import logging
from typing import Tuple

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from IPython.display import HTML
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components
from pytorch3d.renderer.implicit.renderer import VolumeSampler
from pytorch3d.structures import Volumes
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene


In [2]:
# Height and width of the images the model renders; passed as both
# render_image_height and render_image_width when the GenericModel is built.
output_resolution: int = 256

In [3]:
CONSTRUCT_MODEL_FROM_CONFIG = True
# NOTE(review): CONSTRUCT_MODEL_FROM_CONFIG is not read by any cell visible here;
# the model below is always built directly from keyword arguments.

# Build the GenericModel directly from keyword arguments: a voxel-grid implicit
# function, a ResNetFeatureExtractor for image features, square renders of
# `output_resolution`, and RGB MSE as the only weighted loss term.
gm = GenericModel(
    implicit_function_class_type="VoxelGridImplicitFunction",
    image_feature_extractor_class_type="ResNetFeatureExtractor",
    raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0},
    render_image_width=output_resolution,
    render_image_height=output_resolution,
    loss_weights={"loss_rgb_mse": 1.0},
    tqdm_trigger_threshold=19000,
)

# Recover the structured DictConfig that corresponds to how `gm` was configured.
cfg = OmegaConf.structured(gm)

In [8]:
# Prune (in place on `cfg`) the config sections of components that are not in
# use, then serialise the remaining configuration to a YAML string with keys in
# declaration order rather than alphabetical order.
# NOTE(review): the name `yaml` would shadow the PyYAML module if it were ever
# imported in this notebook.
remove_unused_components(cfg)
yaml = OmegaConf.to_yaml(cfg, resolve=False, sort_keys=False)
# To page through the YAML interactively, uncomment:  %page -r yaml

In [9]:
# Pick the device for the model: the first CUDA GPU when one is available,
# otherwise the CPU (with a warning) instead of crashing inside `gm.to(...)`
# on machines without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    logging.warning("CUDA is not available; falling back to the CPU.")
    device = torch.device("cpu")

gm.to(device)

# Sanity check: the model's parameters now live on the selected device.
# (On a CUDA machine this is equivalent to the previous `is_cuda` check.)
assert next(gm.parameters()).device == device


In [10]:
# Adam over every parameter of the model.
# NOTE(review): lr=0.1 is large for Adam — confirm it is intended for this model.
optimizer = torch.optim.Adam(params=gm.parameters(), lr=0.1)
# Put the model into training mode.
gm.train()


In [13]:
import os
from datamodule import UnpairedDataModule

# Root directory; every dataset folder below is given relative to it.
datadir = "data"


def _under_datadir(*relative_dirs):
    """Return ``relative_dirs`` joined onto ``datadir``, as a new list in the same order."""
    return [os.path.join(datadir, rel) for rel in relative_dirs]


# The MOSMED sub-folders CT-0 ... CT-4.
_MOSMED_TRAIN_DIRS = [
    f"ChestXRLungSegmentation/MOSMED/processed/train/images/CT-{idx}" for idx in range(5)
]

# ---------------------------------------------------------------- training ---
# 3D image folders.
train_image3d_folders = _under_datadir(
    "ChestXRLungSegmentation/NSCLC/processed/train/images",
    *_MOSMED_TRAIN_DIRS,
    # "ChestXRLungSegmentation/Imagenglab/processed/train/images",
    "ChestXRLungSegmentation/MELA2022/raw/train/images",
    "ChestXRLungSegmentation/MELA2022/raw/val/images",
)
# Empty, and not passed to UnpairedDataModule below.
train_label3d_folders = []

# 2D image folders.
train_image2d_folders = _under_datadir(
    "ChestXRLungSegmentation/JSRT/processed/images/",
    "ChestXRLungSegmentation/ChinaSet/processed/images/",
    "ChestXRLungSegmentation/Montgomery/processed/images/",
    "ChestXRLungSegmentation/VinDr/v1/processed/train/images/",
    # "ChestXRLungSegmentation/VinDr/v1/processed/test/images/",
    # "SpineXRVertSegmentation/T62020/20200501/raw/images",
    # "SpineXRVertSegmentation/T62021/20211101/raw/images",
    # "SpineXRVertSegmentation/VinDr/v1/processed/train/images/",
    # # "SpineXRVertSegmentation/VinDr/v1/processed/test/images/",
)
# Empty, and not passed to UnpairedDataModule below.
train_label2d_folders = []

# -------------------------------------------------------------- validation ---
# NOTE(review): these are exactly the training 3D folders (a separate list with
# the same entries), so validation on the 3D side overlaps training — confirm
# this is intended.
val_image3d_folders = _under_datadir(
    "ChestXRLungSegmentation/NSCLC/processed/train/images",
    *_MOSMED_TRAIN_DIRS,
    # "ChestXRLungSegmentation/Imagenglab/processed/train/images",
    "ChestXRLungSegmentation/MELA2022/raw/train/images",
    "ChestXRLungSegmentation/MELA2022/raw/val/images",
    # "ChestXRLungSegmentation/AMOS2022/raw/train/images",
    # "ChestXRLungSegmentation/AMOS2022/raw/val/images",
    # "SpineXRVertSegmentation/Verse2019/raw/train/rawdata/",
    # "SpineXRVertSegmentation/Verse2020/raw/train/rawdata/",
    # "SpineXRVertSegmentation/Verse2019/raw/val/rawdata/",
    # "SpineXRVertSegmentation/Verse2020/raw/val/rawdata/",
    # "SpineXRVertSegmentation/Verse2019/raw/test/rawdata/",
    # "SpineXRVertSegmentation/Verse2020/raw/test/rawdata/",
    # "SpineXRVertSegmentation/UWSpine/processed/train/images",
    # "SpineXRVertSegmentation/UWSpine/processed/test/images/",
)

# NOTE(review): JSRT, ChinaSet and Montgomery are also in the training 2D list.
val_image2d_folders = _under_datadir(
    "ChestXRLungSegmentation/JSRT/processed/images/",
    "ChestXRLungSegmentation/ChinaSet/processed/images/",
    "ChestXRLungSegmentation/Montgomery/processed/images/",
    # "ChestXRLungSegmentation/VinDr/v1/processed/train/images/",
    "ChestXRLungSegmentation/VinDr/v1/processed/test/images/",
    # "SpineXRVertSegmentation/T62020/20200501/raw/images",
    # "SpineXRVertSegmentation/T62021/20211101/raw/images",
    # # "SpineXRVertSegmentation/VinDr/v1/processed/train/images/",
    # "SpineXRVertSegmentation/VinDr/v1/processed/test/images/",
)

# -------------------------------------------------------------------- test ---
# The test split reuses the very same list objects as validation.
test_image3d_folders = val_image3d_folders
test_image2d_folders = val_image2d_folders

# Build the data module from the folder lists and prepare its datasets.
datamodule = UnpairedDataModule(
    train_image3d_folders=train_image3d_folders,
    train_image2d_folders=train_image2d_folders,
    val_image3d_folders=val_image3d_folders,
    val_image2d_folders=val_image2d_folders,
    test_image3d_folders=test_image3d_folders,
    test_image2d_folders=test_image2d_folders,
    train_samples=8,
    val_samples=8,
    test_samples=8,
    batch_size=8,
    img_shape=256,
    vol_shape=256
)
datamodule.setup()

2392
['data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
15951
['data/ChestXRLungSegmentation/ChinaSet/processed/images/CHNCXR_0001_0.png']
2392
['data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
3951
['data/ChestXRLungSegmentation/ChinaSet/processed/images/CHNCXR_0001_0.png']
2392
['data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
3951
['data/ChestXRLungSegmentation/ChinaSet/processed/images/CHNCXR_0001_0.png']


In [None]:
# Optimise `gm` by cycling over a fixed sequence of pre-collated training batches.
#
# NOTE(review): `train_data_collated` is not created by any cell visible in this
# notebook, so on a fresh kernel this cell cannot run until it is defined. It must
# be a non-empty sequence whose items can be unpacked as `gm(**frame, ...)`.
# TODO: build it from `datamodule` above.
N_TRAIN_ITERATIONS = 2000  # total optimisation steps
LOSS_LOG_EVERY = 100       # refresh the progress-bar loss every this many steps

# Fail fast with a clear message instead of a bare NameError / ZeroDivisionError.
if "train_data_collated" not in globals() or len(train_data_collated) == 0:
    raise RuntimeError(
        "`train_data_collated` must be defined as a non-empty sequence of "
        "training batches before running the training loop."
    )

iterator = tqdm.tqdm(range(N_TRAIN_ITERATIONS))
for n_batch in iterator:
    optimizer.zero_grad()

    # Wrap around using the length of the sequence actually being indexed
    # (previously `len(dataset_map.train)`, an unrelated and undefined object).
    frame = train_data_collated[n_batch % len(train_data_collated)]
    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)
    out["objective"].backward()
    if n_batch % LOSS_LOG_EVERY == 0:
        iterator.set_postfix_str(f"loss: {float(out['objective']):.5f}")
    optimizer.step()
