In [5]:
class ConvBlock(nn.Module):
    """Encoder stage made of two 3D convolutions followed by optional pooling.

    The first convolution maps ``in_channels`` to ``out_channels // 2`` and
    the second maps that intermediate width to ``out_channels``. Each
    convolution is followed by an optional ``BatchNorm3d`` and a ``ReLU``.
    Unless the block is a bottleneck, a ``MaxPool3d`` is applied at the end.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the second convolution.
        kernel_size: kernel size used by both convolutions.
        padding: padding used by both convolutions.
        use_batch_norm: when true, add a ``BatchNorm3d`` after each convolution.
        is_bottleneck: when true, no pooling layer is created or applied.
        pool_kernel_size: kernel size of the max-pooling layer.
        pool_strid: stride of the max-pooling layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1,
                 use_batch_norm=True, is_bottleneck=False, pool_kernel_size=2, pool_strid=2):
        """Create the layers of the block and print its channel configuration."""
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.is_bottleneck = is_bottleneck

        half_width = out_channels // 2
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        # First conv -> (bn) -> relu
        self.conv_1 = nn.Conv3d(in_channels, half_width, **conv_kwargs)
        if use_batch_norm:
            self.bn_1 = nn.BatchNorm3d(half_width)
        self.relu_1 = nn.ReLU()

        # Second conv -> (bn) -> relu
        self.conv_2 = nn.Conv3d(half_width, out_channels, **conv_kwargs)
        if use_batch_norm:
            self.bn_2 = nn.BatchNorm3d(out_channels)
        self.relu_2 = nn.ReLU()

        print(f'in and out: {in_channels, out_channels}')

        # The bottleneck keeps its spatial size, so it gets no pooling layer.
        if not is_bottleneck:
            self.pooling = nn.MaxPool3d(pool_kernel_size, stride=pool_strid)

    def forward(self, x):
        """Run the block on ``x``.

        Returns:
            A tuple ``(out, res)``: ``res`` is the activation after the second
            ReLU, and ``out`` is ``res`` after max pooling (or ``res`` itself
            when the block is a bottleneck).
        """
        features = self.conv_1(x)
        if self.use_batch_norm:
            features = self.bn_1(features)
        features = self.relu_1(features)

        features = self.conv_2(features)
        if self.use_batch_norm:
            features = self.bn_2(features)
        res = self.relu_2(features)

        out = res if self.is_bottleneck else self.pooling(res)
        return out, res


class UpsampleBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two convs.

    The input is upsampled by a ``ConvTranspose3d`` that keeps ``in_channels``
    channels, optionally concatenated with a skip tensor along dimension 1,
    and then passed through two ``Conv3d`` + ``BatchNorm3d`` + ``ReLU`` stages
    that both produce ``in_channels // 2`` channels. An output block appends a
    kernel-size-1 convolution mapping to ``num_classes`` channels.

    Args:
        in_channels: channels of the tensor coming from the deeper level.
        res_channels: channels of the skip tensor concatenated after upsampling.
        up_kernel_size: kernel size of the transposed convolution.
        up_stride_size: stride of the transposed convolution.
        kernel_size: kernel size of the two regular convolutions.
        padding: padding of the two regular convolutions.
        is_output: when true, add the final ``num_classes`` convolution.
        num_classes: output channels of the final convolution.
    """

    def __init__(self, in_channels, res_channels, up_kernel_size=2, up_stride_size=2,
                 kernel_size=3, padding=1, is_output=False, num_classes=2):
        """Create the layers of the block and print its channel configuration."""
        super().__init__()
        self.res_channels = res_channels
        self.is_output = is_output

        half_width = in_channels // 2
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        self.conv_trans = nn.ConvTranspose3d(
            in_channels, in_channels, kernel_size=up_kernel_size, stride=up_stride_size
        )

        # The first conv consumes the upsampled tensor plus the skip tensor.
        self.conv_1 = nn.Conv3d(in_channels + res_channels, half_width, **conv_kwargs)
        self.bn_1 = nn.BatchNorm3d(half_width)
        self.relu_1 = nn.ReLU()

        self.conv_2 = nn.Conv3d(half_width, half_width, **conv_kwargs)
        self.bn_2 = nn.BatchNorm3d(half_width)
        self.relu_2 = nn.ReLU()

        print(f'in, res and out: {in_channels, res_channels, half_width}')

        if is_output:
            self.conv_3 = nn.Conv3d(half_width, num_classes, kernel_size=1)
            print(f'output numbner of classes: {num_classes}')

    def forward(self, x, res):
        """Upsample ``x``, merge it with ``res`` (if given) and convolve.

        Args:
            x: tensor from the deeper level.
            res: skip tensor to concatenate along dimension 1, or ``None`` to
                skip the concatenation.

        Returns:
            The block output; for an output block this is the result of the
            final ``num_classes`` convolution.
        """
        upsampled = self.conv_trans(x)
        merged = upsampled if res is None else torch.cat((upsampled, res), 1)

        hidden = self.relu_1(self.bn_1(self.conv_1(merged)))
        hidden = self.relu_2(self.bn_2(self.conv_2(hidden)))

        if not self.is_output:
            return hidden
        return self.conv_3(hidden)
    

