In [156]:
#BraveNet3D

#Encoder

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

# Kernels / strides: 3 × 3 × 3 with stride 1 for convolutions, 2 × 2 × 2 with stride 2 for max-pooling.

class ConvBlock(nn.Module):
  """3-D convolution followed by ReLU, then batch normalisation.

  Note the order: the activation is applied *before* BatchNorm3d.

  Args:
    in_channels: input channels of the convolution.
    out_channels: output channels of the convolution (and BatchNorm features).
    kernel_size, stride, padding: forwarded to ``nn.Conv3d``; the defaults
      (3, 1, 1) keep the spatial size unchanged.
  """

  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super().__init__()
    conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
    norm = nn.BatchNorm3d(out_channels)
    # Attribute names are part of the state_dict keys - keep them.
    self.conv3D = conv
    self.batch_norm = norm

  def forward(self, x):
    return self.batch_norm(F.relu(self.conv3D(x)))


class Encoder(nn.Module):
  """Contracting path of BraveNet3D, applied to two inputs with shared weights.

  For each level ``depth`` in ``range(model_depth)`` the following layers are
  registered in ``module_dict``, in this order:
  ``conv_{depth}_0`` (ConvBlock), ``dropout{depth}`` (Dropout3d),
  ``conv_{depth}_1`` (ConvBlock) and, for every level except the deepest,
  ``max_pooling_{depth}`` (MaxPool3d, stride 2). Level ``depth`` uses
  ``2**(depth+1) * 16`` feature maps (32, 64, 128, 256 for the default depth).

  Args:
    in_channels: number of channels of each input volume.
    model_depth: number of resolution levels.
    pool_size: kernel size of the max-pooling layers.
  """

  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps  # 32, 64, 128, 256

      for i in range(self.num_conv_block):
        # NOTE: assigning to self.conv_block / self.dropout / self.pooling also
        # registers the most recent layer under that attribute name. Kept so
        # state_dict keys and parameter order match existing checkpoints.
        self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
        self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
        if i != 1:
          # Dropout sits between the two convs of each level.
          self.dropout = nn.Dropout3d()
          self.module_dict['dropout{}'.format(depth)] = self.dropout
        in_channels = feat_map_channels

      if depth < model_depth - 1:  # no pooling after the deepest level
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

    # Second name for the SAME ModuleDict: both inputs are encoded with
    # identical weights. Also preserves the 'module_dict2.*' state_dict keys.
    self.module_dict2 = self.module_dict

  def _run_branch(self, modules, x, log):
    """Apply every layer of ``modules`` to ``x`` in registration order.

    Returns the outputs of the ``conv_*`` layers whose key ends in '1'
    (i.e. ``conv_{depth}_1``, the last conv of each level), shallowest first.
    When ``log`` is true each layer's output shape is printed.
    """
    skips = []
    for k, op in modules.items():
      x = op(x)
      if log:
        print ('operation: ',k, 'output shape: ', x.shape)
      if k.startswith('conv') and k.endswith('1'):
        skips.append(x)
    return skips

  def forward(self, x, y):
    """Encode both inputs with the shared layers.

    Returns:
      (down_sampling_feature, down_sampling_feature2): one list of per-level
      skip features for ``x`` and one for ``y``. Only the ``x`` pass is logged.
    """
    print ('------Encoding------')
    down_sampling_feature = self._run_branch(self.module_dict, x, log=True)
    down_sampling_feature2 = self._run_branch(self.module_dict2, y, log=False)
    return down_sampling_feature, down_sampling_feature2


# if __name__ == '__main__':
#     inputs_low = torch.randn(1, 1, 8, 64, 64)
#     inputs_low = inputs_low.cuda()

#     inputs_high = torch.randn(1, 1, 8, 64, 64)
#     inputs_high = inputs_high.cuda()
#     # print('the shape of input = ', inputs.shape)

#     encoder = Encoder(1)

#     # print(encoder)
#     encoder.cuda()


