## LOAD LIBRARIES

In [1]:
# ENVIRONMENT
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import pathlib
from glob import glob

import warnings
warnings.filterwarnings("ignore")

# TENSORFLOW 2.0
import tensorflow as tf
print(tf.__version__)
print('GPU available:', tf.test.is_gpu_available())

from tensorflow.keras.models import Model
from tensorflow.python.keras.layers import Add, BatchNormalization, Conv2D, Dense, Flatten, Input, LeakyReLU, PReLU, Lambda, MaxPool2D
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError, MeanSquaredError
binary_cross_entropy = BinaryCrossentropy()

from tensorflow.keras.metrics import Mean

# import tensorflow.keras.backend as K
from tensorflow.keras import backend as K
from tensorflow.python.data.experimental import AUTOTUNE

# ESSENTIAL 
import numpy as np
from sklearn.model_selection import train_test_split
import cv2

# VISUALIZER
import matplotlib.pyplot as plt
import mahotas 
import imutils
plt.ioff()
%matplotlib inline
from IPython import display

# UTILS
import time
from datetime import date
from tqdm import tqdm

2.1.0
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU available: False


In [2]:
# Setting seed
tf.random.set_seed(27)
np.random.seed(27)

In [3]:
# Pre-trained VGG 
from tensorflow.python.keras.applications.vgg19 import VGG19
vgg_19 = VGG19(input_shape=(None, None, 3), weights='imagenet', include_top=False)

In [2]:
tf.compat.v1.disable_eager_execution()

NameError: name 'tf' is not defined

## LOAD DIRECTORIES

```
GAN
│   
└───data
│   │
│   └───real_images
│   │    │  
│   │    └─── tranmission
│   │    │  
│   │    └─── blended
│   │
│   └───synthetic_images
│        │  
│        └─── tranmission
│        │  
│        └─── reflection
│   
└───logs
    │
    └─── ckpts 
    │    │  
    │    └─── pretrain 
    │    │  
    │    └─── train 
    │
    └─── output
         │  
         └─── pretrain 
         │  
         └─── train 

```

In [None]:
# Dataset Directories
DATA_DIR = "./data"

BLENDED_TRAIN_DIR = os.path.join(DATA_DIR, 'real_images/blended')
TRANSMISSION_TRAIN_DIR = os.path.join(DATA_DIR, 'real_images/tranmission')

SYN_REFLECTION_TRAIN_DIR = os.path.join(DATA_DIR, 'synthetic_images/reflection')
SYN_TRANSMISSION_TRAIN_DIR = os.path.join(DATA_DIR, 'synthetic_images/tranmission')

BLENDED_TEST_DIR = os.path.join()
TRANSMISSION_TEST_DIR = os.path.join()

# IF THERE ARE SAMPLES
SAMPLE_DIR = os.path.join()

# Training Support Directories 
LOG_DIR = './logs'

CKPT_DIR  = os.path.join(LOG_DIR, 'ckpts')
OUTPUT_DIR  = os.path.join(LOG_DIR, 'output')

## PREPARE DATASET

### Configuration

In [4]:
## Image Configuration
# SCALE = 4  
# HR_SIZE = 128
# LR_SIZE = HR_SIZE//SCALE 
# CHANNEL = 3
BUFFER_SIZE = 400
IMG_WIDTH = 256
IMG_HEIGHT = 256


# Training Configuration
BATCH_SIZE = 8
BUFFER_SIZE = 400
PRETRAIN_LR = 2e-4
PRETRAIN_LR_DECAY_STEP = 20000
GAN_LR = 1e-4

# Loss Configuration 
GAN_LOSS_COEFF = 0.005
CONTENT_LOSS_COEFF = 0.01

EPS = 1e-12

In [3]:
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

### Process Data

**Please remember to use tf.Dataset & tf.image.ImageGenerator to load images to tf dataset**

