TD4 : Image inpainting

INSA HdF
S. Delprat

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tqdm import tqdm
import io
import os

from IPython.display import clear_output

In [None]:
# Dataset location. The commented-out branch below is the Google Colab variant
# (mount Google Drive and point baseFolder at the notebook folder); the active
# assignment uses the Kaggle default folder for this dataset.
# if 1==1:
#     # if you use google drive, then update the folder according to its structure
#     baseFolder='/content/gdrive/MyDrive/Colab Notebooks/TD/TD4 Impainting'
#     from google.colab import drive
#     drive.mount("/content/gdrive")
# else:
#     # kaggle default folder for this dataset
baseFolder = '../input/inpaintingdataset/'

In [None]:
# Target size (in pixels) that read_image resizes images to when resize_needed is True.
img_height = 512
img_width = 512

In [None]:
def tensor_info(Img,name='img'):
    """Print a one-line summary of a tensor: min, max, dtype and shape.

    Args:
        Img: tensor to describe.
        name: label printed at the start of the line.
    """
    lowest = tf.math.reduce_min(Img).numpy()
    highest = tf.math.reduce_max(Img).numpy()
    dims = Img.shape.as_list()
    print(name,' min : ',lowest,' Max : ',highest,' type : ',Img.dtype, ' shape : ',dims)

Read an image.
1.   Resize the image to (img_height, img_width) if resize_needed==True
2.   Rescale the pixel values to [-1,1] if rescale_symetric==True, otherwise to [0,1]

rescale_symetric should be False for the mask so that its values are in [0,1] and not in [-1,1]


In [None]:
def read_image(filename,resize_needed,rescale_symetric=True):
    """Load a PNG file as a float32 RGB batch of one image.

    Args:
        filename: path of the PNG file to read.
        resize_needed: if True, resize to (img_height, img_width); otherwise
            only cast the decoded pixels to float32.
        rescale_symetric: if True, map pixel values from [0,255] to [-1,1];
            if False, map them to [0,1] (use this for masks).

    Returns:
        The image tensor with a leading batch axis of size 1.
    """
    raw_bytes = tf.io.read_file(filename)
    pixels = tf.image.decode_png(raw_bytes, channels=3)
    pixels = (
        tf.image.resize(pixels, [img_height, img_width])
        if resize_needed
        else tf.cast(pixels, tf.float32)
    )
    if not rescale_symetric:
        return tf.expand_dims(pixels/255.0, axis=0)
    return tf.expand_dims((pixels-127.5)/127.5, axis=0)

# Q1) Write the buildDecoder function

Inputs:
+ inputShape : size (nx x nx x nz) of the input image (fixed noise prior)
+ nc : list whose elements are the number of convolution filters in each block
+ display : boolean indicating whether or not the model summary should be displayed

output:
+ model