class UNet3D(nn.Module):
    """3D U-Net model.

    Dynamic 3D U-Net for semantic segmentation whose depth and width follow
    ``block_channels``: one ``ConvBlock`` is created per entry (the last one
    being the bottleneck, without pooling) and one ``UpsampleBlock`` per
    non-bottleneck level, in reverse order.

    Args:
        in_channels: number of channels of the input data.
        num_classes: number of classes to identify (output channels).
        block_channels: list or tuple with the number of channels at each
            level of the downsampling path; the upsampling path uses the
            same numbers in reverse. Needs at least two entries.
            Default ``(64, 128, 256, 512)``.
    """

    def __init__(self, in_channels, num_classes, block_channels=(64, 128, 256, 512)):
        super(UNet3D, self).__init__()

        # The blocks must live in nn.ModuleList containers, not plain Python
        # lists: only registered submodules are visible to parameters(),
        # state_dict(), .to(device), .train() and .eval().
        self.conv_blocks = nn.ModuleList()

        # Downsampling path: first level, intermediate levels, bottleneck.
        self.conv_blocks.append(ConvBlock(in_channels, block_channels[0]))
        for i in range(len(block_channels) - 2):
            self.conv_blocks.append(ConvBlock(block_channels[i], block_channels[i + 1]))
        # Bottleneck block: no pooling.
        self.conv_blocks.append(
            ConvBlock(block_channels[-2], block_channels[-1], is_bottleneck=True)
        )

        # Upsampling path: deepest level first, output block last.
        self.upsample_blocks = nn.ModuleList()
        for i in range(len(block_channels) - 1, 1, -1):
            self.upsample_blocks.append(UpsampleBlock(block_channels[i], block_channels[i - 1]))
        self.upsample_blocks.append(
            UpsampleBlock(block_channels[1], block_channels[0],
                          is_output=True, num_classes=num_classes)
        )

    def forward(self, input):
        """Run the network on ``input`` and return the output of the last block."""
        *encoders, bottleneck = self.conv_blocks

        out = input
        res_list = []
        for block in encoders:
            out, res = block(out)
            res_list.append(res)

        # Bottleneck block: no max pooling, so its two outputs are the same tensor.
        out = bottleneck(out)[0]

        # Internal invariant: every upsample block consumes exactly one residual.
        assert len(self.upsample_blocks) == len(res_list), "number of upsample blocks and number of residuals don't match!"
        for block, res in zip(self.upsample_blocks, reversed(res_list)):
            out = block(out, res)

        return out



In [6]:
class UNet3DFixed(nn.Module):
    """3D U-Net with a fixed depth of three levels plus a bottleneck.

    Every block is stored as a direct attribute of the module.

    Args:
        in_channels: number of input channels.
        num_classes: number of output channels (one mask per class).
        level_channels: number of channels at each of the three levels,
            counted top-down; only the first three entries are read.
        bottleneck_channel: number of channels of the bottleneck block.

    ``forward(input)`` takes the input tensor and returns the tensor produced
    by the last synthesis block.
    """

    def __init__(self, in_channels, num_classes, level_channels=[64, 128, 256], bottleneck_channel=512) -> None:
        super().__init__()
        top = level_channels[0]
        middle = level_channels[1]
        bottom = level_channels[2]

        # Analysis (downsampling) path.
        self.a_block1 = ConvBlock(in_channels=in_channels, out_channels=top)
        self.a_block2 = ConvBlock(in_channels=top, out_channels=middle)
        self.a_block3 = ConvBlock(in_channels=middle, out_channels=bottom)
        self.bottleNeck = ConvBlock(
            in_channels=bottom, out_channels=bottleneck_channel, is_bottleneck=True
        )

        # Synthesis (upsampling) path, deepest level first.
        self.s_block3 = UpsampleBlock(in_channels=bottleneck_channel, res_channels=bottom)
        self.s_block2 = UpsampleBlock(in_channels=bottom, res_channels=middle)
        self.s_block1 = UpsampleBlock(
            in_channels=middle, res_channels=top, num_classes=num_classes, is_output=True
        )

    def forward(self, input):
        """Feed ``input`` through the analysis path, bottleneck and synthesis path."""
        out = input
        skips = []

        # Analysis path: keep each level's pre-pooling activation as a skip tensor.
        for encoder in (self.a_block1, self.a_block2, self.a_block3):
            out, skip = encoder(out)
            skips.append(skip)

        out, _ = self.bottleNeck(out)

        # Synthesis path: consume the skip tensors from deepest to shallowest.
        for decoder in (self.s_block3, self.s_block2, self.s_block1):
            out = decoder(out, skips.pop())

        return out

In [7]:
# Build both model variants with 3 input channels and a single output class.
# Each constructor prints the channel configuration of every block it creates.
model = UNet3D(in_channels=3, num_classes=1)
model_fixed = UNet3DFixed(in_channels=3, num_classes=1)

in and out: (3, 64)
in and out: (64, 128)
in and out: (128, 256)
in and out: (256, 512)
in, res and out: (512, 256, 256)
in, res and out: (256, 128, 128)
in, res and out: (128, 64, 64)
output numbner of classes: 1
in and out: (3, 64)
in and out: (64, 128)
in and out: (128, 256)
in and out: (256, 512)
in, res and out: (512, 256, 256)
in, res and out: (256, 128, 128)
in, res and out: (128, 64, 64)
output numbner of classes: 1
