In [13]:
#BraveNet3D

#Encoder

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

#kernels and stride to 3 × 3 × 3 and 1 for convolutions and 2 × 2 × 2 and 2 for max-pooling

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super(ConvBlock, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

  def forward(self, x):
    x = F.relu(self.conv3D(x))
    x = self.batch_norm(x)

    return(x)


class Encoder(nn.Module):
  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps #32, 64, 128, 256
      

      for i in range(self.num_conv_block):
        if depth == 0:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels) 
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

        else:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

      if depth == model_depth - 1:  #depth = 3
        break
      else:
        nn.MaxPool3d
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

    self.module_dict2 = self.module_dict

  def forward(self, x, y):
    down_sampling_feature = []
    down_sampling_feature2 = []
    for k, op in self.module_dict.items():
      # print('k: ', k, 'op: ', op)

      if k.startswith('conv'):
        # print ('input', x.shape)
        x = op(x)
        # print ('output', x.shape)
        if k.endswith('1'):
          down_sampling_feature.append(x)

      else:
        # print ('input', x.shape)
        x = op(x)
        # print ('output', x.shape)
        # print(k, x.shape)

    for k, op in self.module_dict2.items():
        if k.startswith('conv'):
            y = op(y)
            if k.endswith('1'):
                down_sampling_feature2.append(y)
        else:
            y = op(y)

    return down_sampling_feature, down_sampling_feature2


if __name__ == '__main__':
    inputs_low = torch.randn(1, 1, 8, 64, 64)
    inputs_low = inputs_low.cuda()

    inputs_high = torch.randn(1, 1, 8, 64, 64)
    inputs_high = inputs_high.cuda()
    # print('the shape of input = ', inputs.shape)

    encoder = Encoder(1)

    # print(encoder)
    encoder.cuda()


    d_low, d_high = encoder(inputs_low, inputs_high)
    # print('the shape of output = ', x_test[0].shape)

In [14]:
# encoder = Encoder(1)
# print(encoder)

In [15]:
#Decoder

class ConvTranspose(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1):
    super(ConvTranspose, self).__init__()
    self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size,
    stride=stride, padding=padding, output_padding = output_padding)

  def forward(self, x):
    return(self.conv3d_transpose(x))

