Import all required modules and functions

In [None]:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from random import random
import numpy as np
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy import vstack

from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Sequential
from keras.models import Input
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Activation, Concatenate
from keras.layers import Input, Dense, Add, Dot, Reshape, Flatten, BatchNormalization, Lambda, Softmax, Embedding, Multiply, Add
from matplotlib import pyplot
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from time import time
import os
import functools

from scipy.linalg import sqrtm
from skimage.transform import resize
from tensorflow.keras.models import load_model
from tensorflow.python.ops import array_ops

from keras.engine.base_layer import Layer, InputSpec
from keras.engine import *
from keras.legacy import interfaces
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.utils.generic_utils import func_dump
from keras.utils.generic_utils import func_load
from keras.utils.generic_utils import deserialize_keras_object
from keras.utils.generic_utils import has_arg
from keras.utils import conv_utils
from keras.models import load_model
from random import randint, shuffle, uniform
import glob
import time
import warnings
from PIL import Image
from random import randint, shuffle, uniform

from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers import Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.preprocessing import image

In [None]:
# TF1-style session setup. allow_growth asks TensorFlow to allocate GPU memory
# on demand rather than reserving it all when the session starts (see the
# TensorFlow GPU-options docs).
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
# Make Keras build and run its graph in this configured session.
K.set_session(sess)

Define the spectral-normalization layers and the self-attention gamma layer

In [None]:

