From 463ed7744ea205d4fc8c7dc48d25afa6f97df5cf Mon Sep 17 00:00:00 2001
From: Konstantinos Rematas
Date: Mon, 19 Jul 2021 08:48:20 -0700
Subject: [PATCH] Add the NeRF model, together with training and evaluation scripts.

PiperOrigin-RevId: 385563950
---
 .../projects/radiance_fields/nerf/eval.py    | 185 +++++++++++++
 .../projects/radiance_fields/nerf/model.py   | 260 ++++++++++++++++++
 .../radiance_fields/nerf/tests/model_test.py |  55 ++++
 .../projects/radiance_fields/nerf/train.py   | 114 ++++++++
 4 files changed, 614 insertions(+)
 create mode 100644 tensorflow_graphics/projects/radiance_fields/nerf/eval.py
 create mode 100644 tensorflow_graphics/projects/radiance_fields/nerf/model.py
 create mode 100644 tensorflow_graphics/projects/radiance_fields/nerf/tests/model_test.py
 create mode 100644 tensorflow_graphics/projects/radiance_fields/nerf/train.py

diff --git a/tensorflow_graphics/projects/radiance_fields/nerf/eval.py b/tensorflow_graphics/projects/radiance_fields/nerf/eval.py
new file mode 100644
index 000000000..fff10a7dd
--- /dev/null
+++ b/tensorflow_graphics/projects/radiance_fields/nerf/eval.py
@@ -0,0 +1,185 @@
+# Copyright 2020 The TensorFlow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation script for NeRF."""
+import os
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from PIL import Image
+from skimage import metrics
+import tensorflow as tf
+
+import tensorflow_graphics.projects.radiance_fields.data_loaders as data_loaders
+import tensorflow_graphics.projects.radiance_fields.nerf.model as model_lib
+import tensorflow_graphics.projects.radiance_fields.utils as utils
+import tensorflow_graphics.rendering.camera.perspective as perspective
+
+
+flags.DEFINE_string('checkpoint_dir', '/tmp/lego/',
+                    'Path to the directory of the checkpoint.')
+flags.DEFINE_string('split', 'val', 'Train/val/test split.')
+flags.DEFINE_boolean('single_eval', False, 'Whether to run eval only once.')
+flags.DEFINE_string('output_dir', '/tmp/lego/eval',
+                    'Path to the directory of the output results.')
+flags.DEFINE_string('dataset_dir', '/path/to/dataset/',
+                    'Path to the directory of the dataset images.')
+flags.DEFINE_string('dataset_name', 'lego', 'Dataset name.')
+flags.DEFINE_float('dataset_scale', 0.5,
+                   'Resolution of the dataset (1.0=800 pixels).')
+flags.DEFINE_integer('num_epochs', 10000, 'How many epochs to train.')
+flags.DEFINE_integer('batch_size', 5, 'Number of images for each batch.')
+flags.DEFINE_float('learning_rate', 0.0004, 'The optimizer learning rate.')
+flags.DEFINE_integer('n_filters', 256, 'Number of filters of the MLP.')
+flags.DEFINE_integer('n_freq_posenc_xyz', 10,
+                     'Frequencies for the 3D location positional encoding.')
+flags.DEFINE_string('scene_bbox', '-1.0,-1.0,-1.0,1.0,1.0,1.0',
+                    'Bounding box of the scene.')
+
+flags.DEFINE_integer('n_freq_posenc_dir', 4,
+                     'Frequencies for the direction positional encoding.')
+flags.DEFINE_float('near', 2.0, 'Closest ray location to get samples.')
+flags.DEFINE_float('far', 6.0, 'Furthest ray location to get samples.')
+flags.DEFINE_integer('ray_samples_coarse', 64,
+                     'Samples on a ray for the coarse network.')
+flags.DEFINE_integer('ray_samples_fine', 128,
+                     'Samples on a ray for the fine network.')
+flags.DEFINE_integer('n_rays', 512, 'Number of rays per image for training.')
+flags.DEFINE_boolean('white_background', True, 'Use white background.')
+flags.DEFINE_string('master', 'local', 'Location of the session.')
+FLAGS = flags.FLAGS
+
+
+def main(_):
+
+  dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
+      dataset_dir=FLAGS.dataset_dir,
+      dataset_name=FLAGS.dataset_name,
+      split=FLAGS.split,
+      scale=FLAGS.dataset_scale,
+      batch_size=1,
+      shuffle=False)
+
+  model = model_lib.NeRF(
+      ray_samples_coarse=FLAGS.ray_samples_coarse,
+      ray_samples_fine=FLAGS.ray_samples_fine,
+      near=FLAGS.near,
+      far=FLAGS.far,
+      n_freq_posenc_xyz=FLAGS.n_freq_posenc_xyz,
+      scene_bbox=tuple([float(x) for x in FLAGS.scene_bbox.split(',')]),
+      n_freq_posenc_dir=FLAGS.n_freq_posenc_dir,
+      n_filters=FLAGS.n_filters,
+      white_background=FLAGS.white_background)
+  model.init_coarse_and_fine_models()
+  model.init_optimizer(learning_rate=FLAGS.learning_rate)
+  model.init_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)
+
+  if not tf.io.gfile.exists(FLAGS.output_dir):
+    tf.io.gfile.makedirs(FLAGS.output_dir)
+  summary_writer = tf.summary.create_file_writer(FLAGS.output_dir)
+
+  # ----------------------------------------------------------------------------
+  current_evaluation = 0
+  current_checkpoint = ''
+  while True:
+    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+
+    if latest_checkpoint is None:
+      continue
+
+    if current_checkpoint == latest_checkpoint:
+      continue
+
+    current_checkpoint = latest_checkpoint
+    model.load_checkpoint(current_checkpoint)
+
+    total_psnr = []
+    total_ssim = []
+    image_counter = 0
+    for image, focal, principal_point, transform_matrix in dataset:
+
+      img_rays, _ = perspective.random_patches(
+          focal,
+          principal_point,
+          height,
+          width,
+          patch_height=height,
+          patch_width=width,
+          scale=1.0)
+
+      # Batchify the image to fit into memory
+      batch_rays = tf.split(img_rays, height, axis=1)
+      output = []
+      for random_rays in batch_rays:
+        random_rays = utils.change_coordinate_system(random_rays,
+                                                     (0., 0., 0.),
+                                                     (1., -1., -1.))
+        rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
+            random_rays,
+            transform_matrix)
+
+        rgb_fine, *_ = model.inference(rays_org, rays_dir)
+        output.append(rgb_fine)
+      final_image = tf.concat(output, axis=0)
+      final_image_np = final_image.numpy()
+
+      image_rgb_no_alpha, image_a = tf.split(image, [3, 1], axis=-1)
+      image = image_rgb_no_alpha
+      if FLAGS.white_background:
+        image = image_rgb_no_alpha * image_a + 1 - image_a
+
+      image_np = image.numpy()[0]
+      ssim = metrics.structural_similarity(image_np,
+                                           final_image_np,
+                                           multichannel=True,
+                                           data_range=1)
+      psnr = metrics.peak_signal_noise_ratio(image_np,
+                                             final_image_np,
+                                             data_range=1)
+      total_psnr.append(psnr)
+      total_ssim.append(ssim)
+
+      filename = os.path.join(FLAGS.output_dir,
+                              '{0:05d}.png'.format(image_counter))
+      img_to_save = Image.fromarray((final_image_np*255).astype(np.uint8))
+      with tf.io.gfile.GFile(filename, 'wb') as f:
+        img_to_save.save(f)
+
+      logging.info('Image %d: ssim %.3f / psnr: %.3f',
+                   image_counter, ssim, psnr)
+      image_counter += 1
+
+      # Show some images
+      if image_counter < 5:
+        with summary_writer.as_default():
+          tf.summary.image('rgb_fine/{0}'.format(image_counter),
+                           tf.expand_dims(final_image, 0),
+                           step=current_evaluation,
+                           max_outputs=4)
+    with summary_writer.as_default():
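+      # Log metrics averaged over the whole split for this checkpoint.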
+      tf.summary.scalar('eval_ssim', np.mean(total_ssim),
+                        step=current_evaluation)
+      tf.summary.scalar('eval_psnr', np.mean(total_psnr),
+                        step=current_evaluation)
+    logging.info('ssim %.3f', np.mean(total_ssim))
+    logging.info('psnr %.3f', np.mean(total_psnr))
+    current_evaluation += 1
+
+    if FLAGS.single_eval:
+      break
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/tensorflow_graphics/projects/radiance_fields/nerf/model.py b/tensorflow_graphics/projects/radiance_fields/nerf/model.py
new file mode 100644
index 000000000..6953d8851
--- /dev/null
+++ b/tensorflow_graphics/projects/radiance_fields/nerf/model.py
@@ -0,0 +1,260 @@
+# Copyright 2020 The TensorFlow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""NeRF models."""
+from absl import logging
+import numpy as np
+import tensorflow as tf
+import tensorflow_graphics.geometry.representation.ray as ray
+import tensorflow_graphics.math.feature_representation as feature_rep
+import tensorflow_graphics.projects.radiance_fields.nerf.layers as nerf_layers
+import tensorflow_graphics.projects.radiance_fields.utils as utils
+import tensorflow_graphics.rendering.volumetric.ray_radiance as ray_radiance
+
+
+class NeRF:
+  """Original NeRF network."""
+
+  def __init__(self,
+               ray_samples_coarse=128,
+               ray_samples_fine=128,
+               near=1.0,
+               far=3.0,
+               n_freq_posenc_xyz=8,
+               n_freq_posenc_dir=4,
+               scene_bbox=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0),
+               n_filters=256,
+               white_background=True,
+               coarse_sampling_strategy="stratified"):
+
+    # Ray parameters
+    self.ray_samples_coarse = ray_samples_coarse
+    self.ray_samples_fine = ray_samples_fine
+    self.near = near
+    self.far = far
+    self.white_background = white_background
+    # Network parameters
+    self.n_freq_posenc_xyz = n_freq_posenc_xyz
+    self.n_freq_posenc_dir = n_freq_posenc_dir
+
+    scene_bbox = np.array(scene_bbox).reshape([2, 3])
+    area_dims = scene_bbox[1, :] - scene_bbox[0, :]
+    scene_scale = 1./(max(area_dims)/2.)
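+    # Together with the translation below, this maps the scene bbox into
+    # [-1, 1]^3: recenter on the bbox center, then scale by 2 / longest side.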
+ scene_translation = -np.mean(scene_bbox, axis=0) + self.scene_scale = scene_scale + self.scene_transl = scene_translation + + self.n_filters = n_filters + + self.coarse_sampling_strategy = coarse_sampling_strategy + + self.xyz_dim = n_freq_posenc_xyz * 2 * 3 + 3 + self.dir_dim = n_freq_posenc_dir * 2 * 3 + 3 + self.coarse_model = None + self.fine_model = None + self.optimizer_network = None + self.network_vars = None + + self.coarse_model_backup = None + self.fine_model_backup = None + + self.latest_epoch = None + self.global_step = None + self.summary_writer = None + self.checkpoint = None + self.manager = None + + def init_coarse_and_fine_models(self): + """Initialize models and variables.""" + self.coarse_model = self.get_model() + self.fine_model = self.get_model() + self.latest_epoch = tf.Variable(0, trainable=False, dtype=tf.int64) + self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64) + self.network_vars = (self.coarse_model.trainable_variables + + self.fine_model.trainable_variables) + + def init_optimizer(self, learning_rate=0.0001, decay_steps=1000, + decay_rate=0.98, staircase=True): + """Initialize the optimizers with a scheduler.""" + learning_rate = tf.keras.optimizers.schedules.ExponentialDecay( + learning_rate, + decay_steps=decay_steps, + decay_rate=decay_rate, + staircase=staircase) + self.optimizer_network = tf.keras.optimizers.Adam( + learning_rate=learning_rate) + + def init_checkpoint(self, checkpoint_dir, checkpoint=None): + """Initialize the checkpoints.""" + self.summary_writer = tf.summary.create_file_writer(checkpoint_dir) + self.checkpoint = tf.train.Checkpoint( + coarse_nerf=self.coarse_model, + fine_nerf=self.fine_model, + optimizer_network=self.optimizer_network, + epoch=self.latest_epoch, + global_step=self.global_step) + self.manager = tf.train.CheckpointManager( + checkpoint=self.checkpoint, directory=checkpoint_dir, max_to_keep=5) + self.load_checkpoint(checkpoint=checkpoint) + + def load_checkpoint(self, checkpoint=None): + """Load checkpoints.""" + latest_checkpoint = self.manager.latest_checkpoint if checkpoint is None else checkpoint + + if latest_checkpoint is not None: + logging.info("Checkpoint %s restored", latest_checkpoint) + _ = self.checkpoint.restore(latest_checkpoint).expect_partial() + else: + logging.warning("No checkpoint was restored.") + + def get_model(self): + """Constructs the original NeRF network as a keras model.""" + with tf.name_scope("Network/"): + xyz_features = tf.keras.layers.Input(shape=[None, None, self.xyz_dim]) + dir_features = tf.keras.layers.Input(shape=[None, None, self.dir_dim]) + + feat0 = nerf_layers.concat_block(xyz_features, + n_filters=self.n_filters, + n_layers=4) + feat1 = nerf_layers.dense_block(feat0, + n_filters=self.n_filters, + n_layers=4) + feat2 = tf.keras.layers.Dense(self.n_filters)(feat1) + density = tf.keras.layers.Dense(1)(feat2) + feat2_dir = tf.keras.layers.concatenate([feat2, dir_features], -1) + feat3 = tf.keras.layers.Dense(self.n_filters//2)(feat2_dir) + rgb = tf.keras.layers.Dense(3)(feat3) + rgb_density = tf.keras.layers.concatenate([rgb, density], -1) + return tf.keras.Model(inputs=[xyz_features, dir_features], + outputs=[rgb_density]) + + @tf.function + def prepare_positional_encoding(self, ray_points, ray_dirs): + """Estimate the positional encoding of the 3D position and direction of the samples along a ray. + + Args: + ray_points: A tensor of shape `[A, B, C, 3]` where A is the batch size, B + is the number of rays, C is the number of samples per ray. 
+      ray_dirs: A tensor of shape `[A, B, 3]` where A is the batch size, B
+        is the number of rays.
+
+    Returns:
+      A tuple of a tensor of shape `[A, B, C, M]` and a tensor of shape
+      `[A, B, C, N]`, where M and N are the dimensionalities of the location
+      and direction positional encodings, respectively.
+    """
+    n_ray_samples = tf.shape(ray_points)[-2]
+    scaled_ray_points = self.scene_scale * (ray_points + self.scene_transl)
+    features_xyz = feature_rep.positional_encoding(scaled_ray_points,
+                                                   self.n_freq_posenc_xyz)
+    ray_dirs = tf.tile(tf.expand_dims(ray_dirs, -2), [1, 1, n_ray_samples, 1])
+    features_dir = feature_rep.positional_encoding(ray_dirs,
+                                                   self.n_freq_posenc_dir)
+    return features_xyz, features_dir
+
+  @tf.function
+  def render_network_output(self, rgb_density, ray_points):
+    """Renders the network output into rgb and density values.
+
+    Args:
+      rgb_density: A tensor of shape `[A, B, C, 4]` where A is the batch size, B
+        is the number of rays, C is the number of samples per ray.
+      ray_points: A tensor of shape `[A, B, C, 3]`.
+
+    Returns:
+      A tensor of shape `[A, B, 3]` and a tensor of shape `[A, B, C]`.
+
+    """
+    rgb, density = tf.split(rgb_density, [3, 1], axis=-1)
+    rgb = tf.sigmoid(rgb)
+    density = tf.nn.relu(density)
+    rgb_density = tf.concat([rgb, density], axis=-1)
+
+    dists = utils.get_distances_between_points(ray_points)
+    rgb_render, a_render, weights = ray_radiance.compute_radiance(rgb_density,
+                                                                  dists)
+    if self.white_background:
+      rgb_render = rgb_render + 1 - a_render
+    return rgb_render, weights
+
+  @tf.function
+  def inference(self, r_org, r_dir):
+    """Runs both coarse and fine networks for the given rays.
+
+    Args:
+      r_org: A tensor of shape `[A, B, 3]` where A is the batch size, B is the
+        number of rays.
+      r_dir: A tensor of shape `[A, B, 3]` where A is the batch size, B is the
+        number of rays.
+
+    Returns:
+      Fine and coarse rgb tensors of shape `[A, B, 3]`, then their depth maps.
+    """
+    ray_points_coarse, z_vals_coarse = ray.sample_1d(
+        r_org,
+        r_dir,
+        near=self.near,
+        far=self.far,
+        n_samples=self.ray_samples_coarse,
+        strategy=self.coarse_sampling_strategy)
+    posenc_features = self.prepare_positional_encoding(ray_points_coarse, r_dir)
+    rgb_density = self.coarse_model(posenc_features)
+    rgb_coarse, weights_coarse = self.render_network_output(rgb_density,
+                                                            ray_points_coarse)
+    depth_map_coarse = tf.reduce_sum(weights_coarse * z_vals_coarse, axis=-1)
+
+    ray_points_fine, z_vals_fine = ray.sample_inverse_transform_stratified_1d(
+        r_org,
+        r_dir,
+        z_vals_coarse,
+        weights_coarse,
+        n_samples=self.ray_samples_fine,
+        combine_z_values=True)
+    posenc_features = self.prepare_positional_encoding(ray_points_fine, r_dir)
+    rgb_density = self.fine_model(posenc_features)
+    rgb_fine, weights_fine = self.render_network_output(rgb_density,
+                                                        ray_points_fine)
+    depth_map_fine = tf.reduce_sum(weights_fine * z_vals_fine, axis=-1)
+    return rgb_fine, rgb_coarse, depth_map_fine, depth_map_coarse
+
+  @tf.function
+  def train_step(self, r_org, r_dir, gt_rgb):
+    """Training function for coarse and fine networks.
+
+    Args:
+      r_org: A tensor of shape `[B, N, 3]` where B is the batch size, N is the
+        number of rays.
+      r_dir: A tensor of shape `[B, N, 3]` where B is the batch size, N is the
+        number of rays.
+      gt_rgb: A tensor of shape `[B, N, 3]` where B is the batch size, N is the
+        number of rays.
+
+    Returns:
+      The scalar loss.
+ """ + with tf.GradientTape() as tape: + rgb_fine, rgb_coarse, _, _ = self.inference(r_org, r_dir) + rgb_coarse_loss = utils.l2_loss(rgb_coarse, gt_rgb) + rgb_fine_loss = utils.l2_loss(rgb_fine, gt_rgb) + total_loss = rgb_coarse_loss + rgb_fine_loss + gradients = tape.gradient(total_loss, self.network_vars) + self.optimizer_network.apply_gradients(zip(gradients, self.network_vars)) + + with self.summary_writer.as_default(): + step = self.global_step + tf.summary.scalar("total_loss", total_loss, step=step) + tf.summary.scalar("rgb_loss_f", rgb_fine_loss, step=step) + tf.summary.scalar("rgb_loss_c", rgb_coarse_loss, step=step) + return total_loss diff --git a/tensorflow_graphics/projects/radiance_fields/nerf/tests/model_test.py b/tensorflow_graphics/projects/radiance_fields/nerf/tests/model_test.py new file mode 100644 index 000000000..49864de85 --- /dev/null +++ b/tensorflow_graphics/projects/radiance_fields/nerf/tests/model_test.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Tests for the NeRF model.""" + +from absl import flags +import tensorflow as tf +import tensorflow_graphics.projects.radiance_fields.nerf.model as model_lib + +from tensorflow_graphics.util import test_case + +FLAGS = flags.FLAGS + + +class ModelTest(test_case.TestCase): + + def test_model_training(self): + """Tests whether the NeRF model is initialized properly and can be trained.""" + + model = model_lib.NeRF( + ray_samples_coarse=128, + ray_samples_fine=128, + near=1.0, + far=6.0, + scene_bbox=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0), + n_freq_posenc_xyz=10, + n_freq_posenc_dir=4, + n_filters=128, + white_background=True) + model.init_coarse_and_fine_models() + model.init_optimizer(learning_rate=0.0001) + model.init_checkpoint(checkpoint_dir="/tmp/") + + batch_size = 10 + n_rays = 256 + + rays_org = tf.zeros((batch_size, n_rays, 3), dtype=tf.float32) + rays_dir = tf.zeros((batch_size, n_rays, 3), dtype=tf.float32) + pixels_rgb = tf.zeros((batch_size, n_rays, 3), dtype=tf.float32) + + rgb_loss = model.train_step(rays_org, rays_dir, pixels_rgb) + self.assertAllInRange(rgb_loss, 0.0, 1000.0) + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_graphics/projects/radiance_fields/nerf/train.py b/tensorflow_graphics/projects/radiance_fields/nerf/train.py new file mode 100644 index 000000000..0cd41ae52 --- /dev/null +++ b/tensorflow_graphics/projects/radiance_fields/nerf/train.py @@ -0,0 +1,114 @@ +# Copyright 2020 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Train script for NeRF."""
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf
+import tensorflow_graphics.projects.radiance_fields.data_loaders as data_loaders
+import tensorflow_graphics.projects.radiance_fields.nerf.model as model_lib
+import tensorflow_graphics.projects.radiance_fields.utils as utils
+import tensorflow_graphics.rendering.camera.perspective as perspective

+flags.DEFINE_string('checkpoint_dir', '/tmp/lego/',
+                    'Path to the directory of the checkpoint.')
+flags.DEFINE_string('split', 'train', 'Train/val/test split.')
+flags.DEFINE_string('dataset_dir', '/path/to/dataset/',
+                    'Path to the directory of the dataset images.')
+flags.DEFINE_string('dataset_name', 'lego', 'Dataset name.')
+flags.DEFINE_float('dataset_scale', 0.5,
+                   'Resolution of the dataset (1.0=800 pixels).')
+flags.DEFINE_integer('num_epochs', 10000, 'How many epochs to train.')
+flags.DEFINE_integer('batch_size', 5, 'Number of images for each batch.')
+flags.DEFINE_float('learning_rate', 0.0004, 'The optimizer learning rate.')
+flags.DEFINE_integer('decay_steps', 10000, 'Learning rate decay steps.')
+flags.DEFINE_integer('n_filters', 256, 'Number of filters of the MLP.')
+flags.DEFINE_integer('n_freq_posenc_xyz', 10,
+                     'Frequencies for the 3D location positional encoding.')
+flags.DEFINE_string('scene_bbox', '-1.0,-1.0,-1.0,1.0,1.0,1.0',
+                    'Bounding box of the scene.')
+
+flags.DEFINE_integer('n_freq_posenc_dir', 4,
+                     'Frequencies for the direction positional encoding.')
+flags.DEFINE_float('near', 2.0, 'Closest ray location to get samples.')
+flags.DEFINE_float('far', 6.0, 'Furthest ray location to get samples.')
+flags.DEFINE_integer('ray_samples_coarse', 64,
+                     'Samples on a ray for the coarse network.')
+flags.DEFINE_integer('ray_samples_fine', 128,
+                     'Samples on a ray for the fine network.')
+flags.DEFINE_integer('n_rays', 512, 'Number of rays per image for training.')
+flags.DEFINE_boolean('white_background', True, 'Use white background.')
+flags.DEFINE_string('master', 'local', 'Location of the session.')
+
+FLAGS = flags.FLAGS
+
+
+def main(_):
+
+  dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
+      dataset_dir=FLAGS.dataset_dir,
+      dataset_name=FLAGS.dataset_name,
+      split=FLAGS.split,
+      scale=FLAGS.dataset_scale,
+      batch_size=FLAGS.batch_size)
+
+  model = model_lib.NeRF(
+      ray_samples_coarse=FLAGS.ray_samples_coarse,
+      ray_samples_fine=FLAGS.ray_samples_fine,
+      near=FLAGS.near,
+      far=FLAGS.far,
+      n_freq_posenc_xyz=FLAGS.n_freq_posenc_xyz,
+      scene_bbox=tuple([float(x) for x in FLAGS.scene_bbox.split(',')]),
+      n_freq_posenc_dir=FLAGS.n_freq_posenc_dir,
+      n_filters=FLAGS.n_filters,
+      white_background=FLAGS.white_background)
+  model.init_coarse_and_fine_models()
+  model.init_optimizer(learning_rate=FLAGS.learning_rate,
+                       decay_steps=FLAGS.decay_steps)
+  model.init_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)
+
+  for epoch in range(int(model.latest_epoch.numpy()), FLAGS.num_epochs):
+    epoch_loss = 0.0
+    for image, focal, principal_point, transform_matrix in dataset:
+      random_rays, random_pixels_xy = perspective.random_rays(
+          focal,
+          principal_point,
+          height,
+          width,
+          FLAGS.n_rays)
+      random_rays = utils.change_coordinate_system(random_rays,
+                                                   (0., 0., 0.),
+                                                   (1., -1., -1.))
+      rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
+          random_rays,
+          transform_matrix)
+      random_pixels_yx = tf.reverse(random_pixels_xy, axis=[-1])
+      pixels = tf.gather_nd(image, random_pixels_yx, batch_dims=1)
+      pixels_rgb, pixels_a = tf.split(pixels, [3, 1], axis=-1)
+      pixels_rgb = pixels_rgb * pixels_a + 1 - pixels_a
+
+      step_loss = model.train_step(rays_org, rays_dir, pixels_rgb)
+      epoch_loss += step_loss.numpy()
+      model.global_step.assign_add(1)
+
+    with model.summary_writer.as_default():
+      tf.summary.scalar('epoch_loss', epoch_loss, step=epoch)
+    if epoch % 20 == 0:
+      model.manager.save()
+    logging.info('Epoch %d: %.3f.', epoch, epoch_loss)
+    model.latest_epoch.assign(epoch + 1)
+  model.manager.save()

+if __name__ == '__main__':
+  app.run(main)
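
A minimal sketch of how the pieces above fit together (mirroring
tests/model_test.py; the checkpoint directory, ray counts, and random inputs
are illustrative only, not part of the patch):

    import tensorflow as tf
    import tensorflow_graphics.projects.radiance_fields.nerf.model as model_lib

    # Build the coarse/fine MLPs, optimizer, and checkpointing, as train.py does.
    model = model_lib.NeRF(
        ray_samples_coarse=64,
        ray_samples_fine=128,
        near=2.0,
        far=6.0,
        n_freq_posenc_xyz=10,
        n_freq_posenc_dir=4,
        n_filters=256,
        white_background=True)
    model.init_coarse_and_fine_models()
    model.init_optimizer(learning_rate=4e-4)
    model.init_checkpoint(checkpoint_dir='/tmp/nerf_smoke_test/')

    # One optimization step on a batch of 2 images x 512 random rays each.
    rays_org = tf.random.uniform((2, 512, 3))
    rays_dir = tf.math.l2_normalize(
        tf.random.uniform((2, 512, 3)) - 0.5, axis=-1)
    gt_rgb = tf.random.uniform((2, 512, 3))
    loss = model.train_step(rays_org, rays_dir, gt_rgb)

    # Inference returns the fine and coarse rgb renderings and depth maps.
    rgb_fine, rgb_coarse, depth_fine, depth_coarse = model.inference(
        rays_org, rays_dir)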