In [14]:
#3D

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super(ConvBlock, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

  def forward(self, x):
    x = self.batch_norm(self.conv3D(x))
    x = F.relu(x)

    return(x)


class Encoder(nn.Module):
  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps #2*16, 4*16, 8*16, 16*16

      for i in range(self.num_conv_block):
        if depth == 0:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels) 
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
        else:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2

      if depth == model_depth - 1:  #depth = 3
        break
      else:
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

  def forward(self, x):
    down_sampling_feature = []
    for k, op in self.module_dict.items():
      # print('k: ', k, 'op: ', op)

      if k.startswith('conv'):
        x = op(x)
        # print (k, x.shape)
        if k.endswith('1'):
          down_sampling_feature.append(x)

      elif k.startswith('max_pooling'):
        x = op(x)
        # print(k, x.shape)

    return x, down_sampling_feature

class ConvTranspose(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1):
    super(ConvTranspose, self).__init__()
    self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size,
    stride=stride, padding=padding, output_padding = output_padding)

  def forward(self, x):
    return(self.conv3d_transpose(x))

class Decoder(nn.Module):
  def __init__(self, out_channels, model_depth = 4):
    super(Decoder, self).__init__()
    self.num_conv_blocks = 2
    self.num_feat_maps = 16
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth-2, -1, -1): #2, 1, 0
      feat_map_channels = 2**(depth+1)*self.num_feat_maps #128, 64, 32
      self.deconv = ConvTranspose(in_channels=feat_map_channels*4, out_channels=feat_map_channels*4) #512, 256, 128
      self.module_dict['deconv_{}'.format(depth)] = self.deconv
      for i in range(self.num_conv_blocks):
        if i == 0:
          self.conv = ConvBlock(in_channels=feat_map_channels*6, out_channels=feat_map_channels*2) #768/256, 384/128, 192/64
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv
        else:
          self.conv = ConvBlock(in_channels=feat_map_channels*2, out_channels=feat_map_channels*2) #256/256, 128/128, 64/64
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv
      if depth == 0:
        self.final_conv = ConvBlock(in_channels=feat_map_channels*2, out_channels=out_channels)
        self.module_dict['final_conv'] = self.final_conv

  def forward(self, x, down_sampling_feature):
    for k, op in self.module_dict.items():
      if k.startswith('deconv'):
        x = op(x)
        x = torch.cat((down_sampling_feature[int(k[-1])], x), dim = 1)
      elif k.startswith('conv'):
        x = op(x)
      else:
        x = op(x)
    
    return x


In [15]:
class Unet3D(nn.Module):
    def __init__(self, in_channels, out_channels, model_depth = 4, final_activation = 'sigmoid'):
        super(Unet3D, self).__init__()
        self.encoder = Encoder(in_channels=in_channels, model_depth=model_depth)
        self.decoder = Decoder(out_channels=out_channels, model_depth=model_depth)
        if final_activation == 'sigmoid':
            self.f_activation = nn.Sigmoid()
        else:
            self.f_activation = nn.Softmax(dim = 1)

    def forward(self, x):
        x, d_features = self.encoder(x)
        x = self.decoder(x, d_features)
        x = self.f_activation(x)
        print("Final output shape: ", x.shape)

        return x

if __name__ == '__main__':
    inputs = torch.randn(1, 1, 96, 96, 96)
    inputs = inputs.cuda()
    # print('the shape of input = ', inputs.shape)

    model = Unet3D(in_channels=1, out_channels=1)
    model.cuda()

    x_test = model(inputs)

Final output shape:  torch.Size([1, 1, 96, 96, 96])


In [38]:
#BraveNet

#3D

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

#kernels and stride to 3 × 3 × 3 and 1 for convolutions and 2 × 2 × 2 and 2 for max-pooling

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super(ConvBlock, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

  def forward(self, x):
    x = F.relu(self.conv3D(x))
    x = self.batch_norm(x)

    return(x)

class Encoder(nn.Module):
  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps #32, 64, 128, 256
      

      for i in range(self.num_conv_block):
        if depth == 0:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels) 
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

        else:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

      if depth == model_depth - 1:  #depth = 3
        break
      else:
        nn.MaxPool3d
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

  def forward(self, x):
    down_sampling_feature = []
    for k, op in self.module_dict.items():
      print('k: ', k, 'op: ', op)

      if k.startswith('conv'):
        print ('input', x.shape)
        x = op(x)
        print ('output', x.shape)
        if k.endswith('1'):
          down_sampling_feature.append(x)

      elif k.startswith('max_pooling'):
        print ('input', x.shape)
        x = op(x)
        print ('output', x.shape)
        # print(k, x.shape)

      else:
        print ('input', x.shape)
        x = op(x)
        print ('output', x.shape)

    return x, down_sampling_feature


if __name__ == '__main__':
    inputs = torch.randn(1, 1, 8, 64, 64)
    inputs = inputs.cuda()
    # print('the shape of input = ', inputs.shape)

    encoder = Encoder(1)
    # print(encoder)
    encoder.cuda()

    x_test = encoder(inputs)
    print('the shape of output = ', x_test[0].shape)

# encoder = Encoder(1)
# print(encoder)

