# Importing libraries

In [None]:
import tensorflow as tf
# Report the TensorFlow version and the GPUs TensorFlow can see.
# The deprecated `tf.test.is_gpu_available()` call was dropped: its return
# value was discarded (only a cell's last expression is displayed) and
# `tf.config.list_physical_devices` is the supported replacement.
print(tf.__version__)
tf.config.list_physical_devices('GPU')

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

from tensorflow_examples.models.pix2pix import pix2pix

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from IPython.display import clear_output
import matplotlib.pyplot as plt
import os
import shutil
import numpy as np
import random
from random import shuffle
from PIL import Image

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Creating Image Data Generators

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Shorthand for tf.data's automatic-tuning constant.
# NOTE(review): AUTOTUNE is not referenced anywhere else in the visible notebook.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Last expression of the cell: echoes the installed TensorFlow version.
tf.__version__

'2.2.0'

In [None]:
# Size passed as `target_size` to every flow_from_directory call below.
sz = (128, 128)
# Number of image/mask pairs per generator batch.
BATCH_SIZE = 20

In [None]:
# Keras data generators expect a path to the folder located two levels above
# the image/mask files, i.e. the parent of the folder that holds them:
#   <directory fed to the Keras generator>/
#       <sub-directory>/   (usually one per class of a classification model;
#                           a segmentation task needs only a single folder)
#           images or masks

# directories of the sugarbeet dataset
rgb_images_dir = '/home/path/to/images'
annotations_color_dir = '/home/path/to/masks'

In [None]:
# Images and masks get their own ImageDataGenerator, configured identically so
# that, with the same seed, both draw the same random flips for a given pair.
data_gen_args = dict(
    horizontal_flip=True,
    vertical_flip=True,
    rescale=1./255,
    validation_split=0.1,
)

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

#Splitting training and validation without applying same transformations to both datasets:
#https://stackoverflow.com/questions/53037510/can-flow-from-directory-get-train-and-validation-data-from-the-same-directory-in
#https://stackoverflow.com/questions/42443936/keras-split-train-test-set-when-using-imagedatagenerator
# NOTE: the validation generators below are created from the same augmenting
# datagens, so validation batches are randomly flipped as well.

# Provide the same seed and keyword arguments to the fit and flow methods
seed = 1

# Arguments shared by all four flows; only the directory and subset differ.
_flow_kwargs = dict(
    target_size=sz,
    batch_size=BATCH_SIZE,
    class_mode=None,
    shuffle=True,
    seed=seed,
)

train_image_generator = image_datagen.flow_from_directory(
    rgb_images_dir, subset='training', **_flow_kwargs)
train_mask_generator = mask_datagen.flow_from_directory(
    annotations_color_dir, subset='training', **_flow_kwargs)

val_image_generator = image_datagen.flow_from_directory(
    rgb_images_dir, subset='validation', **_flow_kwargs)
val_mask_generator = mask_datagen.flow_from_directory(
    annotations_color_dir, subset='validation', **_flow_kwargs)

 #about resetting test generators
#https://medium.com/@vijayabhaskar96/tutorial-image-classification-with-keras-flow-from-directory-and-generators-95f75ebe5720


Found 4224 images belonging to 1 classes.
Found 4224 images belonging to 1 classes.
Found 469 images belonging to 1 classes.
Found 469 images belonging to 1 classes.


In [None]:
def normalize(images):
  """Scale every image by the ceiling of its own maximum value.

  Args:
    images: sequence of numeric numpy arrays (a list or a batched array).

  Returns:
    A list with one new array per input image, each equal to
    ``image / ceil(max(image))``. An image whose ceiled maximum is 0
    (e.g. an all-black image) is returned unscaled instead of being divided
    by zero, which would fill it with NaN/inf.
  """
  normalized_images = []
  for image in images:
    max_val = np.ceil(np.max(image))
    # Dividing by 1 keeps the "always return a new float array" behaviour
    # while avoiding the 0/0 case.
    scale = max_val if max_val != 0 else 1.0
    normalized_images.append(image / scale)
  return normalized_images