In [None]:
class DataLoader(object):
    """ Class to call a pipeline that load images from paths to Tensorflow dataset
    """
    def load_paths_from_directory(self, directory):
        return sorted(glob(os.path.join(directory, "*.{}".format('jpg'))))

In [None]:
class ImagePreprocess(object):
    """ Preprocess images before flowing into the model.
        Perform random crop into training size, random augmentation

    """   

In [None]:
def load(image_file):
    image = tf.io.read_file(image_file)
    
    
    return input_image, real_image

In [None]:
def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width],
                                  method= tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width],
                                 method= tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
    return input_image, real_image

In [None]:
def random_crop(input_image, real_image):
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, 
                                         size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
  
    return cropped_image[0], cropped_image[1]
  

In [None]:
# normalizing the images to [-1, 1]

def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1
  
    return input_image, real_image

In [None]:
def load_image_train(image_file):
    input_image, real_image = load(image_file)
    # resizing to 512 x 512 x 3
    input_image, real_image = resize(input_image, real_image, 512, 512)
    # randomly cropping to 256 x 256 x 3
    input_image, real_image = random_crop(input_image, real_image)

    input_image, real_image = normalize(input_image, real_image)
  
    return input_image, real_image

In [None]:
def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image, real_image,
                                     IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize(input_image, real_image)
  
    return input_image, real_image

### Input Pipeline

Or using train_test_split

In [None]:
train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)


### Load Examples

SOME TEST IMAGES

In [5]:
# Load test images
trans_test = cv2.imread('./Test-Images/back_2.jpg', -1)
blend_test = cv2.imread('./Test-Images/blended_2.jpg', -1)

trans_image = tf.io.read_file('./Test-Images/back_2.jpg')
trans_image = tf.image.decode_jpeg(trans_image, channels= 3)

##########
neww = 128
newh = 128

channel = 64

########## RESIZE
output_t = cv2.resize(np.float32(trans_test), (neww,newh), cv2.INTER_CUBIC)/255.0
output_b = cv2.resize(np.float32(blend_test), (neww,newh), cv2.INTER_CUBIC)/255.0

output_image = tf.image.resize(trans_image, [neww, newh],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

######### EXPAND DIM
output_t = np.expand_dims(output_t,axis=0)
output_b = np.expand_dims(output_b,axis=0)

output_image = tf.expand_dims(output_image, axis=0)

In [73]:
# input_img = output_t #output_image   #output_t
# in_image = output_t #output_image  #output_t

# for block_id in range(1, 6):
#     weight = vgg_weight_block(input_img, block_id)
#     print(block_id, weight.shape)
#     weight_resize = tf.image.resize(weight, (tf.shape(input= input_img)[1], tf.shape(input= input_img)[2]), 
#                                            method = tf.image.ResizeMethod.BILINEAR) / 255.0
#     print(block_id, weight_resize.shape)
    
    
# #     in_img = tf.concat([tf.image.resize(weight, (tf.shape(input= in_img)[1], tf.shape(input= in_img)[2]), 
# #                                            method=tf.image.ResizeMethod.BILINEAR) / 255.0, 
# #                           in_img], axis=3)
#     in_image = tf.concat([weight_resize, in_image], axis= 3)
#     print(block_id, in_image.shape)


1 (1, 128, 128, 64)
1 (1, 128, 128, 64)
1 (1, 128, 128, 67)
2 (1, 64, 64, 128)
2 (1, 128, 128, 128)
2 (1, 128, 128, 195)
3 (1, 32, 32, 256)
3 (1, 128, 128, 256)
3 (1, 128, 128, 451)
4 (1, 16, 16, 512)
4 (1, 128, 128, 512)
4 (1, 128, 128, 963)
5 (1, 8, 8, 512)
5 (1, 128, 128, 512)
5 (1, 128, 128, 1475)


## VISUALIZATION UTILITIES

In [None]:
def visualize_sample(ds):
    i = 0
    plt.figure(figsize=(8, 8))
    for lr in ds.take(4):
        plt.subplot(2, 2, i+1)
        plt.imshow((lr.numpy()[0] * 255).astype('int'))
        plt.axis('off')
        i+=1
        plt.title('sample_{}'.format(i))
    
    plt.show()
    
def visualize_ds(ds):
    for lr, hr in ds.take(1):
        plt.figure(figsize=(10, 8))
        
        count_plot=1
        for i in range(3):
            plt.subplot(3, 2, count_plot)
            plt.imshow((lr.numpy()[i] * 255).astype('int'))
            plt.axis('off')
            plt.title('LR')

            plt.subplot(3, 2, count_plot+1)
            plt.imshow((hr.numpy()[i] * 255).astype('int'))
            plt.title('HR')
            plt.axis('off')
            
            count_plot+=2
    
    plt.show()

In [None]:
# def generate_and_save_images(model, epoch, sample_ds, train_times=1, mode='train'):
#     i = 0
#     fig = plt.figure(figsize=(10,10))
#     for sample in sample_ds.take(4):
#         i+=1
#         predictions = model(sample, training=False)
#         predictions = ((predictions.numpy()[0] * 255).astype('int'))
#         plt.subplot(2, 2, i)
#         plt.imshow(predictions)
#         plt.axis('off')
    
#     path = os.path.join(OUTPUT_DIR, mode, 'img_at_{}_epoch_{:04d}.png'.format(train_times, epoch)) 
#     plt.savefig(path)
#     plt.close(fig)

In [None]:
def generate_images(model, test_input, tar):
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))
  
    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
  
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()