class ConvBlock_One_ReLu(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0):
    super(ConvBlock_One_ReLu, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    x = F.relu(self.conv3D(x))

    return(x)


class ConvBlock_One(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0):
    super(ConvBlock_One, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    x = self.conv3D(x)

    return(x)


class Decoder(nn.Module):
  def __init__(self, out_channels, model_depth = 4):
    super(Decoder, self).__init__()
    self.num_conv_blocks = 2
    self.num_feat_maps = 16
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth-1, -1, -1): #3, 2, 1, 0

      feat_map_channels = 2**(depth+1)*self.num_feat_maps #256 ,128, 64, 32

      if depth == 3:
        for i in range(self.num_conv_blocks):
          if i == 0:
            self.conv_one = ConvBlock_One_ReLu(feat_map_channels*2, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one
          else:
            self.conv_one = ConvBlock_One_ReLu(feat_map_channels, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one

        self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels) #256
        self.module_dict['deconv_{}'.format(depth)] = self.deconv

      else:  #depth = 2, 1, 0 feat_map_channels = 128, 64, 32
        for i in range(self.num_conv_blocks):
          if i == 0:
            self.conv_block = ConvBlock(in_channels=feat_map_channels*4, out_channels=feat_map_channels) #512/128, 256/64, 128/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
          else:
            self.conv_block = ConvBlock(in_channels=feat_map_channels, out_channels=feat_map_channels) #128/128, 64/64, 32/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block

        if depth != 0:
          self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels)
          self.module_dict['deconv_{}'.format(depth)] = self.deconv

      if depth == 0:
        self.final_conv = ConvBlock_One(in_channels=feat_map_channels, out_channels=out_channels)
        self.module_dict['final_conv'] = self.final_conv

  def forward(self, down_sampling_feature, down_sampling_feature2):

    for k, op in self.module_dict.items():
        if k.startswith('conv'):
            if k.endswith('3_0'):
                x = torch.cat((down_sampling_feature[3], down_sampling_feature2[3]), dim = 1)
                print ('first concat_x.shape', x.shape) 
                x = op(x)
                print ('operation: ',k, 'output shape: ', x.shape) 
            else:
                x = op(x)
                print ('operation: ',k, 'output shape: ', x.shape) 
        elif k.startswith('deconv'):  #deconv_3, deconv_2, deconv_1
            x = op(x)
            print ('operation: ',k, 'output shape: ', x.shape) 
            # print ('int(k[-1])', int(k[-1])) 
            layer = int(k[-1])-1 #2, 1, 0
            y = torch.cat((down_sampling_feature[layer], down_sampling_feature2[layer]), dim = 1)
            print ('concat_shape_1', y.shape) 
            x = torch.cat((x, y), dim = 1)
            print ('x.concat shape', x.shape) 
        else:
            x = op(x)
            print ('operation: ',k, 'output shape: ', x.shape) 

    return x

# if __name__ == '__main__':


#     decoder = Decoder(1)
#     decoder.cuda()
#     x = decoder(d_low, d_high)
    
#     print('the shape of output = ', x .shape)

In [16]:
decoder = Decoder(1)
decoder

Decoder(
  (module_dict): ModuleDict(
    (conv_one_3_0): ConvBlock_One_ReLu(
      (conv3D): Conv3d(512, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (conv_one_3_1): ConvBlock_One_ReLu(
      (conv3D): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (deconv_3): ConvTranspose(
      (conv3d_transpose): ConvTranspose3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
    )
    (conv_2_0): ConvBlock(
      (conv3D): Conv3d(512, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (dropout2): Dropout3d(p=0.5, inplace=False)
    (conv_2_1): ConvBlock(
      (conv3D): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (deconv_2): ConvTranspose(
      (conv3d_

In [17]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, out_channels, model_depth = 4, final_activation = 'sigmoid'):
        super(ContextUnet, self).__init__()
        self.encoder = Encoder(in_channels=in_channels, model_depth=model_depth)
        self.decoder = Decoder(out_channels=out_channels, model_depth=model_depth)
        if final_activation == 'sigmoid':
            self.f_activation = nn.Sigmoid()
        else:
            self.f_activation = nn.Softmax(dim = 1)

    def forward(self, x, y):
        d_features, d_features2 = self.encoder(x, y)
        x = self.decoder(d_features, d_features2)
        x = self.f_activation(x)
        print("Final output shape: ", x.shape)

        return x

if __name__ == '__main__':
    inputs_low = torch.randn(1, 1, 8, 64, 64)
    inputs_low = inputs_low.cuda()

    inputs_high = torch.randn(1, 1, 8, 64, 64)
    inputs_high = inputs_high.cuda()

    # print('the shape of input = ', inputs.shape)

    model = ContextUnet(in_channels=1, out_channels=1)
    model.cuda()

    x_test = model(inputs_low, inputs_high)
    

first concat_x.shape torch.Size([1, 512, 1, 8, 8])
operation:  conv_one_3_0 output shape:  torch.Size([1, 256, 1, 8, 8])
operation:  conv_one_3_1 output shape:  torch.Size([1, 256, 1, 8, 8])
operation:  deconv_3 output shape:  torch.Size([1, 256, 2, 16, 16])
concat_shape_1 torch.Size([1, 256, 2, 16, 16])
x.concat shape torch.Size([1, 512, 2, 16, 16])
operation:  conv_2_0 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  dropout2 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  conv_2_1 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  deconv_2 output shape:  torch.Size([1, 128, 4, 32, 32])
concat_shape_1 torch.Size([1, 128, 4, 32, 32])
x.concat shape torch.Size([1, 256, 4, 32, 32])
operation:  conv_1_0 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  dropout1 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  conv_1_1 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  deconv_1 output shape:  torch.Size([1, 64, 8, 64, 64])
concat_shape_

In [18]:
x_test

tensor([[[[[0.4475, 0.5440, 0.5654,  ..., 0.3638, 0.6241, 0.4367],
           [0.5357, 0.4622, 0.3724,  ..., 0.4390, 0.6015, 0.2745],
           [0.3636, 0.4844, 0.5065,  ..., 0.4693, 0.6628, 0.3861],
           ...,
           [0.4565, 0.5578, 0.5475,  ..., 0.5738, 0.5020, 0.5477],
           [0.5913, 0.5395, 0.3446,  ..., 0.4422, 0.6531, 0.4209],
           [0.4872, 0.4813, 0.4483,  ..., 0.2595, 0.3753, 0.5534]],

          [[0.3793, 0.3781, 0.4511,  ..., 0.4986, 0.3813, 0.5117],
           [0.3602, 0.4761, 0.6450,  ..., 0.4184, 0.3597, 0.5232],
           [0.6027, 0.7492, 0.6268,  ..., 0.5129, 0.5211, 0.5568],
           ...,
           [0.5250, 0.5006, 0.3647,  ..., 0.2646, 0.6313, 0.5253],
           [0.6121, 0.2297, 0.6982,  ..., 0.5259, 0.4092, 0.5431],
           [0.5163, 0.4656, 0.4025,  ..., 0.4765, 0.4122, 0.3629]],

          [[0.5087, 0.3745, 0.3668,  ..., 0.2743, 0.3800, 0.3814],
           [0.5692, 0.3851, 0.3346,  ..., 0.4671, 0.4983, 0.4166],
           [0.4602, 0.5352