#     d_low, d_high = encoder(inputs_low, inputs_high)
#     # print('the shape of output = ', x_test[0].shape)

In [157]:
# encoder = Encoder(1)
# print(encoder)

In [158]:
#Decoder

class ConvTranspose(nn.Module):
  """Learned 3-D up-sampling layer wrapping ``nn.ConvTranspose3d``.

  With the defaults (kernel 3, stride 2, padding 1, output_padding 1) every
  spatial dimension of the input is exactly doubled.
  """

  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1):
    super().__init__()
    self.conv3d_transpose = nn.ConvTranspose3d(
      in_channels,
      out_channels,
      kernel_size,
      stride=stride,
      padding=padding,
      output_padding=output_padding,
    )

  def forward(self, x):
    upsampled = self.conv3d_transpose(x)
    return upsampled

class ConvBlock_One_ReLu(nn.Module):
  """3-D convolution (1x1x1 by default) followed by a ReLU activation."""

  def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0):
    super().__init__()
    self.conv3D = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    pre_activation = self.conv3D(x)
    return F.relu(pre_activation)


class ConvBlock_One(nn.Module):
  """Plain 3-D convolution (1x1x1 by default) with no activation."""

  def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0):
    super().__init__()
    self.conv3D = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    projected = self.conv3D(x)
    return projected


# Deep-supervision shape notes (for a 1 x 1 x 8 x 64 x 64 input):
#   conv_2_1 output torch.Size([1, 128, 2, 16, 16]) -> upsampled to 8 x 64 x 64
#   conv_1_1 output torch.Size([1, 64, 4, 32, 32])   -> upsampled to 8 x 64 x 64
#   final_conv output torch.Size([1, 1, 8, 64, 64])