In [None]:
def buildDecoder(inputShape,nc,display=False):
    """Build the upsampling decoder used as image prior.

    The network is a stack of len(nc) blocks, each made of a 1x1 convolution,
    batch normalization and ReLU. Every block except the first one is preceded
    by a x2 bilinear upsampling, so the output is 2**(len(nc)-1) times larger
    than the input. A final 1x1 convolution with tanh produces 3 channels
    in [-1,1].

    Args:
        inputShape: shape (height, width, channels) of the input noise image.
        nc: list with the number of convolution filters of each block.
        display: if True, print the model summary.

    Returns:
        The tf.keras.Model mapping the input image to the generated image.
    """
    def conv_bn_relu(x, filters):
        # use_bias must be the boolean False: the string "False" is truthy and
        # would silently enable the bias.
        x = layers.Conv2D(filters=filters, kernel_size=1,
                          kernel_initializer=tf.keras.initializers.HeNormal(),
                          strides=1, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return layers.ReLU()(x)

    # Input
    Input = layers.Input(shape=inputShape)

    # Block 1 (no upsampling)
    x = conv_bn_relu(Input, nc[0])

    # Remaining blocks: upsample x2, then conv / batch norm / ReLU
    for filters in nc[1:]:
        x = layers.UpSampling2D(interpolation="bilinear")(x)
        x = conv_bn_relu(x, filters)

    # Final block: project to 3 channels in [-1,1]
    x = layers.Conv2D(filters=3, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False,
                      activation="tanh")(x)

    # Model
    model = tf.keras.Model(Input, x)

    if display:
        # summary() prints by itself and returns None, so it is not wrapped in print()
        model.summary()

    return model

In [None]:
def display(imgTarget,img,allLoss):
    """Refresh the progress figure: produced image, target image and loss curve.

    Args:
        imgTarget: batch holding the masked target image, values in [-1,1].
        img: batch holding the image produced by the network, values in [-1,1].
        allLoss: sequence of loss values plotted in the third panel.
    """
    clear_output(wait=True)
    plt.figure(figsize=[20,10])
    # Panels 1 and 2: images mapped from [-1,1] back to [0,1] for imshow
    for panel, batch in enumerate((img, imgTarget), start=1):
        plt.subplot(1,3,panel)
        plt.imshow(batch[0,:,:,:]*0.5+0.5)
        plt.axis('off')
    # Panel 3: loss history
    plt.subplot(1,3,3)
    plt.plot(allLoss)
    plt.show()

# Q2) Define an adam optimizer with learning rate 0.01




In [None]:
# Adam optimizer (learning rate 0.01) used by trainStep to update the model weights.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Q3) Define an iteration of the training loop
inputs:
+ imgInput  : network input image (noise prior)
+ imgTarget : masked image
+ imgMask   : mask
+ stddev    : standard deviation of the noise (normal distribution) to be added dynamically to the input (helps regularization of the network)

outputs:
+ loss
+ imgResult : produced image

The function computes the loss and updates the network parameters using the optimizer (already defined in Q2).

Very important: activate the @tf.function() decorator only when the code is 100% working (this decorator compiles the computation graph, so iterations are computed faster, but it also makes debugging more cumbersome).

Why keep this print('*** Tracing ***') statement?
=> When using the @tf.function() decorator, if the code is designed correctly, the computation graph should be built only once or twice during program execution. The print is executed only while the graph is being built, so you will see it each time that happens.

If for some reason (typically one of the function inputs is not a tensor) the graph is built many times, you will see many **** tracing **** lines in the output.
In that case you need to stop the program and fix it. Building the graph takes a lot of time and consumes memory. Sooner or later, the memory usage will exceed the available memory and THE SESSION WILL CRASH.

In [None]:
def NoiseInit(height, width):
    """Create the fixed noise input of the network.

    Args:
        height: height of the noise image.
        width: width of the noise image.

    Returns:
        A float32 tensor of shape (1, height, width, 3) drawn uniformly
        in [-0.1, 0.1).
    """
    # Not compiled with tf.function: it is called once per experiment with
    # Python ints, which would only trigger a retrace for every new size.
    Img=tf.random.uniform(shape=[height,width,3], minval=-0.1, maxval=0.1, dtype=tf.float32)
    Imgs=tf.expand_dims(Img,axis=0)
    return Imgs


@tf.function
def _trainStepGraph(net, opt, imgInput, imgTarget, imgMask, stddev):
    """Compiled optimization step; net and opt are arguments so that a new
    model or optimizer triggers a retrace instead of reusing a stale graph."""
    # Executed only while the graph is being traced, not at every iteration.
    print('*************** Tracing *********************')
    # The noise takes the shape of the actual input rather than a global variable.
    noise = tf.random.normal(shape=tf.shape(imgInput), stddev=stddev)

    with tf.GradientTape() as tape:
        imgResult = net(imgInput+noise, training=True)
        # Only the pixels selected by the mask contribute to the error.
        masked_error = tf.multiply(imgTarget-imgResult, imgMask)
        loss = tf.reduce_mean(tf.square(masked_error))

    gradients = tape.gradient(loss, net.trainable_variables)
    opt.apply_gradients(zip(gradients, net.trainable_variables))

    return loss,imgResult


def trainStep(imgInput,imgTarget,imgMask,stddev):
    """Run one optimization iteration on the global model with the global optimizer.

    Args:
        imgInput: network input image (noise prior).
        imgTarget: masked image to reproduce.
        imgMask: mask multiplied with the reconstruction error.
        stddev: standard deviation of the normal noise added to imgInput.

    Returns:
        (loss, imgResult): the masked mean squared error and the image
        produced by the network for this iteration.
    """
    # Converting stddev to a tensor avoids one retrace per distinct Python value.
    stddev = tf.convert_to_tensor(stddev, dtype=tf.float32)
    # model and optimizer are looked up at call time, so rebinding them between
    # experiments is picked up by the compiled step.
    return _trainStepGraph(model, optimizer, imgInput, imgTarget, imgMask, stddev)

# Q4) Inpainting optimization loop
This function is the main optimization loop. 
Inputs:
+ nbiter : number of optimization iterations with the Adam solver
+ stddev : standard deviation of the gaussian noise w

Output :
+ imgResult : final optimized image
+ allLoss: list with all the losses all over the iterations

The iteration progress should be displayed using tqdm. Every 100 iterations, the display should be updated using the display function so the user can track the progress. 

In [None]:
def inpatining(nbiter,stddev):
    """Main inpainting optimization loop.

    Calls trainStep nbiter times on the global imgInput, imgTarget and imgMask,
    and refreshes the progress figure every 100 iterations.

    Args:
        nbiter: number of optimization iterations (must be positive).
        stddev: standard deviation of the noise added to the input at each step.

    Returns:
        (imgResult, allLoss): the image produced at the last iteration (batch
        axis removed) and the list of the losses of all iterations.

    Raises:
        ValueError: if nbiter is not positive (no image would be produced).
    """
    if nbiter <= 0:
        raise ValueError("nbiter must be a positive number of iterations")

    allLoss = []

    for step in tqdm(range(nbiter)):
        loss, imgResult = trainStep(imgInput,imgTarget,imgMask,stddev)
        allLoss.append(loss)

        if step%100 == 0:
            display(imgTarget,imgResult,allLoss)

    return imgResult[0],allLoss

# Main program

## Q5) Read input and mask images

In [None]:
no=4

# Example number -> (image file, mask file) inside baseFolder
imageFiles = {
    1: ('building.png', 'buildingMask1.png'),  # building. Pick a mask buildingMask1, buildingMask2 or buildingMask3
    2: ('0011_img.png', '0011_mask.png'),      # abstract wall paper
    3: ('0071_img.png', '0071_mask.png'),      # dog
    4: ('0090_img.png', '0090_mask.png'),      # bison
    5: ('0063_img.png', '0063_mask.png'),
    6: ('0089_img.png', '0089_mask.png'),
}
if no not in imageFiles:
    raise ValueError(f"Unknown example number {no}, expected one of {sorted(imageFiles)}")
imgName, maskName = imageFiles[no]
filenameImg=os.path.join(baseFolder,imgName)
filenameMask=os.path.join(baseFolder,maskName)

# read 2 images with read_image : original image in [-1,1] + mask in [0,1]
imgOrigine = read_image(filenameImg,True)
imgMask = read_image(filenameMask,True,False)

# create masked image
imgTarget = tf.multiply(imgOrigine,imgMask)

# Display in a figure with subplot : the original image, the mask and the target.
# Images in [-1,1] are mapped back to [0,1], otherwise imshow clips the
# negative values to black.
plt.figure(figsize=(10,15))
plt.subplot(1,3,1)
plt.imshow(imgOrigine[0]*0.5+0.5)
plt.axis('off')
plt.subplot(1,3,2)
plt.imshow(imgMask[0])
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(imgTarget[0]*0.5+0.5)
plt.axis('off')
plt.show()

## Q6) Main program
1) Generate the network parameters & the network
2) Generate imgNoise, the input image
3) Perform inpainting
4) Display results

