In [None]:
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import os
from itertools import islice

import numpy as np
import cv2
import skimage
import skimage.io

from IPython.display import clear_output
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
def normalize(input_image):
    """Map 0-255 pixel intensities onto the unit interval.

    Returns a float32 tensor. Used as the ImageDataGenerator
    ``preprocessing_function`` for the input images.
    """
    pixels = tf.cast(input_image, dtype=tf.float32)
    return tf.math.truediv(pixels, 255.0)

def normalize_mask(input_image):
    """Binarize a label mask and scale it to float32.

    Every strictly positive pixel becomes foreground (1.0) and zero pixels
    stay 0.0. Used as the ImageDataGenerator ``preprocessing_function`` for
    the label masks.

    The input is not modified. The previous version did
    ``input_image[input_image > 0] = 255``, which wrote into the caller's
    array and raised on tensors, which do not support item assignment.

    Returns:
        float32 tensor with the same shape as ``input_image``.
    """
    binarized = np.where(np.asarray(input_image) > 0, 255, input_image)
    return tf.cast(binarized, tf.float32) / 255.0

In [None]:
# Network input resolution and number of samples per training batch.
IMG_HEIGHT = IMG_WIDTH = 128
BATCH_SIZE = 48

# Root of the Kaggle dataset and the subset used here. `obj_path` must end
# with '/' because later cells append sub-paths with plain concatenation.
DATASET_ROOT = '/kaggle/input/synthetic-dataset-for-object-detection/'
obj_path = DATASET_ROOT + 'sofa_5/sofa_5/'

In [None]:
def _load_resized(directory, flags=cv2.IMREAD_COLOR):
    """Read every file in `directory` with cv2.imread and resize it to the input size.

    Files are read in sorted order so the stacking order is deterministic
    (os.listdir order is arbitrary). Returns an array of shape
    (n_files, IMG_HEIGHT, IMG_WIDTH[, 3]) depending on `flags`.
    """
    return np.array([
        # cv2.resize takes dsize as (width, height), not (height, width).
        cv2.resize(cv2.imread(os.path.join(directory, name), flags),
                   dsize=(IMG_WIDTH, IMG_HEIGHT),
                   interpolation=cv2.INTER_CUBIC)
        for name in tqdm(sorted(os.listdir(directory)))
    ])


# In-memory copies (BGR as returned by cv2.imread) used to fit the generators.
image_dataset = _load_resized(f'{obj_path}features/images/')
# Grayscale masks get an explicit trailing channel axis: (n, H, W, 1).
mask_dataset = np.expand_dims(
    _load_resized(f'{obj_path}labels/images/', cv2.IMREAD_GRAYSCALE), axis=-1)

image_for_test = _load_resized(os.path.join(obj_path, 'image_for_test'))

