In [26]:
import datetime
import matplotlib.pyplot as plt
import numpy as np
import scipy
import sys
import os

from data_helper import predict_15k, save_hist, save_model

from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, History

In [None]:
import keras
from keras import backend as K
#from keras.datasets import mnist
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, Add, Lambda
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

def lambda_output(input_shape):
    """Output-shape callback for the ``Lambda`` layer wrapping ``minb_disc``.

    :param input_shape: shape sequence of the tensor entering the layer.
    :return: the first two entries of ``input_shape`` (same sequence type).
    """
    leading_dims = input_shape[0:2]
    return leading_dims

def minb_disc(x):
    """Minibatch-discrimination feature built from Keras backend ops.

    The caller reshapes its input to ``(num_kernels, dim_per_kernel)`` per
    sample before applying this function, so ``x`` is rank 3 with the batch
    on axis 0.

    :param x: rank-3 backend tensor.
    :return: tensor obtained by summing ``exp(-L1 distance)`` over the last
        axis of the pairwise-difference tensor; rank 2.
    """
    # Move axis 0 to the end and add a new leading axis so the subtraction
    # below broadcasts every sample against every other sample.
    others = K.expand_dims(K.permute_dimensions(x, [1, 2, 0]), 0)
    own = K.expand_dims(x, 3)
    pairwise = own - others

    # L1 distance over axis 2, then a negative-exponential similarity summed
    # over the remaining last axis.
    l1_dist = K.sum(K.abs(pairwise), 2)
    return K.sum(K.exp(-l1_dist), 2)

def generate_patch_gan_loss(last_disc_conv_layer, patch_dim, input_layer, nb_patches):
    """Wrap a per-patch conv stack into a multi-patch discriminator model.

    A shared ``patch_gan`` model (``input_layer`` -> softmax score and
    flattened features) is applied to each of ``nb_patches`` patch inputs.
    The per-patch scores are concatenated with a minibatch-discrimination
    feature (https://arxiv.org/pdf/1606.03498.pdf) computed from the
    per-patch flattened features, and a final 2-way softmax is produced.

    :param last_disc_conv_layer: output tensor of the last conv block that was
        built on top of ``input_layer``.
    :param patch_dim: shape of a single patch (without the batch axis).
    :param input_layer: the ``Input`` tensor ``last_disc_conv_layer`` derives from.
    :param nb_patches: number of patch inputs the discriminator receives.
    :return: ``Model`` named ``discriminator_nn`` taking a list of
        ``nb_patches`` patch tensors and returning one 2-way softmax output.
    """
    # One named Input per patch fed to the discriminator.
    list_input = [Input(shape=patch_dim, name="patch_gan_input_%s" % i) for i in range(nb_patches)]

    # Per-patch head: flattened conv features and a 2-way softmax on top.
    x_flat = Flatten()(last_disc_conv_layer)
    x = Dense(2, activation='softmax', name="disc_dense")(x_flat)

    patch_gan = Model(inputs=[input_layer], outputs=[x, x_flat], name="patch_gan")

    # Apply the shared model ONCE per patch and reuse both outputs; calling it
    # separately for each output would build every per-patch sub-graph twice.
    patch_outputs = [patch_gan(patch) for patch in list_input]
    x = [outputs[0] for outputs in patch_outputs]
    x_mbd = [outputs[1] for outputs in patch_outputs]

    # Merge the per-patch scores when there are several patches.
    if len(x) > 1:
        x = Concatenate(name="merged_features")(x)
    else:
        x = x[0]

    # Merge the per-patch features used for minibatch discrimination (mbd).
    if len(x_mbd) > 1:
        x_mbd = Concatenate(name="merged_feature_mbd")(x_mbd)
    else:
        x_mbd = x_mbd[0]

    num_kernels = 100
    dim_per_kernel = 5

    # Linear projection to num_kernels * dim_per_kernel values, reshaped so
    # that minb_disc sees one (num_kernels, dim_per_kernel) matrix per sample.
    M = Dense(num_kernels * dim_per_kernel, use_bias=False, activation=None)
    MBD = Lambda(minb_disc, output_shape=lambda_output)

    x_mbd = M(x_mbd)
    x_mbd = Reshape((num_kernels, dim_per_kernel))(x_mbd)
    x_mbd = MBD(x_mbd)

    x = Concatenate()([x, x_mbd])

    x_out = Dense(2, activation="softmax", name="disc_output")(x)

    discriminator = Model(inputs=list_input, outputs=[x_out], name='discriminator_nn')
    return discriminator

