In [None]:
''' 
This notebook assumes the following package versions:

MONAI version: 0.2.0+166.g12b3fbf
Python version: 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)  [GCC 7.3.0]
Numpy version: 1.19.5
Pytorch version: 1.7.1

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.1
scikit-image version: 0.15.0
Pillow version: 7.2.0
Tensorboard version: 1.15.0+nv
gdown version: 3.12.2
TorchVision version: 0.8.0a0
ITK version: 5.1.1

Later MONAI/PyTorch versions may rename, deprecate or remove APIs used here (e.g. transform classes or constructor arguments), so some calls may need minor adjustments.
'''

In [None]:
import logging
import os
import shutil
import sys
import time
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from matplotlib import pylab as plt

import monai
from monai.networks.layers import Norm
from monai.data import create_test_image_3d, list_data_collate, ITKReader
from monai.inferers import sliding_window_inference
from monai.inferers import SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    AsChannelFirstd,
    AsChannelLastd,
    AddChanneld,
    RandAdjustContrastd,
    Compose,
    DivisiblePadd,
    LoadNiftid,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotated,
    RandZoomd,
    RandFlipd,
    RandShiftIntensityd,
    RandScaleIntensityd,
    RandAffined,
    Rand3DElasticd,
    RandGaussianNoised,
    ScaleIntensityd,
    SpatialPadd,
    ToTensord,
    DataStats,
)
from monai.utils import first, set_determinism
from monai.visualize import plot_2d_or_3d_image
import itk

# Print the installed MONAI version and its dependency versions, to compare against the versions listed in the first cell.
monai.config.print_config()

In [None]:
class IEVnetInferer():
    '''
        Inference wrapper around a trained IEVNet model (MONAI 3D UNet parametrized like VNet).

        Usage:
            ievnet = IEVnetInferer()
            ievnet.predict(filenames_in, filenames_out)

        Assumes NIfTI input volumes of dimensions [200,150,100] at 0.2mm resolution and writes
        sigmoid probability maps of the same dimensions, one output file per input file.

        Args:
            model_path: path to the trained state dict (default: 'best_metric_model.pth',
                resolved relative to the current working directory).
            device: torch.device or device string to run on; defaults to 'cuda' when
                available, otherwise 'cpu'.
    '''
    def __init__(self, model_path='best_metric_model.pth', device=None):
        print('Loading IEVNet model...')
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # using MONAI Unet implementation for cuda optimizations
        # parametrize UNet like VNet (see: https://arxiv.org/abs/1606.04797):
        #   - 4x downsampling
        #   - 16/32/64/128/256 filter channels,
        #   - down/up-convolutions with stride 2 instead of max-pooling/up-pooling
        #   - number of residual units in each layer: 2
        #   - Dice loss
        self.model = monai.networks.nets.UNet(
                        dimensions=3,
                        in_channels=1,
                        out_channels=1,
                        channels=(16, 32, 64, 128, 256),
                        strides=(2, 2, 2, 2),
                        dropout=0.5,
                        num_res_units=2).to(self.device)
        # map_location lets weights saved from a GPU run be loaded on a CPU-only machine
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

        # pre-processing (array transforms, referenced via the monai package because only
        # the dictionary-based "...d" variants are imported at the top of the notebook):
        #   - load the image array only (image_only=True; the header is re-read when saving),
        #     without reorienting, so array indices match the file's own voxel grid
        #   - add a channel dimension
        #   - pad to [208,160,112], divisible by 16 for the four stride-2 levels; assumes
        #     SpatialPad's default symmetric mode so the center crop below undoes it
        #   - scale intensities (MONAI defaults), convert to tensor
        self.transforms = Compose([monai.transforms.LoadNifti(as_closest_canonical=False, image_only=True),
                                   monai.transforms.AddChannel(),
                                   monai.transforms.SpatialPad(spatial_size=[208, 160, 112]),
                                   monai.transforms.ScaleIntensity(),
                                   monai.transforms.ToTensor(),])

        # post-processing: center-cropping back to 200,150,100 voxels
        self.cropper = monai.transforms.CenterSpatialCrop([200, 150, 100])

        print('Done loading model - IEVNet ready.')

    def predict(self, filenames_in, filenames_out):
        '''
            Segment each volume in filenames_in and write its sigmoid probability map
            (float32) to the path at the same position in filenames_out.

            The output reuses the affine/header of the corresponding input file so it
            overlays the input volume.

            Args:
                filenames_in: sequence of input NIfTI file paths.
                filenames_out: sequence of output NIfTI file paths, same length as filenames_in.
            Raises:
                ValueError: if filenames_in and filenames_out differ in length.
        '''
        filenames_in = list(filenames_in)
        filenames_out = list(filenames_out)
        if len(filenames_in) != len(filenames_out):
            raise ValueError(
                'filenames_in and filenames_out must have the same length, got {} and {}'.format(
                    len(filenames_in), len(filenames_out)))

        # create dataset and loader on the fly; Dataset applies self.transforms to each filename
        dataset = monai.data.Dataset(data=filenames_in, transform=self.transforms)
        # batch_size=1 without shuffling keeps each batch aligned with filenames_in/filenames_out
        loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=list_data_collate)

        with torch.no_grad():
            for filename_in, filename_out, img in zip(filenames_in, filenames_out, loader):
                # img is a batch holding one padded volume, expected (1, 1, 208, 160, 112)
                pred_raw = self.model(img.to(self.device))
                pred_sigmoid = torch.sigmoid(pred_raw)
                # the cropper expects channel-first data without a batch dimension
                pred = self.cropper(pred_sigmoid[0].cpu().numpy())
                pred_vol = np.asarray(pred[0], dtype=np.float32)

                # export using the original image's geometry; the array was loaded without
                # reorientation, so the input affine applies to it directly
                reference = nib.load(filename_in)
                out_img = nib.Nifti1Image(pred_vol, reference.affine, header=reference.header)
                # store probabilities as float32 instead of the input's (possibly integer) dtype
                out_img.set_data_dtype(np.float32)
                out_dir = os.path.dirname(filename_out)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                nib.save(out_img, filename_out)


In [None]:
# Sample inference on two example volumes.
# Each entry pairs an input volume with the path its prediction is written to;
# replace the placeholder paths with your own data.
ievnet = IEVnetInferer()
sample_pairs = [
    ('/path/to/my/testfile/vol_01.nii.gz', '/path/to/my/testfile/pred_01.nii.gz'),
    ('/path/to/my/testfile/vol_02.nii.gz', '/path/to/my/testfile/pred_02.nii.gz'),
]
filenames_in = [src for src, _ in sample_pairs]
filenames_out = [dst for _, dst in sample_pairs]
ievnet.predict(filenames_in, filenames_out)