In [None]:
# Test images for the qualitative overlay check, kept as RGB uint8 arrays
# (cv2 reads BGR, so each one is converted).
ImgDir = os.path.join(obj_path, 'image_for_test')
image_for_test = sorted(os.listdir(ImgDir))  # sorted: listdir order is arbitrary
imgs = []
for file_name in image_for_test:
    bgr = cv2.imread(os.path.join(ImgDir, file_name))
    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(bgr, dsize=(IMG_WIDTH, IMG_HEIGHT),
                         interpolation=cv2.INTER_CUBIC)
    imgs.append(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
    

In [None]:
# Augmentation shared by images and masks. Image/mask alignment relies on both
# generators flowing with the same `seed`. Keep this to geometric transforms:
# intensity changes (e.g. channel_shift_range) would corrupt the binary masks.
data_gen_args = dict(
    horizontal_flip=True,
)
VALIDATION_SPLIT = 0.3

image_datagen = ImageDataGenerator(**data_gen_args, preprocessing_function=normalize,
                                   validation_split=VALIDATION_SPLIT)
mask_datagen = ImageDataGenerator(**data_gen_args, preprocessing_function=normalize_mask,
                                  validation_split=VALIDATION_SPLIT)
image_for_test_datagen = ImageDataGenerator(preprocessing_function=normalize)

seed = 1046527

# ImageDataGenerator.fit only computes statistics for featurewise
# normalization / ZCA whitening. With none of those enabled it just augments
# 2000 in-memory images and discards the result, so fit only when needed.
FEATUREWISE_OPTIONS = ('featurewise_center', 'featurewise_std_normalization', 'zca_whitening')
if any(data_gen_args.get(option) for option in FEATUREWISE_OPTIONS):
    image_datagen.fit(image_dataset[:2000], augment=True, seed=seed)
    mask_datagen.fit(mask_dataset[:2000], augment=True, seed=seed)

In [None]:
def _flow(datagen, subdir, batch_size=BATCH_SIZE, **extra):
    """flow_from_directory with the settings every stream in this notebook shares.

    Unlabelled (class_mode=None), resized to the network input and seeded with
    the common `seed`. Pairing image and mask streams relies on that shared seed.
    """
    return datagen.flow_from_directory(
        f'{obj_path}{subdir}',
        class_mode=None,
        seed=seed,
        batch_size=batch_size,
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        **extra
    )


# Training split: RGB images paired with grayscale masks.
train_image_generator = _flow(image_datagen, 'features', subset='training')
train_mask_generator = _flow(mask_datagen, 'labels', color_mode='grayscale', subset='training')
train_generator = zip(train_image_generator, train_mask_generator)

# Validation split, built the same way.
valid_image_generator = _flow(image_datagen, 'features', subset='validation')
valid_mask_generator = _flow(mask_datagen, 'labels', color_mode='grayscale', subset='validation')
valid_generator = zip(valid_image_generator, valid_mask_generator)

# One image at a time from the separate test folder.
image_for_test_generator = _flow(image_for_test_datagen, 'image_for_test_dg/',
                                 batch_size=1, subset='training')

In [None]:
def display(display_list):
    """Plot up to three images in a single row with fixed panel titles.

    Args:
        display_list: sequence of image arrays/tensors accepted by
            ``array_to_img``, in the order input, true mask, predicted mask.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))
    for position, panel in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(titles[position - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
        plt.axis('off')
    plt.show()

In [None]:
# Pull one (images, masks) batch from the training stream and keep its first
# pair as the fixed sample shown by show_predictions().
image, mask = next(train_generator)
sample_image = image[0]
sample_mask = mask[0]
display([sample_image, sample_mask])

In [None]:
# Encoder: MobileNetV2 without its classification head. `classes=2` was
# dropped. MobileNetV2 only uses `classes` when include_top=True, so it had no
# effect here and wrongly suggested a 2-class head. weights='imagenet' is the
# Keras default, spelled out so the pretrained initialization is visible.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=[IMG_HEIGHT, IMG_WIDTH, 3],
    include_top=False,
    weights='imagenet',
)

# Activations reused as U-Net skip connections, shallow to deep. The sizes in
# the comments assume the 128x128 input. In general they are the input size
# divided by 2, 4, 8, 16 and 32.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Multi-output feature extractor, frozen so only the decoder is trained.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False

In [None]:
# Decoder blocks, deepest first: one pix2pix upsample (kernel size 3) per
# skip level, taking 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 at a 128x128 input,
# with filter counts shrinking as resolution grows.
up_stack = [pix2pix.upsample(n_filters, 3) for n_filters in (1536, 1024, 512, 256)]

In [None]:
def unet_model(output_channels):
    """Assemble the U-Net from the frozen `down_stack` encoder and `up_stack`.

    Args:
        output_channels: number of channels in the final sigmoid map.

    Returns:
        tf.keras.Model mapping (IMG_HEIGHT, IMG_WIDTH, 3) inputs to sigmoid
        maps with `output_channels` channels. The final stride-2 transposed
        convolution doubles the resolution of the last decoder stage.
    """
    inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3])

    # The deepest encoder activation starts the decoder; the remaining ones,
    # deepest first, are the skip connections.
    encoder_features = down_stack(inputs)
    x = encoder_features[-1]
    skip_features = encoder_features[-2::-1]

    # Upsample, then concatenate the matching-resolution encoder activation.
    for upsample_block, skip in zip(up_stack, skip_features):
        x = upsample_block(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    head = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2, padding='same', activation='sigmoid')
    outputs = head(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
from tensorflow.keras import backend as K


def jaccard_distance(y_true, y_pred, smooth=100, axis=-1):
    """Smoothed Jaccard (IoU) distance, scaled by `smooth`.

    Computes ``(1 - (I + smooth) / (U + smooth)) * smooth``. Here I is the sum
    of |y_true * y_pred| and U is the sum of |y_true| + |y_pred| minus I, both
    summed along `axis`.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor, same shape as `y_true`.
        smooth: additive smoothing constant; also rescales the result.
        axis: axis or axes to reduce over. The default -1 keeps the original
            behaviour. NOTE: for single-channel masks (this notebook's model
            has one sigmoid channel and grayscale labels) the last axis has
            size 1, so the default gives a per-pixel loss rather than a
            per-image IoU. Pass ``axis=[1, 2, 3]`` for a per-image distance.

    Returns:
        Distance tensor with `axis` reduced away.
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=axis)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=axis)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth


def dice_coef(y_true, y_pred, smooth=1):
    """Batch-averaged smoothed Dice coefficient.

    Per sample: (2 * overlap + smooth) / (|y_true| + |y_pred| + smooth), with
    sums over axes 1, 2 and 3 (so rank-4 tensors are expected), then averaged
    over the batch axis.
    """
    per_sample_axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=per_sample_axes)
    total = K.sum(y_true, axis=per_sample_axes) + K.sum(y_pred, axis=per_sample_axes)
    per_sample_dice = (2. * overlap + smooth) / (total + smooth)
    return K.mean(per_sample_dice, axis=0)

In [None]:
class BinaryMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU over {background, foreground} for single-channel sigmoid outputs.

    tf.keras.metrics.MeanIoU expects integer class ids and casts predictions
    to integers. That truncates every sigmoid probability below 1.0 to class 0,
    so the stock metric never sees a foreground prediction. This subclass
    thresholds y_true and y_pred at `threshold` before updating the confusion
    matrix. The default name matches the stock metric's log key, 'mean_io_u'.
    """

    def __init__(self, num_classes=2, threshold=0.5, name='mean_io_u', **kwargs):
        super().__init__(num_classes=num_classes, name=name, **kwargs)
        self.threshold = threshold

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true > self.threshold, tf.int32)
        y_pred = tf.cast(y_pred > self.threshold, tf.int32)
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        config = super().get_config()
        config['threshold'] = self.threshold
        return config


# One sigmoid output channel: foreground probability per pixel.
model = unet_model(1)
model.compile(
    optimizer='adam',
    loss=jaccard_distance,
    metrics=['accuracy', BinaryMeanIoU(num_classes=2), dice_coef],
)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the compiled
# U-Net (the encoder `down_stack` was set trainable=False, so its weights are
# expected under the non-trainable count).
model.summary()

In [None]:
def create_mask(pred_mask):
    """Binarize the first prediction of a batch.

    Args:
        pred_mask: batch indexed as [sample, row, col, channel], e.g. the
            array returned by ``model.predict``. Only channel 0 of sample 0
            is used.

    Returns:
        Tensor of shape (rows, cols, 1) holding ``tf.round`` of those values.
        For sigmoid outputs that is 0.0 or 1.0; tf.round rounds half to even,
        so exactly 0.5 maps to 0.0.
    """
    first_sample_channel = pred_mask[0, :, :, 0:1]
    return tf.round(first_sample_channel)

In [None]:
def show_predictions(dataset=None, num=1):
    """Display (input, true mask, predicted mask) triplets.

    Args:
        dataset: optional iterable of (images, masks) batches, such as the
            zipped train/valid generators. When None, the module-level
            `sample_image` / `sample_mask` pair is shown instead.
        num: number of batches to draw from `dataset`; the first item of each
            batch is displayed. Values <= 0 display nothing. Previously they
            looped until the iterator ended, i.e. forever on an endless stream.
    """
    if dataset is None:
        prediction = model.predict(sample_image[tf.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return

    # islice consumes exactly `num` batches, even when `dataset` never ends.
    for image, mask in islice(dataset, max(num, 0)):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])

In [None]:
def _predict_binary(img, mode):
    """Run `mode` on one image and return its rounded 0/1 mask as a 2-D numpy array."""
    return create_mask(mode.predict(img[tf.newaxis, ...]))[:, :, 0].numpy()


def _fill_contours(binary_mask, cont=False):
    """Fill the external contours of a 0/1 mask and return a uint8 0/255 mask.

    Args:
        binary_mask: 2-D array; nonzero pixels are treated as foreground.
        cont: if True, keep only the largest external contour (by area) and
            clear everything else. If there is no contour at all, print a
            notice and return the (empty) mask.

    Returns:
        2-D uint8 array with 255 inside the kept contour(s) and 0 elsewhere.
    """
    mask = np.where(np.asarray(binary_mask) > 0, 255, 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if cont:
        if not contours:
            print('Contours not found!')
            return mask
        contours = [max(contours, key=cv2.contourArea)]
        mask = np.zeros_like(mask)
    # thickness=cv2.FILLED paints each contour's interior. RETR_EXTERNAL only
    # returns outer boundaries, so holes inside an object are filled too.
    cv2.drawContours(mask, contours, -1, color=255, thickness=cv2.FILLED)
    return mask


def pred(img, mode, cont=False):
    """Predict a mask for one image and clean it up with contours.

    Args:
        img: single image (no batch axis) in the model's input format.
        mode: the Keras model to run.
        cont: if True, keep only the largest predicted region; otherwise fill
            every external contour.

    Returns:
        2-D uint8 mask with values 0/255.

    Fixes over the previous version:
    - Values were scaled by 255 twice, so uint8 overflow left foreground at 1.
    - cv2.FILLED was passed as the contour *color* with thickness 8.
    - The fill color 256 was outside the uint8 range.
    - A bare ``except`` hid unrelated errors.
    """
    return _fill_contours(_predict_binary(img, mode), cont)


def show_fm():
    """Show the next test image beside its full and largest-contour predicted masks.

    Takes the first image of the next `image_for_test_generator` batch and runs
    the model once. Both masks are post-processed from that single prediction.
    """
    img = next(image_for_test_generator)[0]
    raw_mask = _predict_binary(img, model)
    panels = (img, _fill_contours(raw_mask, False), _fill_contours(raw_mask, True))
    titles = ['Input Image', 'Model Mask', 'Contour Mask']

    plt.figure(figsize=(15, 15))
    for position, (title, panel) in enumerate(zip(titles, panels), start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(panel)
        plt.axis('off')
    plt.show()
 

In [None]:
# Baseline before training: with no dataset argument, show_predictions()
# predicts on the module-level sample_image and plots it next to sample_mask.
show_predictions()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Redraw qualitative predictions at the end of every training epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Replace the previous cell output with fresh sample predictions.

        Shows three images from the test stream (show_fm), then the fixed
        training sample (show_predictions), then the 1-based epoch number.
        """
        clear_output(wait=True)
        for _ in range(3):
            show_fm()
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch+1))