## MODEL

Some introduction about the model, blah blah blah

In [6]:
def _vgg(output_layer):
    """ Get VGG19 layers as a model, from input layer
        to the chosen ouput layers
    """
    return Model(vgg_19.input, vgg_19.layers[output_layer].output)

6 blocks vggs

In [33]:
# vgg_mean = [103.939, 116.779, 123.68]

# def get_vgg_blocks():
#     # All required VGG19 layers for this model
#     block_vgg_0 = _vgg(0)   # Input, layer 0 
#     block_vgg_1 = _vgg(2)   # conv1_2, layer 2 
#     block_vgg_2 = _vgg(5)   # conv2_2, layer 5 
#     block_vgg_3 = _vgg(8)   # conv3_2, layer 8 
#     block_vgg_4 = _vgg(13)  # conv4_2, layer 13
#     block_vgg_5 = _vgg(18)  # conv5_2, layer 18

#     block_vgg_list = []
#     for i in range(0, 6):
#         block_vgg_list.append( eval('block_vgg_{}'.format(i)) )
#     return block_vgg_list

# def preprocess_vgg(in_img):
#     in_img = tf.cast(in_img * 255., dtype=tf.float32)
#     r, g, b = tf.split(in_img, 3, 3)
#     bgr = tf.concat([b - vgg_mean[0],
#                      g - vgg_mean[1],
#                      r - vgg_mean[2]], axis=3)
#     return bgr

# block_vgg_list = get_vgg_blocks()

# def vgg_weight_block(in_img, block_num):
#     """ Get weights from blocks of VGG19 pre-trained model
#         Following research paper, get the hypercolumn feature of layer 'conv_(1,2,3,4,5)_2'
#         and concatenate to input before entering Generator's layers
#     """
#     img = preprocess_vgg(in_img)
#     return block_vgg_list[block_num](img)

vgg layers from input to conv5_2

In [8]:
# Get VGG Weights
# vgg_mean = [103.939, 116.779, 123.68]

# def vgg_52(input):
#     """ Get weights from VGG19 pre-trained model
#         Following research paper, get the hypercolumn feature until layer 'conv_5_2'
#     """
#     input = tf.cast(input * 255., dtype=tf.float32)
#     r, g, b = tf.split(input, 3, 3)
#     bgr = tf.concat([b - vgg_mean[0],
#                      g - vgg_mean[1],
#                      r - vgg_mean[2]], axis=3)
#     return _vgg(18)(bgr)