In [None]:
#If this doesn't work, applying mask changes
#https://github.com/keras-team/keras-preprocessing/issues/125
def prepare_mask(masks):
  """Convert a batch of RGB colour masks into single-channel label masks.

  For each channel, values above 150/255 are replaced by a label
  (red -> 3, green -> 1, blue -> 2); values at or below the threshold are left
  unchanged. Every blue-labelled pixel is then grown by one pixel in all eight
  directions (a single 3x3 dilation of the original blue pixels), and the
  three channels are merged with an element-wise maximum.

  The dilation reaches row 0 and column 0 like any other border, and the
  input batch is not modified (the work is done on a copy).

  NOTE(review): red pixels map to label 3, while the model elsewhere in this
  notebook has 3 output channels (labels 0..2) -- confirm the masks never
  contain red above the threshold.

  Args:
    masks: array-like of shape (batch, height, width, channels) with at least
      three channels in R, G, B order.

  Returns:
    numpy array of shape (batch, height, width, 1) with the merged labels,
    in the dtype of the input.
  """
  threshold = 150 / 255
  work = np.array(masks, copy=True)
  # Channel views into the private copy; assignments below edit `work` only.
  red, green, blue = work[..., 0], work[..., 1], work[..., 2]
  red[red > threshold] = 3
  green[green > threshold] = 1
  blue[blue > threshold] = 2

  # 3x3 dilation of the blue pixels: OR together the nine one-pixel shifts of
  # the seed map, using a zero-padded copy so shifts never wrap around.
  seeds = blue == 2
  height, width = seeds.shape[1], seeds.shape[2]
  padded = np.pad(seeds, ((0, 0), (1, 1), (1, 1)), mode='constant',
                  constant_values=False)
  grown = np.zeros_like(seeds)
  for dx in range(3):
    for dy in range(3):
      grown |= padded[:, dx:dx + height, dy:dy + width]
  blue[grown] = 2

  merged = np.maximum.reduce([red, green, blue])
  return np.expand_dims(merged, 3)

In [None]:
def my_image_mask_generator(image_generator, mask_generator):
    """Yield (images, label_masks) batches from paired image/mask generators.

    Both generators are advanced in lockstep with ``zip``; every mask batch is
    passed through ``prepare_mask`` before being yielded with its image batch.
    """
    for image_batch, mask_batch in zip(image_generator, mask_generator):
        yield (image_batch, prepare_mask(mask_batch))

In [None]:
def display(display_list):
  """Show the given images side by side in a single row.

  Args:
    display_list: sequence of image arrays accepted by
      ``tf.keras.preprocessing.image.array_to_img``. The first four panels are
      captioned 'Input Image', 'True Mask', 'Predicted Mask' and
      'Fourth image'; any further panel gets a generic 'Image N' caption
      instead of raising an IndexError.
  """
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask', 'Fourth image']

  for i, item in enumerate(display_list):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i] if i < len(title) else 'Image {}'.format(i + 1))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
    plt.axis('off')
  plt.show()

In [None]:
 #for joining generators
 #https://stackoverflow.com/questions/3211041/how-to-join-two-generators-in-python
 #https://stackoverflow.com/questions/49404993/keras-how-to-use-fit-generator-with-multiple-inputs/49405175#comment89680557_49405175
# Pair each image generator with its mask generator so every step yields an
# (images, prepared label masks) tuple.
my_train_generator = my_image_mask_generator(train_image_generator, train_mask_generator)
my_val_generator = my_image_mask_generator(val_image_generator, val_mask_generator)

In [None]:
# Pull one training batch: x holds the images, y the prepared label masks.
x,y = next(my_train_generator)

In [None]:
# Keep the first pair of the batch as the reference sample; show_predictions()
# falls back to these globals when called without a dataset.
sample_image = x[0]
sample_mask = y[0]
display([sample_image,sample_mask])

In [None]:
# Pull one validation batch (images in z, prepared label masks in w).
# NOTE(review): z and w are not used anywhere else in the visible notebook.
z,w = next(my_val_generator)

# Define the model
The model being used here is a modified U-Net. A pretrained model (ResNet101V2 with ImageNet weights) is used as the frozen encoder. The decoder is built from the upsample block already implemented in TensorFlow Examples in the [Pix2pix tutorial](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py).

In [None]:
# Number of channels produced by the model's final layer; passed to
# resnet_model() as `output_channels`.
OUTPUT_CHANNELS = 3