NB : in Kaggle, you may get the image by opening the console tab (bottom of the screen), right-clicking on the image and choosing "Save as...".

In [None]:
# Experiment: plain decoder, no noise added to the input
stddev = 0/100 # 2/100
nbiter = 8000

nfilter = [128]*6
# One x2 upsampling per block except the first one
upsample_factor = 2**(len(nfilter)-1)
inputShape = [img_height//upsample_factor,img_width//upsample_factor,3]

imgInput = NoiseInit(inputShape[0], inputShape[1])

model = buildDecoder(inputShape,nfilter,display=False)

imgResult, allLoss = inpatining(nbiter,stddev)

# display final result (tanh output in [-1,1] mapped back to [0,1] for imshow)
plt.figure(figsize=[15,10])
plt.imshow(imgResult*0.5+0.5)
plt.show()

In [None]:
# Experiment: plain decoder, noise of standard deviation 0.02 added to the input
stddev = 2/100
nbiter = 8000

nfilter = [128]*6
# One x2 upsampling per block except the first one
upsample_factor = 2**(len(nfilter)-1)
inputShape = [img_height//upsample_factor,img_width//upsample_factor,3]

imgInput = NoiseInit(inputShape[0], inputShape[1])

model = buildDecoder(inputShape,nfilter,display=False)

imgResult, allLoss = inpatining(nbiter,stddev)

# display final result (tanh output in [-1,1] mapped back to [0,1] for imshow)
plt.figure(figsize=[15,10])
plt.imshow(imgResult*0.5+0.5)
plt.show()

## Q7)
This approach won't perform well when the masked region is very large, because the image is produced by upsampling. Since the masked pixels are "guessed", it is better to use it with small masks, so that the difference between the produced image and the original image stays small.

## Q8)
The role of the first 1x1 convolutional layer with n1[0] channels is to increase the number of channels of the image without changing its spatial dimensions (width and height). Furthermore, the number of channels of this layer should be equal to the number of channels of the next block, because the next block is a residual block: in order to add the identity branch to the output of the 3 blocks included in the residual block, both must have the same number of channels.

## Q9) 
The structure can only work for n1[i] = n1[i-1], for the same reason as mentioned in Q8. Imagine that n1[i] is different from n1[i-1]. In a residual block, we sum the output of the 3 convolutional blocks and the identity. If the input of the residual block has a different number of channels than its output, then the identity has a different number of channels than the output of the convolutional blocks. Consequently, the sum cannot be performed.

## Q10) 
A color image always has 3 or 4 channels. In our case we use RGB, so the number of channels should be 3.

## Q11) 

In [None]:
def residual_block(Input, k, n1, n2):
    """Residual block: three conv / batch norm / ReLU stages plus an identity shortcut.

    The stages are a 1x1 convolution with n1 filters, a kxk convolution with n2
    filters and a 1x1 convolution with n1 filters. The block input is added to
    the result, so Input must already have n1 channels.

    Args:
        Input: input Keras tensor of the block.
        k: kernel size of the middle convolution.
        n1: number of filters of the first and last convolutions.
        n2: number of filters of the middle convolution.

    Returns:
        The output Keras tensor (same shape as Input).
    """
    shortcut = Input

    x = Input
    for filters, kernel_size in ((n1, 1), (n2, k), (n1, 1)):
        # use_bias must be the boolean False: the string "False" is truthy and
        # would silently enable the bias.
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size,
                          kernel_initializer=tf.keras.initializers.HeNormal(),
                          strides=1, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = layers.ReLU()(x)

    # Residual connection
    return layers.Add()([x, shortcut])

## Q12)

In [None]:
def genereResidual2(inputShape, n1, n2, display=False):
    """Build the decoder made of residual blocks.

    A 1x1 convolution first maps the input to n1[0] channels. It is followed by
    len(n1) residual blocks (3x3 middle convolution); every block except the
    first one is preceded by a x2 bilinear upsampling. A final 1x1 convolution
    with tanh produces 3 channels in [-1,1].

    Args:
        inputShape: shape (height, width, channels) of the input noise image.
        n1: list with the n1 parameter of each residual block.
        n2: list with the n2 parameter of each residual block.
        display: if True, print the model summary.

    Returns:
        The tf.keras.Model mapping the input image to the generated image.

    Raises:
        ValueError: if n2 has fewer elements than n1.
    """
    if len(n2) < len(n1):
        raise ValueError("n2 must provide one value per element of n1")

    # Input
    Input = layers.Input(shape=inputShape)

    # Block 0: 1x1 convolution (use_bias is the boolean False, not the truthy string "False")
    x = layers.Conv2D(filters=n1[0], kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(Input)

    # Block 1 (no upsampling)
    x = residual_block(x, 3, n1[0], n2[0])

    # From Block 2 until before Final Block
    for outer_filters, inner_filters in zip(n1[1:], n2[1:]):
        x = layers.UpSampling2D(interpolation="bilinear")(x)
        x = residual_block(x, 3, outer_filters, inner_filters)

    # Final block: project to 3 channels in [-1,1]
    x = layers.Conv2D(filters=3, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False,
                      activation="tanh")(x)

    # Model
    model = tf.keras.Model(Input, x)

    if display:
        # summary() prints by itself and returns None, so it is not wrapped in print()
        model.summary()

    return model

## Q13)

In [None]:
# Experiment: residual decoder, no noise added to the input
stddev = 0/100
nbiter = 8000

n1 = [128]*6
n2 = [64]*6
# One x2 upsampling per residual block except the first one
upsample_factor = 2**(len(n1)-1)
inputShape = [img_height//upsample_factor,img_width//upsample_factor,3]

imgInput = NoiseInit(inputShape[0], inputShape[1])

model = genereResidual2(inputShape, n1, n2, display=False)

imgResult, allLoss = inpatining(nbiter,stddev)

# display final result (tanh output in [-1,1] mapped back to [0,1] for imshow)
plt.figure(figsize=[15,10])
plt.imshow(imgResult*0.5+0.5)
plt.show()

## Q14) "Advanced" residual block

In [None]:
def residual_block_advanced(Input, k, n1, n2):
    """Residual block whose shortcut is a 1x1 conv / batch norm / ReLU branch.

    The main branch is a 1x1 convolution with n1 filters, a kxk convolution
    with n2 filters and a 1x1 convolution with n1 filters, each followed by
    batch normalization and ReLU. The shortcut branch maps the block input to
    n1 channels, so the input may have any number of channels.

    Args:
        Input: input Keras tensor of the block.
        k: kernel size of the middle convolution of the main branch.
        n1: number of output filters of both branches.
        n2: number of filters of the middle convolution.

    Returns:
        The sum of the main branch and of the shortcut branch.
    """
    def conv_bn_relu(x, filters, kernel_size):
        # use_bias must be the boolean False: the string "False" is truthy and
        # would silently enable the bias.
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size,
                          kernel_initializer=tf.keras.initializers.HeNormal(),
                          strides=1, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return layers.ReLU()(x)

    # Additional branch (shortcut)
    additional_branch = conv_bn_relu(Input, n1, 1)

    # Main branch: Block 1, Block 2, Block 3
    x = conv_bn_relu(Input, n1, 1)
    x = conv_bn_relu(x, n2, k)
    x = conv_bn_relu(x, n1, 1)

    # Residual connection
    return layers.Add()([x, additional_branch])

In [None]:
def genereAdvancedResidual2(inputShape, n1, n2, display=False):
    """Build the decoder made of "advanced" residual blocks.

    A 1x1 convolution first maps the input to n1[0] channels. It is followed by
    len(n1) advanced residual blocks (3x3 middle convolution); every block
    except the first one is preceded by a x2 bilinear upsampling. A final 1x1
    convolution with tanh produces 3 channels in [-1,1].

    Args:
        inputShape: shape (height, width, channels) of the input noise image.
        n1: list with the n1 parameter of each residual block.
        n2: list with the n2 parameter of each residual block.
        display: if True, print the model summary.

    Returns:
        The tf.keras.Model mapping the input image to the generated image.

    Raises:
        ValueError: if n2 has fewer elements than n1.
    """
    if len(n2) < len(n1):
        raise ValueError("n2 must provide one value per element of n1")

    # Input
    Input = layers.Input(shape=inputShape)

    # Block 0: 1x1 convolution (use_bias is the boolean False, not the truthy string "False")
    x = layers.Conv2D(filters=n1[0], kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(Input)

    # Block 1 (no upsampling)
    x = residual_block_advanced(x, 3, n1[0], n2[0])

    # From Block 2 until before Final Block
    for outer_filters, inner_filters in zip(n1[1:], n2[1:]):
        x = layers.UpSampling2D(interpolation="bilinear")(x)
        x = residual_block_advanced(x, 3, outer_filters, inner_filters)

    # Final block: project to 3 channels in [-1,1]
    x = layers.Conv2D(filters=3, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False,
                      activation="tanh")(x)

    # Model
    model = tf.keras.Model(Input, x)

    if display:
        # summary() prints by itself and returns None, so it is not wrapped in print()
        model.summary()

    return model

In [None]:
# Experiment: advanced residual decoder, noise of standard deviation 0.01
stddev = 1/100
nbiter = 8000

# Number of filters halved at every block: 256, 128, 64, 32, 16 and 128, 64, 32, 16, 8
n1 = [256//2**i for i in range(5)]
n2 = [128//2**i for i in range(5)]
# One x2 upsampling per residual block except the first one
upsample_factor = 2**(len(n1)-1)
inputShape = [img_height//upsample_factor,img_width//upsample_factor,3]

imgInput = NoiseInit(inputShape[0], inputShape[1])

model = genereAdvancedResidual2(inputShape, n1, n2, display=False)

imgResult, allLoss = inpatining(nbiter,stddev)

# display final result (tanh output in [-1,1] mapped back to [0,1] for imshow)
plt.figure(figsize=[15,10])
plt.imshow(imgResult*0.5+0.5)
plt.show()

## Q15) UNet Architecture

In [None]:
def down_conv_block(x, nd, kd):
    """Downsampling block: strided conv / BN / ReLU, then conv / BN / ReLU.

    Args:
        x: input Keras tensor.
        nd: number of filters of both convolutions.
        kd: kernel size of both convolutions.

    Returns:
        The output Keras tensor; the first convolution uses stride 2, which
        halves the spatial size.
    """
    # use_bias must be the boolean False: the string "False" is truthy and
    # would silently enable the bias.
    # Block 1: stride 2 performs the downsampling
    x = layers.Conv2D(filters=nd, kernel_size=kd,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=2, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = layers.ReLU()(x)

    # Block 2
    x = layers.Conv2D(filters=nd, kernel_size=kd,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = layers.ReLU()(x)

    return x

def up_conv_block(x, nu, ku):
    """Upsampling block: kuxku conv / BN / ReLU, 1x1 conv / BN / ReLU, x2 upsampling.

    Args:
        x: input Keras tensor.
        nu: number of filters of both convolutions.
        ku: kernel size of the first convolution.

    Returns:
        The output Keras tensor, upsampled by 2 with bilinear interpolation.
    """
    # use_bias must be the boolean False: the string "False" is truthy and
    # would silently enable the bias.
    # Block 1
    x = layers.Conv2D(filters=nu, kernel_size=ku,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = layers.ReLU()(x)

    # Block 2
    x = layers.Conv2D(filters=nu, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = layers.ReLU()(x)

    # Block 3
    x = layers.UpSampling2D(interpolation="bilinear")(x)

    return x
    
def skip_block(x, ns):
    """Skip-connection block: 1x1 convolution, batch normalization and ReLU.

    Args:
        x: input Keras tensor.
        ns: number of filters of the 1x1 convolution.

    Returns:
        The output Keras tensor with ns channels.
    """
    # use_bias must be the boolean False: the string "False" is truthy and
    # would silently enable the bias.
    x = layers.Conv2D(filters=ns, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = layers.ReLU()(x)

    return x

In [None]:
def genereUNet(inputShape, nd, kd, nu, ku, ns, display=False):
    """Build the UNet-like encoder/decoder with skip connections.

    Encoder: len(nd) chained down_conv_block, each halving the spatial size.
    Skips: one skip_block applied to the output of every encoder level.
    Decoder: starts from the deepest skip output and chains one up_conv_block
    per level (deepest first); after each upsampling, the result is
    concatenated with the skip output of the level above. A final 1x1
    convolution with tanh produces 3 channels in [-1,1].

    Each decoder stage takes the running tensor as input, so the deeper levels
    actually contribute to the output.

    Args:
        inputShape: shape (height, width, channels) of the input image; height
            and width should be multiples of 2**len(nd) so that the upsampled
            tensors and the skip outputs have matching sizes.
        nd: number of filters of each down block.
        kd: kernel size of each down block.
        nu: number of filters of each up block.
        ku: kernel size of each up block.
        ns: number of filters of each skip block.
        display: if True, print the model summary.

    Returns:
        (model, output_skip): the tf.keras.Model and the list of the skip
        block outputs (shallowest level first).

    Raises:
        ValueError: if the parameter lists are empty or of different lengths.
    """
    depth = len(nd)
    if depth == 0 or any(len(params) != depth for params in (kd, nu, ku, ns)):
        raise ValueError("nd, kd, nu, ku and ns must be non-empty lists of the same length")

    # Input
    Input = layers.Input(shape=inputShape)

    # Encoder: keep the output of every level for the skip connections
    output_down = []
    x = Input
    for filters, kernel_size in zip(nd, kd):
        x = down_conv_block(x, filters, kernel_size)
        output_down.append(x)

    # Skip connections
    output_skip = [skip_block(features, filters)
                   for features, filters in zip(output_down, ns)]

    # Decoder: from the deepest level up to the input resolution
    x = output_skip[-1]
    for level in reversed(range(depth)):
        x = up_conv_block(x, nu[level], ku[level])
        if level > 0:
            # After upsampling, x has the spatial size of the level above
            x = layers.Concatenate()([x, output_skip[level-1]])

    # Final block (use_bias is the boolean False, not the truthy string "False")
    x = layers.Conv2D(filters=3, kernel_size=1,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      strides=1, padding="same", use_bias=False,
                      activation="tanh")(x)

    # Model
    model = tf.keras.Model(Input, x)

    if display:
        # summary() prints by itself and returns None, so it is not wrapped in print()
        model.summary()

    return model, output_skip

In [None]:
# Experiment: UNet, noise of standard deviation 0.01 added to the input
stddev = 1/100
nbiter = 8000

# One value per level for the up, down and skip blocks
nu = [128]*5
nd = [128]*5
ns = [4]*5
kd = [3]*5
ku = [5]*5

# The UNet input has the size of the image itself
inputShape = [img_height, img_width, 3]

imgInput = NoiseInit(inputShape[0], inputShape[1])

model, output_skip = genereUNet(inputShape, nd, kd, nu, ku, ns, display=False)

imgResult, allLoss = inpatining(nbiter,stddev)

# display final result (tanh output in [-1,1] mapped back to [0,1] for imshow)
plt.figure(figsize=[15,10])
plt.imshow(imgResult*0.5+0.5)
plt.show()

## Q16) Inception

In [None]:
def inception_block(x, n1, n3, n5, nd, np):
    """Inception block: four parallel branches concatenated along the channels.

    Branch 1: 1x1 conv (n1 filters).
    Branch 2: 1x1 conv (nd filters) then 3x3 conv (n3 filters).
    Branch 3: 1x1 conv (nd filters) then two 3x3 convs (n5 filters each).
    Branch 4: max pooling with stride 1 then 1x1 conv (np filters).

    Args:
        x: input Keras tensor.
        n1: number of filters of branch 1.
        n3: number of filters of the 3x3 convolution of branch 2.
        n5: number of filters of the 3x3 convolutions of branch 3.
        nd: number of filters of the 1x1 reductions of branches 2 and 3.
        np: number of filters of the 1x1 convolution of branch 4.

    Returns:
        The concatenation of the four branches (n1 + n3 + n5 + np channels).
    """
    def conv(inputs, filters, kernel_size):
        # use_bias must be the boolean False: the string "False" is truthy and
        # would silently enable the bias.
        return layers.Conv2D(filters=filters, kernel_size=kernel_size,
                             kernel_initializer=tf.keras.initializers.HeNormal(),
                             strides=1, padding="same", use_bias=False)(inputs)

    def bn_relu(inputs):
        inputs = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(inputs)
        return layers.ReLU()(inputs)

    # Branch 1
    x1 = bn_relu(conv(x, n1, 1))

    # Branch 2 - Block 1
    x2 = bn_relu(conv(x, nd, 1))

    # Branch 2 - Block 2
    x2 = bn_relu(conv(x2, n3, 3))

    # Branch 3 - Block 1
    x3 = bn_relu(conv(x, nd, 1))

    # Branch 3 - Block 2
    # NOTE(review): unlike every other block, this convolution is not followed
    # by batch normalization / ReLU; kept as is, confirm against the assignment.
    x3 = conv(x3, n5, 3)

    # Branch 3 - Block 3
    x3 = bn_relu(conv(x3, n5, 3))

    # Branch 4 - Block 1
    x4 = layers.MaxPool2D(strides=1, padding="same")(x)

    # Branch 4 - Block 2
    x4 = bn_relu(conv(x4, np, 1))

    # Concatenation block
    x = tf.keras.layers.Concatenate()([x1, x2, x3, x4])

    return x