In [7]:
# Preprocess images with VGG means
vgg_mean = [103.939, 116.779, 123.68]

def preprocess_vgg_means(input_img):
    """ Subtract vgg means from images
    """
#     input_img = tf.cast(input_img * 255., dtype= tf.float32)
    input_img = tf.cast(input_img * 255, dtype= tf.float32) # If tf.readfile -> img will be int type, for numpy, I've already converted above
    r, g, b = tf.split(input_img, 3, 3)
    bgr = tf.concat([b - vgg_mean[0],
                     g - vgg_mean[1],
                     r - vgg_mean[2]], axis=3)
    return bgr

In [15]:
# Get VGG Weights

def get_vgg_weights(input_img):
    """ Get weights from VGG19 pre-trained model
        Following research paper, get the hypercolumn feature until layer 'conv_5_2'
    """
    in_img = preprocess_vgg_means(input_img)
#     in_img = input_img
    
    vgg_18 = _vgg(18)
    inp = vgg_18.input
    outputs = [layer.output for layer in vgg_18.layers] 
    functors = [K.function([inp, K.learning_phase()], [out]) for out in outputs]
    
    features = {}
    for i in [0, 2, 5, 8, 13, 18]:
        features[i] = functors[i]([in_img, 0])[0]
        
    features_list = []
    for i in [0, 2, 5, 8, 13, 18]:
        features_list.append(features[i])
    return features, features_list

In [11]:
output_image.shape

TensorShape([1, 128, 128, 3])

In [12]:
img_mean = preprocess_vgg_means(output_image)
img_mean

