In [1]:
import subprocess
import shutil
import torch
import os
import numpy as np
from scipy.ndimage import label
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import nibabel as nib
import argparse
from totalsegmentator.python_api import totalsegmentator

nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
nnUNet_results is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.


In [2]:
from torchinfo import summary
from torchviz import make_dot
from torchview import draw_graph

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


# 3D UNet

In [4]:
predictor_3d = nnUNetPredictor(
	tile_step_size=0.5,
	use_gaussian=True,
	use_mirroring=True,
	perform_everything_on_device=True,
	device=device,
	verbose=False,
	verbose_preprocessing=False,
	allow_tqdm=True,
)

predictor_3d.initialize_from_trained_model_folder(
	'models/multiclass/nnUNetTrainerDA5__nnUNetResEncUNetLPlans__3d_fullres',
	use_folds=(0, 1, 2, 3, 4),
	checkpoint_name='checkpoint_best.pth',
)

perform_everything_on_device=True is only supported for cuda devices! Setting this to False


In [None]:
model = predictor_3d.network

In [6]:
print(model)

ResidualEncoderUNet(
  (encoder): ResidualEncoder(
    (stem): StackedConvBlocks(
      (convs): Sequential(
        (0): ConvDropoutNormReLU(
          (conv): Conv3d(2, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
          (norm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
          (all_modules): Sequential(
            (0): Conv3d(2, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
    (stages): Sequential(
      (0): StackedResidualBlocks(
        (blocks): Sequential(
          (0): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
              

In [7]:
summary(model)

Layer (type:depth-idx)                                                      Param #
ResidualEncoderUNet                                                         --
├─ResidualEncoder: 1-1                                                      --
│    └─StackedConvBlocks: 2-1                                               --
│    │    └─Sequential: 3-1                                                 672
│    └─Sequential: 2-2                                                      --
│    │    └─StackedResidualBlocks: 3-2                                      18,624
│    │    └─StackedResidualBlocks: 3-3                                      206,080
│    │    └─StackedResidualBlocks: 3-4                                      3,329,280
│    │    └─StackedResidualBlocks: 3-5                                      20,391,424
│    │    └─StackedResidualBlocks: 3-6                                      32,718,720
│    │    └─StackedResidualBlocks: 3-7                                      33,189,120
│    │

In [None]:
x = torch.randn(1, 2, 32, 512, 512)  # adjust input size
y = model(x)

dot = make_dot(y, params=dict(model.named_parameters()))
dot.render('model_graph', format='png')

In [10]:
x = torch.randn(1, 1, 32, 128, 128)

graph = draw_graph(model, input_data=x, expand_nested=True)
graph.visual_graph.render('unet_torchview', format='png')




'unet_torchview.png'

In [34]:
model.decoder

UNetDecoder(
  (encoder): PlainConvEncoder(
    (stages): Sequential(
      (0): Sequential(
        (0): StackedConvBlocks(
          (convs): Sequential(
            (0): ConvDropoutNormReLU(
              (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_runni

In [15]:
import inspect

In [18]:
function_source = inspect.getsource(model.forward)
print('Source code of model.forward:')
print(function_source)


function_source = inspect.getsource(model.encoder.forward)
print('Source code of model.encoder.forward:')
print(function_source)


function_source = inspect.getsource(model.decoder.forward)
print('Source code of model.decoder.forward:')
print(function_source)

Source code of model.forward:
    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

Source code of model.encoder.forward:
    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

Source code of model.decoder.forward:
    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
          

# 2D UNet

In [19]:
predictor_2d = nnUNetPredictor(
	tile_step_size=0.5,
	use_gaussian=True,
	use_mirroring=True,
	perform_everything_on_device=True,
	device=device,
	verbose=False,
	verbose_preprocessing=False,
	allow_tqdm=True,
)
# CHARITE 2d
predictor_2d.initialize_from_trained_model_folder(
	'models/prelabeling/nnUNetTrainer__nnUNetPlans__2d',
	use_folds=(0,),
	checkpoint_name='checkpoint_best.pth',
)
model = predictor_2d.network

perform_everything_on_device=True is only supported for cuda devices! Setting this to False




In [20]:
print(model)

PlainConvUNet(
  (encoder): PlainConvEncoder(
    (stages): Sequential(
      (0): Sequential(
        (0): StackedConvBlocks(
          (convs): Sequential(
            (0): ConvDropoutNormReLU(
              (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_run

In [21]:
summary(model)

Layer (type:depth-idx)                                                 Param #
PlainConvUNet                                                          --
├─PlainConvEncoder: 1-1                                                --
│    └─Sequential: 2-1                                                 --
│    │    └─Sequential: 3-1                                            9,696
│    │    └─Sequential: 3-2                                            55,680
│    │    └─Sequential: 3-3                                            221,952
│    │    └─Sequential: 3-4                                            886,272
│    │    └─Sequential: 3-5                                            3,542,016
│    │    └─Sequential: 3-6                                            4,721,664
│    │    └─Sequential: 3-7                                            4,721,664
│    │    └─Sequential: 3-8                                            4,721,664
├─UNetDecoder: 1-2                                            

In [27]:
x = torch.randn(1, 1, 512, 512)  # adjust input size
y = model(x)

dot = make_dot(y, params=dict(model.named_parameters()))
dot.render('model_graph_2d', format='png')

'model_graph_2d.png'

In [29]:
x = torch.randn(1, 1, 512, 512)

graph = draw_graph(model, input_data=x, expand_nested=True)
graph.visual_graph.render('unet_torchview_2d', format='png')




'unet_torchview_2d.png'

In [33]:
model.decoder

UNetDecoder(
  (encoder): PlainConvEncoder(
    (stages): Sequential(
      (0): Sequential(
        (0): StackedConvBlocks(
          (convs): Sequential(
            (0): ConvDropoutNormReLU(
              (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_runni

In [31]:
import inspect

In [32]:
function_source = inspect.getsource(model.forward)
print('Source code of model.forward:')
print(function_source)


function_source = inspect.getsource(model.encoder.forward)
print('Source code of model.encoder.forward:')
print(function_source)


function_source = inspect.getsource(model.decoder.forward)
print('Source code of model.decoder.forward:')
print(function_source)

Source code of model.forward:
    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

Source code of model.encoder.forward:
    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

Source code of model.decoder.forward:
    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
          