In [1]:
! python -c "import onnxruntime" || pip install -q onnxruntime
import onnxruntime
import skimage.io
import skimage.transform
import torch
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from main_nerv_inv import make_cameras_dea, DXRLightningModule

In [2]:
import os
from datamodule import UnpairedDataModule

from typing import NamedTuple, Optional, Union

class Hparams(NamedTuple):
    """Immutable settings bundle used to configure the data module and model.

    Every field has a default, so ``Hparams()`` is valid and callers override
    only what they need. Fields whose default is ``None`` are annotated
    ``Optional[...]`` so the annotations agree with the defaults (and with
    callers that pass ``ckpt=None``).
    """

    # --- optimisation / bookkeeping ---
    lr: float = 1e-4
    ckpt: Optional[str] = "temp"        # callers may pass None
    datadir: Optional[str] = None       # joined with dataset sub-folders by the notebook
    strict: Optional[str] = None        # NOTE(review): annotated as str in the original; confirm intended type
    # NOTE(review): the six scalars below look like weighting coefficients
    # consumed by the model; their exact meaning is not visible here.
    alpha: float = 1.0
    gamma: float = 1.0
    delta: float = 1.0
    theta: float = 1.0
    omega: float = 1.0
    lamda: float = 1.0
    timesteps: int = 1000
    logsdir: Optional[str] = None
    sh: int = 0                         # NOTE(review): meaning not visible here
    pe: int = 0                         # NOTE(review): meaning not visible here

    # --- model / rendering ---
    n_pts_per_ray: int = 400
    weight_decay: float = 1e-2
    devices: torch.device = torch.device('cpu')
    backbone: str = "efficientnet-b7"

    # --- data ---
    train_samples: int = 100
    val_samples: int = 100
    test_samples: int = 100
    img_shape: int = 256
    vol_shape: int = 256
    fov_depth: int = 256
    batch_size: int = 1

# Settings for this notebook run: one sample per split at 256-sized images and
# volumes. The dataset root defaults to the original machine-specific location
# but can be redirected without editing the notebook by setting the
# NERV_DATADIR environment variable before starting the kernel.
hparams = Hparams(
    ckpt=None,
    datadir=os.environ.get("NERV_DATADIR", "/home/quantm/data"),
    train_samples=1,
    val_samples=1,
    test_samples=1,
    img_shape=256,
    vol_shape=256,
    batch_size=1,
)

# Create data module
# Every folder below is a path relative to hparams.datadir; commented-out
# entries are datasets that exist in the layout but are disabled for this run.

# 3D image folders used for training.
train_image3d_folders = [
    os.path.join(hparams.datadir, relpath)
    for relpath in (
        "ChestXRLungSegmentation/NSCLC/processed/train/images",
        # MOSMED is split into sub-folders CT-0 ... CT-4.
        *(
            f"ChestXRLungSegmentation/MOSMED/processed/train/images/CT-{idx}"
            for idx in range(5)
        ),
        # "ChestXRLungSegmentation/Imagenglab/processed/train/images",
        "ChestXRLungSegmentation/MELA2022/raw/train/images",
        "ChestXRLungSegmentation/MELA2022/raw/val/images",
        # "ChestXRLungSegmentation/AMOS2022/raw/train/images",
        # "ChestXRLungSegmentation/AMOS2022/raw/val/images",
    )
]

# No label folders are supplied for either modality.
train_label3d_folders = []

# 2D image folders used for training.
train_image2d_folders = [
    os.path.join(hparams.datadir, relpath)
    for relpath in (
        # "ChestXRLungSegmentation/JSRT/processed/images/",
        # "ChestXRLungSegmentation/ChinaSet/processed/images/",
        # "ChestXRLungSegmentation/Montgomery/processed/images/",
        "ChestXRLungSegmentation/VinDr/v1/processed/train/images/",
        # "ChestXRLungSegmentation/VinDr/v1/processed/test/images/",
    )
]

train_label2d_folders = []

# 3D image folders used for validation (all training-side folders disabled).
val_image3d_folders = [
    os.path.join(hparams.datadir, relpath)
    for relpath in (
        # "ChestXRLungSegmentation/NSCLC/processed/train/images",
        # "ChestXRLungSegmentation/MOSMED/processed/train/images/CT-0",
        # "ChestXRLungSegmentation/MOSMED/processed/train/images/CT-1",
        # "ChestXRLungSegmentation/MOSMED/processed/train/images/CT-2",
        # "ChestXRLungSegmentation/MOSMED/processed/train/images/CT-3",
        # "ChestXRLungSegmentation/MOSMED/processed/train/images/CT-4",
        # "ChestXRLungSegmentation/Imagenglab/processed/train/images",
        # "ChestXRLungSegmentation/MELA2022/raw/train/images",
        # "ChestXRLungSegmentation/MELA2022/raw/val/images",
        # "ChestXRLungSegmentation/AMOS2022/raw/train/images",
        # "ChestXRLungSegmentation/AMOS2022/raw/val/images",
        "ChestXRLungSegmentation/TCIA/CT-Covid-19-2020/images",
        "ChestXRLungSegmentation/TCIA/CT-Covid-19-2021/images",
    )
]