In [None]:
# Encoder: ResNet101V2 without its classification head, initialised with
# ImageNet weights, for 128x128 RGB inputs.
base_model = tf.keras.applications.ResNet101V2(input_shape=[128, 128, 3], include_top=False,weights="imagenet")
# Backbone activations exposed as encoder outputs, ordered from the highest to
# the lowest spatial resolution; the trailing comments give each feature-map size.
layer_names = [
    'conv1_conv', #64x64
    'conv2_block3_1_relu',   # 32x32
    'conv3_block4_1_relu',   # 16x16
    'conv4_block23_1_relu',  # 8x8
    'conv5_block3_2_relu',      # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]
# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Freeze the encoder so its weights are not updated during training.
down_stack.trainable = False

In [None]:
# Render the encoder graph with tensor shapes as the cell output.
tf.keras.utils.plot_model(down_stack, show_shapes=True)

The decoder/upsampler is simply a series of upsample blocks implemented in TensorFlow examples.

In [None]:
# Decoder blocks, applied in order from the deepest feature map upwards; each
# is created with pix2pix.upsample(filters, size). The trailing comments give
# the spatial size before -> after each block.
up_stack_crop = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
   pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

In [None]:
def resnet_model(output_channels, inputs):
  """Build the U-Net style segmentation model around the `down_stack` encoder.

  The encoder's deepest output is upsampled step by step with the blocks in
  the module-level `up_stack_crop`; after every step the result is
  concatenated with the encoder output of matching position (skip
  connection). A final stride-2 transposed convolution produces the output.

  Note: `down_stack` and the `up_stack_crop` layers are module-level objects,
  so models built by repeated calls share those layers.

  Args:
    output_channels: number of channels of the final layer.
    inputs: Keras input tensor fed to `down_stack`.

  Returns:
    tf.keras.Model mapping `inputs` to the final layer's output.
  """
  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  # Remaining encoder outputs, deepest first, to pair with the upsample blocks.
  skips = list(reversed(skips[:-1]))

  ##################################################
  ############# Crop/Weed branch ###################
  ##################################################

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack_crop, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  #64x64 -> 128x128
  x = last(x)
  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Instantiate the segmentation model for 128x128 RGB inputs.
inputs = tf.keras.layers.Input(shape=[128, 128, 3])
model = resnet_model(OUTPUT_CHANNELS, inputs)

In [None]:
# Render the full model graph with tensor shapes as the cell output.
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# Restore the weights saved as "epoch250.h5".
# NOTE(review): a later cell calls load_weights again with LAST_EPOCH, which
# overrides what is loaded here -- confirm which checkpoint is intended.
model.load_weights("/home/path/to/weights/epoch"+str(250)+".h5")

# Define the IoU metric

In [None]:
def mean_iou(y_true, y_pred):
  """Intersection-over-union of the non-background labels.

  The predicted label of each pixel is the arg-max over the last axis of
  `y_pred`. The intersection counts pixels whose true label is non-zero and
  equals the predicted label; the union counts pixels where the true or the
  predicted label is non-zero. Despite the name, a single ratio is computed
  over all pixels passed in (no per-class or per-image averaging).

  The arg-max is taken for the whole batch. (`create_mask` keeps only batch
  element 0, which would compare every true mask with the first prediction.)

  Args:
    y_true: label masks, shape (..., height, width, 1).
    y_pred: per-class scores, shape (batch, height, width, classes).

  Returns:
    Scalar float32 tensor; 0 when the union is empty (instead of NaN).
  """
  y_true = tf.cast(y_true, 'float32')
  pred_labels = tf.cast(tf.argmax(y_pred, axis=-1)[..., tf.newaxis], 'float32')
  inter = tf.math.count_nonzero(
      tf.logical_and(tf.not_equal(y_true, 0), tf.equal(y_true, pred_labels)))
  union = tf.math.count_nonzero(tf.add(y_true, pred_labels))
  return tf.math.divide_no_nan(tf.cast(inter, 'float32'),
                               tf.cast(union, 'float32'))

# Compile the model
The network is trying to assign each pixel a label. In the true segmentation mask, each pixel has a label in {0, 1, 2}. The network outputs three channels; essentially, each channel learns to predict one class.

In [None]:
# Integer labels + raw (un-softmaxed) model outputs, hence the sparse
# cross-entropy with from_logits=True; accuracy and the custom IoU are tracked.
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy",mean_iou]
              )

# Viewing the model performance before training the trainable parameters

Using the output of the network, the label assigned to the pixel is the channel with the highest value. This is what the create_mask function is doing.

In [None]:
# NOTE(review): `threshold` is not referenced anywhere else in the visible
# notebook (create_mask uses arg-max, not a threshold) -- confirm it is needed.
threshold = 0.5

In [None]:
def create_mask(pred_mask):
  """Return the label mask of the first batch element.

  Each pixel's label is the index of the largest value along the last axis of
  `pred_mask`; a trailing channel axis is appended and only element 0 of the
  leading (batch) axis is returned.
  """
  labels = tf.argmax(pred_mask, axis=-1)
  with_channel = tf.expand_dims(labels, axis=-1)
  return with_channel[0]

In [None]:
def _predict_and_report(image, mask):
  """Predict one image, print its IoU against `mask`, return the label mask."""
  scores = model.predict(image[tf.newaxis, ...])
  my_iou = mean_iou(mask, scores)
  print('My Mean IoU for sample image: ', tf.keras.backend.eval(my_iou))
  return create_mask(scores)


