In [14]:
#3D

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super(ConvBlock, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

  def forward(self, x):
    x = self.batch_norm(self.conv3D(x))
    x = F.relu(x)

    return(x)


class Encoder(nn.Module):
  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps #2*16, 4*16, 8*16, 16*16

      for i in range(self.num_conv_block):
        if depth == 0:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels) 
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
        else:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2

      if depth == model_depth - 1:  #depth = 3
        break
      else:
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

  def forward(self, x):
    down_sampling_feature = []
    for k, op in self.module_dict.items():
      print('k: ', k, 'op: ', op)

      if k.startswith('conv'):
        print ('input', x.shape)
        x = op(x)
        print ('output', x.shape)
        # print (k, x.shape)
        if k.endswith('1'):
          down_sampling_feature.append(x)

      elif k.startswith('max_pooling'):
        print ('input', x.shape)
        x = op(x)
        print ('output', x.shape)
        # print(k, x.shape)

    return x, down_sampling_feature

class ConvTranspose(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1):
    super(ConvTranspose, self).__init__()
    self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size,
    stride=stride, padding=padding, output_padding = output_padding)

  def forward(self, x):
    return(self.conv3d_transpose(x))

class Decoder(nn.Module):
  def __init__(self, out_channels, model_depth = 4):
    super(Decoder, self).__init__()
    self.num_conv_blocks = 2
    self.num_feat_maps = 16
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth-2, -1, -1): #2, 1, 0
      feat_map_channels = 2**(depth+1)*self.num_feat_maps #128, 64, 32
      self.deconv = ConvTranspose(in_channels=feat_map_channels*4, out_channels=feat_map_channels*4) #512, 256, 128
      self.module_dict['deconv_{}'.format(depth)] = self.deconv
      for i in range(self.num_conv_blocks):
        if i == 0:
          self.conv = ConvBlock(in_channels=feat_map_channels*6, out_channels=feat_map_channels*2) #768/256, 384/128, 192/64
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv
        else:
          self.conv = ConvBlock(in_channels=feat_map_channels*2, out_channels=feat_map_channels*2) #256/256, 128/128, 64/64
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv
      if depth == 0:
        self.final_conv = ConvBlock(in_channels=feat_map_channels*2, out_channels=out_channels)
        self.module_dict['final_conv'] = self.final_conv

  def forward(self, x, down_sampling_feature):
    for k, op in self.module_dict.items():
      if k.startswith('deconv'):
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
        x = torch.cat((down_sampling_feature[int(k[-1])], x), dim = 1)
        print ('concat shape1', down_sampling_feature[int(k[-1])].shape) 
        print ('operation: concat', 'output shape: ', x.shape) 
      elif k.startswith('conv'):
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
      else:
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
    
    return x


In [10]:
encoder = Encoder(1)
print(encoder)

# if __name__ == '__main__':
#     inputs = torch.randn(1, 1, 8, 64, 64)
#     inputs = inputs.cuda()
#     # print('the shape of input = ', inputs.shape)

#     encoder = Encoder(1)
#     # print(encoder)
#     encoder.cuda()

#     x_test = encoder(inputs)
#     # print('the shape of output = ', x_test.shape)