In [None]:
# Training schedule. The train/valid streams are zip objects with no length,
# so explicit step counts tell Keras where an epoch and a validation pass end.
EPOCHS = 100
STEPS_PER_EPOCH = 300
VALIDATION_STEPS = 30

training_callbacks = [DisplayCallback()]
model_history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=training_callbacks,
)

In [None]:
# Take the input size from the model itself. `model.input_shape` is
# (batch, height, width, channels). Unlike indexing
# get_config()['layers'][0]['config']['batch_input_shape'], it does not depend
# on the serialized config layout, whose key names differ across Keras versions.
IMG_HEIGHT, IMG_WIDTH = model.input_shape[1:3]

# Test images as RGB uint8 arrays at the model's input size.
ImgDir = os.path.join(obj_path, 'image_for_test')
image_for_test = sorted(os.listdir(ImgDir))  # sorted: listdir order is arbitrary
imgs = []
for file_name in image_for_test:
    bgr = cv2.imread(os.path.join(ImgDir, file_name))
    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(bgr, dsize=(IMG_WIDTH, IMG_HEIGHT),
                         interpolation=cv2.INTER_CUBIC)
    imgs.append(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))

In [None]:
def show_masked(img, color=(255, 0, 255)):
    """Return a copy of `img` with the model's predicted regions painted over.

    Args:
        img: RGB uint8 image at the model's input size (as built by the
            loading cell); it is scaled to [0, 1] before prediction.
        color: fill color for predicted regions. The previous hard-coded
            (256, 0, 256) was outside the uint8 range; OpenCV saturated it to
            (255, 0, 255), which is now the default.

    Returns:
        Copy of `img` with every external contour of the predicted mask filled
        with `color`. The input array is not modified.
    """
    overlay = img.copy()
    prediction = model.predict((overlay / 255.)[tf.newaxis, ...])
    binary = create_mask(prediction)[:, :, 0].numpy()
    # findContours needs a single-channel uint8 image.
    mask = np.where(binary > 0, 255, 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.fillPoly(overlay, pts=contours, color=color)
    return overlay


for test_img in imgs:
    plt.figure(figsize=(8, 8))
    plt.imshow(show_masked(test_img))
    plt.axis('off')
    plt.show()