def res_block(x, nb_filters, strides):
    """Pre-activation residual block with a projected (1x1 conv) shortcut.

    :param x: input tensor.
    :param nb_filters: indexable pair; filter counts of the first and second
        3x3 convolution. The shortcut projects to ``nb_filters[1]``.
    :param strides: indexable pair; strides of the first and second 3x3
        convolution. The shortcut uses ``strides[0]``.
    :return: sum of the shortcut branch and the residual branch.
    """
    # Residual branch: two identical BN -> ReLU -> 3x3 conv stages.
    residual = x
    for stage in (0, 1):
        residual = BatchNormalization()(residual)
        residual = Activation(activation='relu')(residual)
        residual = Conv2D(filters=nb_filters[stage], kernel_size=(3, 3),
                          padding='same', strides=strides[stage])(residual)

    # Shortcut branch: 1x1 projection followed by batch normalisation.
    projected = Conv2D(nb_filters[1], kernel_size=(1, 1), strides=strides[0])(x)
    projected = BatchNormalization()(projected)

    return Add()([projected, residual])

def decoder(x, from_encoder):
    """Three-stage upsampling path that consumes the encoder skip tensors.

    Each stage upsamples by 2, concatenates the matching skip tensor on
    axis 3 and applies a stride-1 ``res_block``.

    :param x: tensor coming out of the bottleneck.
    :param from_encoder: indexable of skip tensors; entries 2, 1 and 0 are
        used, in that order.
    :return: output tensor of the last stage (32 filters).
    """
    # (skip index, filter count) for the successive stages, deepest first.
    stages = ((2, 128), (1, 64), (0, 32))

    upsampled = x
    for skip_index, width in stages:
        upsampled = UpSampling2D(size=(2, 2))(upsampled)
        upsampled = Concatenate(axis=3)([upsampled, from_encoder[skip_index]])
        upsampled = res_block(upsampled, [width, width], [(1, 1), (1, 1)])

    return upsampled

