In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Load every split of Oxford-IIIT Pet (no `split=` given, so tfds returns all
# of them keyed by split name) together with the DatasetInfo metadata object.
dataset, info = tfds.load(name='oxford_iiit_pet:3.*.*', with_info=True)

In [None]:
def normalize(input_image, input_mask):
  """Rescale pixels to [0, 1] and shift mask labels down by one.

  Args:
    input_image: image tensor; cast to float32 and divided by 255.
    input_mask: label tensor; 1 is subtracted from every entry (presumably
      mapping the pet trimap labels {1, 2, 3} to 0-based {0, 1, 2} for the
      sparse loss — TODO confirm against the dataset card).

  Returns:
    (image, mask) tuple.
  """
  image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return image, input_mask

def load_image(datapoint, input_size=572, output_size=388):
  """Turn one tfds example into a spatially aligned (image, mask) pair.

  The model below uses only 'valid' convolutions, so an input of
  `input_size` x `input_size` yields a prediction of `output_size` x
  `output_size` whose pixels correspond to the *centre* of the input (a
  border of (input_size - output_size) / 2 pixels is consumed as context).
  Resizing the image to `input_size` and the mask to `output_size`
  independently would therefore compare each prediction with a label from a
  different location in the photo. Instead, both are resized to
  `output_size`, and the image is mirror-padded out to `input_size` (the
  overlap-tile strategy from the U-Net paper), so prediction pixel (i, j)
  and label pixel (i, j) describe the same point.

  Args:
    datapoint: dict with 'image' and 'segmentation_mask' entries.
    input_size: side length expected by the model's Input layer.
    output_size: side length of the model's output, and of the returned mask.
      Must be smaller than `input_size`.

  Returns:
    (image, mask) after `normalize`: image is input_size x input_size,
    mask is output_size x output_size.
  """
  input_image = tf.image.resize(datapoint['image'], (output_size, output_size))
  input_mask = tf.image.resize(
    datapoint['segmentation_mask'],
    (output_size, output_size),
    # Nearest neighbour keeps the mask's integer class labels intact.
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
  )

  # Mirror-pad the image so the network sees plausible context around the
  # labelled region rather than zeros.
  border = input_size - output_size
  before, after = border // 2, border - border // 2
  input_image = tf.pad(
    input_image,
    [[before, after], [before, after], [0, 0]],
    mode='REFLECT',
  )

  return normalize(input_image, input_mask)

In [None]:
def display(display_list):
  """Plot the given images side by side in a single 15x15-inch figure.

  Panels are titled 'Input Image', 'True Mask' and 'Predicted Mask' in that
  order; passing more than three items raises IndexError.
  """
  captions = ['Input Image', 'True Mask', 'Predicted Mask']
  panel_count = len(display_list)

  plt.figure(figsize=(15, 15))
  for idx, item in enumerate(display_list):
    plt.subplot(1, panel_count, idx + 1)
    plt.title(captions[idx])
    plt.imshow(tf.keras.utils.array_to_img(item))
    plt.axis('off')
  plt.show()

In [None]:
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 8
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Decode and resize every example with load_image so the pipeline can be
# rebuilt from a fresh kernel.
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

# Training stream: cached after the first pass, shuffled, batched, repeated
# indefinitely (fit() bounds each epoch via STEPS_PER_EPOCH), and prefetched.
train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

In [None]:
# Preview two training batches. After the loop, sample_image / sample_mask
# hold the first example of the second batch; show_predictions reuses them.
for image_batch, mask_batch in train_batches.take(2):
  sample_image = image_batch[0]
  sample_mask = mask_batch[0]
  display([sample_image, sample_mask])

In [None]:
from numpy import concatenate
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D

def _double_conv(tensor, filters):
  """Two 3x3 'valid' ReLU convolutions; each trims 2 pixels from H and W."""
  tensor = Conv2D(filters, 3, activation='relu', padding='valid')(tensor)
  return Conv2D(filters, 3, activation='relu', padding='valid')(tensor)


def _center_crop_to(skip, target):
  """Centre-crop `skip` so its height and width match `target`'s.

  U-Net crops the (larger) encoder feature map before concatenating it with
  the upsampled decoder map. Resizing instead would interpolate the features
  and shift them relative to the decoder, breaking spatial alignment.
  """
  dh = skip.shape[1] - target.shape[1]
  dw = skip.shape[2] - target.shape[2]
  return tf.keras.layers.Cropping2D(
      cropping=((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)))(skip)


inputs = tf.keras.Input(shape=[572, 572, 3])

# Contracting path: remember each stage's output for the skip connections.
skips = []
x = inputs
for filters in (64, 128, 256, 512):
  x = _double_conv(x, filters)
  skips.append(x)
  x = MaxPooling2D()(x)

# Bottleneck.
x = _double_conv(x, 1024)

# Expanding path: upsample, concatenate the cropped skip, convolve.
for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
  x = Conv2DTranspose(filters, 3, strides=(2, 2), padding='same')(x)
  x = tf.keras.layers.Concatenate(axis=3)([_center_crop_to(skip, x), x])
  x = _double_conv(x, filters)

# Raw per-class logits (no softmax): the loss in the compile cell is
# SparseCategoricalCrossentropy(from_logits=True), which applies softmax
# itself. A softmax here would be applied twice.
outputs = Conv2D(3, (1, 1), padding='same')(x)

model = tf.keras.Model(inputs, outputs)

In [None]:
# NOTE(review): from_logits=True means the loss applies softmax internally,
# so the model's final layer must output raw scores (no softmax activation).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [None]:
from IPython.display import clear_output
def create_mask(pred_mask):
  """Collapse per-class scores to a label map for the first batch element.

  Takes the argmax over the last axis, re-adds a trailing channel axis of
  size 1, and returns element 0 of the leading (batch) axis.
  """
  labels = tf.math.argmax(pred_mask, axis=-1)
  return labels[..., tf.newaxis][0]

def show_predictions(dataset=None, num=1):
  """Display input, true mask and predicted mask.

  With a (truthy) `dataset`, shows the first example of each of the first
  `num` batches. Otherwise falls back to the global sample_image/sample_mask
  pair, adding a batch axis before calling the model.
  """
  if not dataset:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
    return

  for image_batch, mask_batch in dataset.take(num):
    predicted = model.predict(image_batch)
    display([image_batch[0], mask_batch[0], create_mask(predicted)])
    
class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that redraws the sample prediction after every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    # Replace the previous epoch's output instead of stacking figures.
    clear_output(wait=True)
    show_predictions()
    finished = epoch + 1  # Keras passes a 0-based epoch index
    print(f'\nSample Prediction after epoch {finished}\n')

In [None]:
EPOCHS = 100
# Whole batches only; the final partial test batch is not evaluated.
VALIDATION_STEPS = info.splits['test'].num_examples // BATCH_SIZE

history = model.fit(
    train_batches,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_batches,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback()],
)