def show_predictions(dataset=None, num=1):
  """Display input, true mask and predicted mask for one or more samples.

  Args:
    dataset: optional generator yielding (images, masks) batches. When given,
      one batch is drawn from it and its first `num` samples are shown.
      When omitted, the module-level `sample_image` / `sample_mask` are used.
    num: number of samples to show from the drawn batch; clamped to the batch
      size so that a short (last) batch does not raise an IndexError.
  """
  if dataset is not None:
    pred_x, pred_y = next(dataset)
    for i in range(min(num, len(pred_x))):
      image, mask = pred_x[i], pred_y[i]
      pred_mask = _predict_and_report(image, mask)
      print(np.unique(pred_mask), pred_mask.shape)

      display([image, mask, pred_mask])
  else:
    pred_mask = _predict_and_report(sample_image, sample_mask)
    # Report the labels of the arg-maxed mask, as the message states, rather
    # than the unique values of the raw per-class scores.
    print('after applying mask, uniques of predicted mask:', np.unique(pred_mask), pred_mask.shape)
    print('uniques of sample mask', np.unique(sample_mask))

    display([sample_image, sample_mask, pred_mask])

In [None]:
# Predict on the stored sample_image / sample_mask pair with the current weights.
show_predictions()

# Define the Callback

In [None]:
# Epoch offset added to Keras' zero-based epoch index when naming saved weight
# files and in the progress message; also selects the checkpoint loaded below.
LAST_EPOCH=0

class ShowPredictions(tf.keras.callbacks.Callback):
  """Keras callback that shows the sample prediction after every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Display the sample prediction, then print the offset epoch number."""
    show_predictions()
    finished_epoch = epoch + 1 + LAST_EPOCH
    print('\nSample Prediction after epoch {}\n'.format(finished_epoch))

class SaveWeights(tf.keras.callbacks.Callback):
  """Keras callback that writes the weights to an .h5 file after every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Save weights as "epoch<N>.h5", N being the epoch number plus LAST_EPOCH.

    Uses `self.model` (set by Keras on the callback) so the weights saved are
    those of the model actually being trained, not of whatever the global
    `model` name happens to refer to.
    """
    self.model.save_weights("/home/path/to/weights/epoch"+str(epoch+1+LAST_EPOCH)+".h5")

# Load a saved model

In [None]:
# Resume from the checkpoint written after epoch LAST_EPOCH.
# NOTE(review): with LAST_EPOCH = 0 this looks for "epoch0.h5", a file name the
# SaveWeights callback never produces (it starts at epoch 1) -- confirm the
# file exists or set LAST_EPOCH before running this cell.
model.load_weights("/home/path/to/weights/epoch"+str(LAST_EPOCH)+".h5")

In [None]:
def build_callbacks():
  """Return fresh training callbacks: weight saving first, then the preview."""
  return [SaveWeights(), ShowPredictions()]

# Train the model

In [None]:
# Number of whole batches per epoch, as integers (np.floor returned floats).
# The sample counts come from the directory iterators instead of hard-coded
# totals, so they stay correct if the dataset changes.
STEPS_PER_EPOCH = train_image_generator.samples // BATCH_SIZE
VALIDATION_STEPS = val_image_generator.samples // BATCH_SIZE
print(STEPS_PER_EPOCH, VALIDATION_STEPS)

In [None]:
# `Model.fit` accepts Python generators directly; `fit_generator` is deprecated
# in TF 2.x. Step counts are cast to int because `fit` expects whole numbers
# of batches.
model_history = model.fit(my_train_generator, epochs=200,
                          steps_per_epoch=int(STEPS_PER_EPOCH),
                          validation_data=my_val_generator,
                          validation_steps=int(VALIDATION_STEPS),
                          callbacks=build_callbacks()
                          )

In [None]:
# Per-epoch IoU curves. The variable names `loss` / `val_loss` are kept for any
# later cell that reads them, but they hold the mean_iou metric, so the plot is
# labelled as IoU.
loss = model_history.history['mean_iou']
val_loss = model_history.history['val_mean_iou']

# One x value per recorded epoch; a fixed range(20) does not match the length
# of the history and makes plt.plot raise.
epochs = range(len(loss))

plt.figure()
plt.plot(epochs, loss, 'r', label='Training IoU')
plt.plot(epochs, val_loss, 'bo', label='Validation IoU')
plt.title('Training and Validation IoU')
plt.xlabel('Epoch')
plt.ylabel('IoU')
plt.legend()
plt.show()

# Save a model

In [None]:
# Persist the full model to "final.h5" with tf.keras.models.save_model.
tf.keras.models.save_model(model,"/home/path/to/weights/final.h5")

# Make predictions

See how the system behaves on a set of images

In [None]:
# Advance the validation generator by one batch. The trailing semicolon stops
# the notebook from echoing the whole (images, masks) batch as cell output.
next(my_val_generator);

In [None]:
# Draw the next validation batch and show predictions for up to 10 of its samples.
show_predictions(my_val_generator,num=10)