##### Copyright 2021 Google LLC.




In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/projects/radiance_fields/TFG_tiny_nerf.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/projects/radiance_fields/TFG_tiny_nerf.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Setup and imports

In [None]:
%pip install tensorflow_graphics

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.layers as layers

import tensorflow_graphics.projects.radiance_fields.data_loaders as data_loaders
import tensorflow_graphics.projects.radiance_fields.utils as utils
import tensorflow_graphics.rendering.camera.perspective as perspective
import tensorflow_graphics.geometry.representation.ray as ray
import tensorflow_graphics.math.feature_representation as feature_rep
import tensorflow_graphics.rendering.volumetric.ray_radiance as ray_radiance

Please download the data from the original [repository](https://github.com/bmild/nerf). In this tutorial we experimented with the synthetic data (lego, ship, boat, etc.), which can be found [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). You can then either point to the data locally (if you run a custom kernel) or upload it to Google Colab.

Alternatively, you can simply create a shortcut to the shared folder in your own Google Drive and load the data from that path directly.

In [None]:
# `google.colab` is only importable inside a Colab runtime. On a custom/local
# kernel the Drive mount is skipped, so the notebook still runs top to bottom
# and DATASET_DIR can simply point at a local copy of the data.
try:
  from google.colab import drive
except ImportError:
  drive = None

if drive is not None:
  drive.mount('/content/drive')

In [None]:
# Folder holding the NeRF synthetic scenes; it is passed as `dataset_dir` to
# the data loader. Change it if the data is not at this Google Drive location.
DATASET_DIR = '/content/drive/MyDrive/nerf_synthetic'

In [None]:
#@title Parameters

# Batch size passed to the training dataset loader.
batch_size = 10 #@param {type:"integer"}
# Frequency count given to `feature_rep.positional_encoding`; it also
# determines the width of the network input.
n_posenc_freq = 6 #@param {type:"integer"}
# Learning rate of the Adam optimizer.
learning_rate = 0.0005 #@param {type:"number"}
# Number of units in every hidden Dense layer of the network.
n_filters = 256 #@param {type:"integer"}


# Number of passes over the training dataset.
num_epochs = 100 #@param {type:"integer"}
# Number of random rays requested from `perspective.random_rays` per step.
n_rays = 512 #@param {type:"integer"}
# Near and far bounds passed to `ray.sample_1d` when sampling along each ray.
near = 2.0 #@param {type:"number"}
far = 6.0 #@param {type:"number"}
# Number of samples per ray (`n_samples` of `ray.sample_1d`).
ray_steps = 64 #@param {type:"integer"}

# Training a NeRF network

In [None]:
#@title Load the lego dataset { form-width: "350px" }

# Training split of the 'lego' scene. `scale=0.125` is forwarded to the loader
# (presumably an image downscaling factor — TODO confirm). The returned
# `height` and `width` are reused later when generating camera rays.
dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
    dataset_dir=DATASET_DIR,
    dataset_name='lego',
    split='train',
    scale=0.125,
    batch_size=batch_size)

In [None]:
#@title Prepare the NeRF model and optimizer { form-width: "350px" }

# Width of the network input: 2 * 3 values per encoding frequency plus 3 more.
# Presumably sin and cos of each xyz coordinate plus the raw xyz itself —
# TODO confirm against `feature_rep.positional_encoding`.
input_dim = n_posenc_freq * 2 * 3 + 3


def get_model():
  """Builds the tiny NeRF multi-layer perceptron.

  The network stacks five ReLU-activated Dense layers of `n_filters` units,
  concatenates the original input features back in (skip connection), applies
  three more ReLU-activated Dense layers and ends with a linear Dense layer of
  4 units (split downstream into 3 color channels and 1 alpha channel).

  Returns:
    A `tf.keras.Model` taking features of size `input_dim` and producing 4
    values per input row.
  """
  with tf.name_scope("Network/"):
    encoded_input = layers.Input(shape=[input_dim])

    # First stage: five hidden layers applied directly to the input features.
    hidden = encoded_input
    for _ in range(5):
      hidden = layers.Dense(n_filters, activation=layers.ReLU())(hidden)

    # Skip connection: feed the input features back in before the second stage.
    hidden = layers.concatenate([hidden, encoded_input], -1)
    for _ in range(3):
      hidden = layers.Dense(n_filters, activation=layers.ReLU())(hidden)

    # Linear head; activations are applied by the rendering code.
    rgba = layers.Dense(4)(hidden)
    return tf.keras.Model(inputs=[encoded_input], outputs=[rgba])


# Build the network once and pair it with an Adam optimizer that uses the
# learning rate chosen in the parameters cell. Both are read as globals by
# `train_step`.
model = get_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# @title Set up the training procedure { form-width: "350px" }

