##### Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Inverse Rendering
<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

This notebook demonstrates an optimization that uses differentiable rendering functions to match a rendering of a 3D shape to a target image captured under unknown camera and lighting conditions. The optimized variables are the **camera rotation**, **position**, and **field of view**, the **lighting direction**, and the **background color**.

Because TensorFlow Graphics (TFG) rendering does not include global illumination effects such as shadows, the optimized rendering will not perfectly match the target image. To make the optimization robust to this mismatch, we use a loss based on the [structural similarity (SSIM) metric](https://www.tensorflow.org/api_docs/python/tf/image/ssim).

As demonstrated here, accurate derivatives at occlusion boundaries are critical for the optimization to succeed. TensorFlow Graphics implements the **rasterize-then-splat** algorithm [Cole et al., 2021] to produce derivatives at occlusions. Rasterization with no special treatment of occlusions is provided for comparison; without handling occlusion boundaries, the optimization diverges.

## Setup Notebook

In [None]:
%%capture
#@title Install TensorFlow Graphics
%pip install tensorflow_graphics

In [None]:
#@title Fetch the model and target image
!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.zip
!unzip -o spot.zip
!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.png

In [None]:
#@title Import modules
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image as PilImage
import tempfile

import tensorflow_graphics.geometry.transformation.quaternion as quat
import tensorflow_graphics.geometry.transformation.euler as euler
import tensorflow_graphics.geometry.transformation.look_at as look_at
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as rotation_matrix_3d
from tensorflow_graphics.rendering.camera import perspective
from tensorflow_graphics.rendering import triangle_rasterizer
from tensorflow_graphics.rendering import splat

from tensorflow_graphics.rendering.texture import texture_map
from tensorflow_graphics.geometry.representation.mesh import normals as normals_module

## Load the Spot model

In [None]:
#@title Load the mesh and texture
def load_and_flatten_obj(obj_path):
  """Loads an .obj and flattens the vertex lists into a single array.

  .obj files may contain separate lists of positions, texture coordinates, and
  normals. In this case, a face vertex will have up to three values: indices
  into each of the position, texture, and normal lists. This function flattens
  those lists into a single vertex array by looking for unique combinations of
  position, texture, and normal, adding those to a list, and then reindexing
  the triangles.

  This function processes only 'v', 'vt', 'vn', and 'f' .obj lines; text after
  a '#' is treated as a comment. Faces with more than three vertices are
  fan-triangulated, and negative (relative) indices are resolved against the
  elements defined so far. Extra components on 'v'/'vt'/'vn' lines (e.g. the
  optional 'w' weight) are ignored and missing ones are zero-filled.

  Args:
    obj_path: the path to the Wavefront .obj file.

  Returns:
    a float32 numpy array of vertices and an Mx3 int32 numpy array of triangle
    indices.

    The vertex array will have shape Nx3, Nx5, Nx6, or Nx8, depending on whether
    position, position + texture, position + normals, or
    position + texture + normals are present.

    Unlike .obj, the triangle vertex indices are 0-based.

  Raises:
    ValueError: if a face has fewer than three vertices.
    IndexError: if a face references an element that is not defined (including
      the invalid .obj index 0).
  """
  VERTEX_TYPES = ['v', 'vt', 'vn']
  # Floats contributed by each attribute to a flattened vertex. Every flattened
  # vertex has exactly sum(VERTEX_SIZES) floats so rows can be sliced reliably.
  VERTEX_SIZES = {'v': 3, 'vt': 2, 'vn': 3}
  flat_size = sum(VERTEX_SIZES.values())

  vertex_lists = {t: [] for t in VERTEX_TYPES}
  flat_vertices_list = []
  # Maps a tuple of resolved 0-based (v, vt, vn) indices (None when an attribute
  # is unspecified) to the vertex's position in flat_vertices_list. Keying on
  # resolved indices rather than the raw token keeps relative (negative)
  # indices correct: the same token can name different elements on different
  # lines.
  flat_vertices_indices = {}
  flat_triangles = []
  # Keep track of encountered vertex types.
  has_type = {t: False for t in VERTEX_TYPES}

  def resolve_index(vertex_type, token):
    # .obj indices are 1-based; negative ones count back from the most recently
    # defined element (-1 is the last one defined so far).
    count = len(vertex_lists[vertex_type])
    index = int(token)
    resolved = count + index if index < 0 else index - 1
    if not 0 <= resolved < count:
      raise IndexError(
          f"{obj_path}: '{vertex_type}' index {index} is out of range "
          f'({count} defined)')
    return resolved

  def get_flat_vertex_index(vertex_name):
    # vertex_name has the form 'v', 'v/vt', 'v//vn', or 'v/vt/vn'.
    parts = vertex_name.split('/')
    parts += [''] * (len(VERTEX_TYPES) - len(parts))
    key = tuple(
        resolve_index(vertex_type, part) if part else None
        for vertex_type, part in zip(VERTEX_TYPES, parts))
    if key in flat_vertices_indices:
      return flat_vertices_indices[key]
    flat_vertex = []
    for vertex_type, index in zip(VERTEX_TYPES, key):
      size = VERTEX_SIZES[vertex_type]
      if index is None:
        # Append zeros for missing attributes.
        flat_vertex += [0.0] * size
      else:
        values = vertex_lists[vertex_type][index][:size]
        flat_vertex += values + [0.0] * (size - len(values))
        has_type[vertex_type] = True
    flat_index = len(flat_vertices_list)
    flat_vertices_list.append(flat_vertex)
    flat_vertices_indices[key] = flat_index
    return flat_index

  with open(obj_path) as obj_file:
    for line in obj_file:
      tokens = line.split('#', 1)[0].split()
      if not tokens:
        continue
      line_type = tokens[0]
      # We skip lines not starting with v, vt, vn, or f.
      if line_type in VERTEX_TYPES:
        vertex_lists[line_type].append([float(x) for x in tokens[1:]])
      elif line_type == 'f':
        corners = [get_flat_vertex_index(name) for name in tokens[1:]]
        if len(corners) < 3:
          raise ValueError(f'{obj_path}: face with fewer than 3 vertices: '
                           f'{line.strip()!r}')
        # Fan-triangulate polygons; a triangle yields itself unchanged.
        for k in range(1, len(corners) - 1):
          flat_triangles.append([corners[0], corners[k], corners[k + 1]])

  # reshape keeps a 2-D layout even when the file defines no faces.
  flat_vertices_array = np.float32(flat_vertices_list).reshape(-1, flat_size)
  # Keep only vertex types that are used in at least one vertex.
  columns = [flat_vertices_array[:, 0:3]]
  if has_type['vt']:
    columns.append(flat_vertices_array[:, 3:5])
  if has_type['vn']:
    columns.append(flat_vertices_array[:, 5:8])
  flat_vertices = np.concatenate(columns, axis=-1)

  return flat_vertices, np.int32(flat_triangles).reshape(-1, 3)

def load_texture(texture_filename):
  """Reads an image file into a float32 array scaled to the [0, 1] range.

  Args:
    texture_filename: path of an image file readable by PIL.

  Returns:
    a float32 numpy array holding the image's 8-bit values divided by 255.
  """
  with open(texture_filename, 'rb') as texture_file:
    pixels = np.asarray(PilImage.open(texture_file))
  return pixels.astype(np.float32) / 255.0

spot_texture_map = load_texture('spot/spot_texture.png')

vertices, triangles = load_and_flatten_obj('spot/spot_triangulated.obj')
# load_and_flatten_obj returns Nx5 (position + uv) or Nx8 (position + uv +
# normals) when UVs are present; Nx3/Nx6 mean the file has no UVs, which the
# texture lookup needs. Slice by column instead of splitting into exactly
# (3, 2) so a file that also ships normals still loads; normals are recomputed
# from the geometry below either way.
assert vertices.shape[1] in (5, 8), (
    'Expected texture coordinates in the .obj file, got vertex array of shape '
    f'{vertices.shape}')
uv_coords = tf.constant(vertices[:, 3:5])
vertices = tf.constant(vertices[:, :3])
normals = normals_module.vertex_normals(vertices, triangles)
print(vertices.shape, uv_coords.shape, normals.shape, triangles.shape)

In [None]:
#@title Load and display target image
from PIL import Image as PilImage
import matplotlib.pyplot as plt

def show_image(image, show=True):
  """Draws `image` on the current axes with the axis decorations hidden.

  Row 0 of the image is drawn at the bottom (origin='lower').

  Args:
    image: an array-like image accepted by plt.imshow.
    show: when True, call plt.show() after drawing; pass False to keep
      composing a multi-panel figure.
  """
  plt.imshow(image, origin='lower')
  plt.axis('off')
  if not show:
    return
  plt.show()

with open('spot.png', 'rb') as target_file:
  pil_target = PilImage.open(target_file)
  # Downscale in place so the image fits within 200x200 pixels.
  pil_target.thumbnail([200, 200])
  # Scale to [0, 1] floats and flip vertically so row 0 is the bottom row,
  # matching the origin='lower' convention of show_image.
  target_image = np.flipud(np.asarray(pil_target, dtype=np.float32) / 255.0)

image_height, image_width = target_image.shape[:2]

show_image(target_image)

## Set up rendering functions and variables

In [None]:
#@title Initial variables
import math

def make_initial_variables():
  """Creates a fresh set of trainable scene variables.

  Returns:
    a dict of tf.Variable with keys 'quaternion' (identity rotation, with a
    leading batch axis), 'translation' (camera four units back along -Z),
    'fov' (40 degrees, in radians), 'background_color' (opaque white RGBA)
    and 'light_direction'.
  """
  identity_rotation = quat.from_euler((0.0, 0.0, 0.0))
  return dict(
      quaternion=tf.Variable(identity_rotation[tf.newaxis, :]),
      translation=tf.Variable([[0.0, 0.0, -4]]),
      fov=tf.Variable([40.0 * math.pi / 180.0]),
      background_color=tf.Variable([1.0, 1.0, 1.0, 1.0]),
      light_direction=tf.Variable([0.5, 0.5, 1.0]),
  )


In [None]:
#@title Rendering functions


def shade(rasterized, light_direction, ka=0.5, kd=0.5):
  """Shades the input rasterized buffer using a basic illumination model.

  The surface color is `spot_texture_map` sampled at the interpolated UVs,
  scaled by an ambient term plus a clamped Lambertian (n.l) diffuse term.

  Args:
    rasterized: a dictionary of interpolated attribute buffers with keys 'uv',
      'normals' and 'mask'. Presumably [height, width, channels] buffers (a
      batch axis is added and removed around the texture lookup) — TODO confirm
      against the rasterizer documentation.
    light_direction: a vector defining the direction of a single light; it does
      not need to be unit length.
    ka: ambient lighting coefficient.
    kd: diffuse lighting coefficient.

  Returns:
    an RGBA buffer of shaded pixels. RGB is premultiplied by the coverage mask
    (alpha), matching the `rendered + background * (1 - alpha)` compositing
    done in render_forward.
  """
  mask = rasterized['mask']
  textured = texture_map.map_texture(rasterized['uv'][tf.newaxis, ...],
                                     spot_texture_map)[0, ...]

  # Interpolating unit vertex normals across a triangle yields vectors that are
  # shorter than unit length; renormalize so n.l is a proper cosine.
  # l2_normalize maps all-zero normals to zero instead of producing NaNs.
  normals = tf.math.l2_normalize(rasterized['normals'], axis=-1)
  light_direction = tf.reshape(light_direction, [1, 1, 3])
  light_direction = tf.math.l2_normalize(light_direction, axis=-1)
  n_dot_l = tf.clip_by_value(
      tf.reduce_sum(normals * light_direction, axis=-1, keepdims=True), 0.0,
      1.0)
  ambient = textured * ka
  diffuse = textured * kd * n_dot_l
  # Zero the color wherever nothing was rasterized. Uncovered pixels may
  # otherwise carry the texture sampled at whatever UV the rasterizer left
  # there, which the premultiplied composite would add onto the background.
  lit = (ambient + diffuse) * mask

  return tf.concat((lit, mask), -1)


def rasterize_without_splatting(projection, image_width, image_height,
                                light_direction):
  """Rasterizes and shades the mesh with no special occlusion handling.

  Provided as the baseline to compare against rasterize_then_splat.

  Args:
    projection: the combined model-view-projection matrix.
    image_width: output width in pixels.
    image_height: output height in pixels.
    light_direction: direction of the single light passed to shade().

  Returns:
    the RGBA buffer produced by shade().
  """
  attributes = {'uv': uv_coords, 'normals': normals}
  image_size = (image_height, image_width)
  rasterized = triangle_rasterizer.rasterize(vertices, triangles, attributes,
                                             projection, image_size)
  return shade(rasterized, light_direction)


def rasterize_then_splat(projection, image_width, image_height,
                         light_direction):
  """Renders the mesh with rasterize-then-splat to get occlusion derivatives.

  Args:
    projection: the combined model-view-projection matrix.
    image_width: output width in pixels.
    image_height: output height in pixels.
    light_direction: direction of the single light passed to shade().

  Returns:
    the splatted RGBA image returned by splat.rasterize_then_splat.
  """
  def shade_layer(rasterized_layer):
    return shade(rasterized_layer, light_direction)

  attributes = {'uv': uv_coords, 'normals': normals}
  return splat.rasterize_then_splat(vertices, triangles, attributes, projection,
                                    (image_height, image_width), shade_layer)


def render_forward(variables, rasterization_func):
  """Renders the mesh for the current variables over the background color.

  Args:
    variables: dict of tf.Variable as created by make_initial_variables().
    rasterization_func: rasterize_then_splat or rasterize_without_splatting.

  Returns:
    an image_height x image_width x 4 RGBA tensor.
  """
  eye = variables['translation']
  # The camera looks down +Z: aim at a point one unit in front of the eye.
  center = eye + tf.constant([[0.0, 0.0, 1.0]])
  world_up = tf.constant([[0.0, 1.0, 0.0]])

  # Nothing keeps the quaternion variable unit-length during optimization, so
  # normalize it before building the rotation.
  quaternion = variables['quaternion']
  unit_quaternion = quaternion / tf.norm(quaternion, axis=1, keepdims=True)
  rotation_3x3 = rotation_matrix_3d.from_quaternion(unit_quaternion)
  # Homogeneous model matrix: zero-pad (R - I) to 4x4 and add I4, leaving R in
  # the upper-left block and 1 in the bottom-right corner.
  model_4x4 = tf.pad(rotation_3x3 - tf.eye(3),
                     ((0, 0), (0, 1), (0, 1))) + tf.eye(4)

  view_4x4 = look_at.right_handed(eye, center, world_up)
  aspect_ratio = image_width / image_height
  clip_4x4 = perspective.right_handed(variables['fov'], (aspect_ratio,),
                                      (0.01,), (10.0,))
  projection = clip_4x4 @ (view_4x4 @ model_4x4)

  rendered = rasterization_func(projection, image_width, image_height,
                                variables['light_direction'])

  # Premultiplied "over": the background shows through where alpha < 1.
  alpha = rendered[..., 3:4]
  background = tf.tile(
      tf.reshape(variables['background_color'], [1, 1, 4]),
      [image_height, image_width, 1])
  return rendered + background * (1.0 - alpha)

In [None]:
#@title Loss function
def ssim_loss(target, rendered):
  """Returns 1 - SSIM between the RGB parts of two images, compared in YUV.

  Only the first three channels are used, so an alpha channel is ignored.
  """
  to_yuv = tf.compat.v2.image.rgb_to_yuv
  similarity = tf.compat.v2.image.ssim(
      to_yuv(target[..., :3]), to_yuv(rendered[..., :3]), max_val=1.0)
  return 1.0 - similarity

In [None]:
#@title Backwards pass
@tf.function
def render_grad(target, variables, rasterization_func):
  """Renders, scores against `target`, and differentiates in one traced call.

  Returns:
    (rendered image, dict of gradients keyed like `variables`, SSIM loss).
  """
  with tf.GradientTape() as tape:
    rendered = render_forward(variables, rasterization_func)
    loss_value = ssim_loss(target, rendered)
  return rendered, tape.gradient(loss_value, variables), loss_value

## Run optimization

In [None]:
#@title Run gradient descent
variables = make_initial_variables()

# Choose "rasterize without splatting" in the form field to run the same
# optimization without rasterize-then-splat (RtS).
rasterization_mode = 'rasterize then splat'  #@param [ "rasterize then splat", "rasterize without splatting"]
rasterization_func = (
    rasterize_then_splat
    if rasterization_mode == 'rasterize then splat' else rasterize_without_splatting)

learning_rate = 0.02 #@param {type: "slider", min: 0.002, max: 0.05, step: 0.002}
start = render_forward(variables, rasterization_func)
optimizer = tf.keras.optimizers.Adam(learning_rate)
animation_images = [start.numpy()]
num_steps = 300 #@param { type: "slider", min: 100, max: 2000, step: 100}

for i in range(num_steps):
  # `current` and `loss` describe the variables *before* this step's update.
  current, grads, loss = render_grad(target_image, variables, rasterization_func)
  to_apply = [(grads[k], variables[k]) for k in variables.keys()]
  optimizer.apply_gradients(to_apply)
  if i > 0 and i % 10 == 0:
    animation_images.append(current.numpy())
  if i % 100 == 0:
    print('Loss at step {:03d}: {:.3f}'.format(i, loss.numpy()))

# The last `current` from the loop predates the final update. Re-render so the
# displayed result, difference image, animation's last frame and the printed
# camera parameters all describe the same optimized variables (this also works
# when num_steps is 0).
current = render_forward(variables, rasterization_func)
loss = ssim_loss(target_image, current)
animation_images.append(current.numpy())
print('Final loss after {:d} steps: {:.3f}'.format(num_steps, loss.numpy()))

In [None]:
#@title Display results
plt.figure(figsize=[18, 6])
plt.subplot(1, 4, 1)
plt.title('Initialization')
show_image(np.clip(start, 0.0, 1.0), show=False)
plt.subplot(1, 4, 2)
plt.title('After Optimization')
show_image(np.clip(current, 0.0, 1.0), show=False)
plt.subplot(1, 4, 3)
plt.title('Target')
show_image(target_image, show=False)
plt.subplot(1, 4, 4)
plt.title('Difference')
# Signed red-channel difference on a fixed, symmetric diverging scale: white is
# a perfect match, red/blue show the sign. imshow's default autoscaled
# sequential colormap hides both the sign and the magnitude.
red_difference = np.asarray(current)[..., 0] - target_image[..., 0]
plt.imshow(red_difference, origin='lower', cmap='RdBu', vmin=-1.0, vmax=1.0)
plt.axis('off')
plt.colorbar(fraction=0.046, pad=0.04)
plt.show()

In [None]:
%%capture
#@title Display animation
import matplotlib.animation as animation

def save_animation(images):
  """Builds (but does not write to disk) an animation of the given frames.

  Args:
    images: sequence of images stored with row 0 at the bottom, the layout
      that show_image draws with origin='lower'.

  Returns:
    a matplotlib ArtistAnimation showing one frame every 50 ms.
  """
  fig = plt.figure(figsize=(8, 8))
  plt.axis('off')
  frames = []
  for frame in images:
    clipped = np.clip(frame, 0.0, 1.0)
    # Flip rows because imshow's default origin is the top-left corner.
    frames.append([plt.imshow(np.flipud(clipped))])
  return animation.ArtistAnimation(fig, frames, interval=50, blit=True)

anim = save_animation(animation_images)

In [None]:
from IPython.display import HTML

# Convert the animation to a self-contained JavaScript player; as the cell's
# last expression, the HTML object is rendered inline.
animation_player = HTML(anim.to_jshtml())
animation_player

In [None]:
#@title Display initial and optimized camera parameters
def print_camera_params(v):
  """Prints field of view, position, and orientation from a variables dict.

  Args:
    v: dict of tf.Variable as created by make_initial_variables().
  """
  print(f"FoV (degrees): {v['fov'].numpy() * 180.0 / math.pi}")
  print(f"Position: {v['translation'].numpy()}")
  # The optimizer does not keep the quaternion unit-length, but render_forward
  # normalizes it before use and euler.from_quaternion expects a normalized
  # quaternion. Normalize here too so the reported angles are the rotation that
  # was actually rendered.
  unit_quaternion = tf.math.l2_normalize(v['quaternion'], axis=-1)
  print(f"Orientation (xyz angles): {euler.from_quaternion(unit_quaternion).numpy()}")

print("INITIAL CAMERA:")
print_camera_params(make_initial_variables())
print("\nOPTIMIZED CAMERA:")
print_camera_params(variables)