Skip to content

Commit b6559a3

Browse files
krematascopybara-github
authored andcommitted
Implementation of alpha rendering for samples on a ray using only the density information.
PiperOrigin-RevId: 371591566
1 parent 241dc7d commit b6559a3

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""This module implements the radiance-based ray rendering."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
from tensorflow_graphics.util import export_api
23+
from tensorflow_graphics.util import shape
24+
25+
26+
def compute_density(density_values, distances, name=None):
27+
"""Renders the density values (alpha) for points along a ray, as described in ["NeRF Representing Scenes as Neural Radiance Fields for View Synthesis"](https://github.com/bmild/nerf).
28+
29+
Note:
30+
In the following, A1 to An are optional batch dimensions.
31+
32+
Args:
33+
density_values: A tensor of shape `[A1, ..., An, N, 1]`,
34+
where N are the samples on the ray.
35+
distances: A tensor of shape `[A1, ..., An, N]` containing the distances
36+
between the samples, where N are the samples on the ray.
37+
name: A name for this op. Defaults to "ray_radiance".
38+
39+
Returns:
40+
A tensor of shape `[A1, ..., An, 1]` for the estimated density values,
41+
and a tensor of shape `[A1, ..., An, N]` for the sample weights.
42+
"""
43+
44+
with tf.compat.v1.name_scope(name, "ray_density",
45+
[density_values, distances]):
46+
density_values = tf.convert_to_tensor(value=density_values)
47+
distances = tf.convert_to_tensor(value=distances)
48+
distances = tf.expand_dims(distances, -1)
49+
50+
shape.check_static(
51+
tensor=density_values, tensor_name="density_values",
52+
has_dim_equals=(-1, 1))
53+
shape.check_static(
54+
tensor=density_values,
55+
tensor_name="density_values",
56+
has_rank_greater_than=1)
57+
shape.check_static(
58+
tensor=distances,
59+
tensor_name="distances",
60+
has_rank_greater_than=1)
61+
shape.compare_batch_dimensions(
62+
tensors=(density_values, distances),
63+
tensor_names=("density_values", "dists"),
64+
last_axes=-3,
65+
broadcast_compatible=True)
66+
shape.compare_dimensions(
67+
tensors=(density_values, distances),
68+
tensor_names=("density_values", "dists"),
69+
axes=-2)
70+
71+
alpha = 1. - tf.exp(-density_values * distances)
72+
alpha = tf.squeeze(alpha, -1)
73+
ray_sample_weights = alpha * tf.math.cumprod(1. - alpha + 1e-10, -1,
74+
exclusive=True)
75+
ray_alpha = tf.expand_dims(tf.reduce_sum(ray_sample_weights, -1), axis=-1)
76+
return ray_alpha, ray_sample_weights
77+
78+
# API contains all public functions and classes.
79+
__all__ = export_api.get_functions_and_classes()
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for radiance ray rendering."""
15+
16+
from absl.testing import flagsaver
17+
from absl.testing import parameterized
18+
import numpy as np
19+
20+
from tensorflow_graphics.rendering.volumetric import ray_density
21+
from tensorflow_graphics.util import test_case
22+
23+
24+
def generate_random_test_ray_render():
25+
"""Generates random test for the voxels rendering functions."""
26+
batch_shape = np.random.randint(1, 3)
27+
n_rays = np.random.randint(1, 512)
28+
random_ray_values = np.random.uniform(size=[batch_shape] + [n_rays, 1])
29+
random_ray_dists = random_ray_values[..., -1]
30+
return random_ray_values, random_ray_dists
31+
32+
33+
class RayDensityTest(test_case.TestCase):
34+
35+
@parameterized.parameters(
36+
((6, 1), (6,)),
37+
((8, 16, 44, 1), (8, 16, 44)),
38+
((12, 8, 16, 22, 1), (12, 8, 16, 22)),
39+
((32, 32, 256, 1), (32, 32, 256)),
40+
((32, 32, 256, 1), (1, 1, 256)),
41+
((32, 32, 256, 1), (256,)),
42+
)
43+
def test_render_shape_exception_not_raised(self, *shapes):
44+
"""Tests that the shape exceptions are not raised."""
45+
self.assert_exception_is_not_raised(ray_density.compute_density, shapes)
46+
47+
@parameterized.parameters(
48+
("must have a rank greater than 1", ((1,), (3,))),
49+
("must have exactly 1 dimensions in axis -1", ((44, 4), (44, 1))),
50+
("Not all batch dimensions are broadcast-compatible.",
51+
((32, 32, 256, 1,), (32, 16, 256,))),
52+
("must have the same number of dimensions",
53+
((32, 32, 128, 1,), (32, 32, 555,))),
54+
)
55+
def test_render_shape_exception_raised(self, error_msg, shape):
56+
"""Tests that the shape exception is raised."""
57+
self.assert_exception_is_raised(ray_density.compute_density,
58+
error_msg, shape)
59+
60+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
61+
def test_render_jacobian_random(self):
62+
"""Tests the Jacobian of render."""
63+
point_values, point_distance = generate_random_test_ray_render()
64+
self.assert_jacobian_is_correct_fn(
65+
lambda x: ray_density.compute_density(x, point_distance)[0],
66+
[point_values])
67+
self.assert_jacobian_is_correct_fn(
68+
lambda x: ray_density.compute_density(point_values, x)[0],
69+
[point_distance])
70+
71+
def test_render_preset(self):
72+
"""Checks that render returns the expected value."""
73+
74+
image_rays = np.zeros((128, 128, 64, 1))
75+
image_rays[32:96, 32:96, 16:32, :] = 1
76+
distances = np.zeros((128, 128, 64)) + 1.5
77+
target_image = np.zeros((128, 128, 1))
78+
target_image[32:96, 32:96, :] = 1
79+
rendered_image, *_ = ray_density.compute_density(image_rays, distances)
80+
self.assertAllClose(rendered_image, target_image)
81+
82+
83+
if __name__ == "__main__":
84+
test_case.main()

0 commit comments

Comments
 (0)