##### Copyright 2019 Google LLC.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Closed Form Matting Energy
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/matting.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/matting.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Setup & Imports
If TensorFlow Graphics is not installed on your system, the following cell can install the TensorFlow Graphics package for you.

In [0]:
# Makes the `tensorflow_graphics` package available in the working directory.
# The pip-based install is left commented out; instead the GitHub repository
# is cloned and its `tensorflow_graphics` directory is moved next to the
# notebook (presumably so the latest sources are used -- TODO confirm).
#!pip install tensorflow_graphics
!git clone https://github.com/tensorflow/graphics.git
!mv graphics/tensorflow_graphics tensorflow_graphics
# Deletes what is left of the clone once the package directory has been moved.
!rm -rf graphics

Now that TensorFlow Graphics is installed, let's import everything needed to run the demos contained in this notebook.

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow_graphics.image import matting

# Enables eager execution through the `tf.compat.v1` endpoint, which exists in
# both TF 1.x and TF 2.x; the top-level `tf.enable_eager_execution` alias was
# removed in TF 2.x.
tf.compat.v1.enable_eager_execution()

## Import the Image and Scribbles
Download the image and the scribbles, then decode and display them.

In [0]:
# Downloads the image and scribbles
# Courtesy of Keenan Crane www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/.
!wget https://storage.googleapis.com/tensorflow-graphics/notebooks/matting/image.png
!wget https://storage.googleapis.com/tensorflow-graphics/notebooks/matting/scribbles.png

# Reads and decode images.
source = tf.read_file('image.png')
source = tf.cast(tf.io.decode_png(source), tf.float64) / 255.0
source = tf.expand_dims(source, axis=0)
scribbles = tf.read_file('scribbles.png')
scribbles = tf.cast(tf.io.decode_png(scribbles), tf.float64) / 255.0
scribbles = tf.expand_dims(scribbles, axis=0)

# Shows images
fig = plt.figure(figsize=(22, 18))
axes = fig.add_subplot(1, 2, 1)
axes.grid(False)
axes.set_title('Input image', fontsize=14)
plt.imshow(source[0, ...].numpy())
axes = fig.add_subplot(1, 2, 2)
axes.grid(False)
axes.set_title('Input scribbles', fontsize=14)
plt.imshow(scribbles[0, ...].numpy())

Extract the foreground and background constraints from the scribble image.

In [0]:
# Extracts the foreground and background constraints from the scribble image.
# The per-pixel difference between scribbles and source is summed over the
# channels; its sign, clipped to [0, 1], gives a binary mask. Pixels where the
# scribbles are brighter than the source form the foreground mask, pixels
# where they are darker form the background mask.
foreground = tf.clip_by_value(
    tf.sign(tf.reduce_sum(scribbles - source, axis=-1, keepdims=True)), 0.0,
    1.0)
background = tf.clip_by_value(
    tf.sign(tf.reduce_sum(source - scribbles, axis=-1, keepdims=True)), 0.0,
    1.0)

# Shows foreground and background constraints.
fig = plt.figure(figsize=(22, 18))
for index, (title, mask) in enumerate(
    (('Foreground constraints', foreground),
     ('Background constraints', background)), start=1):
  axes = fig.add_subplot(1, 2, index)
  # `grid(False)` rather than `grid('off')`: a non-empty string is truthy and
  # would switch the grid on in current matplotlib versions.
  axes.grid(False)
  axes.set_title(title, fontsize=14)
  # Drops the batch and channel dimensions to display a 2D grayscale mask.
  axes.imshow(mask[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)

## Set Up & Run the Optimization

Set up the matting loss function using TensorFlow Graphics.

In [0]:
# Initializes the matte with uniform random values in [0, 1). Its shape keeps
# every dimension of `source` except the last one, which is replaced by a
# single channel.
leading_dims = tf.shape(source)[:-1]
matte_shape = tf.concat([leading_dims, [1]], axis=0)
initial_matte = tf.random.uniform(
    matte_shape, minval=0.0, maxval=1.0, dtype=tf.float64)
matte = tf.Variable(initial_matte)
# Creates the closed form matting Laplacian; the second value returned by
# `build_matrices` is not used in this notebook.
L, _ = matting.build_matrices(source)


# Function computing the loss and applying the gradient.
@tf.function
def optimize(optimizer):
  """Performs one optimization step on the notebook-level `matte` variable.

  The minimized loss is the scribble-constraint term weighted by 100 plus the
  matting loss evaluated with the Laplacian `L`.

  Args:
    optimizer: Optimizer whose `apply_gradients` method is used to update
      `matte`.
  """
  with tf.GradientTape() as tape:
    tape.watch(matte)
    # Computes a loss enforcing the scribble constraints. Multiplying by
    # `foreground + background` limits the penalty to scribbled pixels, where
    # the matte is pulled towards 1 on foreground and 0 on background.
    # `tf.math.squared_difference` replaces the TF 1.x-only
    # `tf.squared_difference` alias.
    constraints = tf.reduce_mean(
        (foreground + background) *
        tf.math.squared_difference(matte, foreground))
    # Computes the matting loss.
    smoothness = matting.loss(matte, L)
    # Sums up the constraint and matting losses.
    total_loss = 100 * constraints + smoothness
  # Computes the gradient and applies it to the matte.
  gradients = tape.gradient(total_loss, [matte])
  optimizer.apply_gradients(zip(gradients, (matte,)))

Run the Adam optimizer for 500 iterations. 

In [0]:
# Runs the Adam optimizer for 500 iterations.
# The optimizer is created through `tf.compat.v1.train`, which is available in
# both TF 1.x and TF 2.x; the `tf.train.AdamOptimizer` alias was removed in
# TF 2.x.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1.0)
nb_iterations = 500
for it in range(nb_iterations):
  optimize(optimizer)

# Displays the result: input image, input scribbles and optimized matte.
panels = (
    ('Input image', source[0, ...].numpy(), {}),
    ('Input scribbles', scribbles[0, ...].numpy(), {}),
    ('Matte', matte.numpy()[0, ..., 0], dict(cmap='gray', vmin=0, vmax=1)),
)
fig = plt.figure(figsize=(22, 18))
for index, (title, pixels, imshow_kwargs) in enumerate(panels, start=1):
  axes = fig.add_subplot(1, 3, index)
  # `grid(False)` rather than `grid('off')`: a non-empty string is truthy and
  # would switch the grid on in current matplotlib versions.
  axes.grid(False)
  axes.set_title(title, fontsize=14)
  axes.imshow(pixels, **imshow_kwargs)