class DenseSN(Dense):
    """Fully connected layer whose kernel is spectrally normalised on each call.

    Same constructor as ``keras.layers.Dense``. ``build`` adds a
    non-trainable ``(1, units)`` vector ``u`` next to the kernel and bias.
    ``call`` runs one power-iteration step from ``u`` to estimate the
    kernel's largest singular value and divides the kernel by it. Unless
    ``training`` is 0/False, it also writes the refreshed vector back to ``u``.
    """

    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        # Persistent power-iteration vector; never touched by the optimiser.
        n_out = self.kernel.shape.as_list()[-1]
        self.u = self.add_weight(shape=(1, n_out),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        def unit(vec, eps=1e-12):
            # Divide by the L2 norm of the whole tensor; eps avoids 0/0.
            return vec / (K.sum(vec ** 2) ** 0.5 + eps)

        full_shape = self.kernel.shape.as_list()
        # View the kernel as a matrix with `units` columns.
        w_mat = K.reshape(self.kernel, [-1, full_shape[-1]])

        # A single power-iteration step starting from the stored u.
        v_hat = unit(K.dot(self.u, K.transpose(w_mat)))
        u_hat = unit(K.dot(v_hat, w_mat))

        # sigma = v . W . u^T, a (1, 1) estimate of the top singular value.
        sigma = K.dot(K.dot(v_hat, w_mat), K.transpose(u_hat))
        w_norm = w_mat / sigma

        if training in {0, False}:
            kernel = K.reshape(w_norm, full_shape)
        else:
            # Make the u update run before the normalised kernel is used.
            with tf.control_dependencies([self.u.assign(u_hat)]):
                kernel = K.reshape(w_norm, full_shape)

        out = K.dot(inputs, kernel)
        if self.use_bias:
            out = K.bias_add(out, self.bias, data_format='channels_last')
        if self.activation is None:
            return out
        return self.activation(out)
        
class _ConvSN(Layer):
    """Rank-generic convolution (1-D, 2-D or 3-D) with optional spectral normalisation.

    Mirrors the structure of Keras's private ``_Conv`` base layer. When
    ``spectral_normalization`` is True, ``build`` adds a non-trainable
    ``(1, filters)`` vector ``u``. ``call`` then divides the kernel by a
    one-step power-iteration estimate of its largest singular value, and
    stores the refreshed ``u`` unless ``training`` is 0/False.
    """

    def __init__(self, rank,
                 filters,
                 kernel_size,
                 strides=1,
                 padding='valid',
                 data_format=None,
                 dilation_rate=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 spectral_normalization=True,
                 **kwargs):
        super(_ConvSN, self).__init__(**kwargs)
        self.rank = rank
        self.filters = filters
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, 'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(ndim=self.rank + 2)
        self.spectral_normalization = spectral_normalization
        self.u = None

    def _l2normalize(self, v, eps=1e-12):
        """Divide *v* by its whole-tensor L2 norm; *eps* guards against division by zero."""
        return v / (K.sum(v ** 2) ** 0.5 + eps)

    def power_iteration(self, u, W):
        """Run one power-iteration step on matrix *W*, starting from row vector *u*.

        According to the spectral-normalization paper, one iteration per call
        is enough. Returns the updated ``(u, v)`` pair.
        """
        v = self._l2normalize(K.dot(u, K.transpose(W)))
        u = self._l2normalize(K.dot(v, W))
        return u, v

    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Spectral normalization: persistent power-iteration vector.
        if self.spectral_normalization:
            self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
                                     initializer=initializers.RandomNormal(0, 1),
                                     name='sn',
                                     trainable=False)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Convolve *inputs* with the (optionally spectrally normalised) kernel.

        ``training`` set to 0/False skips the update of ``u``; any other value
        (including the default None) updates it. The normalised kernel is
        kept in a local variable so ``self.kernel`` always remains the
        trainable weight and repeated calls are not normalised twice.

        Raises:
            ValueError: if ``self.rank`` is not 1, 2 or 3.
        """
        kernel = self.kernel
        if self.spectral_normalization:
            W_shape = kernel.shape.as_list()
            # Flatten to a matrix with `filters` columns.
            W_reshaped = K.reshape(kernel, [-1, W_shape[-1]])
            _u, _v = self.power_iteration(self.u, W_reshaped)
            # sigma = v . W . u^T estimates the largest singular value.
            sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
            W_bar = W_reshaped / sigma
            if training in {0, False}:
                kernel = K.reshape(W_bar, W_shape)
            else:
                # Tie the u update to the kernel actually used below.
                with tf.control_dependencies([self.u.assign(_u)]):
                    kernel = K.reshape(W_bar, W_shape)

        if self.rank == 1:
            outputs = K.conv1d(
                inputs,
                kernel,
                strides=self.strides[0],
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate[0])
        elif self.rank == 2:
            outputs = K.conv2d(
                inputs,
                kernel,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate)
        elif self.rank == 3:
            outputs = K.conv3d(
                inputs,
                kernel,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate)
        else:
            raise ValueError('rank must be 1, 2 or 3; got ' + str(self.rank))

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            space = input_shape[1:-1]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            return (input_shape[0],) + tuple(new_space) + (self.filters,)
        if self.data_format == 'channels_first':
            space = input_shape[2:]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            return (input_shape[0], self.filters) + tuple(new_space)

    def get_config(self):
        """Return a config dict that round-trips through ``from_config``."""
        config = {
            'rank': self.rank,
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'spectral_normalization': self.spectral_normalization,
        }
        base_config = super(_ConvSN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
class ConvSN2D(Conv2D):
    """2-D convolution whose kernel is spectrally normalised on every call.

    Same constructor as ``keras.layers.Conv2D``. ``build`` creates the
    kernel, the optional bias and a non-trainable ``(1, filters)`` vector
    ``u``. ``call`` runs one power-iteration step from ``u`` to estimate the
    kernel's largest singular value and divides the kernel by it. Unless
    ``training`` is 0/False, it also stores the refreshed vector in ``u``.
    """

    def build(self, input_shape):
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis]
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        self.kernel = self.add_weight(shape=self.kernel_size + (input_dim, self.filters),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)

        # Persistent power-iteration vector, one entry per output filter.
        n_out = self.kernel.shape.as_list()[-1]
        self.u = self.add_weight(shape=(1, n_out),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        def unit(vec, eps=1e-12):
            # Divide by the whole-tensor L2 norm; eps avoids 0/0.
            return vec / (K.sum(vec ** 2) ** 0.5 + eps)

        kernel_shape = self.kernel.shape.as_list()
        # (kh, kw, in, filters) viewed as a (kh*kw*in, filters) matrix.
        w_mat = K.reshape(self.kernel, [-1, kernel_shape[-1]])

        # One power-iteration step is enough according to the SN paper.
        v_hat = unit(K.dot(self.u, K.transpose(w_mat)))
        u_hat = unit(K.dot(v_hat, w_mat))

        sigma = K.dot(K.dot(v_hat, w_mat), K.transpose(u_hat))
        w_sn = w_mat / sigma

        if training in {0, False}:
            kernel = K.reshape(w_sn, kernel_shape)
        else:
            with tf.control_dependencies([self.u.assign(u_hat)]):
                kernel = K.reshape(w_sn, kernel_shape)

        outputs = K.conv2d(inputs,
                           kernel,
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is None:
            return outputs
        return self.activation(outputs)

    

class ConvSN2DTranspose(Conv2DTranspose):
    """Transposed 2-D convolution whose kernel is spectrally normalised on every call.

    Same constructor as ``keras.layers.Conv2DTranspose``. ``build`` adds a
    non-trainable ``(1, filters)`` vector ``u``. ``call`` divides the kernel
    by a one-step power-iteration estimate of its largest singular value,
    and updates ``u`` unless ``training`` is 0/False.
    """

    def build(self, input_shape):
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank 4; '
                             'Received input shape: ' + str(input_shape))
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        # Transposed-conv kernels are laid out (kh, kw, filters, in_channels).
        kernel_shape = self.kernel_size + (self.filters, input_dim)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.u = self.add_weight(shape=tuple([1, self.filters]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    def _normalized_kernel(self, training=None):
        """Return the kernel divided by its power-iteration spectral-norm estimate.

        Updates ``self.u`` (via a control dependency on the returned tensor)
        unless *training* is 0/False. ``self.kernel`` itself is never replaced.

        NOTE(review): the (kh, kw, filters, in_channels) kernel is flattened
        row-major to (-1, filters). When in_channels != filters this is not
        a per-axis matricisation. It is kept as-is so existing ``u`` weights
        stay compatible; confirm it is intended.
        """
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        W_shape = self.kernel.shape.as_list()
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-2]])
        # One power-iteration step starting from the stored u.
        _v = _l2normalize(K.dot(self.u, K.transpose(W_reshaped)))
        _u = _l2normalize(K.dot(_v, W_reshaped))
        sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
        W_bar = W_reshaped / sigma
        if training in {0, False}:
            return K.reshape(W_bar, W_shape)
        with tf.control_dependencies([self.u.assign(_u)]):
            return K.reshape(W_bar, W_shape)

    def call(self, inputs, training=None):
        """Apply the spectrally normalised transposed convolution to *inputs*."""
        input_shape = K.shape(inputs)
        batch_size = input_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2

        height, width = input_shape[h_axis], input_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides
        if self.output_padding is None:
            out_pad_h = out_pad_w = None
        else:
            out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape:
        out_height = conv_utils.deconv_length(dim_size=height,
                                              kernel_size=kernel_h,
                                              padding=self.padding,
                                              output_padding=out_pad_h,
                                              stride_size=stride_h,
                                              dilation=self.dilation_rate[0])
        out_width = conv_utils.deconv_length(dim_size=width,
                                             kernel_size=kernel_w,
                                             padding=self.padding,
                                             output_padding=out_pad_w,
                                             stride_size=stride_w,
                                             dilation=self.dilation_rate[1])
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
        output_shape_tensor = array_ops.stack(output_shape)

        # Local variable: self.kernel must stay the trainable weight.
        kernel = self._normalized_kernel(training)

        # dilation_rate must match the one used for the output-shape computation above.
        outputs = K.conv2d_transpose(
            inputs,
            kernel,
            output_shape_tensor,
            self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs



class SelfAttentionGamma(Layer):
    """Learnable residual gate for a self-attention branch.

    Called on a list ``[x, attn]`` and returns
    ``gamma * reshape(attn, shape(x)) + x``. ``gamma`` is a single trainable
    scalar initialised to zero, so the layer initially passes ``x`` through
    unchanged.
    """

    def __init__(self,
                 **kwargs):
        super(SelfAttentionGamma, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        self.gamma = self.add_weight(name='gamma',
                                     shape=[1],
                                     initializer='zeros')
        self.built = True

    def call(self, inputs, training=None):
        residual = inputs[0]
        # The attention output arrives flattened; restore the residual's shape.
        attended = K.reshape(inputs[1], shape=K.shape(residual))
        return self.gamma * attended + residual

    def get_config(self):
        # No constructor arguments beyond those of the base Layer.
        return dict(super(SelfAttentionGamma, self).get_config())

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        return input_shape[0]

Load the InceptionV3 network (used for the FID metric)

In [None]:
# InceptionV3 without its classification head. Global average pooling
# reduces each 299x299x3 input to one feature vector per image.
inception = InceptionV3(input_shape=(299, 299, 3), include_top=False, pooling='avg')

Hyperparameters

In [None]:
# Output root (models/, logs/) and dataset root. Both end with a separator
# because sub-paths are appended by plain string concatenation.
currdir = 'C:\\Users\\parit\\Documents\\CycleGAN\\test\\'
datadir = 'C:\\Users\\parit\\Documents\\CycleGAN\\datasets\\cezanne2photo\\'

# Learning rates for the generators and the discriminators.
Gen_learningrate = 2e-4
Disc_learningrate = 2e-4

batch_size = 1

# Loss weights; presumably for the cycle-consistency and identity terms
# (their use is outside this view; confirm).
LAMBDACYCLE = 10
LAMBDAID = 5

Compute InceptionV3 activations for the FID calculation

In [None]:
# Number of images passed to inception_activations (and so to
# inception.predict) per call.
FID_BATCH_SIZE = 1


def scale_images(images, new_shape):
    """Resize every image in *images* to *new_shape* and stack the results.

    Calls skimage ``resize`` with ``order=0`` (nearest-neighbour
    interpolation) and returns a NumPy array.
    """
    return asarray([resize(img, new_shape, 0) for img in images])

def inception_activations(images):
    """Return pooled InceptionV3 activations for a batch of images.

    Parameters
    ----------
    images : NumPy array of images. Callers in this notebook pass pixel
        values mapped to [0, 255] via ``(x + 1) * 127.5``. It is cast to
        float32 here.

    Returns
    -------
    The result of ``inception.predict``: one feature row per input image.
    """
    size = 299  # side length of the `inception` model's inputs
    images = images.astype('float32')
    images = scale_images(images, (size, size, 3))
    # Apply Keras's InceptionV3-specific input preprocessing.
    images = preprocess_input(images)
    return inception.predict(images)
# Activations are computed FID_BATCH_SIZE images at a time to limit memory use.
def get_inception_activations(inps):
    """Return a float32 ``(len(inps), 2048)`` array of Inception activations.

    *inps* is processed in consecutive chunks of FID_BATCH_SIZE images. The
    last chunk may be shorter.
    """
    total = inps.shape[0]
    act = np.zeros([total, 2048], dtype=np.float32)
    for start in range(0, total, FID_BATCH_SIZE):
        chunk = inps[start:start + FID_BATCH_SIZE]
        act[start:start + chunk.shape[0]] = inception_activations(chunk)
    return act

Compute the Fréchet Inception Distance (FID)

In [None]:
def _frechet_distance(mu1, sigma1, mu2, sigma2):
    """Return ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))``."""
    sum_sq_diff = np.sum((mu1 - mu2) ** 2)
    cov_mean = sqrtm(sigma1.dot(sigma2))
    # sqrtm may return a complex array (e.g. tiny imaginary parts from
    # numerical error); keep only the real part.
    if np.iscomplexobj(cov_mean):
        cov_mean = cov_mean.real
    return sum_sq_diff + np.trace(sigma1 + sigma2 - 2.0 * cov_mean)


def fidCalculate(g_func, testB, test_mean, test_sigma):
    """Compute, print, log and return the FID of *g_func*'s outputs on *testB*.

    Parameters
    ----------
    g_func : model whose ``predict`` output is mapped from [-1, 1] to
        [0, 255] before feature extraction.
    testB : inputs fed to ``g_func.predict``.
    test_mean, test_sigma : mean vector and covariance matrix of the
        reference activations.

    Returns
    -------
    The FID value. It is also printed and appended to
    ``currdir + "logs/FID_LOGS.txt"``.
    """
    X_out = (g_func.predict(testB) + 1.) * 127.5
    f2 = get_inception_activations(X_out)
    mean2, sigma2 = f2.mean(axis=0), np.cov(f2, rowvar=False)
    fid = _frechet_distance(test_mean, test_sigma, mean2, sigma2)
    print(fid)
    with open(currdir + "logs/FID_LOGS.txt", "a") as f:
        f.write(str(fid) + '\n')
    return fid

Save models

In [None]:
def save_models(step, g_model_AtoB, g_model_BtoA, d_model_A, d_model_B):
    """Save the weights of all four CycleGAN models under ``currdir + "models\\"``.

    Each file name embeds ``step + 1`` zero-padded to six digits.
    """
    checkpoints = (
        ('g_model_AtoB_weights_%06d.h5', g_model_AtoB),
        ('g_model_BtoA_weights_%06d.h5', g_model_BtoA),
        ('d_model_A_weights_%06d.h5', d_model_A),
        ('d_model_B_weights_%06d.h5', d_model_B),
    )
    for template, model in checkpoints:
        model.save_weights(currdir + "models\\" + template % (step + 1))
    print('>Saved: models')

Load and preprocess images (random cropping and flipping)

In [None]:
def _random_crop_flip(img_arr, crop_size=128):
    """Take a random ``crop_size`` x ``crop_size`` window of *img_arr*, mirrored half the time.

    The top and left offsets are drawn independently and uniformly over every
    valid position. The crop is flipped along axis 1 (width) with
    probability 1/2.

    Raises:
        ValueError: if *img_arr* is smaller than *crop_size* in either dimension.
    """
    height, width = img_arr.shape[0], img_arr.shape[1]
    if height < crop_size or width < crop_size:
        raise ValueError('image of size %dx%d is smaller than crop size %d'
                         % (height, width, crop_size))
    # np.random.randint excludes `high`, hence the +1.
    top = np.random.randint(0, height - crop_size + 1)
    left = np.random.randint(0, width - crop_size + 1)
    cropped = img_arr[top:top + crop_size, left:left + crop_size, :]
    if np.random.randint(0, 2):
        cropped = cropped[:, ::-1]
    return cropped


def imageprocess(img_path, load_size=143, crop_size=128):
    """Load one image, scale its pixels to [-1, 1] and apply a random crop and flip.

    The image is resized to ``load_size`` x ``load_size`` (bilinear), which
    leaves a margin of ``load_size - crop_size`` pixels for the random crop.

    Returns an array of shape ``(crop_size, crop_size, channels)``.
    """
    img = image.load_img(img_path)
    img = img.resize((load_size, load_size), Image.BILINEAR)
    img_norm = np.array(image.img_to_array(img)) / 255 * 2 - 1
    return _random_crop_flip(img_norm, crop_size)

def _load_split(path, split):
    """Preprocess every ``*.jpg`` in ``path + split``, in sorted path order, into one float32 array."""
    # glob's order depends on the file system; sort for reproducible datasets.
    img_paths = sorted(glob.glob(path + split + "/*.jpg"))
    return np.float32([imageprocess(p) for p in img_paths])


def loadimage(path):
    """Load and preprocess the trainA, trainB, testA and testB splits under *path*.

    *path* is concatenated directly with the split name, so it must end with a
    path separator. Each image goes through ``imageprocess`` (random crop and
    flip). The shape of the trainA array is printed.

    Returns
    -------
    ``(train_A, train_B, test_A, test_B)`` as float32 NumPy arrays.
    """
    train_A = _load_split(path, "trainA")
    print(train_A.shape)
    train_B = _load_split(path, "trainB")
    test_A = _load_split(path, "testA")
    test_B = _load_split(path, "testB")
    return train_A, train_B, test_A, test_B

Load data

In [None]:
# Load all four splits from datadir, preprocessed by imageprocess (random crop/flip, values in [-1, 1]).
train_A, train_B, test_A, test_B = loadimage(datadir)

Compute the mean and covariance of the Inception activations for the domain-A (artistic) images

In [None]:
# Stack the domain-A train and test images and map them from [-1, 1] back to
# [0, 255]. Their activations are then summarised by mean and covariance,
# the reference statistics that fidCalculate takes.
fid_pre = (vstack((train_A, test_A)) + 1.) * 127.5
f1 = get_inception_activations(fid_pre)
test_mean = f1.mean(axis=0)
test_sigma = np.cov(f1, rowvar=False)
# Release the large intermediates.
del fid_pre, f1

Helper used by the self-attention layers

In [None]:
def mult1(x):
    """Return the batched product ``x[0] @ transpose(x[1])``.

    x : pair ``[a, b]`` of 3-D tensors. They are assumed to be
        (batch, positions, channels), as produced by the reshapes in
        define_discriminator. TODO confirm for other callers.

    Returns a ``(batch, positions_a, positions_b)`` tensor of dot products,
    i.e. the attention logits.
    """
    a, b = x[0], x[1]
    # (batch, n, c) . (batch, c, m) -> (batch, n, m)
    return K.batch_dot(a, K.permute_dimensions(b, pattern=(0, 2, 1)), axes=(2, 1))

Discriminator Model

In [None]:

def _self_attention(d, channels):
    """Attach a self-attention block to feature map *d* and return its output.

    Three 1x1 spectrally normalised convolutions give f and g (channels // 8
    filters each) and h (*channels* filters). Each is flattened over its
    spatial positions. ``softmax(g . f^T)`` weights h, and SelfAttentionGamma
    adds the result back onto *d*.
    """
    key_channels = channels // 8
    f = ConvSN2D(key_channels, kernel_size=1, strides=1, kernel_initializer='glorot_uniform', padding='same', activation=None)(d)
    f_flat = Lambda(lambda x: K.reshape(x, shape=(K.shape(x)[0], -1, K.shape(x)[3])))(f)
    g = ConvSN2D(key_channels, kernel_size=1, strides=1, kernel_initializer='glorot_uniform', padding='same', activation=None)(d)
    g_flat = Lambda(lambda x: K.reshape(x, shape=(K.shape(x)[0], K.shape(x)[1] * K.shape(x)[2], K.shape(x)[-1])))(g)
    h = ConvSN2D(channels, kernel_size=1, strides=1, kernel_initializer='glorot_uniform', padding='same', activation=None)(d)
    h_flat = Lambda(lambda x: K.reshape(x, shape=(K.shape(x)[0], K.shape(x)[1] * K.shape(x)[2], K.shape(x)[-1])))(h)
    logits = Lambda(lambda x: mult1(x))([g_flat, f_flat])
    attn_map = Softmax(axis=-1)(logits)
    attended = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=(2, 1)))([attn_map, h_flat])
    return SelfAttentionGamma()([d, attended])


# define the discriminator model
def define_discriminator(image_shape=(128, 128, 3)):
    """Build the discriminator as a Keras Model.

    The architecture is:
    - two stride-2 convolutions (64 and 128 filters), then self-attention;
    - two more stride-2 convolutions (256 and 512 filters);
    - a plain Conv2D with BatchNormalization, which is always called with
      ``training=1``;
    - a final 1-filter spectrally normalised conv that produces the
      per-patch output map.

    Every convolution except the plain Conv2D is a ConvSN2D, and each hidden
    stage ends in LeakyReLU(0.2).
    """
    init = 'glorot_uniform'
    in_image = Input(shape=image_shape)

    d = in_image
    for n_filters in (64, 128):
        d = ConvSN2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
        d = LeakyReLU(alpha=0.2)(d)

    d = _self_attention(d, 128)

    for n_filters in (256, 512):
        d = ConvSN2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
        d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization(momentum=0.9, epsilon=1.01e-5)(d, training=1)
    d = LeakyReLU(alpha=0.2)(d)

    patch_out = ConvSN2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    return Model(in_image, patch_out)

ResNet block

In [None]:
def resnet_block(n_filters, input_layer):
    """Return ``input_layer + conv(relu(conv(input_layer)))``.

    Both convolutions are 3x3 ConvSN2D layers with *n_filters* filters and
    'same' padding. *n_filters* must equal the channel count of
    *input_layer* for the final Add to be valid.
    """
    conv_opts = dict(padding='same', kernel_initializer='glorot_uniform')
    branch = ConvSN2D(n_filters, (3, 3), **conv_opts)(input_layer)
    branch = Activation('relu')(branch)
    branch = ConvSN2D(n_filters, (3, 3), **conv_opts)(branch)
    return Add()([branch, input_layer])

Generator Model

In [None]:
def define_generator(image_shape=(128,128,3), n_resnet=6):
    """Build the ResNet-style CycleGAN generator.

    Architecture, in the exact order the layers are created:
      stem      7x7 ConvSN2D(64) + ReLU
      encoder   3x3 stride-2 ConvSN2D(128) + ReLU, then ConvSN2D(256) + ReLU
      bottleneck n_resnet residual blocks at 256 filters
      decoder   3x3 stride-2 ConvSN2DTranspose(128) + ReLU, then (64) + ReLU
      head      7x7 ConvSN2D(3) + tanh

    Layer creation order is kept fixed because weights are later restored with
    ``load_weights`` from .h5 files.

    Args:
        image_shape: input tensor shape (height, width, channels).
        n_resnet: number of residual blocks in the bottleneck.

    Returns:
        keras Model mapping an input image to a 3-channel tanh output.
    """
    init = 'glorot_uniform'
    in_image = Input(shape=image_shape)

    # Stem: stride left at the ConvSN2D default.
    x = ConvSN2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    x = Activation('relu')(x)

    # Encoder: two stride-2 convolutions.
    for width in (128, 256):
        x = ConvSN2D(width, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(x)
        x = Activation('relu')(x)

    # Bottleneck.
    for _ in range(n_resnet):
        x = resnet_block(256, x)

    # Decoder: two stride-2 transposed convolutions.
    for width in (128, 64):
        x = ConvSN2DTranspose(width, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(x)
        x = Activation('relu')(x)

    # Output head: project to 3 channels and squash with tanh.
    x = ConvSN2D(3, (7,7), padding='same', kernel_initializer=init)(x)
    out_image = Activation('tanh')(x)
    return Model(in_image, out_image)

In [None]:
# One discriminator per image domain, both for 128x128x3 inputs (Disc_A is
# built first, then Disc_B), and print the architecture of the first.
Disc_A, Disc_B = [define_discriminator(image_shape=(128,128,3)) for _ in range(2)]
Disc_A.summary()

In [None]:
# The two translation generators, both for 128x128x3 inputs (Gen_AtoB is
# built first, then Gen_BtoA), and print the architecture of the first.
Gen_AtoB, Gen_BtoA = [define_generator(image_shape=(128,128,3)) for _ in range(2)]
Gen_AtoB.summary()

Define Loss Functions: Generator (Adversarial) Loss, Discriminator Loss, Cycle-Consistency Loss, and Identity Loss

Define functions for training

In [None]:
def loss_fn(output, target):
    """Least-squares adversarial loss: mean squared error between the
    discriminator output tensor and a same-shaped target label tensor.

    Args:
        output: discriminator output tensor.
        target: tensor of target labels (ones for "real", zeros for "fake").

    Returns:
        Scalar tensor with the mean of (output - target)**2.
    """
    # K.square is already non-negative, so no K.abs is needed.
    return K.mean(K.square(output - target))

# Placeholders fed with batches drawn from the fake-image history pools by the
# training loop. The discriminators' "fake" losses are computed on these.
fake_pool_a = K.placeholder(shape=(None, 128, 128, 3))
fake_pool_b = K.placeholder(shape=(None, 128, 128, 3))

# Symbolic inputs of the two generators (real images of each domain).
real_A = Gen_AtoB.inputs[0]
real_B = Gen_BtoA.inputs[0]

# One-step translations: A -> B and B -> A.
fake_B = Gen_AtoB.outputs[0]
fake_A = Gen_BtoA.outputs[0]

# Round-trip reconstructions: A -> B -> A and B -> A -> B.
rec_A = Gen_BtoA([fake_B])
rec_B = Gen_AtoB([fake_A])

# Inference helpers: real image in, [translation, reconstruction] out.
cycle_ABA = K.function([real_A], [fake_B, rec_A])
cycle_BAB = K.function([real_B], [fake_A, rec_B])


# Discriminator A outputs on real A, live fake A (used only by the generator
# loss below) and pooled fake A (used by the discriminator loss).
Disc_op_real_A = Disc_A([real_A])
Disc_op_fake_A = Disc_A([fake_A])
Disc_op_fakepool_A = Disc_A([fake_pool_a])

# Discriminator A targets: 1 for real images, 0 for pooled fakes.
Disc_A_real_loss = loss_fn(Disc_op_real_A, K.ones_like(Disc_op_real_A))
Disc_A_fake_loss = loss_fn(Disc_op_fakepool_A, K.zeros_like(Disc_op_fakepool_A))

Disc_A_loss = Disc_A_real_loss + Disc_A_fake_loss

# Same construction for domain B.
Disc_op_real_B = Disc_B([real_B])
Disc_op_fake_B = Disc_B([fake_B])
Disc_op_fakepool_B = Disc_B([fake_pool_b])

Disc_B_real_loss = loss_fn(Disc_op_real_B, K.ones_like(Disc_op_real_B))
Disc_B_fake_loss = loss_fn(Disc_op_fakepool_B, K.zeros_like(Disc_op_fakepool_B))

Disc_B_loss = Disc_B_real_loss + Disc_B_fake_loss

# Adversarial generator losses: push the discriminator's output on the live
# fakes toward the "real" target 1.
Gen_BtoA_loss = loss_fn(Disc_op_fake_A, K.ones_like(Disc_op_fake_A))
Gen_AtoB_loss = loss_fn(Disc_op_fake_B, K.ones_like(Disc_op_fake_B))

# Cycle-consistency losses: mean absolute error between input and round trip.
loss_cycle_A = K.mean(K.abs(rec_A-real_A))
loss_cycle_B = K.mean(K.abs(rec_B-real_B))

tot_loss_cycle = loss_cycle_A + loss_cycle_B

# Identity losses: a generator fed an image already in its target domain
# should return it unchanged (e.g. Gen_BtoA applied to a real A image).
id_A = Gen_BtoA([real_A])
loss_id_A = K.mean(K.abs(id_A - real_A))

id_B = Gen_AtoB([real_B])
loss_id_B = K.mean(K.abs(id_B - real_B))

tot_loss_id = loss_id_A + loss_id_B

tot_loss_D = Disc_A_loss + Disc_B_loss

# LAMBDACYCLE and LAMBDAID are loss weights defined earlier in the notebook
# (not visible in this cell).
tot_loss_G = Gen_BtoA_loss + Gen_AtoB_loss + LAMBDACYCLE*tot_loss_cycle + LAMBDAID*tot_loss_id

# Each optimizer only updates its own side's weights. The generator loss is
# back-propagated through the discriminators, but their weights stay fixed
# during the generator step.
weightsD = Disc_A.trainable_weights + Disc_B.trainable_weights
weightsG = Gen_AtoB.trainable_weights + Gen_BtoA.trainable_weights

# NOTE(review): get_updates(params, constraints, loss) is the legacy Keras 2.0
# argument order; newer Keras 2.x expects get_updates(loss, params) and only
# accepts this form via its legacy-interface shim. Confirm against the
# installed version.
training_updates_disc = Adam(lr=Disc_learningrate, beta_1=0.5).get_updates(weightsD,[],tot_loss_D)
Disc_train = K.function([real_A, real_B, fake_pool_a, fake_pool_b],[Disc_A_loss, Disc_B_loss], training_updates_disc)

# Gen_train takes the same four inputs as Disc_train so both share one call
# signature; the generator losses above do not depend on the pool placeholders.
training_updates_gen = Adam(lr=Gen_learningrate, beta_1=0.5).get_updates(weightsG,[], tot_loss_G)
Gen_train = K.function([real_A, real_B, fake_pool_a, fake_pool_b], [Gen_BtoA_loss, Gen_AtoB_loss, tot_loss_cycle, tot_loss_id], training_updates_gen)

Display Images

In [None]:
# Inline image display for the notebook preview grids.
from IPython.display import display
from io import BytesIO
# NOTE(review): byte_io is not referenced in this part of the notebook; it
# looks like a leftover and is a candidate for removal once confirmed unused.
byte_io = BytesIO()
def display_image(X, rows=1, iteration=1, sv = False):
    assert X.shape[0]%rows == 0
    int_X = ((X+1.)*127.5).clip(0,255).astype('uint8')
    int_X = int_X.reshape(rows, -1, 128, 128,3).swapaxes(1,2).reshape(rows*128,-1, 3)
    img = Image.fromarray(int_X)
    display(img)
    if sv:
        img.save(currdir + "images/" + "{}.png".format(iteration),"PNG")

Generate Images

In [None]:
def gen_image(A,B, iteration = 0, sv = False):
    """Preview both translation cycles for matching batches of A and B images.

    Shows a 3-row grid via display_image:
      row 1: the real inputs, A then B;
      row 2: their translations, A -> B then B -> A;
      row 3: their reconstructions, A -> B -> A then B -> A -> B.

    Args:
        A: batch of domain-A images.
        B: batch of domain-B images with exactly the same shape as A.
        iteration: label for the saved file name when ``sv`` is True.
        sv: forwarded to display_image to also save the grid as PNG.

    Raises:
        ValueError: if A and B have different shapes.
    """
    if A.shape != B.shape:
        raise ValueError(
            "A and B must have the same shape, got {} and {}".format(A.shape, B.shape))

    def run_cycle(fn_generate, X):
        # Feed one image at a time. Each call returns [translated, reconstructed];
        # assuming each is (1, H, W, C), stacking gives (N, 2, 1, H, W, C) and the
        # swap + index below yields (2, N, H, W, C).
        r = np.array([fn_generate([X[i:i+1]]) for i in range(X.shape[0])])
        return r.swapaxes(0,1)[:,:,0]

    rA = run_cycle(cycle_ABA, A)
    rB = run_cycle(cycle_BAB, B)
    arr = np.concatenate([A,B,rA[0],rB[0],rA[1],rB[1]])
    display_image(arr, 3, iteration, sv)

Train

In [None]:
import time
from IPython.display import clear_output
t0 = time.time()
MAXEPOCHS = 100
steps = 0
epoch = 0
localepoch1 = 0          # completed passes over train_A
localepoch2 = 0          # completed passes over train_B
err_Gen_BtoA = 0
err_Gen_AtoB = 0
err_Disc_A = 0
err_Disc_B = 0
err_cycle = 0
err_id = 0
display_iters = 50       # refresh the on-screen preview every N steps
fid_freq = 1000          # report losses and FID every N steps
save_freq = 7000         # checkpoint all four models every N steps
save_image_iters = 3000  # additionally write the preview grid to disk every N steps
POOL_SIZE = 50           # capacity of each fake-image history pool
PREVIEW_SIZE = 4         # images per domain in the preview grid
j = 0                    # read cursor into train_A
i = 0                    # read cursor into train_B


fake_A_pool = []
fake_B_pool = []


def _query_image_pool(pool, images, pool_size=POOL_SIZE):
    """Return a batch that mixes fresh fakes with previously generated ones.

    While ``pool`` holds fewer than ``pool_size`` images, each new image is
    stored and returned as-is. Once the pool is full, each new image is, with
    probability 1/2, swapped into a random slot and the evicted (older) image
    is returned in its place; otherwise the new image is returned unchanged.

    Args:
        pool: list used as the history buffer; mutated in place.
        images: iterable of generated images (e.g. a predict() batch).
        pool_size: maximum number of images kept in ``pool``.

    Returns:
        numpy array stacking one selected image per input image.
    """
    selected = []
    for img in images:
        if len(pool) < pool_size:
            pool.append(img)
            selected.append(img)
        elif np.random.uniform(0, 1) > 0.5:
            # randint is presumably random.randint (inclusive), the last binding
            # of that name in the import cell, so every slot can be chosen.
            slot = randint(0, pool_size - 1)
            selected.append(np.copy(pool[slot]))
            pool[slot] = img
        else:
            selected.append(img)
    return np.array(selected)


def _preview_slice(data):
    """Return PREVIEW_SIZE consecutive samples starting at a random offset.

    Unlike the training batches, this does not move the i/j cursors or the
    epoch counters, so preview images are never taken out of training.
    """
    start = randint(0, max(0, len(data) - PREVIEW_SIZE))
    return data[start:start + PREVIEW_SIZE]


while epoch < MAXEPOCHS:

    # Next sequential batch from each domain. When a domain cannot fill a full
    # batch, reshuffle it and count one local epoch for it.
    if j+batch_size > len(train_A):
        j = 0
        localepoch1 += 1
        np.random.shuffle(train_A)
    A = train_A[j:j+batch_size]
    j += batch_size

    if i+batch_size > len(train_B):
        i = 0
        localepoch2 += 1
        np.random.shuffle(train_B)
    B = train_B[i:i+batch_size]
    i += batch_size

    # An epoch is complete only once both domains have been traversed.
    epoch = min(localepoch1, localepoch2)

    tmp_fake_A = Gen_BtoA.predict(B)
    tmp_fake_B = Gen_AtoB.predict(A)

    # The discriminators see a mix of current and historical fakes.
    pool_a = _query_image_pool(fake_A_pool, tmp_fake_A)
    pool_b = _query_image_pool(fake_B_pool, tmp_fake_B)

    err_Disc_A, err_Disc_B = Disc_train([A, B, pool_a, pool_b])

    err_Gen_BtoA, err_Gen_AtoB, err_cycle, err_id = Gen_train([A, B, pool_a, pool_b])
    steps+=1

    losses = (err_Disc_A, err_Disc_B, err_Gen_BtoA, err_Gen_AtoB, err_cycle, err_id)
    with open(currdir + "logs/" + "losses_logs" + ".txt", "a") as f:
        f.write('\t'.join(str(v) for v in losses) + '\n')
    print("err_Disc_A: {}".format(err_Disc_A) + " err_Disc_B: {}".format(err_Disc_B) + " err_Gen_BtoA: {}".format(err_Gen_BtoA) + " err_Gen_AtoB: {}".format(err_Gen_AtoB))

    if steps%fid_freq == 0:
        print(time.time()-t0)
        print("Epoch-{}".format(epoch))
        print("err_Disc_A: {}".format(err_Disc_A) + " err_Disc_B: {}".format(err_Disc_B) + " err_Gen_BtoA: {}".format(err_Gen_BtoA) + " err_Gen_AtoB: {}".format(err_Gen_AtoB))
        # FID of the B -> A generator on a random 100-image slice of test_B.
        idx = randint(0,len(test_B)-100)
        fid_B_data = test_B[idx:idx+100]
        fidCalculate(Gen_BtoA, fid_B_data, test_mean, test_sigma)

    if steps%display_iters==0:
        clear_output()

        # Preview images are sampled independently of the training cursors so
        # that previews do not skip training samples or advance the epochs.
        preview_A = _preview_slice(train_A)
        preview_B = _preview_slice(train_B)

        if(steps%save_image_iters == 0):
            gen_image(preview_A, preview_B, steps, sv=True)
        else:
            gen_image(preview_A, preview_B, steps)

    if steps%save_freq == 0:
        save_models(steps, Gen_AtoB, Gen_BtoA, Disc_A, Disc_B)
            

Test

In [None]:
# random.randint includes both endpoints, so the last valid start index for a
# one-image slice is len(train_B) - 1; len(train_B) would give an empty B.
idx = randint(0, len(train_B) - 1)
B = train_B[idx:idx+1]
def G(fn_generate, X):
    """Run a two-output K.function on X one image at a time.

    Returns an array indexed as [output_index, image_index, ...]. For
    cycle_BAB, index 0 is the B -> A translation and index 1 the B -> A -> B
    reconstruction.
    """
    r = np.array([fn_generate([X[i:i+1]]) for i in range(X.shape[0])])
    return r.swapaxes(0,1)[:,:,0]
CEZANNE_WEIGHTS_DIR = 'C:\\Users\\parit\\Documents\\CycleGAN\\CycleGANRestNetGenSN1SA\\cezanne\\finalweights\\'
Gen_BtoA.load_weights(CEZANNE_WEIGHTS_DIR + 'g_model_BtoA_weights_010001.h5')
Gen_AtoB.load_weights(CEZANNE_WEIGHTS_DIR + 'g_model_AtoB_weights_010001.h5')
Disc_A.load_weights(CEZANNE_WEIGHTS_DIR + 'd_model_A_weights_010001.h5')
Disc_B.load_weights(CEZANNE_WEIGHTS_DIR + 'd_model_B_weights_010001.h5')
r_cezanne = G(cycle_BAB, B)
# Show the real image next to its translation.
arr = np.concatenate([B, r_cezanne[0]])
display_image(arr, 1)

# FID of the B -> A generator on a random 100-image slice of test_B.
idx_fid = randint(0,len(test_B)-100)
fid_B_data = test_B[idx_fid:idx_fid+100]
print("FID:")
fidCalculate(Gen_BtoA, fid_B_data, test_mean, test_sigma)

Run this cell to translate one random image into each artistic style (Ukiyo-e, Van Gogh, Monet, and Cézanne) and display the results side by side.

In [None]:
# Pick one random domain-B image. random.randint includes both endpoints, so
# the upper bound must be len(train_B) - 1 to avoid an empty slice.
idx = randint(0, len(train_B) - 1)
B = train_B[idx:idx+1]


def G(fn_generate, X):
    """Run a two-output K.function on X one image at a time.

    Returns an array indexed as [output_index, image_index, ...]. For
    cycle_BAB, index 0 is the B -> A translation and index 1 the B -> A -> B
    reconstruction.
    """
    r = np.array([fn_generate([X[i:i+1]]) for i in range(X.shape[0])])
    return r.swapaxes(0,1)[:,:,0]


STYLE_WEIGHTS_ROOT = 'C:\\Users\\parit\\Documents\\CycleGAN\\CycleGANRestNetGenSN1SA\\'


def _load_style_weights(style, step_tag):
    """Load the weights of all four CycleGAN models for one artistic style.

    Files are read from the style's "finalweights" folder under
    STYLE_WEIGHTS_ROOT and are named
    {g_model_BtoA, g_model_AtoB, d_model_A, d_model_B}_weights_<step_tag>.h5.

    Args:
        style: style folder name, e.g. 'monet'.
        step_tag: zero-padded training-step suffix of the checkpoint files.
    """
    weights_dir = STYLE_WEIGHTS_ROOT + style + '\\finalweights\\'
    Gen_BtoA.load_weights(weights_dir + 'g_model_BtoA_weights_' + step_tag + '.h5')
    Gen_AtoB.load_weights(weights_dir + 'g_model_AtoB_weights_' + step_tag + '.h5')
    Disc_A.load_weights(weights_dir + 'd_model_A_weights_' + step_tag + '.h5')
    Disc_B.load_weights(weights_dir + 'd_model_B_weights_' + step_tag + '.h5')


_load_style_weights('ukiyoe', '069001')
r_ukiyoe = G(cycle_BAB, B)

_load_style_weights('vangogh', '024001')
r_vangogh = G(cycle_BAB, B)

_load_style_weights('monet', '041202')
r_monet = G(cycle_BAB, B)

_load_style_weights('cezanne', '010001')
r_cezanne = G(cycle_BAB, B)

# One row: the real image followed by its translation into each style.
arr = np.concatenate([B,r_ukiyoe[0], r_vangogh[0], r_monet[0], r_cezanne[0]])
print("   real image     Ukiyoe image     Vangogh Image     Monet image      Cezanne Image")
display_image(arr, 1)



References: https://github.com/tjwei/GANotebooks/blob/master/CycleGAN-keras.ipynb