# Implementing Image Segmentation

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/images/segmentation">
    <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />
    View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/segmentation.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This tutorial focuses on the task of image segmentation, using a modified <a href="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/" class="external">U-Net</a>.

## What Is Image Segmentation?

In an image classification task, the network assigns a label (or class) to each input image. However, suppose you want to know the shape of that object, which pixel belongs to which object, and so on. In that case, you need to assign a class to each pixel of the image; this task is known as segmentation. A segmentation model therefore returns much more detailed information about the image than a classifier does. Image segmentation has many applications, including medical imaging, self-driving cars, and satellite imaging.
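
The difference in output granularity can be illustrated with a small NumPy sketch. The array sizes and the random scores below are illustrative assumptions, not the output of any model in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes = 3
height, width = 4, 4  # tiny image size chosen only for illustration

# Classification: one score vector per image -> one label per image.
image_scores = rng.random(num_classes)
image_label = int(np.argmax(image_scores))

# Segmentation: one score vector per pixel -> one label per pixel.
pixel_scores = rng.random((height, width, num_classes))
pixel_labels = np.argmax(pixel_scores, axis=-1)

print(image_label)         # a single integer class
print(pixel_labels.shape)  # (4, 4): one class per pixel
```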

This tutorial uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) ([Parkhi et al., 2012](https://www.robots.ox.ac.uk/~vgg/publications/2012/parkhi12a/parkhi12a.pdf)). The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the training and test splits). Each image comes with its corresponding label and a pixel-wise mask that assigns a class label to every pixel. Each pixel is given one of three categories:

- Class 1: Pixel belonging to the pet
- Class 2: Pixel bordering the pet
- Class 3: None of the above or a surrounding pixel

## Step 1: Install TensorFlow Examples
- Install the TensorFlow examples package from the specified GitHub repository using pip

**Note**: Install these packages only when using a local machine, not when working in the Simplilearn lab.

In [None]:
!pip install git+https://github.com/tensorflow/examples.git

## Step 2: Import TensorFlow and TensorFlow Datasets
- Import the TensorFlow library, which is a popular open-source framework for building and training various types of machine learning models
- Import the TensorFlow Datasets library

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Step 3: Enhance Image Translation and Visualization with Pix2Pix and Matplotlib
- Import the **pix2pix** module; it contains an implementation of the Pix2Pix model, which is commonly used for image-to-image translation tasks like colorization and style transfer. This tutorial reuses only its upsample block for the decoder
- Import the **clear_output** function to clear the output of the IPython notebook cell, providing cleaner and updated visualizations
- Import the **matplotlib.pyplot** module, a popular plotting library in Python




In [None]:
from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output
import matplotlib.pyplot as plt

## Step 4: Download the Oxford-IIIT Pets Dataset

The dataset is [available from TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet). The segmentation masks are included in version 3+.

In [None]:
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

Before training, the data needs two preprocessing steps. First, the image color values are normalized to the **[0, 1]** range. Second, as mentioned above, the pixels in the segmentation mask are labeled **{1, 2, 3}**; for the sake of convenience, subtract 1 from the segmentation mask, resulting in labels **{0, 1, 2}**.

## Step 5: Data Normalization for Improved Model Training
- The **normalize** function takes **input_image** and **input_mask** as inputs, performs normalization operations on them, and returns the results.
- The **input_image** is cast to **tf.float32** and divided by 255.0 to normalize the pixel values between 0 and 1.
- 1 is subtracted from **input_mask**, shifting the labels from {1, 2, 3} to {0, 1, 2}.
- The function returns the normalized **input_image** and the updated **input_mask**.



In [None]:
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
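
The effect of these two operations can be checked on a toy example. The sketch below repeats the same arithmetic in NumPy on a hypothetical 2x2 image and mask; the pixel values are made up for illustration.

```python
import numpy as np

# Hypothetical 2x2 RGB image with uint8 values and a mask labeled {1, 2, 3}.
image = np.array([[[0, 128, 255], [64, 64, 64]],
                  [[255, 255, 255], [10, 20, 30]]], dtype=np.uint8)
mask = np.array([[1, 2], [3, 1]], dtype=np.uint8)

# Same arithmetic as normalize(): scale to [0, 1] and shift labels to {0, 1, 2}.
image_norm = image.astype(np.float32) / 255.0
mask_shifted = mask - 1

print(image_norm.min(), image_norm.max())  # 0.0 1.0
print(mask_shifted.tolist())               # [[0, 1], [2, 0]]
```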

## Step 6: Resize the Image and Mask, Then Normalize
- Resize the image and its segmentation mask to (128, 128); the mask is resized with nearest-neighbor interpolation so that its values remain valid class labels
- Call the **normalize** function to normalize the resized **input_image** and **input_mask**
- Return the normalized **input_image** and **input_mask**


In [None]:
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(
    datapoint['segmentation_mask'],
    (128, 128),
    method = tf.image.ResizeMethod.NEAREST_NEIGHBOR,
  )

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask
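
The mask is resized with nearest-neighbor interpolation because its values are categorical labels, and any method that blends neighboring pixels can produce values that are not labels at all. The NumPy sketch below illustrates the idea with two hypothetical helper functions; they are simplified stand-ins and not the algorithms used by `tf.image.resize`.

```python
import numpy as np

# Hypothetical 4x4 mask with class labels {0, 1, 2}.
mask = np.array([[0, 0, 2, 2],
                 [0, 0, 2, 2],
                 [1, 1, 0, 2],
                 [1, 1, 2, 2]], dtype=np.float32)

def downsample_nearest(m, factor):
    """Keep every `factor`-th pixel; every output value is an existing label."""
    return m[::factor, ::factor]

def downsample_average(m, factor):
    """Average each factor x factor block, roughly what a smoothing resize does."""
    h, w = m.shape
    blocks = m.reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))

nearest = downsample_nearest(mask, 2)
averaged = downsample_average(mask, 2)

print(nearest.tolist())   # [[0.0, 2.0], [1.0, 0.0]]
print(averaged.tolist())  # [[0.0, 2.0], [1.0, 1.5]]
```

The averaged result contains 1.5, which does not correspond to any class, while the nearest-neighbor result contains only the original labels.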

## Step 7: Initialize Dataset and Training Parameters for Model Training

**Note: The dataset already contains the required training and test splits, so continue to use the same splits.**

- Define variables for training a model
- Batch size is 64, which refers to the number of examples processed in one iteration during training
- Set the shuffle buffer size to 1000, which is the number of examples the input pipeline samples from when shuffling
- Calculate the number of steps needed to complete one epoch (a full pass through the dataset)



In [None]:
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
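
As a worked example of the steps-per-epoch arithmetic, the sketch below assumes the train split holds 3,680 examples. That figure is an assumption here; `info.splits['train'].num_examples` gives the actual value.

```python
# Assumed split size; read info.splits['train'].num_examples to confirm.
train_length = 3680
batch_size = 64

# Integer division drops the final partial batch from the step count.
steps_per_epoch = train_length // batch_size
leftover = train_length % batch_size

print(steps_per_epoch)  # 57
print(leftover)         # 32
```

Because the training pipeline built later repeats indefinitely, the leftover examples are not discarded; they are consumed during the steps of a later epoch.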

## Step 8: Load Image Data and Preprocess from the Dataset

- Load the **train** split of the dataset and perform preprocessing or transformations on the images
- Enable parallel processing of the images for improved performance
- Load the **test** split of the dataset
- Perform similar preprocessing or transformations on the test images


In [None]:
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)


## Step 9: Perform Augmentation for Input and Label Data

- The following class performs a simple augmentation by randomly flipping an image.
Go to the [Image augmentation](data_augmentation.ipynb) tutorial to learn more.


In [None]:
class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels
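
The reason both layers share a seed is that an image and its mask must receive the same flip, or the labels no longer line up with the pixels. The NumPy sketch below illustrates that requirement. It uses a single random draw for the pair instead of two seeded layers, which is a simplification assumed for this example, and `random_flip_pair` is a hypothetical helper.

```python
import numpy as np

def random_flip_pair(image, mask, rng):
    """Flip image and mask left-right together, driven by one random draw."""
    if rng.random() < 0.5:
        return image[:, ::-1], mask[:, ::-1]
    return image, mask

# Hypothetical 2x3 grayscale image whose bright pixels are exactly the "pet".
image = np.array([[0.9, 0.1, 0.1],
                  [0.8, 0.2, 0.1]])
mask = (image > 0.5).astype(np.int32)

rng = np.random.default_rng(42)
for _ in range(5):
    flipped_image, flipped_mask = random_flip_pair(image, mask, rng)
    # The mask still marks exactly the bright pixels after augmentation.
    assert np.array_equal(flipped_mask, (flipped_image > 0.5).astype(np.int32))
```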

## Step 10: Perform Data Preparation and Batching
Build the input pipeline, applying the augmentation after batching the inputs

In [None]:
train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

## Step 11: Define an Image Display Function for Visualization
- Define a helper that displays an input image next to its true mask and, when one is provided, the predicted mask

In [None]:
def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

## Step 12: Visualize the Sample Images and Masks from Training Batches
- Iterate over the first two batches of images and masks from the train_batches dataset
- **sample_image** is assigned as the first image in the batch (images[0]), and **sample_mask** is assigned as the corresponding mask (masks[0]).
- Visualize the sample_image and sample_mask using the **display** function

In [None]:
for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])

## Define the Model
The model used here is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and a decoder (upsampler). To learn robust features and reduce the number of trainable parameters, use a pretrained model, [MobileNetV2](https://arxiv.org/abs/1801.04381), as the encoder. For the decoder, use the upsample block, which is already implemented in the [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) example in the TensorFlow Examples repository. (For background, see the [pix2pix: Image-to-image translation with a conditional GAN](../generative/pix2pix.ipynb) tutorial.)


## Step 13: Implement Feature Extraction with MobileNetV2 for Transfer Learning

- Create an instance of the MobileNetV2 model as the base_model for feature extraction
- The **input_shape** parameter defines the shape of the input images as **[128, 128, 3]**, representing 128x128 pixel RGB images.
- A list of layer_names is defined, representing the intermediate layers of the MobileNetV2 model at specific resolutions.
- By setting down_stack.trainable to False, the weights of the down_stack model are frozen, and they will not be updated during training.
- Only the additional layers added on top of this feature extraction model will be trainable.


In [None]:
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

layer_names = [
    'block_1_expand_relu',
    'block_3_expand_relu',
    'block_6_expand_relu',
    'block_13_expand_relu',
    'block_16_project',
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
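
The resolutions at which these layers are tapped follow from how much each one has downsampled the input. The sketch below computes them for a 128x128 input; the downsampling factors are assumptions based on the standard MobileNetV2 layout and can be checked against the shapes in `down_stack.outputs`.

```python
input_size = 128

# Assumed downsampling factor of each tapped layer relative to the input.
layer_strides = {
    'block_1_expand_relu': 2,
    'block_3_expand_relu': 4,
    'block_6_expand_relu': 8,
    'block_13_expand_relu': 16,
    'block_16_project': 32,
}

# Spatial size of each feature map for a 128x128 input.
layer_sizes = {name: input_size // stride for name, stride in layer_strides.items()}

for name, size in layer_sizes.items():
    print(f'{name}: {size}x{size}')
```

Under these assumptions the feature maps shrink from 64x64 down to 4x4, with the last entry serving as the bottleneck that the decoder starts from.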

## Step 14: Upsample Layers for Feature Map Enhancement in Pix2Pix Architecture
- The decoder (upsampler) is simply a series of upsample blocks implemented in the TensorFlow Examples package.
- Each **pix2pix.upsample()** call builds a block that doubles the height and width of its input using a transposed convolution; the two arguments are the number of filters and the kernel size.

In [None]:
up_stack = [
    pix2pix.upsample(512, 3),
    pix2pix.upsample(256, 3),
    pix2pix.upsample(128, 3),
    pix2pix.upsample(64, 3),
]

## Step 15: Build the Complete Architecture with Downsampling, Upsampling, and Skip Connections
- Define the input layer
- Downsample the input image with the encoder and keep its intermediate outputs as skip connections
- Reverse the order of the skip connections so that the deepest one comes first
- Upsample the features and concatenate each upsampled output with the corresponding skip connection
- Apply the final transposed convolution, which restores the 128x128 input resolution
- Create and return the model



In [None]:
def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

**Note**:
The number of filters on the last layer is set to the number of **output_channels**. This will be one output channel per class.
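
To see how the decoder and the skip connections fit together, the sketch below traces spatial sizes and channel counts through the loop in `unet_model` using plain Python. The encoder channel counts are assumptions for MobileNetV2 at its default width, so treat the resulting numbers as illustrative and compare them with the output of `plot_model` or `model.summary()`.

```python
# Assumed encoder feature maps as (spatial size, channels), shallowest first.
skips = [(64, 96), (32, 144), (16, 192), (8, 576)]
bottleneck = (4, 320)
up_filters = [512, 256, 128, 64]  # filters passed to pix2pix.upsample
output_channels = 3

size, channels = bottleneck
trace = []
for filters, (skip_size, skip_channels) in zip(up_filters, reversed(skips)):
    size *= 2                           # each upsample block doubles height and width
    assert size == skip_size            # concatenation needs matching spatial sizes
    channels = filters + skip_channels  # Concatenate stacks along the channel axis
    trace.append((size, channels))

final = (size * 2, output_channels)     # last Conv2DTranspose: 64x64 -> 128x128

print(trace)  # [(8, 1088), (16, 448), (32, 272), (64, 160)]
print(final)  # (128, 3)
```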

## Step 16: Train the Model
- Declare the number of output classes for segmentation
- Create the U-Net model with the specified number of output channels or classes
- Compile the model with the Adam optimizer, a sparse categorical cross-entropy loss, and accuracy as the metric tracked during training

**Note:**

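- Because this is a multiclass classification task in which the labels are given as single integers, not as score vectors for each pixel of every class, use the **tf.keras.losses.SparseCategoricalCrossentropy** loss function.
- Set the **from_logits** argument to **True** because the last layer of the model has no softmax activation, so the model outputs raw logits rather than probabilities.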
- When running inference, the label assigned to the pixel is the channel with the highest value. This is what the **create_mask** function is doing.


In [None]:
OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
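
The loss can be sketched in NumPy to show what "sparse" and "from logits" mean in practice: the labels are integer class indices, and the softmax is applied inside the loss. The function name and the toy inputs below are hypothetical, and this is a simplified illustration rather than the Keras implementation.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean per-pixel cross-entropy for integer labels and raw logits."""
    # Numerically stable log-softmax over the class (last) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, labels[..., np.newaxis], axis=-1)
    return float(-picked.mean())

# Hypothetical 1x2 "image": two pixels, three classes.
labels = np.array([[0, 2]])
logits = np.array([[[2.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0]]])

loss = sparse_categorical_crossentropy_from_logits(labels, logits)
print(round(loss, 4))
```

The first pixel is predicted fairly confidently and contributes a small loss, while the second pixel has uniform logits and contributes log(3), the loss of a uniform guess over three classes.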

## Step 17: Plot the Resulting Model Architecture


In [None]:
tf.keras.utils.plot_model(model, show_shapes=True)

## Step 18: Try Out the Model to Check What It Predicts Before Training
- Convert predicted mask to a single-channel mask
- Get the channel with the highest value
- Add a new axis to make it a single-channel mask
- Return the mask for the first image in the batch


In [None]:
def create_mask(pred_mask):
  pred_mask = tf.math.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
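
The same three steps can be followed on a toy array. The sketch below uses NumPy instead of TensorFlow and a hypothetical prediction for one 2x2 image, so the values are made up for illustration.

```python
import numpy as np

# Hypothetical model output for a batch of one 2x2 image with three classes.
pred = np.array([[[[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]],
                  [[-0.2, 0.1, 0.4], [1.5, 1.6, 1.4]]]])

# Same steps as create_mask(): argmax over channels, restore a channel axis,
# then keep the first item in the batch.
mask = np.argmax(pred, axis=-1)[..., np.newaxis][0]

print(mask.shape)             # (2, 2, 1)
print(mask[..., 0].tolist())  # [[1, 0], [2, 1]]
```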

## Step 19: Visualize Model Predictions: Show Image Segmentation Results with Predicted Masks
- If a dataset is provided, take **num** batches from it, generate predicted masks for each batch, and display the first image of the batch with its true mask and predicted mask
- Otherwise, generate the predicted mask for **sample_image** and display it with the image and **sample_mask**


In [None]:
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [None]:
show_predictions()

**Observation:**

- The **show_predictions** function gives a visual check of the segmentation output, either for batches from a dataset or for the single sample image. Because the decoder has not been trained yet at this point, the predicted mask is not expected to be meaningful.

## Step 20: DisplayCallback - Callback for Visualizing Sample Predictions at the End of Each Epoch
- Clear the output to update the display
- Display sample predictions
- Print the epoch number for reference

**Note:**
- The callback defined below is used to observe how the model improves while it is trained.



In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print(f'\nSample Prediction after epoch {epoch + 1}\n')

## Step 21: Training and Validation with Visualized Sample Predictions - Monitor Model Performance and Display Results
- The number of epochs for training is 2.
- The number of validation sub-splits is 5.
- Train the model and store the training history



In [None]:
EPOCHS = 2
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

**Observation:**
- The code trains the model for a specific number of epochs, performs validation during training, and includes a callback to display sample predictions.
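
The validation step count can be worked through in the same way as the training steps. The sketch below assumes the test split holds 3,669 examples, which is an assumption here; `info.splits['test'].num_examples` gives the actual value.

```python
# Assumed split size; read info.splits['test'].num_examples to confirm.
test_length = 3669
batch_size = 64
val_subsplits = 5

# Only a fraction of the test batches is evaluated after each epoch.
validation_steps = test_length // batch_size // val_subsplits
examples_per_validation = validation_steps * batch_size

print(validation_steps)         # 11
print(examples_per_validation)  # 704
```

Under this assumption, each validation pass covers roughly a fifth of the test split, which keeps the per-epoch evaluation short.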


## Step 22: Visualize the Training and Validation Loss over Epochs
- Extract loss values from model_history
- Plot the loss values
- Plot training loss in red
- Plot validation loss as blue dots
- Set the title of the plot and the label for the x-axis and y-axis
- Set the y-axis limits
- Add a legend to the plot
- Display the plot


In [None]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

## Step 23: Make Predictions
- Display predictions for three samples from the **test_batches** dataset

**Note:**
- In the interest of saving time, the number of epochs was kept small, but it can be set to higher values to achieve more accurate results.

In [None]:
show_predictions(test_batches, 3)
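
To complement the visual check, the accuracy metric tracked during training can be understood as the fraction of pixels whose predicted class matches the true class. The NumPy sketch below computes it for a pair of hypothetical masks; `pixel_accuracy` is an illustrative helper, not part of the tutorial code.

```python
import numpy as np

def pixel_accuracy(true_mask, pred_mask):
    """Fraction of pixels whose predicted class matches the true class."""
    return float(np.mean(true_mask == pred_mask))

# Hypothetical 2x3 masks with labels {0, 1, 2}.
true_mask = np.array([[0, 0, 1],
                      [2, 2, 1]])
pred_mask = np.array([[0, 1, 1],
                      [2, 2, 2]])

print(pixel_accuracy(true_mask, pred_mask))  # 4 of 6 pixels match
```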