# 2D image folders used for validation.
val_image2d_folders = [
    os.path.join(hparams.datadir, relpath)
    for relpath in (
        # "ChestXRLungSegmentation/JSRT/processed/images/",
        # "ChestXRLungSegmentation/ChinaSet/processed/images/",
        # "ChestXRLungSegmentation/Montgomery/processed/images/",
        # "ChestXRLungSegmentation/VinDr/v1/processed/train/images/",
        "ChestXRLungSegmentation/VinDr/v1/processed/test/images/",
    )
]

# The test split reuses the validation folders (same list objects).
test_image3d_folders = val_image3d_folders
test_image2d_folders = val_image2d_folders


# Assemble the data module arguments in two groups -- where the data lives and
# how much of it to use -- then build the module and prepare its datasets.
_dm_folders = {
    "train_image3d_folders": train_image3d_folders,
    "train_image2d_folders": train_image2d_folders,
    "val_image3d_folders": val_image3d_folders,
    "val_image2d_folders": val_image2d_folders,
    "test_image3d_folders": test_image3d_folders,
    "test_image2d_folders": test_image2d_folders,
}
_dm_sizes = {
    "train_samples": hparams.train_samples,
    "val_samples": hparams.val_samples,
    "test_samples": hparams.test_samples,
    "batch_size": hparams.batch_size,
    "img_shape": hparams.img_shape,
    "vol_shape": hparams.vol_shape,
}

dm = UnpairedDataModule(**_dm_folders, **_dm_sizes)
dm.setup()

# Quick sanity check of the validation batches (disabled):
# for idx, batch in enumerate(dm.val_dataloader()):
#     image3d = batch["image3d"]
#     print(image3d.shape)

2392
['/home/quantm/data/ChestXRLungSegmentation/MELA2022/raw/train/images/mela_0001.nii.gz']
15000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/train/images/000434271f63a053c4128a0ba6352c7f.png']
771
['/home/quantm/data/ChestXRLungSegmentation/TCIA/CT-Covid-19-2020/images/volume-covid19-A-0000.nii.gz']
3000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/test/images/002a34c58c5b758217ed1f584ccbcfe9.png']
771
['/home/quantm/data/ChestXRLungSegmentation/TCIA/CT-Covid-19-2020/images/volume-covid19-A-0000.nii.gz']
3000
['/home/quantm/data/ChestXRLungSegmentation/VinDr/v1/processed/test/images/002a34c58c5b758217ed1f584ccbcfe9.png']


In [4]:
# Load the trained checkpoint and export the network to ONNX.
ckptpath = '/home/quantm/epoch=99-step=12500.ckpt'

# `load_from_checkpoint` is a classmethod that builds the module itself, so it
# is called on the class: constructing a throw-away instance first only doubles
# the construction cost, and newer Lightning releases reject the instance call.
# All dummy inputs below live on the CPU, so map the weights there as well.
model = DXRLightningModule.load_from_checkpoint(
    checkpoint_path=ckptpath,
    map_location=torch.device('cpu'),
)
model.eval()  # export the inference-mode graph

# Dummy 2D input used to trace the graph.
image2d = torch.randn((1, 1, 256, 256))

# Camera placed at distance 6 with zero elevation and azimuth.
# NOTE(review): view_hidden is not passed to the export below; it is kept
# because cells outside this view may rely on it.
dist_hidden = 6.0 * torch.ones(1, device=torch.device('cpu'))
elev_hidden = torch.zeros(1, device=torch.device('cpu'))
azim_hidden = torch.zeros(1, device=torch.device('cpu'))
view_hidden = make_cameras_dea(dist_hidden, elev_hidden, azim_hidden, fov=18, znear=4, zfar=8)

# Swap only the file extension. `str.replace('.ckpt', '.onnx')` would return
# the checkpoint path unchanged when the suffix differs, and the export would
# then overwrite the checkpoint.
onnxpath = os.path.splitext(ckptpath)[0] + '.onnx'

# ONNX export traces `forward()`. Check for it up front so a module that only
# inherits the base implementation fails here with an actionable message
# instead of deep inside the exporter.
if type(model).forward is LightningModule.forward:
    raise NotImplementedError(
        f'Module [{type(model).__name__}] is missing the required "forward" function; '
        'ONNX export traces forward(), so define it on the module (or export a '
        'wrapper torch.nn.Module around the desired sub-network) before calling to_onnx.'
    )

model.to_onnx(onnxpath, image2d, export_params=True)

verbose: False, log level: Level.ERROR



NotImplementedError: Module [DXRLightningModule] is missing the required "forward" function