In [67]:
import os
import numpy as np
from random import sample, seed
from tqdm import tqdm
import tensorflow as tf 
os.environ['AUTOGRAPH_VERBOSITY'] = "0"
from tensorflow import keras
from tensorflow.keras import layers, Sequential, Model
from tensorflow.keras.layers import Dense, Conv3D, Conv1D, Conv3DTranspose, Flatten, Reshape, Input, BatchNormalization, GlobalAveragePooling3D, Dropout
from tensorflow.keras.activations import relu, sigmoid
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam, SGD
from utils.layers import *

In [33]:
def _buildEncoder(input_shape, filters=[32, 64, 128], last_activation=relu):

    input = Input(shape=input_shape)
    x = Conv3D(filters=filters[0], kernel_size=5, strides=(2,2,2), padding="SAME")(input)
    x = BatchNormalization()(x)
    x = relu(x)
    for i, ft in enumerate(filters[1:]):
        if i == len(filters[1:])-1:
            x = residual_block(x, filters = ft, kernel_size= 3,  
                        strides = (2,2,2), padding = "SAME", activate=relu)
        else:
            x = residual_block(x, filters = ft, kernel_size= 3,  
                        strides = (2,2,2), padding = "SAME", activate=last_activation)

    encoder = Model(inputs=input, outputs=x)         
    return encoder

In [34]:
encoder = _buildEncoder([15,60,60,1], filters=[32, 64, 128])

In [35]:
a = tf.ones(shape=[8,15,60,60,1], dtype=tf.float32)
b = encoder(a)
b.shape

TensorShape([8, 2, 8, 8, 128])

In [49]:
def _buildDecoder(input_shape, filters=[16, 32, 64, 128], last_activation=relu):
    input = Input(shape=input_shape)
    x = Conv3DTranspose(filters=filters[-1], kernel_size=3, strides=(2,)*3, padding="SAME", activation='relu')(input)
    x = BatchNormalization()(x)
    for i, ft in enumerate(filters[-2::-1]):
        if i != len(filters[-2::-1])-1:
            x = resTP_block(x, filters=ft, strides=(2,2,2),padding="SAME")
        else:
            x = resTP_block(x, filters=ft, strides=(2,2,2),padding="SAME", activation=last_activation)
    x = x[:, :15, 2:62, 2:62, :]
    decoder = Model(inputs=input, outputs=x)
    return decoder

In [50]:
encoder.output.shape

TensorShape([None, 2, 8, 8, 128])

In [51]:
decoder = _buildDecoder([1, 4, 4, 128])

In [52]:
decoder.summary()

Model: "functional_17"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_12 (InputLayer)           [(None, 1, 4, 4, 128 0                                            
__________________________________________________________________________________________________
conv3d_transpose_93 (Conv3DTran (None, 2, 8, 8, 128) 442496      input_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_107 (BatchN (None, 2, 8, 8, 128) 512         conv3d_transpose_93[0][0]        
__________________________________________________________________________________________________
conv3d_transpose_94 (Conv3DTran (None, 4, 16, 16, 64 221248      batch_normalization_107[0][0]    
______________________________________________________________________________________

In [49]:
x.shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [48]:
conv = Conv3DTranspose(12, 3, strides=(2,2,2), padding="SAME")

conv(b).shape

TensorShape([8, 4, 16, 16, 12])

In [68]:
def _buildDiscriminator(input_shape, filters=[16, 32, 64, 128], last_activation=relu):

    input = Input(shape=input_shape)
    x = Conv3D(filters=filters[0], kernel_size=5, strides=(2,2,2), padding="SAME")(input)
    x = BatchNormalization()(x)
    x = relu(x)
    for i, ft in enumerate(filters[1:]):
        if i == len(filters[1:])-1:
            x = residual_block(x, filters = ft, kernel_size= 3,  
                        strides = (2,2,2), padding = "SAME", activate=relu)
        else:
            x = residual_block(x, filters = ft, kernel_size= 3,  
                        strides = (2,2,2), padding = "SAME", activate=last_activation)

    x = GlobalAveragePooling3D()(x)
    x = Flatten()(x)
    x = Dense(128)(x)
    x = Dropout(0.7)(x)
    x = Dense(128)(x)
    x = Dropout(0.7)(x)
    x = Dense(1, activation=sigmoid)(x)
    discriminator = Model(inputs=input, outputs=x) 

    return discriminator

In [69]:
ds = _buildDiscriminator([15, 60, 60, 1])

In [70]:
ds(tf.ones(shape=(2,8,8,1)))

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.5013346],
       [0.5013346]], dtype=float32)>

In [62]:
ds.summary()

Model: "functional_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_15 (InputLayer)           [(None, 2, 8, 8, 128 0                                            
__________________________________________________________________________________________________
conv3d_40 (Conv3D)              (None, 1, 4, 4, 16)  256016      input_15[0][0]                   
__________________________________________________________________________________________________
batch_normalization_137 (BatchN (None, 1, 4, 4, 16)  64          conv3d_40[0][0]                  
__________________________________________________________________________________________________
tf_op_layer_Relu_84 (TensorFlow [(None, 1, 4, 4, 16) 0           batch_normalization_137[0][0]    
______________________________________________________________________________________