In [1]:
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torchinfo import summary

In [2]:
from lhunet import LHUNet

In [3]:
# Configuration A: thin-slab crops (16 slices of 160x160).
# NOTE(review): the next cell reassigns `crop_size`, `input_channels`, `num_classes`,
# `deep_supervision` and `network` with a different configuration, so the model built
# here is discarded when the notebook is run top to bottom. Keep this cell only if the
# reference configuration is still wanted.
crop_size = (16, 160, 160)  # trailing spatial dims of the input (the forward-pass cell builds x as (B, C, *crop_size))
input_channels = 3
num_classes = 4
deep_supervision = False

# Place the model on the GPU when one is available; otherwise fall back to the CPU so
# the notebook can still be executed (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): strides look per-axis (D, H, W), with a scalar meaning "same on every
# axis" — confirm against LHUNet. Under that reading the [1, 2, 2] strides leave the
# 16-slice depth untouched through the CNN stages and the first hybrid stage, and only
# the later scalar strides downsample depth.
# NOTE(review): hyb_vit_blocks="MMM" uses a block code ("M") that is not among the codes
# listed in the inline comment; its meaning is defined inside lhunet — TODO document.
# NOTE(review): the dec_* arguments set to None presumably fall back to the encoder-side
# settings — TODO confirm in lhunet.
network = LHUNet(
    spatial_shapes=crop_size,
    in_channels=input_channels,
    out_channels=num_classes,
    do_ds=deep_supervision,
    # encoder params
    cnn_kernel_sizes=[3, 3],
    cnn_features=[8, 16],
    cnn_strides=[[1, 2, 2], [1, 2, 2]],
    cnn_maxpools=[True, True],
    cnn_dropouts=0.0,
    cnn_blocks="nn",  # n= resunet, d= deformconv, b= basicunet,
    hyb_kernel_sizes=[3, 3, 3],
    hyb_features=[16, 32, 64],
    hyb_strides=[[1, 2, 2], 2, 2],
    hyb_maxpools=[True, True, True],
    hyb_cnn_dropouts=0.0,
    hyb_tf_proj_sizes=[64, 32, 0],
    hyb_tf_repeats=[1, 1, 1],
    hyb_tf_num_heads=[4, 4, 4],
    hyb_tf_dropouts=0.1,
    hyb_cnn_blocks="nnn",  # n= resunet, d= deformconv, b= basicunet,
    hyb_vit_blocks="MMM",  # s= dlka_special_v2, S= dlka_sp_seq, c= dlka_channel_v2, C= dlka_ch_seq,
    # hyb_vit_sandwich= False,
    hyb_skip_mode="cat",  # "sum" or "cat",
    hyb_arch_mode="residual",  # sequential, residual, parallel, collective,
    hyb_res_mode="sum",  # "sum" or "cat",
    # decoder params
    dec_hyb_tcv_kernel_sizes=[5, 5, 5],
    dec_cnn_tcv_kernel_sizes=[5, 5],
    dec_cnn_blocks=None,
    dec_tcv_bias=False,
    dec_hyb_tcv_bias=False,
    dec_hyb_kernel_sizes=None,
    dec_hyb_features=None,
    dec_hyb_cnn_dropouts=None,
    dec_hyb_tf_proj_sizes=None,
    dec_hyb_tf_repeats=None,
    dec_hyb_tf_num_heads=None,
    dec_hyb_tf_dropouts=None,
    dec_cnn_kernel_sizes=None,
    dec_cnn_features=None,
    dec_cnn_dropouts=None,
    dec_hyb_cnn_blocks=None,
    dec_hyb_vit_blocks=None,
    # dec_hyb_vit_sandwich= None,
    dec_hyb_skip_mode=None,
    dec_hyb_arch_mode="collective",  # sequential, residual, parallel, collective, sequential-lite,
    dec_hyb_res_mode=None,
)

network = network.to(device)

In [4]:
# Configuration B: 64 slices of 128x128. This cell overwrites the model from the
# previous cell; the forward-pass and parameter-count cells below use this one.
crop_size = (64, 128, 128)  # trailing spatial dims of the input (the forward-pass cell builds x as (B, C, *crop_size))
input_channels = 3
num_classes = 4
deep_supervision = False