def encoder(x):
    """Downsampling path returning the tensors the decoder will reuse.

    :param x: input tensor.
    :return: list of three tensors: the stem output (32 filters, stride 1),
        then the outputs of two stride-2 ``res_block`` stages with 64 and
        128 filters.
    """
    # Stem: conv -> BN -> ReLU -> conv, added to a 1x1-projected shortcut.
    stem = Conv2D(filters=32, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    stem = BatchNormalization()(stem)
    stem = Activation(activation='relu')(stem)
    stem = Conv2D(filters=32, kernel_size=(3, 3), padding='same', strides=(1, 1))(stem)

    projected = Conv2D(filters=32, kernel_size=(1, 1), strides=(1, 1))(x)
    projected = BatchNormalization()(projected)

    level = Add()([projected, stem])
    skips = [level]  # first branching to the decoder

    # Two further levels, each halving the resolution in its first conv.
    for width in (64, 128):
        level = res_block(level, [width, width], [(2, 2), (1, 1)])
        skips.append(level)

    return skips


def build_res_unet(input_shape):
    """Build the residual U-Net as a Keras ``Model``.

    Encoder -> two 256-filter bottleneck ``res_block``s (the first with
    stride 2) -> decoder -> 1x1 convolution with a sigmoid activation.

    :param input_shape: shape of one input sample (without the batch axis).
    :return: ``Model`` mapping the input to a single-channel sigmoid output.
    """
    inputs = Input(shape=input_shape)

    to_decoder = encoder(inputs)

    # Bottleneck: one downsampling block followed by one stride-1 block.
    path = res_block(to_decoder[2], [256, 256], [(2, 2), (1, 1)])
    path = res_block(path, [256, 256], [(1, 1), (1, 1)])  # Yu.add - in 2018-12-02 16-09-04_15 only once

    path = decoder(path, from_encoder=to_decoder)

    path = Conv2D(filters=1, kernel_size=(1, 1), activation='sigmoid')(path)

    # Keras 2 keyword names (`inputs`/`outputs`), as used by every other
    # Model(...) call in this file; `input`/`output` are the legacy spelling.
    return Model(inputs=inputs, outputs=path)

class EL_GAN(): # Based on pix2pix
    """Pix2pix-style GAN: residual U-Net generator plus a convolutional patch discriminator.

    ``__init__`` builds the generator, builds and compiles the discriminator,
    and compiles a combined model (generator followed by the frozen
    discriminator). ``train`` alternates discriminator and generator updates.
    """

    def __init__(self):
        """Create the generator, the discriminator and the combined training model."""

        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'mapgen'

        # Output shape of `build_discriminator`: four stride-2 convolutions
        # divide the image side by 2**4.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Patch layout used by `build_PatchGanDiscriminator`
        self.patch_size = 32
        self.nb_patches = int((self.img_rows / self.patch_size) * (self.img_cols / self.patch_size))
        self.patch_gan_dim = (self.patch_size, self.patch_size, self.channels)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        optimizer = Adam(0.0002, 0.5)
        #optimizer = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08) # An old version of Pix2pix

        #-------------------------
        # Construct Computational
        #   Graph of Generator
        #-------------------------

        # Build the generator (`build_generator` is the older plain U-Net)
        self.generator = self.build_res_unet_generator()

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape) # Target
        img_B = Input(shape=self.img_shape) # Input

        # By conditioning on B generate a fake version of A
        fake_A = self.generator(img_B)

        # Build and compile the discriminator.
        # The single-input discriminator is used because the rest of this
        # class calls the discriminator with one image tensor and trains it
        # against `self.disc_patch`-shaped labels. `build_2head_discriminator`
        # takes two inputs and returns four outputs, so wiring it in would
        # also require changing the combined model and `train`.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of translated images
        valid = self.discriminator([fake_A])

        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])

        # Original Pix2Pix - low weight for discriminator
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer, metrics=['accuracy'])


    def build_res_unet_generator(self):
        """Residual U-Net Generator.

        :return: ``Model`` mapping an ``img_shape`` input to a single-channel
            sigmoid output.
        """

        inputs = Input(shape=self.img_shape)
        to_decoder = encoder(inputs)
        path = res_block(to_decoder[2], [256, 256], [(2, 2), (1, 1)])
        path = res_block(path, [256, 256], [(1, 1), (1, 1)])
        path = decoder(path, from_encoder=to_decoder)
        path = Conv2D(filters=1, kernel_size=(1, 1), activation='sigmoid')(path)

        # Keras 2 keyword names; `input`/`output` are the legacy spelling.
        return Model(inputs=inputs, outputs=path)

    def build_generator(self):
        """U-Net Generator (older alternative to ``build_res_unet_generator``).

        Seven stride-2 downsampling layers followed by upsampling layers with
        skip connections; the output layer uses ``tanh``.
        """

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(d0, output_img)

    def build_PatchGanDiscriminator(self):
        """
        Creates the discriminator according to the specs in the paper below.
        [https://arxiv.org/pdf/1611.07004v1.pdf][5. Appendix]

        PatchGAN only penalizes structure at the scale of patches. This
        discriminator tries to classify if each N x N patch in an
        image is real or fake.

        The discriminator has two parts. The first part is the per-patch
        convolutional stack built here; the second part
        (``generate_patch_gan_loss``) applies that stack to every patch and
        merges the per-patch responses into a single output.

        The patch shape comes from ``self.patch_gan_dim`` and the number of
        patches from ``self.nb_patches``.

        :return: the multi-patch discriminator ``Model``.
        """
        # -------------------------------
        # DISCRIMINATOR
        # C64-C128-C256-C512-C512-C512 (for 256x256)
        # otherwise, it scales from 64
        # 1 layer block = Conv - BN - LeakyRelu
        # -------------------------------

        output_img_dim = self.img_shape
        patch_dim = self.patch_gan_dim
        input_layer = Input(shape=patch_dim)

        # We have to build the discriminator dynamically because
        # the number of conv layers depends on the image size
        num_filters_start = self.gf
        nb_conv = int(np.floor(np.log(output_img_dim[1]) / np.log(2)))
        filters_list = [num_filters_start * min(8, (2 ** i)) for i in range(nb_conv)]

        # CONV 1
        # Do first conv because it is different from the rest:
        # the paper skips batch norm for the first layer
        disc_out = Conv2D(filters=64, kernel_size=(4, 4), padding='same', strides=(2, 2), name='disc_conv_1')(input_layer)
        disc_out = LeakyReLU(alpha=0.2)(disc_out)

        # CONV 2 - CONV N
        # do the rest of the convs based on the sizes from the filters
        for i, filter_size in enumerate(filters_list[1:]):
            name = 'disc_conv_{}'.format(i+2)

            disc_out = Conv2D(filters=filter_size, kernel_size=(4, 4), padding='same', strides=(2, 2), name=name)(disc_out)
            disc_out = BatchNormalization(name=name + '_bn')(disc_out)
            disc_out = LeakyReLU(alpha=0.2)(disc_out)

        # ------------------------
        # BUILD PATCH GAN
        # this is where we evaluate the loss over each sublayer of the input
        # ------------------------
        # `nb_patches` lives on the instance; the bare name is not defined here.
        patch_gan_discriminator = generate_patch_gan_loss(last_disc_conv_layer=disc_out,
                                                          patch_dim=patch_dim,
                                                          input_layer=input_layer,
                                                          nb_patches=self.nb_patches)
        return patch_gan_discriminator

    def build_2head_discriminator(self):
        """Two-input discriminator whose conv/BN layers are shared between both inputs.

        :return: ``Model`` taking ``[img_A, img_B]`` and returning
            ``[validity_A, validity_B, prediction_A, prediction_B]``: a
            1-filter conv map and a single sigmoid unit for each input.
        """

        def d_layer(img_A, img_B, filters, f_size=3, bn=True, cname = 'c1'): # Changed here for the order of bn and activation
            """Discriminator layer applied with the same weights to both inputs."""

            conv1 = Conv2D(filters, kernel_size=f_size, strides=2, padding='same', name = cname)
            batch1 = BatchNormalization(momentum=0.8)
            act1 = Activation(activation='relu')

            c11 = conv1(img_A)
            if bn:
                c11 = batch1(c11)
            c11 = act1(c11)

            s11 = conv1(img_B)
            if bn:
                s11 = batch1(s11)
            s11 = act1(s11)

            return c11, s11

        df = self.df

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        c11, s11 = d_layer(img_A, img_B, df, bn=False, cname = 'c1')
        c22, s22 = d_layer(c11, s11, df*2, cname = 'c2')
        c33, s33 = d_layer(c22, s22, df*4, cname = 'c3')
        c44, s44 = d_layer(c33, s33, df*8, cname = 'c4')

        # A softmax over a single unit is always exactly 1.0, so the
        # one-unit prediction heads use a sigmoid instead.
        emb_A = Flatten(name = 'embeddingA')(c44)
        prediction_A = Dense(1, activation='sigmoid')(emb_A)
        validity_A = Conv2D(1, kernel_size=3, strides=1, padding='same')(c44)

        emb_B = Flatten(name = 'embeddingB')(s44)
        prediction_B = Dense(1, activation='sigmoid')(emb_B)
        validity_B = Conv2D(1, kernel_size=3, strides=1, padding='same')(s44)

        model = Model(inputs= [img_A, img_B], outputs=[validity_A, validity_B, prediction_A, prediction_B])

        return model


    def build_discriminator(self):
        """Single-input patch discriminator.

        Four stride-2 conv layers (``df``, ``2*df``, ``4*df``, ``8*df``
        filters) followed by a 1-filter 3x3 convolution.

        :return: ``Model`` mapping one ``img_shape`` image to a validity map.
        """

        def d_layer(layer_input, filters, f_size=3, bn=True): # Changed here for the order of bn and activation
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = Activation(activation='relu')(d)
            #d = LeakyReLU(alpha=0.2)(d)
            return d

        img_A = Input(shape=self.img_shape)

        # The conditioning image is not concatenated here: the discriminator
        # only sees the (real or generated) image itself.
        d1 = d_layer(img_A, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(d4)

        return Model([img_A], validity)

    def train_generator_only(self, x_train_sim, y_train_sim, x_test_sim, y_test_sim, outPath):
        """Fit the generator alone on augmented (image, mask) batches.

        Both arrays go through ``ImageDataGenerator`` flows that share the
        same rotation range and seed. The history is saved with ``save_hist``
        and the best model is checkpointed to ``outPath + "weights.hdf5"``.

        NOTE(review): nothing in this class compiles ``self.generator``
        directly - confirm it is compiled before this method is called.
        """

        data_gen_args = dict(rotation_range=180.)
        image_datagen = ImageDataGenerator(**data_gen_args)
        mask_datagen = ImageDataGenerator(**data_gen_args)

        # Same seed for both flows so images and masks get identical transforms.
        seed = 1
        BATCH_SIZE = 16
        result_generator = zip(image_datagen.flow(x_train_sim, batch_size=BATCH_SIZE, seed=seed),
                               mask_datagen.flow(y_train_sim, batch_size=BATCH_SIZE, seed=seed))

        History1 = History()
        hist1 = self.generator.fit_generator( result_generator,
                                              epochs = 100,
                                              steps_per_epoch=2000,
                                              verbose=1,
                                              shuffle=True,
                                              callbacks=[History1,
                                                         EarlyStopping(patience=5),
                                                         ReduceLROnPlateau(patience = 3, verbose = 0),
                                                         ModelCheckpoint(outPath + "weights.hdf5",
                                                                         save_best_only = True,
                                                                         save_weights_only = False)],
                                              validation_data=(x_test_sim, y_test_sim))
        save_hist(History1, outPath)


    def train(self, x_train_sim, y_train_sim, x_test_sim, y_test_sim, outPath, epochs, batch_size=1, sample_interval=50):
        """Adversarial training loop.

        For every batch the discriminator is trained on real and generated
        images, then the combined model trains the generator. Every
        ``sample_interval`` batches the combined model is evaluated on the
        test arrays, a progress line is printed and sample images are saved.

        NOTE(review): ``load_batch`` is not defined in the visible part of
        this notebook; from its use it is expected to yield
        ``(imgs_A, imgs_B)`` batches - confirm where it comes from.
        """

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        total_samples = len(x_train_sim)
        # NOTE(review): `ids` is shuffled but not used below; kept so the
        # global NumPy RNG is consumed exactly as before.
        ids = np.arange(total_samples)
        np.random.shuffle(ids)
        n_batches = int(total_samples / batch_size)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(load_batch(x_train_sim, y_train_sim, batch_size)):

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Condition on B and generate a translated version
                fake_A = self.generator.predict(imgs_B)

                # Train the discriminator (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([imgs_A], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # -----------------
                #  Train Generator
                # -----------------

                g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

                elapsed_time = datetime.datetime.now() - start_time

                # Report progress and save generated samples at the same interval
                if batch_i % sample_interval == 0:

                    valid_test = np.ones((len(x_test_sim),) + self.disc_patch)
                    t_loss = self.combined.evaluate([y_test_sim, x_test_sim], [valid_test, y_test_sim], verbose=0)

                    print ("[Epoch %d/%d-%d/%d] [D loss&acc: %.3f, %.3f%%] [G loss&accA&accB: %.3f, %.3f%%, %.3f%%] [Test loss&acc: %.3f, %.3f%%, %.3f%%] time: %s" % (epoch, epochs,
                                                                                batch_i, n_batches,
                                                                                d_loss[0], 100*d_loss[1],
                                                                                g_loss[2], 100*g_loss[3], 100*g_loss[4],
                                                                                t_loss[2], 100*t_loss[3], 100*t_loss[4],
                                                                                elapsed_time))

                    self.sample_images(outPath, epoch, batch_i,
                                       x_test=x_test_sim, y_test=y_test_sim)


    def sample_images(self, outPath, epoch, batch_i, examples = [0, 77, 34], x_test=None, y_test=None):
        """Save a grid of input / generated / target images as a PNG.

        :param outPath: prefix of the output file; the file name is
            ``outPath + "<epoch>_<batch_i>.png"``.
        :param epoch: epoch number used in the file name.
        :param batch_i: batch number used in the file name.
        :param examples: indices of the test samples to show (one column
            each). The default list is only read, never mutated.
        :param x_test: conditioning images to sample from. When ``None`` the
            notebook-level ``x_test_sim`` is used, as before.
        :param y_test: target images to sample from. When ``None`` the
            notebook-level ``y_test_sim`` is used, as before.
        """
        # Fall back to the notebook globals the original code relied on.
        if x_test is None:
            x_test = x_test_sim
        if y_test is None:
            y_test = y_test_sim

        titles = ['Input', 'Generated', 'Target']
        r, c = len(titles), len(examples)

        imgs_A = y_test[examples]
        imgs_B = x_test[examples]

        # Use this instance's generator rather than a global `gan` object.
        fake_A = self.generator.predict(imgs_B)

        # Row-major order: all inputs, then all generated, then all targets.
        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images 0 - 1
        # NOTE(review): this maps [-1, 1] to [0, 1], but the residual U-Net
        # generator ends in a sigmoid - confirm the intended data range.
        gen_imgs = 0.5 * gen_imgs + 0.5

        # squeeze=False keeps `axs` two-dimensional even for a single column.
        fig, axs = plt.subplots(r, c, squeeze=False)
        cnt = 0
        for i in range(r):
            for j in range(c):
                gen = np.reshape(gen_imgs[cnt], (self.img_rows, self.img_cols))
                axs[i, j].imshow(gen)
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(outPath + "%d_%d.png" % (epoch, batch_i),
                   format='png', transparent=True, dpi=300)
        # Close this figure explicitly so repeated sampling does not pile up figures.
        plt.close(fig)

In [None]:

def model(input_shape):
    """Build a two-input, two-output network whose first four convs are shared.

    Branch 1 (``input1``) runs through the shared convolutions ``c1``-``c4``
    with batch normalisation, a 1024-filter conv ``c5``, a global max-pool
    over the first axis that is tiled back and concatenated with the local
    features, and a stack of 1x1 convolutions ending in a 2-way softmax
    (``output``). Branch 2 (``input2``) reuses ``c1``-``c4`` without batch
    normalisation, is flattened and projected to 64 values (``output2``).

    :param input_shape: shape of one sample, e.g. ``(2048, 3, 1)``. The first
        entry is the length pooled over and tiled back, so it must be a
        concrete integer.
    :return: ``Model`` with inputs ``[input1, input2]`` and outputs
        ``[prediction, a6]``.
    """
    # Length of the first axis; previously hard-coded as 2048.
    n_points = input_shape[0]

    visible1 = Input(shape=input_shape, name = 'input1')
    visible2 = Input(shape=input_shape, name = 'input2')

    # Shared layers: each layer object is applied to both branches. Only the
    # first branch is batch-normalised.
    conv2d_1 = Conv2D(64, (1,3),  activation='relu', name ='c1')
    c11  = conv2d_1(visible1)
    a1 = conv2d_1(visible2)
    c11 = BatchNormalization()(c11)

    conv2d_2 = Conv2D(64, (1, 1), strides=1, activation='relu', name = 'c2')
    c12 = conv2d_2(c11)
    a2 = conv2d_2(a1)
    c12 = BatchNormalization()(c12)

    conv2d_3 = Conv2D(64, (1, 1), strides =1, activation='relu', name = 'c3')
    c13 = conv2d_3(c12)
    a3 = conv2d_3(a2)
    c13 = BatchNormalization()(c13)

    conv2d_4 = Conv2D(128, (1, 1), strides =1, activation='relu', name = 'c4')
    c14 = conv2d_4(c13)
    a4 = conv2d_4(a3)
    c14 = BatchNormalization()(c14)

    # Branch 1 only from here on.
    conv2d_5 = Conv2D(1024, (1, 1), strides =1, activation='relu', name = 'c5')
    c15 = conv2d_5(c14)
    c15 = BatchNormalization()(c15)
    local_feat = c15

    # Global feature: max over the whole first axis, then repeated along that
    # axis so it can be concatenated with the per-position features.
    global_feat = keras.layers.MaxPooling2D(pool_size=(n_points, 1))(c15)
    tile = Lambda(lambda x: keras.backend.tile(x, (1, n_points, 1, 1)))(global_feat)

    all_feat = keras.layers.Concatenate(axis = -1)([local_feat, tile])

    # point_net_cls
    conv2d6 = Conv2D(512, (1, 1), strides =1, activation='relu', name = 'c6')
    c21 = conv2d6(all_feat)
    c21 = BatchNormalization()(c21)
    c22 = Conv2D(256, (1, 1), strides =1, activation='relu', name = 'c7')(c21)
    c22 = BatchNormalization()(c22)
    c23 = Conv2D(128, (1, 1), strides =1, activation='relu', name = 'c8')(c22)
    c23 = BatchNormalization()(c23)
    c24 = Conv2D(128, (1, 1), strides =1, activation='relu', name = 'c9')(c23)
    c24 = BatchNormalization()(c24)
    prediction = Conv2D(2, (1, 1), strides =1, activation='softmax', name = 'output')(c24)
    # --------------------------------------------------end of pointnet

    # Second head: flattened branch-2 features projected to 64 values.
    a5 = Flatten()(a4)
    a6 = Dense(64, name = 'output2')(a5)

    # Local name differs from the function name to avoid shadowing it.
    two_head_model = Model(inputs= [visible1, visible2], outputs=[prediction, a6])

    return two_head_model
#    
#model = Model(inputs=visible, outputs=prediction)
# Build the two-input network for a (2048, 3, 1) sample shape and print its
# layer summary.
bsp_model = model((2048,3,1))
bsp_model.summary()


In [57]:
def build_2head_discriminator(img_shape, df = 64):
    """Two-input discriminator whose conv/BN layers are shared between both inputs.

    :param img_shape: shape of one input image (without the batch axis).
    :param df: number of filters in the first conv layer; later layers use
        ``2*df``, ``4*df`` and ``8*df``.
    :return: ``Model`` taking ``[img_A, img_B]`` and returning
        ``[validity_A, validity_B, prediction_A, prediction_B]``: a 1-filter
        conv map and a single sigmoid unit for each input.
    """

    def d_layer(img_A, img_B, filters, f_size=3, bn=True, cname = 'c1'): # Changed here for the order of bn and activation
        """Discriminator layer applied with the same weights to both inputs."""

        conv1 = Conv2D(filters, kernel_size=f_size, strides=2, padding='same', name = cname)
        batch1 = BatchNormalization(momentum=0.8)
        act1 = Activation(activation='relu')

        c11 = conv1(img_A)
        if bn:
            c11 = batch1(c11)
        c11 = act1(c11)

        s11 = conv1(img_B)
        if bn:
            s11 = batch1(s11)
        s11 = act1(s11)

        return c11, s11

    img_A = Input(shape=img_shape)
    img_B = Input(shape=img_shape)

    c11, s11 = d_layer(img_A, img_B, df, bn=False, cname = 'c1')
    c22, s22 = d_layer(c11, s11, df*2, cname = 'c2')
    c33, s33 = d_layer(c22, s22, df*4, cname = 'c3')
    c44, s44 = d_layer(c33, s33, df*8, cname = 'c4')

    # A softmax over a single unit is always exactly 1.0, so the one-unit
    # prediction heads use a sigmoid instead.
    emb_A = Flatten(name = 'embeddingA')(c44)
    prediction_A = Dense(1, activation='sigmoid')(emb_A)
    validity_A = Conv2D(1, kernel_size=3, strides=1, padding='same')(c44)

    emb_B = Flatten(name = 'embeddingB')(s44)
    prediction_B = Dense(1, activation='sigmoid')(emb_B)
    validity_B = Conv2D(1, kernel_size=3, strides=1, padding='same')(s44)

    model = Model(inputs= [img_A, img_B], outputs=[validity_A, validity_B, prediction_A, prediction_B])

    return model

In [58]:
# Build the two-head discriminator for 128x128 single-channel images and
# print its layer summary. This rebinds the notebook-level name `bsp_model`.
bsp_model = build_2head_discriminator((128, 128, 1), df = 64)
bsp_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_35 (InputLayer)           (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
input_36 (InputLayer)           (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
c1 (Conv2D)                     (None, 64, 64, 64)   640         input_35[0][0]                   
                                                                 input_36[0][0]                   
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 64, 64, 64)   0           c1[0][0]                         
          