Encoder(
  (module_dict): ModuleDict(
    (conv_0_0): ConvBlock(
      (conv3D): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_0_1): ConvBlock(
      (conv3D): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (max_pooling_0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv_1_0): ConvBlock(
      (conv3D): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_1_1): ConvBlock(
      (conv3D): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, t

In [2]:
# if __name__ == '__main__':
#     x_test = torch.randn(1, 1, 96, 96, 96)
#     x_test = x_test.cuda()
#     print('the shape of input = ', x_test.shape)

#     encoder = Encoder(in_channels=1)
#     encoder.cuda()
#     print(encoder)
#     x_test, h = encoder(x_test)

#     db = Decoder(out_channels=1)
#     db.cuda()
#     x_test = db(x_test, h)

#     print('the shape of output = ', x_test.shape)

In [15]:
class Unet3D(nn.Module):
    def __init__(self, in_channels, out_channels, model_depth = 4, final_activation = 'sigmoid'):
        super(Unet3D, self).__init__()
        self.encoder = Encoder(in_channels=in_channels, model_depth=model_depth)
        self.decoder = Decoder(out_channels=out_channels, model_depth=model_depth)
        if final_activation == 'sigmoid':
            self.f_activation = nn.Sigmoid()
        else:
            self.f_activation = nn.Softmax(dim = 1)

    def forward(self, x):
        x, d_features = self.encoder(x)
        x = self.decoder(x, d_features)
        x = self.f_activation(x)
        print("Final output shape: ", x.shape)

        return x

if __name__ == '__main__':
    inputs = torch.randn(1, 1, 96, 96, 96)
    inputs = inputs.cuda()
    # print('the shape of input = ', inputs.shape)

    model = Unet3D(in_channels=1, out_channels=1)
    model.cuda()

    x_test = model(inputs)

k:  conv_0_0 op:  ConvBlock(
  (conv3D): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
input torch.Size([1, 1, 96, 96, 96])
output torch.Size([1, 32, 96, 96, 96])
k:  conv_0_1 op:  ConvBlock(
  (conv3D): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
input torch.Size([1, 32, 96, 96, 96])
output torch.Size([1, 64, 96, 96, 96])
k:  max_pooling_0 op:  MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
input torch.Size([1, 64, 96, 96, 96])
output torch.Size([1, 64, 48, 48, 48])
k:  conv_1_0 op:  ConvBlock(
  (conv3D): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
input torch.Size([1, 64, 48, 48

In [4]:
# # _*_ coding: utf-8 _*_
# # Author: Jielong
# # @Time: 21/08/2019 15:52
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
#         super(ConvBlock, self).__init__()
#         self.conv3d = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
#                                 stride=stride, padding=padding)
#         self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

#     def forward(self, x):
#         x = self.batch_norm(self.conv3d(x))
#         # x = self.conv3d(x)
#         x = F.elu(x)
#         return x


# class EncoderBlock(nn.Module):
#     def __init__(self, in_channels, model_depth=4, pool_size=2):
#         super(EncoderBlock, self).__init__()
#         self.root_feat_maps = 16
#         self.num_conv_blocks = 2
#         # self.module_list = nn.ModuleList()
#         self.module_dict = nn.ModuleDict()
#         for depth in range(model_depth):
#             feat_map_channels = 2 ** (depth + 1) * self.root_feat_maps
#             for i in range(self.num_conv_blocks):
#                 # print("depth {}, conv {}".format(depth, i))
#                 if depth == 0:
#                     # print(in_channels, feat_map_channels)
#                     self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
#                     self.module_dict["conv_{}_{}".format(depth, i)] = self.conv_block
#                     in_channels, feat_map_channels = feat_map_channels, feat_map_channels * 2
#                 else:
#                     # print(in_channels, feat_map_channels)
#                     self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
#                     self.module_dict["conv_{}_{}".format(depth, i)] = self.conv_block
#                     in_channels, feat_map_channels = feat_map_channels, feat_map_channels * 2
#             if depth == model_depth - 1:
#                 break
#             else:
#                 self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
#                 self.module_dict["max_pooling_{}".format(depth)] = self.pooling

#     def forward(self, x):
#         down_sampling_features = []
#         for k, op in self.module_dict.items():
#             if k.startswith("conv"):
#                 x = op(x)
#                 print(k, x.shape)
#                 if k.endswith("1"):
#                     down_sampling_features.append(x)
#             elif k.startswith("max_pooling"):
#                 x = op(x)
#                 print(k, x.shape)

#         return x, down_sampling_features


# class ConvTranspose(nn.Module):
#     def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1, output_padding=1):
#         super(ConvTranspose, self).__init__()
#         self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels,
#                                                    out_channels=out_channels,
#                                                    kernel_size=k_size,
#                                                    stride=stride,
#                                                    padding=padding,
#                                                    output_padding=output_padding)

#     def forward(self, x):
#         return self.conv3d_transpose(x)


# class DecoderBlock(nn.Module):
#     def __init__(self, out_channels, model_depth=4):
#         super(DecoderBlock, self).__init__()
#         self.num_conv_blocks = 2
#         self.num_feat_maps = 16
#         # user nn.ModuleDict() to store ops
#         self.module_dict = nn.ModuleDict()

#         for depth in range(model_depth - 2, -1, -1):
#             # print(depth)
#             feat_map_channels = 2 ** (depth + 1) * self.num_feat_maps
#             # print(feat_map_channels * 4)
#             self.deconv = ConvTranspose(in_channels=feat_map_channels * 4, out_channels=feat_map_channels * 4)
#             self.module_dict["deconv_{}".format(depth)] = self.deconv
#             for i in range(self.num_conv_blocks):
#                 if i == 0:
#                     self.conv = ConvBlock(in_channels=feat_map_channels * 6, out_channels=feat_map_channels * 2)
#                     self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
#                 else:
#                     self.conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=feat_map_channels * 2)
#                     self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
#             if depth == 0:
#                 self.final_conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=out_channels)
#                 self.module_dict["final_conv"] = self.final_conv

#     def forward(self, x, down_sampling_features):
#         """
#         :param x: inputs
#         :param down_sampling_features: feature maps from encoder path
#         :return: output
#         """
#         for k, op in self.module_dict.items():
#             if k.startswith("deconv"):
#                 x = op(x)
#                 x = torch.cat((down_sampling_features[int(k[-1])], x), dim=1)
#             elif k.startswith("conv"):
#                 x = op(x)
#             else:
#                 x = op(x)
#         return x


# if __name__ == "__main__":
#     # x has shape of (batch_size, channels, depth, height, width)
#     x_test = torch.randn(1, 1, 96, 96, 96)
#     x_test = x_test.cuda()
#     print("The shape of input: ", x_test.shape)

#     encoder = EncoderBlock(in_channels=1)
#     encoder.cuda()
#     print(encoder)
#     x_test, h = encoder(x_test)

#     db = DecoderBlock(out_channels=1)
#     db.cuda()
#     x_test = db(x_test, h)

#     print('the shape of output = ', x_test.shape)


In [5]:
for x in range(4-2, -1, -1):
    print(x)

2
1
0