# Place the model on the GPU when one is available; otherwise fall back to the CPU so
# the notebook can still be executed (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): strides look per-axis (D, H, W), with a scalar meaning "same on every
# axis" — confirm against LHUNet. Under that reading the leading [1, 2, 2] stride halves
# only the in-plane axes (64x128x128 -> 64x64x64), after which every stage downsamples
# isotropically. The recorded forward-pass output (16^3, 8^3, 4^3 hybrid feature maps)
# is consistent with this.
# NOTE(review): hyb_vit_blocks="MMM" uses a block code ("M") that is not among the codes
# listed in the inline comment; its meaning is defined inside lhunet — TODO document.
# NOTE(review): the dec_* arguments set to None presumably fall back to the encoder-side
# settings — TODO confirm in lhunet.
network = LHUNet(
    spatial_shapes=crop_size,
    in_channels=input_channels,
    out_channels=num_classes,
    do_ds=deep_supervision,
    # encoder params
    cnn_kernel_sizes=[3, 3],
    cnn_features=[8, 16],
    cnn_strides=[[1, 2, 2], 2],
    cnn_maxpools=[True, True],
    cnn_dropouts=0.0,
    cnn_blocks="nn",  # n= resunet, d= deformconv, b= basicunet,
    hyb_kernel_sizes=[3, 3, 3],
    hyb_features=[16, 32, 64],
    hyb_strides=[2, 2, 2],
    hyb_maxpools=[True, True, True],
    hyb_cnn_dropouts=0.0,
    hyb_tf_proj_sizes=[64, 32, 0],
    hyb_tf_repeats=[1, 1, 1],
    hyb_tf_num_heads=[4, 4, 4],
    hyb_tf_dropouts=0.1,
    hyb_cnn_blocks="nnn",  # n= resunet, d= deformconv, b= basicunet,
    hyb_vit_blocks="MMM",  # s= dlka_special_v2, S= dlka_sp_seq, c= dlka_channel_v2, C= dlka_ch_seq,
    # hyb_vit_sandwich= False,
    hyb_skip_mode="cat",  # "sum" or "cat",
    hyb_arch_mode="residual",  # sequential, residual, parallel, collective,
    hyb_res_mode="sum",  # "sum" or "cat",
    # decoder params
    dec_hyb_tcv_kernel_sizes=[5, 5, 5],
    dec_cnn_tcv_kernel_sizes=[5, 5],
    dec_cnn_blocks=None,
    dec_tcv_bias=False,
    dec_hyb_tcv_bias=False,
    dec_hyb_kernel_sizes=None,
    dec_hyb_features=None,
    dec_hyb_cnn_dropouts=None,
    dec_hyb_tf_proj_sizes=None,
    dec_hyb_tf_repeats=None,
    dec_hyb_tf_num_heads=None,
    dec_hyb_tf_dropouts=None,
    dec_cnn_kernel_sizes=None,
    dec_cnn_features=None,
    dec_cnn_dropouts=None,
    dec_hyb_cnn_blocks=None,
    dec_hyb_vit_blocks=None,
    # dec_hyb_vit_sandwich= None,
    dec_hyb_skip_mode=None,
    dec_hyb_arch_mode="collective",  # sequential, residual, parallel, collective, sequential-lite,
    dec_hyb_res_mode=None,
)

network = network.to(device)

In [7]:
# Smoke-test one forward pass with a random batch shaped like the training crops:
# (batch, input_channels, *crop_size).
batch_size = 4
# Allocate the input directly on whatever device the model's weights live on, instead
# of hard-coding "cuda" (and instead of building on the CPU and copying over).
model_device = next(network.parameters()).device
x = torch.randn(batch_size, input_channels, *crop_size, device=model_device)
# No backward pass follows, so skip autograd bookkeeping; this avoids keeping every
# intermediate activation alive and lowers peak memory for the check.
with torch.no_grad():
    y = network(x)

torch.Size([4, 16, 16, 16, 16])
torch.Size([4, 16, 1, 1, 1])
torch.Size([4, 16, 16, 16, 16])
torch.Size([4, 32, 8, 8, 8])
torch.Size([4, 32, 1, 1, 1])
torch.Size([4, 32, 8, 8, 8])
torch.Size([4, 64, 4, 4, 4])
torch.Size([4, 64, 1, 1, 1])
torch.Size([4, 64, 4, 4, 4])
torch.Size([4, 32, 8, 8, 8])
torch.Size([4, 32, 1, 1, 1])
torch.Size([4, 32, 8, 8, 8])
torch.Size([4, 16, 16, 16, 16])
torch.Size([4, 16, 1, 1, 1])
torch.Size([4, 16, 16, 16, 16])
torch.Size([4, 16, 32, 32, 32])
torch.Size([4, 16, 1, 1, 1])
torch.Size([4, 16, 32, 32, 32])


In [6]:
# Count trainable parameters (those with requires_grad=True).
# Tensor.numel() returns a Python int for every tensor, including 0-dim ones, whereas
# np.prod(torch.Size([])) yields the float 1.0 and would silently make the total a float.
params = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(params)

1212271