class Decoder(nn.Module):
  """Expanding path of BraveNet3D with two deep-supervision side outputs.

  NOTE: ``forward`` is written for ``model_depth == 4``. It indexes encoder
  levels 3..0 and the ``conv_one_3_*`` keys directly, and ``__init__``
  special-cases ``depth == 3``. Other depths construct but are not supported
  by ``forward``.

  Registered layers (``module_dict`` keys), with C = 128, 64, 32 for d = 2, 1, 0:
    * ``conv_one_3_0`` (1x1x1 conv + ReLU, 512 -> 256), ``conv_one_3_1``
      (256 -> 256), ``deconv_3``.
    * ``conv_{d}_0`` (4C -> C), ``dropout{d}``, ``conv_{d}_1`` (C -> C).
    * For d = 2, 1: ``upsample{d}`` + ``conv_one_output{d}`` (C -> 1) side
      heads, and ``deconv_{d}``.
    * ``final_conv``: 1x1x1 conv from 32 channels to ``out_channels``.

  Args:
    out_channels: channels of the main prediction.
    model_depth: number of resolution levels (see note above).
  """

  def __init__(self, out_channels, model_depth = 4):
    super(Decoder, self).__init__()
    self.num_conv_blocks = 2
    self.num_feat_maps = 16
    self.module_dict = nn.ModuleDict()

    # NOTE: the self.conv_one / self.deconv / self.conv_block / ... assignments
    # below also register the most recent layer under that attribute name.
    # They are kept so state_dict keys and parameter order stay compatible.
    for depth in range(model_depth-1, -1, -1): # 3, 2, 1, 0

      feat_map_channels = 2**(depth+1)*self.num_feat_maps # 256, 128, 64, 32

      if depth == 3:
        # Bottleneck: 1x1x1 ReLU convs over the concatenated deepest features
        # of both encoder branches (hence the doubled input channels).
        for i in range(self.num_conv_blocks):
          if i == 0:
            self.conv_one = ConvBlock_One_ReLu(feat_map_channels*2, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one
          else:
            self.conv_one = ConvBlock_One_ReLu(feat_map_channels, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one

        self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels) # 256
        self.module_dict['deconv_{}'.format(depth)] = self.deconv

      else:  # depth = 2, 1, 0 -> feat_map_channels = 128, 64, 32
        for i in range(self.num_conv_blocks):
          if i == 0:
            # Input = deconv output (2C) + both encoder skips (C each) = 4C.
            self.conv_block = ConvBlock(in_channels=feat_map_channels*4, out_channels=feat_map_channels) # 512/128, 256/64, 128/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
          else:
            self.conv_block = ConvBlock(in_channels=feat_map_channels, out_channels=feat_map_channels) # 128/128, 64/64, 32/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block

            if depth in (1,2):
              # Side head: upsample by depth*2 (4 for level 2, 2 for level 1)
              # and project to a single channel.
              self.upsample = nn.Upsample(scale_factor=depth*2, mode='trilinear')
              self.module_dict['upsample{}'.format(depth)] = self.upsample
              self.conv_block_sig = ConvBlock_One(in_channels=feat_map_channels, out_channels=1) # 128/1, 64/1, 32/1
              self.module_dict['conv_one_output{}'.format(depth)] = self.conv_block_sig

        if depth != 0:
          self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels)
          self.module_dict['deconv_{}'.format(depth)] = self.deconv

      if depth == 0:
        self.final_conv = ConvBlock_One(in_channels=feat_map_channels, out_channels=out_channels)
        self.module_dict['final_conv'] = self.final_conv

  def _decode_stage(self, x, depth, skip, skip2):
    """Run one decoder level.

    Upsamples ``x`` with ``deconv_{depth+1}``, concatenates the two encoder
    skips of level ``depth`` onto it, then applies ``conv_{depth}_0``,
    ``dropout{depth}`` and ``conv_{depth}_1``, printing every output shape.
    """
    key = 'deconv_{}'.format(depth + 1)
    x = self.module_dict[key](x)
    print ('operation: {}'.format(key), 'output shape: ', x.shape)
    y = torch.cat((skip, skip2), dim = 1)
    print ('operation: input concatenate', 'output shape: ', y.shape)
    x = torch.cat((x, y), dim = 1)
    print ('operation: concatenate', 'output shape: ', x.shape)
    for key in ('conv_{}_0'.format(depth), 'dropout{}'.format(depth), 'conv_{}_1'.format(depth)):
      x = self.module_dict[key](x)
      print ('operation: {}'.format(key), 'output shape: ', x.shape)
    return x

  def _side_output(self, x, depth):
    """Deep-supervision head of level ``depth``: upsample, then 1x1x1 conv.

    Returns the raw (pre-sigmoid) single-channel map.
    """
    print ('-------------------------------------------------')
    out = self.module_dict['upsample{}'.format(depth)](x)
    print ('operation: upsample{}'.format(depth), 'output shape: ', out.shape)
    out = self.module_dict['conv_one_output{}'.format(depth)](out)
    print('Shape of output: ', out.shape)
    print ('-------------------------------------------------')
    return out

  def forward(self, down_sampling_feature, down_sampling_feature2):
    """Decode the two encoder skip lists into three sigmoid predictions.

    Args:
      down_sampling_feature, down_sampling_feature2: per-level skip features
        of the two encoder branches, indexed 0 (shallowest) to 3 (deepest).

    Returns:
      (pred1, pred2, pred3): sigmoid of ``final_conv`` (main output), of the
      level-2 side head and of the level-1 side head.
    """
    print ('------Decoding------')

    # Bottleneck: fuse the deepest features of both branches.
    x = torch.cat((down_sampling_feature[3], down_sampling_feature2[3]), dim = 1)
    print ('operation: original input concatenate', 'output shape: ', x.shape)
    for key in ('conv_one_3_0', 'conv_one_3_1'):
      x = self.module_dict[key](x)
      print ('operation: {}'.format(key), 'output shape: ', x.shape)

    x = self._decode_stage(x, 2, down_sampling_feature[2], down_sampling_feature2[2])
    x1 = self._side_output(x, 2)

    x = self._decode_stage(x, 1, down_sampling_feature[1], down_sampling_feature2[1])
    x2 = self._side_output(x, 1)

    x = self._decode_stage(x, 0, down_sampling_feature[0], down_sampling_feature2[0])

    x = self.module_dict['final_conv'](x)
    print ('-------------------------------------------------')
    print('Shape of output: ', x.shape)

    pred1 = torch.sigmoid(x)
    pred2 = torch.sigmoid(x1)
    pred3 = torch.sigmoid(x2)

    return pred1, pred2, pred3

