diff --git a/tensorflow_graphics/notebooks/inverse_rendering.ipynb b/tensorflow_graphics/notebooks/inverse_rendering.ipynb
new file mode 100644
index 000000000..c9043e16f
--- /dev/null
+++ b/tensorflow_graphics/notebooks/inverse_rendering.ipynb
@@ -0,0 +1,604 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4WSXmzOVBn-X"
+ },
+ "source": [
+ "##### Copyright 2021 Google LLC."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "gkJ16-cKBuAK"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hGbNmaA5_si_"
+ },
+ "source": [
+ "# Inverse Rendering\n",
+ "\n",
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ "\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VvLMHI88DBVs"
+ },
+ "source": [
+ "This notebook demonstrates an optimization that fits a rendering of a 3D\n",
+ "shape to a target image under unknown camera and lighting, using\n",
+ "differentiable rendering functions. The optimization variables are the\n",
+ "**camera rotation**, **position**, and **field of view**, the **lighting\n",
+ "direction**, and the **background color**.\n",
+ "\n",
+ "Because the TFG rendering does not include global illumination effects such as\n",
+ "shadows, the output rendering will not perfectly match the target image. To\n",
+ "overcome this issue, we use a robust loss based on the\n",
+ "[structural similarity metric (SSIM)](https://www.tensorflow.org/api_docs/python/tf/image/ssim).\n",
+ "\n",
+ "As demonstrated here, accurate derivatives at occlusion boundaries are critical\n",
+ "for the optimization to succeed. TensorFlow Graphics implements the\n",
+ "**rasterize-then-splat** algorithm [Cole et al., 2021] to produce derivatives\n",
+ "at occlusions. Rasterization with no special treatment of occlusions is provided\n",
+ "for comparison; without handling occlusion boundaries, the optimization\n",
+ "diverges.",
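+ "\n",
+ "Concretely, the loss minimized below is `1 - SSIM(target, rendered)`, computed\n",
+ "on YUV images, and the rendered RGBA foreground is composited over the\n",
+ "background color with premultiplied alpha: `out = rgb + background * (1 - alpha)`."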
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppRKISWUQIeB" + }, + "source": [ + "## Setup Notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pZs6dzmQsdY6" + }, + "outputs": [], + "source": [ + "#@title Install TensorFlow Graphics\n", + "%pip install tensorflow_graphics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "__t3mrMftAA2" + }, + "outputs": [], + "source": [ + "#@title Fetch the model and target image\n", + "!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.zip\n", + "!unzip -o spot.zip\n", + "!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.png" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "3H__1-brS0ms" + }, + "outputs": [], + "source": [ + "#@title Import modules\n", + "import math\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image as PilImage\n", + "import tempfile\n", + "\n", + "import tensorflow_graphics.geometry.transformation.quaternion as quat\n", + "import tensorflow_graphics.geometry.transformation.euler as euler\n", + "import tensorflow_graphics.geometry.transformation.look_at as look_at\n", + "import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as rotation_matrix_3d\n", + "from tensorflow_graphics.rendering.camera import perspective\n", + "from tensorflow_graphics.rendering import triangle_rasterizer\n", + "from tensorflow_graphics.rendering import splat\n", + "\n", + "from tensorflow_graphics.rendering.texture import texture_map\n", + "from tensorflow_graphics.geometry.representation.mesh import normals as normals_module" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2TebacgwQKeG" + }, + "source": [ + "## Load the Spot model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "U98pmFE_OWtn" + }, + "outputs": [], + "source": [ + "#@title Load the mesh and texture\n", + "def load_and_flatten_obj(obj_path):\n", + " \"\"\"Loads an .obj and flattens the vertex lists into a single array.\n", + "\n", + " .obj files may contain separate lists of positions, texture coordinates, and\n", + " normals. In this case, a triangle vertex will have three values: indices into\n", + " each of the position, texture, and normal lists. 
This function flattens those\n", + " lists into a single vertex array by looking for unique combinations of\n", + " position, texture, and normal, adding those to list, and then reindexing the\n", + " triangles.\n", + "\n", + " This function processes only 'v', 'vt', 'vn', and 'f' .obj lines.\n", + "\n", + " Args:\n", + " obj_path: the path to the Wavefront .obj file.\n", + "\n", + " Returns:\n", + " a numpy array of vertices and a Mx3 numpy array of triangle indices.\n", + "\n", + " The vertex array will have shape Nx3, Nx5, Nx6, or Nx8, depending on whether\n", + " position, position + texture, position + normals, or\n", + " position + texture + normals are present.\n", + "\n", + " Unlike .obj, the triangle vertex indices are 0-based.\n", + " \"\"\"\n", + " VERTEX_TYPES = ['v', 'vt', 'vn']\n", + "\n", + " vertex_lists = {n: [] for n in VERTEX_TYPES}\n", + " flat_vertices_list = []\n", + " flat_vertices_indices = {}\n", + " flat_triangles = []\n", + " # Keep track of encountered vertex types.\n", + " has_type = {t: False for t in VERTEX_TYPES}\n", + "\n", + " with open(obj_path) as obj_file:\n", + " for line in iter(obj_file):\n", + " tokens = line.split()\n", + " if not tokens:\n", + " continue\n", + " line_type = tokens[0]\n", + " # We skip lines not starting with v, vt, vn, or f.\n", + " if line_type in VERTEX_TYPES:\n", + " vertex_lists[line_type].append([float(x) for x in tokens[1:]])\n", + " elif line_type == 'f':\n", + " triangle = []\n", + " for i in range(3):\n", + " # The vertex name is one of the form: 'v', 'v/vt', 'v//vn', or\n", + " # 'v/vt/vn'.\n", + " vertex_name = tokens[i + 1]\n", + " if vertex_name in flat_vertices_indices:\n", + " triangle.append(flat_vertices_indices[vertex_name])\n", + " continue\n", + " # Extract all vertex type indices ('' for unspecified).\n", + " vertex_indices = vertex_name.split('/')\n", + " while len(vertex_indices) \u003c 3:\n", + " vertex_indices.append('')\n", + " flat_vertex = []\n", + " for vertex_type, index in zip(VERTEX_TYPES, vertex_indices):\n", + " if index:\n", + " # obj triangle indices are 1 indexed, so subtract 1 here.\n", + " flat_vertex += vertex_lists[vertex_type][int(index) - 1]\n", + " has_type[vertex_type] = True\n", + " else:\n", + " # Append zeros for missing attributes.\n", + " flat_vertex += [0, 0] if vertex_type == 'vt' else [0, 0, 0]\n", + " flat_vertex_index = len(flat_vertices_list)\n", + "\n", + " flat_vertices_list.append(flat_vertex)\n", + " flat_vertices_indices[vertex_name] = flat_vertex_index\n", + " triangle.append(flat_vertex_index)\n", + " flat_triangles.append(triangle)\n", + "\n", + " # Keep only vertex types that are used in at least one vertex.\n", + " flat_vertices_array = np.float32(flat_vertices_list)\n", + " flat_vertices = flat_vertices_array[:, :3]\n", + " if has_type['vt']:\n", + " flat_vertices = np.concatenate((flat_vertices, flat_vertices_array[:, 3:5]),\n", + " axis=-1)\n", + " if has_type['vn']:\n", + " flat_vertices = np.concatenate((flat_vertices, flat_vertices_array[:, -3:]),\n", + " axis=-1)\n", + "\n", + " return flat_vertices, np.int32(flat_triangles)\n", + "\n", + "\n", + "def load_texture(texture_filename):\n", + " \"\"\"Returns a texture image loaded from a file (float32 in [0,1] range).\"\"\"\n", + " with open(texture_filename, 'rb') as f:\n", + " return np.asarray(PilImage.open(f)).astype(np.float32) / 255.0\n", + "\n", + "\n", + "spot_texture_map = load_texture('spot/spot_texture.png')\n", + "\n", + "vertices, triangles = load_and_flatten_obj('spot/spot_triangulated.obj')\n", + 
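"# The flattened vertex array here is Nx5 (xyz position + uv; this mesh has\n",
+ "# no normal records), so split it and recompute smooth per-vertex normals\n",
+ "# from the triangle topology.\n", + 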
"vertices, uv_coords = tf.split(vertices, (3, 2), axis=-1)\n", + "normals = normals_module.vertex_normals(vertices, triangles)\n", + "print(vertices.shape, uv_coords.shape, normals.shape, triangles.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "WHDaIoh7RPSA" + }, + "outputs": [], + "source": [ + "#@title Load and display target image\n", + "from PIL import Image as PilImage\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def show_image(image, show=True):\n", + " plt.imshow(image, origin='lower')\n", + " plt.axis('off')\n", + " if show:\n", + " plt.show()\n", + "\n", + "\n", + "with open('spot.png', 'rb') as target_file:\n", + " target_image = PilImage.open(target_file)\n", + " target_image.thumbnail([200, 200])\n", + " target_image = np.array(target_image).astype(np.float32) / 255.0\n", + " target_image = np.flipud(target_image)\n", + "\n", + "image_width = target_image.shape[1]\n", + "image_height = target_image.shape[0]\n", + "\n", + "show_image(target_image)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FrVYSAlpQoM1" + }, + "source": [ + "## Set up rendering functions and variables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "TDDSf5YQSXVV" + }, + "outputs": [], + "source": [ + "#@title Initial variables\n", + "import math\n", + "\n", + "\n", + "def make_initial_variables():\n", + " camera_translation = tf.Variable([[0.0, 0.0, -4]])\n", + " fov = tf.Variable([40.0 * math.pi / 180.0])\n", + " quaternion = tf.Variable(\n", + " tf.expand_dims(quat.from_euler((0.0, 0.0, 0.0)), axis=0))\n", + " background_color = tf.Variable([1.0, 1.0, 1.0, 1.0])\n", + " light_direction = tf.Variable([0.5, 0.5, 1.0])\n", + " return {\n", + " 'quaternion': quaternion,\n", + " 'translation': camera_translation,\n", + " 'fov': fov,\n", + " 'background_color': background_color,\n", + " 'light_direction': light_direction\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "JEAeBBhvTAez" + }, + "outputs": [], + "source": [ + "#@title Rendering functions\n", + "\n", + "\n", + "def shade(rasterized, light_direction, ka=0.5, kd=0.5):\n", + " \"\"\"Shades the input rasterized buffer using a basic illumination model.\n", + "\n", + " Args:\n", + " rasterized: a dictionary of interpolated attribute buffers.\n", + " light_direction: a vector defining the direction of a single light.\n", + " ka: ambient lighting coefficient.\n", + " kd: diffuse lighting coefficient.\n", + "\n", + " Returns:\n", + " an RGBA buffer of shaded pixels.\n", + " \"\"\"\n", + " textured = texture_map.map_texture(rasterized['uv'][tf.newaxis, ...],\n", + " spot_texture_map)[0, ...]\n", + "\n", + " light_direction = tf.reshape(light_direction, [1, 1, 3])\n", + " light_direction = tf.math.l2_normalize(light_direction, axis=-1)\n", + " n_dot_l = tf.clip_by_value(\n", + " tf.reduce_sum(\n", + " rasterized['normals'] * light_direction, axis=2, keepdims=True), 0.0,\n", + " 1.0)\n", + " ambient = textured * ka\n", + " diffuse = textured * kd * n_dot_l\n", + " lit = ambient + diffuse\n", + "\n", + " lit_rgba = tf.concat((lit, rasterized['mask']), -1)\n", + " return lit_rgba\n", + "\n", + "\n", + "def rasterize_without_splatting(projection, image_width, image_height,\n", + " light_direction):\n", + " rasterized = triangle_rasterizer.rasterize(vertices, triangles, {\n", + " 'uv': uv_coords,\n", + " 'normals': normals\n", + " 
}, projection, (image_height, image_width))\n", + "\n", + " lit = shade(rasterized, light_direction)\n", + " return lit\n", + "\n", + "\n", + "def rasterize_then_splat(projection, image_width, image_height,\n", + " light_direction):\n", + " return splat.rasterize_then_splat(vertices, triangles, {\n", + " 'uv': uv_coords,\n", + " 'normals': normals\n", + " }, projection, (image_height, image_width),\n", + " lambda d: shade(d, light_direction))\n", + "\n", + "\n", + "def render_forward(variables, rasterization_func):\n", + " camera_translation = variables['translation']\n", + " eye = camera_translation\n", + " # Place the \"center\" of the scene along the Z axis from the camera.\n", + " center = tf.constant([[0.0, 0.0, 1.0]]) + camera_translation\n", + " world_up = tf.constant([[0.0, 1.0, 0.0]])\n", + "\n", + " normalized_quaternion = variables['quaternion'] / tf.norm(\n", + " variables['quaternion'], axis=1, keepdims=True)\n", + " model_rotation_3x3 = rotation_matrix_3d.from_quaternion(normalized_quaternion)\n", + " model_rotation_4x4 = tf.pad(model_rotation_3x3 - tf.eye(3),\n", + " ((0, 0), (0, 1), (0, 1))) + tf.eye(4)\n", + "\n", + " look_at_4x4 = look_at.right_handed(eye, center, world_up)\n", + " perspective_4x4 = perspective.right_handed(variables['fov'],\n", + " (image_width / image_height,),\n", + " (0.01,), (10.0,))\n", + "\n", + " projection = tf.matmul(perspective_4x4,\n", + " tf.matmul(look_at_4x4, model_rotation_4x4))\n", + "\n", + " rendered = rasterization_func(projection, image_width, image_height,\n", + " variables['light_direction'])\n", + "\n", + " background_rgba = variables['background_color']\n", + " background_rgba = tf.tile(\n", + " tf.reshape(background_rgba, [1, 1, 4]), [image_height, image_width, 1])\n", + " composited = rendered + background_rgba * (1.0 - rendered[..., 3:4])\n", + " return composited" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "9UU1JD7HPKLv" + }, + "outputs": [], + "source": [ + "#@title SSIM Loss function\n", + "def ssim_loss(target, rendered):\n", + " target_yuv = tf.compat.v2.image.rgb_to_yuv(target[..., :3])\n", + " rendered_yuv = tf.compat.v2.image.rgb_to_yuv(rendered[..., :3])\n", + " ssim = tf.compat.v2.image.ssim(target_yuv, rendered_yuv, max_val=1.0)\n", + " return 1.0 - ssim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "0yYdpz-hPOLL" + }, + "outputs": [], + "source": [ + "#@title Backwards pass\n", + "@tf.function\n", + "def render_grad(target, variables, rasterization_func):\n", + " with tf.GradientTape() as g:\n", + " rendered = render_forward(variables, rasterization_func)\n", + " loss_value = ssim_loss(target, rendered)\n", + " grads = g.gradient(loss_value, variables)\n", + " return rendered, grads, loss_value" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lB9Xq96vO3us" + }, + "source": [ + "## Run optimization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vdcE4k8ZhaOT" + }, + "outputs": [], + "source": [ + "#@title Run gradient descent\n", + "variables = make_initial_variables()\n", + "\n", + "# Change this to rasterize to test without RtS\n", + "rasterization_mode = 'rasterize then splat' #@param [ \"rasterize then splat\", \"rasterize without splatting\"]\n", + "rasterization_func = (\n", + " rasterize_then_splat if rasterization_mode == 'rasterize then splat' else\n", + " rasterize_without_splatting)\n", + "\n", + 
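"# Each step renders the current estimate, evaluates the SSIM loss against the\n",
+ "# target, and applies Adam updates to all five variable groups (quaternion,\n",
+ "# translation, fov, background color, light direction).\n",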
"learning_rate = 0.02 #@param {type: \"slider\", min: 0.002, max: 0.05, step: 0.002}\n", + "start = render_forward(variables, rasterization_func)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate)\n", + "animation_images = [start.numpy()]\n", + "num_steps = 300 #@param { type: \"slider\", min: 100, max: 2000, step: 100}\n", + "\n", + "for i in range(num_steps):\n", + " current, grads, loss = render_grad(target_image, variables,\n", + " rasterization_func)\n", + " to_apply = [(grads[k], variables[k]) for k in variables.keys()]\n", + " optimizer.apply_gradients(to_apply)\n", + " if i \u003e 0 and i % 10 == 0:\n", + " animation_images.append(current.numpy())\n", + " if i % 100 == 0:\n", + " print('Loss at step {:03d}: {:.3f}'.format(i, loss.numpy()))\n", + " pass\n", + "print('Final loss {:03d}: {:.3f}'.format(i, loss.numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "dkThT6Oshf-9" + }, + "outputs": [], + "source": [ + "#@title Display results\n", + "plt.figure(figsize=[18, 6])\n", + "plt.subplot(1, 4, 1)\n", + "plt.title('Initialization')\n", + "show_image(np.clip(start, 0.0, 1.0), show=False)\n", + "plt.subplot(1, 4, 2)\n", + "plt.title('After Optimization')\n", + "show_image(np.clip(current, 0.0, 1.0), show=False)\n", + "plt.subplot(1, 4, 3)\n", + "plt.title('Target')\n", + "show_image(target_image, show=False)\n", + "plt.subplot(1, 4, 4)\n", + "plt.title('Difference')\n", + "show_image(current[..., 0] - target_image[..., 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "aQwoCByjhh44" + }, + "outputs": [], + "source": [ + "%%capture\n", + "#@title Display animation\n", + "import matplotlib.animation as animation\n", + "\n", + "def save_animation(images):\n", + " fig = plt.figure(figsize=(8, 8))\n", + " plt.axis('off')\n", + " ims = [[plt.imshow(np.flipud(np.clip(i, 0.0, 1.0)))] for i in images]\n", + " return animation.ArtistAnimation(fig, ims, interval=50, blit=True)\n", + "\n", + "anim = save_animation(animation_images)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LTG19FACaTTK" + }, + "outputs": [], + "source": [ + "from IPython.display import HTML\n", + "\n", + "HTML(anim.to_jshtml())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "r2ibcdm1hkiC" + }, + "outputs": [], + "source": [ + "#@title Display initial and optimized camera parameters\n", + "def print_camera_params(v):\n", + " print(f\"FoV (degrees): {v['fov'].numpy() * 180.0 / math.pi}\")\n", + " print(f\"Position: {v['translation'].numpy()}\")\n", + " print(\n", + " f\"Orientation (xyz angles): {euler.from_quaternion(v['quaternion']).numpy()}\"\n", + " )\n", + "\n", + "\n", + "print(\"INITIAL CAMERA:\")\n", + "print_camera_params(make_initial_variables())\n", + "print(\"\\nOPTIMIZED CAMERA:\")\n", + "print_camera_params(variables)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "TFG Inverse Rendering.ipynb", + "provenance": [], + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}