k:  conv_0_0 op:  ConvBlock(
  (conv3D): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
input torch.Size([1, 1, 8, 64, 64])
output torch.Size([1, 32, 8, 64, 64])
k:  dropout0 op:  Dropout3d(p=0.5, inplace=False)
input torch.Size([1, 32, 8, 64, 64])
output torch.Size([1, 32, 8, 64, 64])
k:  conv_0_1 op:  ConvBlock(
  (conv3D): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
input torch.Size([1, 32, 8, 64, 64])
output torch.Size([1, 32, 8, 64, 64])
k:  max_pooling_0 op:  MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
input torch.Size([1, 32, 8, 64, 64])
output torch.Size([1, 32, 4, 32, 32])
k:  conv_1_0 op:  ConvBlock(
  (conv3D): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (batch_n

In [54]:
#test

#BraveNet

#3D

from turtle import forward
import torch.nn as nn
import torchvision.models
import torch
import torch.nn.functional as F

#kernels and stride to 3 × 3 × 3 and 1 for convolutions and 2 × 2 × 2 and 2 for max-pooling

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1):
    super(ConvBlock, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
    self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

  def forward(self, x):
    x = F.relu(self.conv3D(x))
    x = self.batch_norm(x)

    return(x)

class Encoder(nn.Module):
  def __init__(self, in_channels, model_depth = 4, pool_size = 2):
    super(Encoder, self).__init__()
    self.root_feat_maps = 16
    self.num_conv_block = 2
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth):
      feat_map_channels = 2**(depth+1)*self.root_feat_maps #32, 64, 128, 256
      

      for i in range(self.num_conv_block):
        if depth == 0:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels) 
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

        else:
          self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
          self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
          if i == 1:
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels*2
          else:
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
            in_channels, feat_map_channels = feat_map_channels, feat_map_channels

      if depth == model_depth - 1:  #depth = 3
        break
      else:
        nn.MaxPool3d
        self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2)
        self.module_dict['max_pooling_{}'.format(depth)] = self.pooling

  def forward(self, x):
    down_sampling_feature = []
    for k, op in self.module_dict.items():
      # print('k: ', k, 'op: ', op)

      if k.startswith('conv'):
        # print ('input', x.shape)
        x = op(x)
        # print ('output', x.shape)
        if k.endswith('1'):
          down_sampling_feature.append(x)

      else:
        # print ('input', x.shape)
        x = op(x)
        # print ('output', x.shape)
        # print(k, x.shape)

    return x, down_sampling_feature


if __name__ == '__main__':
    inputs_low = torch.randn(1, 1, 8, 64, 64)
    inputs_low = inputs.cuda()

    inputs_high = torch.randn(1, 1, 8, 64, 64)
    inputs_high = inputs.cuda()
    # print('the shape of input = ', inputs.shape)

    encoder1 = Encoder(1)
    encoder2 = Encoder(1)
    # print(encoder)
    encoder1.cuda()
    encoder2.cuda()

    x_low, d_low = encoder1(inputs_low)
    x_high, d_high = encoder2(inputs_high)
    # print('the shape of output = ', x_test[0].shape)

# encoder = Encoder(1)
# print(encoder)

In [55]:
class ConvTranspose(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1):
    super(ConvTranspose, self).__init__()
    self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size,
    stride=stride, padding=padding, output_padding = output_padding)

  def forward(self, x):
    return(self.conv3d_transpose(x))

class ConvBlock_One(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, padding = 1):
    super(ConvBlock_One, self).__init__()
    self.conv3D = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    x = F.relu(self.conv3D(x))

    return(x)


class Decoder(nn.Module):
  def __init__(self, out_channels, model_depth = 4):
    super(Decoder, self).__init__()
    self.num_conv_blocks = 2
    self.num_feat_maps = 16
    self.module_dict = nn.ModuleDict()

    for depth in range(model_depth-1, -1, -1): #3, 2, 1, 0

      feat_map_channels = 2**(depth+1)*self.num_feat_maps #256 ,128, 64, 32

      if depth == 3:
        for i in range(self.num_conv_blocks):
          if i == 0:
            self.conv_one = ConvBlock_One(feat_map_channels*2, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one
          else:
            self.conv_one = ConvBlock_One(feat_map_channels, out_channels=feat_map_channels)
            self.module_dict['conv_one_{}_{}'.format(depth, i)] = self.conv_one

        self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels) #256
        self.module_dict['deconv_{}'.format(depth)] = self.deconv

      else:  #depth = 2, 1, 0 feat_map_channels = 128, 64, 32
        for i in range(self.num_conv_blocks):
          if i == 0:
            self.conv_block = ConvBlock(in_channels=feat_map_channels*4, out_channels=feat_map_channels) #512/128, 256/64, 128/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block
            self.dropout = nn.Dropout3d()
            self.module_dict['dropout{}'.format(depth)] = self.dropout
          else:
            self.conv_block = ConvBlock(in_channels=feat_map_channels, out_channels=feat_map_channels) #128/128, 64/64, 32/32
            self.module_dict['conv_{}_{}'.format(depth, i)] = self.conv_block

        if depth != 0:
          self.deconv = ConvTranspose(in_channels=feat_map_channels, out_channels=feat_map_channels)
          self.module_dict['deconv_{}'.format(depth)] = self.deconv

      if depth == 0:
        self.final_conv = ConvBlock_One(in_channels=feat_map_channels, out_channels=out_channels)
        self.module_dict['final_conv'] = self.final_conv

  def forward(self, x, down_sampling_feature):
    for k, op in self.module_dict.items():
      if k.startswith('deconv'):
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
        x = torch.cat((down_sampling_feature[int(k[-1])], x), dim = 1)
        print ('operation: concat','output shape: ', x.shape) 
      elif k.startswith('conv'):
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
      else:
        x = op(x)
        print ('operation: ',k, 'output shape: ', x.shape) 
    
    return x


decoder = Decoder(1)
decoder

SyntaxError: duplicate argument 'down_sampling_feature' in function definition (1191465602.py, line 63)