@tf.function
def network_inference_and_rendering(ray_points, model):
  """Evaluates the NeRF model on ray samples and renders them into pixels.

  The points are positionally encoded with `n_posenc_freq` frequencies,
  flattened into rows for the model, and the 4-channel prediction is reshaped
  back to the layout of `ray_points`. A sigmoid is applied to the first three
  channels (rgb) and a ReLU to the last one (alpha) before the samples are
  composited with `ray_radiance.compute_radiance`.

  Args:
    ray_points: A tensor of shape `[A, B, C, 3]` where A is the batch size,
      B is the number of rays, C is the number of samples per ray.
    model: the NeRF model to run

  Returns:
    A single tensor with the rendered rgb values, i.e. the first output of
    `ray_radiance.compute_radiance` (expected shape `[A, B, 3]`).
  """
  # Encode every 3D sample and flatten to [num_points, feature_dim].
  encoded = feature_rep.positional_encoding(ray_points, n_posenc_freq)
  flat_encoded = tf.reshape(encoded, [-1, tf.shape(encoded)[-1]])

  # Run the network, then restore the leading dims of `ray_points` with
  # 4 output channels in place of the 3 coordinates.
  raw_rgba = model([flat_encoded])
  leading_dims = tf.shape(ray_points)[:-1]
  raw_rgba = tf.reshape(raw_rgba, tf.concat([leading_dims, [4]], axis=-1))

  # Squash color with a sigmoid and clamp alpha to be non-negative.
  raw_rgb, raw_alpha = tf.split(raw_rgba, [3, 1], axis=-1)
  rgba = tf.concat([tf.sigmoid(raw_rgb), tf.nn.relu(raw_alpha)], axis=-1)

  # Composite the samples along each ray into one color.
  dists = utils.get_distances_between_points(ray_points)
  rgb_render, _, _ = ray_radiance.compute_radiance(rgba, dists)
  return rgb_render


@tf.function
def train_step(ray_origin, ray_direction, gt_rgb):
  """Runs one optimization step of the NeRF model on a batch of rays.

  Points are sampled along every ray between the global `near` and `far`
  bounds, rendered through the global `model`, and compared with the
  ground-truth colors. The resulting gradients are applied to the model
  variables with the global `optimizer`.

  Args:
    ray_origin: A tensor of shape `[A, B, 3]` where A is the batch size,
      B is the number of rays.
    ray_direction: A tensor of shape `[A, B, 3]` where A is the batch size,
      B is the number of rays.
    gt_rgb: A tensor of shape `[A, B, 3]` where A is the batch size,
      B is the number of rays.

  Returns:
    The value of `utils.l2_loss` for this batch (expected to be a scalar).
  """
  with tf.GradientTape() as tape:
    samples_along_rays, _ = ray.sample_1d(
        ray_origin,
        ray_direction,
        near=near,
        far=far,
        n_samples=ray_steps,
        strategy='stratified')
    rendered_rgb = network_inference_and_rendering(samples_along_rays, model)
    loss = utils.l2_loss(rendered_rgb, gt_rgb)

  # Differentiate the loss w.r.t. the model variables and update them.
  weights = model.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))
  return loss

In [None]:
# Train for `num_epochs` passes over the dataset, printing the summed loss of
# every epoch.
for epoch in range(num_epochs):
  epoch_loss = 0.0
  for image, focal, principal_point, transform_matrix in dataset:
    # Draw `n_rays` random camera rays together with their pixel coordinates.
    camera_rays, pixels_xy = perspective.random_rays(
        focal, principal_point, height, width, n_rays)
    # TF-Graphics camera rays to NeRF world rays
    camera_rays = utils.change_coordinate_system(
        camera_rays, (0., 0., 0.), (1., -1., -1.))
    rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
        camera_rays, transform_matrix)
    # Flip the pixel coordinates from (x, y) to (y, x) before gathering the
    # matching pixels from the image.
    pixels_yx = tf.reverse(pixels_xy, axis=[-1])
    sampled_pixels = tf.gather_nd(image, pixels_yx, batch_dims=1)
    # Keep the three color channels and discard the fourth one.
    target_rgb, _ = tf.split(sampled_pixels, [3, 1], axis=-1)
    epoch_loss += train_step(rays_org, rays_dir, target_rgb)
  print('Epoch {0} loss: {1:.3f}'.format(epoch, epoch_loss))

# Testing

In [None]:
# @title Load the test data

# Validation split of the same 'lego' scene, loaded with batch_size=1 and
# shuffling disabled. Note that this rebinds the global `height` and `width`
# that the training cell also uses.
test_dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
    dataset_dir=DATASET_DIR,
    dataset_name='lego',
    split='val',
    scale=0.125,
    batch_size=1,
    shuffle=False)

In [None]:
# Render the first validation view with the trained model, show it next to
# the ground-truth image and report the PSNR between the two.
for testimg, focal, principal_point, transform_matrix in test_dataset.take(1):
  # Take the single image of the batch and keep its first three channels
  # (the training loop treats these as rgb).
  testimg = testimg[0, :, :, :3]

  # A single full-size "patch" yields the rays for the whole image.
  img_rays, _ = perspective.random_patches(
      focal,
      principal_point,
      height,
      width,
      patch_height=height,
      patch_width=width,
      scale=1.0)

  # Break the test image into lines, so we don't run out of memory
  batch_rays = tf.split(img_rays, height, axis=1)
  output = []
  for random_rays in batch_rays:
    # TF-Graphics camera rays to NeRF world rays
    random_rays = utils.change_coordinate_system(random_rays,
                                                  (0., 0., 0.),
                                                  (1., -1., -1.))
    rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
        random_rays,
        transform_matrix)
    ray_points, _ = ray.sample_1d(
        rays_org,
        rays_dir,
        near=near,
        far=far,
        n_samples=ray_steps,
        strategy='stratified')
    rgb = network_inference_and_rendering(ray_points, model)
    output.append(rgb)
  # Stack the rendered lines back into one image.
  final_image = tf.concat(output, axis=0)

  # Label both panels so the figure can be read on its own.
  fig, ax = plt.subplots(1, 2)
  ax[0].imshow(final_image)
  ax[0].set_title('NeRF rendering')
  ax[1].imshow(testimg)
  ax[1].set_title('Ground truth')
  plt.show()

  # PSNR = -10 * log10(MSE).
  loss = tf.reduce_mean(tf.square(final_image - testimg))
  psnr = -10. * tf.math.log(loss) / tf.math.log(10.)
  print('PSNR: {:.2f} dB'.format(psnr.numpy()))