Skip to content

Commit

Permalink
Merge 091c42f into aeabe70
Browse files Browse the repository at this point in the history
  • Loading branch information
copybara-service[bot] committed May 30, 2021
2 parents aeabe70 + 091c42f commit e7bf074
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions tensorflow_graphics/geometry/representation/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

from typing import Tuple
from typing import Tuple, Union
from six.moves import range
import tensorflow as tf

Expand Down Expand Up @@ -68,8 +68,8 @@ def _points_from_z_values(ray_org: TensorLike,
def sample_stratified_1d(
ray_org: TensorLike,
ray_dir: TensorLike,
near: float,
far: float,
near: Union[float, TensorLike],
far: Union[float, TensorLike],
n_samples: int,
name: str = "sample_stratified_1d") -> Tuple[tf.Tensor, tf.Tensor]:
"""Sample points on a ray using stratified sampling.
Expand All @@ -79,8 +79,10 @@ def sample_stratified_1d(
where the last dimension represents the 3D position of the ray origin.
ray_dir: A tensor of shape `[A1, ..., An, 3]`,
where the last dimension represents the 3D direction of the ray.
near: The smallest distance from the ray origin that a sample can have.
far: The largest distance from the ray origin that a sample can have.
near: The smallest distance from the ray origin that a sample can have; it
can be a float or a tensor of shape `[A1, ..., An]`, broadcast compatible.
far: The largest distance from the ray origin that a sample can have; it
can be a float or a tensor of shape `[A1, ..., An]`, broadcast compatible.
n_samples: A number M to sample on the ray.
name: A name for this op that defaults to "stratified_sampling".
Expand All @@ -89,6 +91,10 @@ def sample_stratified_1d(
and a tensor of shape `[A1, ..., An, M]` for the Z values on the points.
"""
with tf.name_scope(name):
ray_org = tf.convert_to_tensor(ray_org)
ray_dir = tf.convert_to_tensor(ray_dir)
near = tf.convert_to_tensor(near) * tf.ones((1,))
far = tf.convert_to_tensor(far) * tf.ones((1,))
shape.check_static(
tensor=ray_org,
tensor_name="ray_org",
Expand All @@ -100,12 +106,23 @@ def sample_stratified_1d(
shape.compare_batch_dimensions(
tensors=(ray_org, ray_dir),
tensor_names=("ray_org", "ray_dir"),
last_axes=(-2, -2),
last_axes=-2,
broadcast_compatible=False)
shape.compare_batch_dimensions(
tensors=(tf.expand_dims(near, axis=-1), tf.expand_dims(far, axis=-1)),
tensor_names=("near", "far"),
last_axes=-1,
broadcast_compatible=True)
shape.compare_batch_dimensions(
tensors=(ray_org, tf.expand_dims(near, axis=-1)),
tensor_names=("ray_org", "near"),
last_axes=-2,
broadcast_compatible=True)
near = near * tf.ones(tf.shape(ray_org)[:-1])
far = far * tf.ones(tf.shape(ray_org)[:-1])

batch_dims = tf.shape(ray_org)[:-1]
random_z_values = sampling.stratified_1d(near * tf.ones(batch_dims),
far * tf.ones(batch_dims),
random_z_values = sampling.stratified_1d(near,
far,
n_samples)
points3d = _points_from_z_values(ray_org, ray_dir, random_z_values)
return points3d, random_z_values
Expand Down

0 comments on commit e7bf074

Please sign in to comment.