<tf.Tensor: shape=(1, 128, 128, 3), dtype=float32, numpy=
array([[[[149.061    , 131.22101  , 121.32     ],
         [140.061    , 122.221    , 106.32     ],
         [136.061    , 114.221    ,  80.32     ],
         ...,
         [ 17.060997 ,  -0.7789993, -40.68     ],
         [ 25.060997 ,   4.2210007, -39.68     ],
         [ 59.060997 ,  39.221    ,  13.32     ]],

        [[107.061    ,  80.221    ,  24.32     ],
         [105.061    ,  79.221    ,  26.32     ],
         [105.061    ,  80.221    ,  31.32     ],
         ...,
         [ 16.060997 ,  -5.7789993, -46.68     ],
         [ 48.060997 ,  31.221    ,  -4.6800003],
         [ 69.061    ,  52.221    ,  21.32     ]],

        [[ 98.061    ,  78.221    ,  28.32     ],
         [105.061    ,  85.221    ,  39.32     ],
         [103.061    ,  81.221    ,  36.32     ],
         ...,
         [140.061    , 122.221    , 100.32     ],
         [144.061    , 124.221    , 108.32     ],
         [107.061    ,  81.221    ,  48.32    

In [14]:
features, features_list = get_vgg_weights(img_mean)

AttributeError: 'int' object has no attribute 'op'

In [11]:
vgg18 = _vgg(18)

inp = vgg18.input
outputs = [layer.output for layer in vgg18.layers] 
functors = [K.function([inp, K.learning_phase()], [out]) for out in outputs]


AttributeError: 'int' object has no attribute 'op'

--------------------------

### Generator

GENERATOR RIGHT HERE

In [16]:
# GENERATOR (aka model to train for Reflection Removal)
class Generator(object):
    def __init__(self):
        self.n_filters = 64
        self.channels = 3
        self.initializer = tf.initializers.he_normal(seed=None)
#         self.init_kernel = tf.keras.initializers.Identity(gain= 1.0)

#     def get_weight_vgg(self, in_img):
#         self.input_img = preprocess_vgg(in_img)
#         for block_id in range(1, 6):
#             weight = vgg_weight_block(self.input_img, block_id)
#             in_img = tf.concat([tf.image.resize(weight, (tf.shape(input= in_img)[1], tf.shape(input= in_img)[2]), 
#                                                    method=tf.image.ResizeMethod.BILINEAR)/255.0, in_img], axis=3)
#         return in_img

    def concat_vgg_weight(self, in_img):
        self.in_height = tf.shape(input= in_img)[1]
        self.in_width = tf.shape(input= in_img)[2] 
        _, self.weight_list = get_vgg_weights(in_img)
        
        for block_id in range(1, 6):
            weight = self.weight_list[block_id]
            weight_resize = tf.image.resize(weight, (self.in_height, self.in_width), 
                                            method = tf.image.ResizeMethod.BILINEAR) / 255.0
            in_img = tf.concat([weight_resize, in_img], axis=3)
        return in_img

    def activation_normalizer(self, layer):
        layer = tf.keras.layers.LeakyReLU(0.2)(layer)
        layer = tf.keras.layers.BatchNormalization()(layer)
        return layer


    def build_gen(self, input_shape):

        inp = tf.keras.layers.Input(shape = input_shape)

#         vgg19_features = vgg_52(inp) 
#         vgg19_features = self.get_weight_vgg(inp)
        vgg19_features = self.concat_vgg_weight(inp)

        gen = tf.keras.layers.Conv2D(filters= 64, kernel_size= [1,1], dilation_rate= 1,padding= 'same', kernel_initializer= self.initializer)(vgg19_features)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 1, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 2, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 4, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 8, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 16, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 32, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 64, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        gen = tf.keras.layers.Conv2D(64, kernel_size= [3,3], dilation_rate= 1, padding= 'same', kernel_initializer= self.initializer)(gen)
        gen = self.activation_normalizer(gen)
        # last layer
        gen = tf.keras.layers.Conv2D(6, kernel_size= [1,1], dilation_rate= 1, padding= 'same')(gen)
        gen = Model(inp, gen)
        return gen

In [17]:
generator = Generator().build_gen([256, 256, 3])

AttributeError: 'int' object has no attribute 'op'

In [75]:
generator.summary()

# tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

Model: "model_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
tf_op_layer_mul_24 (TensorFlowO [(None, 256, 256, 3) 0           input_10[0][0]                   
__________________________________________________________________________________________________
tf_op_layer_split_24 (TensorFlo [(None, 256, 256, 1) 0           tf_op_layer_mul_24[0][0]         
__________________________________________________________________________________________________
tf_op_layer_sub_72 (TensorFlowO [(None, 256, 256, 1) 0           tf_op_layer_split_24[0][2]       
___________________________________________________________________________________________

### Discriminator

In [15]:
# DISCRIMINATOR
class Discriminator(object):
    def __init__(self):
        self.filters = 64
        self.init_kernel = tf.initializers.he_normal(seed= None)
#         self.initializer = tf.random_normal_initializer(0., 0.02)
    
    def lrelu(self, x_in, a):
        x = tf.identity(x_in)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
    
    def conv_block(self, x_in, filters, strides, batch_norm= True, relu_act= True):
#         x = tf.keras.layers.Lambda(lambda x: tf.pad(x_in, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT"))
        x_in = tf.pad(x_in, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
    
        x = tf.keras.layers.Conv2D(filters= filters, kernel_size= 4,
                                   strides= strides, use_bias= False, 
                                   padding="VALID",
#                                    padding="SAME", 
                                   kernel_initializer= self.init_kernel)(x_in)
#                                    kernel_initializer= self.initializer)(x_in)
        if batch_norm:
            x = BatchNormalization()(x)
        if relu_act:
            x = self.lrelu(x, 0.2)
#             x = LeakyReLU(0.2)(x)
        return x

    # with build(input_shape):
    def build_discriminator(self, input_shape):
        inp = tf.keras.layers.Input(shape=input_shape, name='input_image')
        tar = tf.keras.layers.Input(shape=input_shape, name='target_image')

        x_in = tf.keras.layers.concatenate([inp, tar])
    
#    # with build(discrim_inputs, discrim_targets):
#     def build_discriminator(self, discrim_inputs, discrim_targets):
#         input_shape = tf.keras.layers.concatenate([discrim_inputs, discrim_targets]).shape
        
#         x_in = tf.keras.layers.Input(shape= input_shape)
        
    # Discriminator Layers:     
        x = self.conv_block(x_in, self.filters, 2, False)  #Layer 1, filters = 64        
        
        x = self.conv_block(x, self.filters *2, 2) #Layer 2, filters = 128
        x = self.conv_block(x, self.filters *4, 2) #Layer 3, filters = 256
        x = self.conv_block(x, self.filters *8, 1) #Layer 4, filters = 512  
        x = self.conv_block(x, 1, 1, False, False) #Layer 5
        
        x = tf.keras.activations.sigmoid(x)
#         x = tf.keras.layers.Dense(1, activation='sigmoid')(x)
        
#         x = Model(x_in, x)
        x = Model(inputs= [inp, tar], outputs= x)
        return x


In [None]:
discriminator = Discriminator().build_discriminator(input_shape= (IMG_WIDTH, IMG_HEIGHT, 3))

### Loss functions

In [None]:
# class Loss():
#     def _perceptual_loss(self,):

# Loss functions
# #L1 loss
def compute_l1_loss(in_img, out_img):
    return tf.reduce_mean(tf.abs(in_img - out_img))

def _l1_loss_reflection(gen_refl, refl, synthetic= False):
    loss_l1_r = compute_l1_loss(gen_refl, refl) if synthetic else 0
    return loss_l1_r

# #Perceptual loss
def _perceptual_loss(gen_img, target_img):
    _, vgg_feature_fake = get_vgg_weights(gen_img)
    _, vgg_feature_real = get_vgg_weights(target_img)
    
    # compute_l1_loss(real, fake)
    p0 = compute_l1_loss(vgg_feature_real[0], vgg_feature_fake[0])             # Input, layer 0
    p1 = compute_l1_loss(vgg_feature_real[1], vgg_feature_fake[1]) /2.6        # conv1_2, layer 2   
    p2 = compute_l1_loss(vgg_feature_real[2], vgg_feature_fake[2]) /4.8        # conv2_2, layer 5 
    p3 = compute_l1_loss(vgg_feature_real[3], vgg_feature_fake[3]) /3.7        # conv3_2, layer 8 
    p4 = compute_l1_loss(vgg_feature_real[4], vgg_feature_fake[4]) /5.6      # conv4_2, layer 13 
    p5 = compute_l1_loss(vgg_feature_real[5], vgg_feature_fake[5]) *10/1.5   # conv5_2, layer 18 

    return p0+p1+p2+p3+p4+p5

In [None]:
# Exclusion loss, in gradient domain  #FLAG, not using synthetic data so do we need this?
def compute_gradient(img):
    gradx = img[:,1:,:,:] - img[:,:-1,:,:]
    grady = img[:,:,1:,:] - img[:,:,:-1,:]
    return gradx, grady

# img1: transmission layer, img2: reflection layer
def _exclusion_loss(img1, img2, level=1):
    gradx_loss=[]
    grady_loss=[]
    
#     trans_layer = img1
#     refl_layer = img2
    
    for l in range(level):
        gradx1, grady1 = compute_gradient(img1)
        gradx2, grady2 = compute_gradient(img2)
        alphax = 2.0* tf.reduce_mean(tf.abs(gradx1)) / tf.reduce_mean(tf.abs(gradx2))
        alphay = 2.0* tf.reduce_mean(tf.abs(grady1)) / tf.reduce_mean(tf.abs(grady2))
        
        gradx1_s = (tf.keras.activations.sigmoid(gradx1) *2) -1
        grady1_s = (tf.keras.activations.sigmoid(grady1) *2) -1
        gradx2_s = (tf.keras.activations.sigmoid(gradx2 *alphax) *2) -1
        grady2_s = (tf.keras.activations.sigmoid(grady2 *alphay) *2) -1

        gradx_loss.append( tf.reduce_mean( tf.multiply( tf.square(gradx1_s), tf.square(gradx2_s)) ) **0.25)
        grady_loss.append( tf.reduce_mean( tf.multiply( tf.square(grady1_s), tf.square(grady2_s)) ) **0.25)

#         img1= tf.nn.avg_pool(img1, [1,2,2,1], [1,2,2,1], padding='SAME')
#         img2= tf.nn.avg_pool(img2, [1,2,2,1], [1,2,2,1], padding='SAME')
        img1= tf.keras.layers.AveragePooling2D( pool_size= (2,2), padding='SAME')(img1)
        img2= tf.keras.layers.AveragePooling2D( pool_size= (2,2), padding='SAME')(img2)

    return gradx_loss, grady_loss

def _gradient_loss(gen_img, refl_img, synthetic= False):
    loss_gradx, loss_grady = _exclusion_loss(gen_img, refl_img, level=3)
    loss_gradxy = tf.reduce_sum(sum(loss_gradx) /3.) + tf.reduce_sum(sum(loss_grady) /3.)
    
    # loss_grad = tf.where(issyn, loss_gradxy/2.0, 0) If is synthetic, = loss_gradxy/2, else 0
    # loss_grad = loss_gradxy / 2.0
    loss_grad = loss_gradxy / 2.0 if synthetic else 0
    return loss_grad

In [1]:
# # Adversarial Loss (aka Generator loss & Discriminator loss)
# EPS = 1e-12
def _adversarial_loss(disc_real_out, disc_gen_out):
#     predict_real = disc_real_out
#     predict_fake = disc_gen_out
#     # Compute Gen & Disc loss
#     gen_loss = tf.reduce_mean(-tf.log(predict_fake + EPS))
#     disc_loss = (tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))) * 0.5
    
    gen_loss = tf.reduce_mean(-tf.log(disc_gen_out + EPS))
    disc_loss = (tf.reduce_mean(-(tf.log(disc_real_out + EPS) + tf.log(1 - disc_gen_out + EPS)))) * 0.5
    
    return gen_loss, disc_loss

In [None]:
# L1 loss on reflection image
## loss_l1_r = tf.where(issyn, compute_l1_loss(reflection_layer, reflection), 0)
# loss_l1_r = compute_l1_loss(reflection_layer, reflection)
loss_l1_r = _l1_loss_reflection(gen_refl, refl)

# Perceptual Loss  #FLAG, not using synthetic data so reflection loss is pretty useless here
loss_percep_t = _perceptual_loss(transmission_layer, target)
# loss_percep_r = tf.where(issyn, compute_percep_loss(reflection_layer, reflection, reuse=True), 0.)
loss_percep_r = _perceptual_loss(reflection_layer, reflection)
# loss_percep = tf.where(issyn, loss_percep_t+loss_percep_r, loss_percep_t)
loss_percep = (loss_percep_t + loss_percep_r) if synthetic else loss_percep_t
# FOR REAL IMAGES ONLY
loss_percep = loss_percep_t 


# Exclusion loss
loss_grad = _gradient_loss(gen_img, refl_img)

# TOTAL LOSS
loss = loss_l1_r + loss_percep *0.2 + loss_grad

# OR IF YOU ONLY USE REAL IMAGES FOR TRAINING
# generator_loss += content_loss + perceptual_loss
# discriminator_loss = self.loss._discriminator_loss(hr_output, sr_output) 

### Generate Synthetic Images for training

### Optimizers

In [None]:
class Optimizer(object):
  
 #      def pretrain_optimizer(self):
#         learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(PRETRAIN_LR, 
#                                                                        PRETRAIN_LR_DECAY_STEP, 
#                                                                        0.5, staircase=True)
#         pre_gen_optimizer = Adam(learning_rate=learning_rate)
        
#         return pre_gen_optimizer

    def gan_optimizer(self):
        boundaries = [50000, 100000, 200000, 300000]
        values = [GAN_LR, GAN_LR * 0.5, GAN_LR * 0.5 ** 2,
                  GAN_LR * 0.5 ** 3, GAN_LR * 0.5 ** 4]
        learning_rate = PiecewiseConstantDecay(boundaries, values)

        dis_optimizer = Adam(learning_rate=learning_rate)
        gen_optimizer = Adam(learning_rate=learning_rate)

        return dis_optimizer, gen_optimizer

In [None]:
# Setting up optimizer
generator_optimizer, discriminator_optimizer = Optimizer().gan_optimizer()

In [None]:
#FLAG, optimizers from paper

# train_vars = tf.compat.v1.trainable_variables()
# d_vars = [var for var in train_vars if 'discriminator' in var.name]
# g_vars = [var for var in train_vars if 'g_' in var.name]
# g_opt=tf.compat.v1.train.AdamOptimizer(learning_rate=0.0002).minimize(loss*100+g_loss, var_list=g_vars) # optimizer for the generator
# d_opt=tf.compat.v1.train.AdamOptimizer(learning_rate=0.0001).minimize(d_loss,var_list=d_vars) # optimizer for the discriminator

In [None]:
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

### Checkpoint

In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)


## METRICS

## TRAINING GAN

In [4]:
EPOCHs = 150

In [None]:
# IF is_training
def prepare_data(train_path):
    input_names=[]
    image1=[]
    image2=[]
    for dirname in train_path:
        train_t_gt = dirname + "transmission_layer/"
        train_r_gt = dirname + "reflection_layer/"
        train_b = dirname + "blended/"
        
        for root, _, fnames in sorted(os.walk(train_t_gt)):
            for fname in fnames:
                if is_image_file(fname):
                    path_input = os.path.join(train_b, fname)
                    path_output1 = os.path.join(train_t_gt, fname)
                    path_output2 = os.path.join(train_r_gt, fname)
                    
                    input_names.append(path_input)
                    image1.append(path_output1)
                    image2.append(path_output2)
    return input_names, image1, image2

In [6]:
import datetime
log_dir = LOG_DIR

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

NameError: name 'log_dir' is not defined

In [16]:
generator = Generator().build_gen(input_shape=(None, None, 3))
discriminator = Discriminator().build_discriminator(input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))

In [None]:
@tf.function
def train_step(input_image, target, epoch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)  #FLAG, OUTPUT of GENERATOR IS (None, None, None, 6)
        gen_output_transmission, gen_output_reflection = tf.split(gen_output, num_or_size_splits=2, axis=3) 
    
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output_transmission], training=True)
    
#         gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
#         disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
        gen_loss, disc_loss = _adversarial_loss(disc_real_output, disc_generated_output)
        perceptual_loss = _perceptual_loss(gen_output, target)
        
        loss = perceptual_loss *0.2 #+ loss_l1_r +loss_grad
        gen_total_loss = gen_loss + loss*100 #FLAG, 100 or 100.0
  
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)
  
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))
  
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)


In [None]:
def fit(train_ds, epochs, test_ds):
    
    for epoch in range(epochs):
        start = time.time()

        display.clear_output(wait=True)

        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)
        print("Epoch: ", epoch)

    # Train
    for n, (input_image, target) in train_ds.enumerate():
        print('.', end='')
        if (n+1) % 100 == 0:
            print()
        train_step(input_image, target, epoch)
    print()

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
    checkpoint.save(file_prefix = checkpoint_prefix)


### Training session

In [None]:
fit(train_dataset, EPOCHS, test_dataset)

## PREDICT