<a href="https://colab.research.google.com/github/vantainguyen/A-B-Testing/blob/main/V_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
# Hyperparameters for the two model configurations (suffix 1 and suffix 2).

# Filter count of the network's final layer.
OUTPUT_CHANNELS = 3

# One Adam optimizer per configuration, each with its own learning rate.
learning_rate1, learning_rate2 = 0.001, 0.01
optimizer1 = tf.keras.optimizers.Adam(learning_rate=learning_rate1)
optimizer2 = tf.keras.optimizers.Adam(learning_rate=learning_rate2)

# Batch size and convolution kernel size are the same for both configurations.
batch_size1 = batch_size2 = 8
kernel_size1 = kernel_size2 = 3

# Base filter count; the convolution layers use multiples of this value.
filter_base1, filter_base2 = 16, 32

#loss1 = 'mean_squared_error'
#loss2 = 'mean_squared_error'

# Set to True to use the root-mean-square loss defined below; when False the
# built-in Keras loss named 'mean_squared_error' is used instead.
self_defined_loss = False


def loss_func(y_true, y_pred):
  """Return the root of the squared error averaged over the last axis."""
  squared_error = tf.square(y_true - y_pred)
  mean_over_last_axis = tf.reduce_mean(squared_error, axis=-1)
  # The tiny offset keeps the sqrt argument strictly positive (presumably to
  # avoid an infinite gradient at exactly zero error).
  return tf.sqrt(mean_over_last_axis + 1e-20)


if not self_defined_loss:
  # Replace the custom function with the name of the built-in loss.
  loss_func = 'mean_squared_error'



# Number of training epochs.
epochs_train = 300

# The save period equals the epoch count (presumably so that a save happens
# only once, at the end of training; confirm against the code that uses it).
save_period = epochs_train

# Kernel initializer used by the convolution layers: normal distribution with
# mean 0 and standard deviation 0.02.
initializer = tf.random_normal_initializer(mean=0., stddev=0.02)

class V_Net(tf.keras.Model):
  """3D encoder-decoder network with skip connections.

  The encoder halves every spatial dimension six times with strided `Conv3D`
  layers; the decoder doubles them back with `Conv3DTranspose` layers. Before
  each decoder step after the first, the decoder tensor is concatenated with
  the encoder tensor of matching resolution. The final transposed convolution
  restores the input resolution and emits `OUTPUT_CHANNELS` channels with no
  activation.

  Every spatial dimension of the input must be divisible by 64 (2**6);
  otherwise the skip-connection shapes do not match at concatenation time.

  Reads the module-level names `filter_base2`, `kernel_size1`, `initializer`
  and `OUTPUT_CHANNELS`.
  """

  def __init__(self, **kwargs):
    """Create all layers.

    Args:
      **kwargs: Forwarded unchanged to `tf.keras.Model.__init__`.
    """
    super().__init__(**kwargs)

    # NOTE(review): filter_base2 is combined with kernel_size1 here. Both
    # kernel sizes are currently 3, so this has no effect today; confirm which
    # configuration this model is meant to follow.
    def down_conv(filters):
      return tf.keras.layers.Conv3D(
          filters, kernel_size=kernel_size1, strides=(2, 2, 2), padding='same',
          kernel_initializer=initializer, use_bias=False)

    def up_conv(filters):
      return tf.keras.layers.Conv3DTranspose(
          filters, kernel_size=kernel_size1, strides=2, padding='same',
          kernel_initializer=initializer, use_bias=False)

    # Encoder: the channel count doubles at every down-sampling step.
    self.convD1 = down_conv(filter_base2)
    self.convD2 = down_conv(filter_base2*2)
    self.convD3 = down_conv(filter_base2*4)
    self.convD4 = down_conv(filter_base2*8)
    self.convD5 = down_conv(filter_base2*16)
    self.convD6 = down_conv(filter_base2*32)

    # Decoder: the channel count halves at every up-sampling step after the
    # first.
    self.convU1 = up_conv(filter_base2*32)
    self.convU2 = up_conv(filter_base2*16)
    self.convU3 = up_conv(filter_base2*8)
    self.convU4 = up_conv(filter_base2*4)
    self.convU5 = up_conv(filter_base2*2)

    # A BatchNormalization layer creates its weights for one specific channel
    # count, so each convolution needs its own instance.
    self.normD1 = tf.keras.layers.BatchNormalization()
    self.normD2 = tf.keras.layers.BatchNormalization()
    self.normD3 = tf.keras.layers.BatchNormalization()
    self.normD4 = tf.keras.layers.BatchNormalization()
    self.normD5 = tf.keras.layers.BatchNormalization()
    self.normD6 = tf.keras.layers.BatchNormalization()
    self.normU1 = tf.keras.layers.BatchNormalization()
    self.normU2 = tf.keras.layers.BatchNormalization()
    self.normU3 = tf.keras.layers.BatchNormalization()
    self.normU4 = tf.keras.layers.BatchNormalization()
    self.normU5 = tf.keras.layers.BatchNormalization()

    # LeakyReLU has no weights, so one instance can be shared by all blocks.
    self.conv_activation = tf.keras.layers.LeakyReLU()

    # Joins decoder and encoder tensors along the last (channel) axis.
    self.concat = tf.keras.layers.Concatenate()

    # Final up-sampling back to the input resolution.
    self.last = tf.keras.layers.Conv3DTranspose(
        OUTPUT_CHANNELS, kernel_size=kernel_size1, strides=2, padding='same',
        kernel_initializer=initializer, use_bias=False)

  def _conv_block(self, conv, norm, x, training):
    """Apply convolution, batch normalization, then LeakyReLU to `x`."""
    return self.conv_activation(norm(conv(x), training=training))

  def call(self, inputs, training=None):
    """Run the forward pass.

    Args:
      inputs: 5D tensor accepted by `Conv3D`; every spatial dimension must be
        divisible by 64.
      training: Passed to the batch-normalization layers. `None` lets Keras
        decide from the calling context.

    Returns:
      Tensor with the spatial size of `inputs` and `OUTPUT_CHANNELS` channels.
    """
    # Encoder; every intermediate result except x6 is reused as a skip
    # connection.
    x1 = self._conv_block(self.convD1, self.normD1, inputs, training)
    x2 = self._conv_block(self.convD2, self.normD2, x1, training)
    x3 = self._conv_block(self.convD3, self.normD3, x2, training)
    x4 = self._conv_block(self.convD4, self.normD4, x3, training)
    x5 = self._conv_block(self.convD5, self.normD5, x4, training)
    x6 = self._conv_block(self.convD6, self.normD6, x5, training)

    # Decoder; each up-sampled tensor is concatenated with the encoder output
    # of the same spatial size (x6 up-sampled once matches x5, and so on).
    up = self._conv_block(self.convU1, self.normU1, x6, training)
    up = self.concat([up, x5])

    up = self._conv_block(self.convU2, self.normU2, up, training)
    up = self.concat([up, x4])

    up = self._conv_block(self.convU3, self.normU3, up, training)
    up = self.concat([up, x3])

    up = self._conv_block(self.convU4, self.normU4, up, training)
    up = self.concat([up, x2])

    up = self._conv_block(self.convU5, self.normU5, up, training)
    up = self.concat([up, x1])

    return self.last(up)



NameError: ignored