# if __name__ == '__main__':


#     decoder = Decoder(1)
#     decoder.cuda()
#     x = decoder(d_low, d_high)
    
#     print('the shape of output = ', x)

In [159]:
# k = 'conv_2_1'
# k in ('conv_2_1','conv_1_1')
# 'layer_{}'.format(k)

In [160]:
decoder = Decoder(out_channels=1)
decoder  # bare expression so the notebook cell echoes the module tree

Decoder(
  (module_dict): ModuleDict(
    (conv_one_3_0): ConvBlock_One_ReLu(
      (conv3D): Conv3d(512, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (conv_one_3_1): ConvBlock_One_ReLu(
      (conv3D): Conv3d(256, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (deconv_3): ConvTranspose(
      (conv3d_transpose): ConvTranspose3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
    )
    (conv_2_0): ConvBlock(
      (conv3D): Conv3d(512, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (dropout2): Dropout3d(p=0.5, inplace=False)
    (conv_2_1): ConvBlock(
      (conv3D): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (batch_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (upsample2): Upsample(scale_factor=4.0, m

In [161]:
class BraveNet(nn.Module):
    """Full BraveNet3D: shared-weight dual-input encoder plus decoder.

    Args:
        in_channels: channels of each of the two input volumes.
        out_channels: channels of the main prediction.
        model_depth: number of resolution levels, passed to both halves.

    ``forward(x, y)`` returns the decoder's three sigmoid predictions
    (main output, level-2 side output, level-1 side output).
    """

    def __init__(self, in_channels, out_channels, model_depth = 4):
        super().__init__()
        self.encoder = Encoder(in_channels=in_channels, model_depth=model_depth)
        self.decoder = Decoder(out_channels=out_channels, model_depth=model_depth)

    def forward(self, x, y):
        skips_x, skips_y = self.encoder(x, y)
        main_pred, side_pred_2, side_pred_1 = self.decoder(skips_x, skips_y)
        return main_pred, side_pred_2, side_pred_1

if __name__ == '__main__':
    # Smoke test: one forward pass on random low/high inputs.
    # Fall back to CPU instead of crashing in .cuda() when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inputs_low = torch.randn(1, 1, 8, 64, 64).to(device)
    inputs_high = torch.randn(1, 1, 8, 64, 64).to(device)

    model = BraveNet(in_channels=1, out_channels=1)
    model.to(device)

    p1, p2, p3 = model(inputs_low, inputs_high)


------Encoding------
operation:  conv_0_0 output shape:  torch.Size([1, 32, 8, 64, 64])
operation:  dropout0 output shape:  torch.Size([1, 32, 8, 64, 64])
operation:  conv_0_1 output shape:  torch.Size([1, 32, 8, 64, 64])
operation:  max_pooling_0 output shape:  torch.Size([1, 32, 4, 32, 32])
operation:  conv_1_0 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  dropout1 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  conv_1_1 output shape:  torch.Size([1, 64, 4, 32, 32])
operation:  max_pooling_1 output shape:  torch.Size([1, 64, 2, 16, 16])
operation:  conv_2_0 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  dropout2 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  conv_2_1 output shape:  torch.Size([1, 128, 2, 16, 16])
operation:  max_pooling_2 output shape:  torch.Size([1, 128, 1, 8, 8])
operation:  conv_3_0 output shape:  torch.Size([1, 256, 1, 8, 8])
operation:  dropout3 output shape:  torch.Size([1, 256, 1, 8, 8])
operation:  conv_3_1 outpu