Skip to content

Commit

Permalink
Merge 60def1d into ee853d4
Browse files Browse the repository at this point in the history
  • Loading branch information
copybara-service[bot] committed Apr 23, 2021
2 parents ee853d4 + 60def1d commit 39a22a2
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions tensorflow_graphics/geometry/representation/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

from typing import Tuple
from typing import Tuple, Union
from six.moves import range
import tensorflow as tf

Expand Down Expand Up @@ -68,8 +68,8 @@ def _points_from_z_values(ray_org: TensorLike,
def sample_stratified_1d(
ray_org: TensorLike,
ray_dir: TensorLike,
near: float,
far: float,
near: Union[float, TensorLike],
far: Union[float, TensorLike],
n_samples: int,
name: str = "sample_stratified_1d") -> Tuple[tf.Tensor, tf.Tensor]:
"""Sample points on a ray using stratified sampling.
Expand All @@ -89,6 +89,10 @@ def sample_stratified_1d(
and a tensor of shape `[A1, ..., An, M]` for the Z values on the points.
"""
with tf.name_scope(name):
ray_org = tf.convert_to_tensor(ray_org)
ray_dir = tf.convert_to_tensor(ray_dir)
near = tf.convert_to_tensor(near) * tf.ones(tf.shape(ray_org)[:-1])
far = tf.convert_to_tensor(far) * tf.ones(tf.shape(ray_org)[:-1])
shape.check_static(
tensor=ray_org,
tensor_name="ray_org",
Expand All @@ -100,12 +104,21 @@ def sample_stratified_1d(
shape.compare_batch_dimensions(
tensors=(ray_org, ray_dir),
tensor_names=("ray_org", "ray_dir"),
last_axes=(-2, -2),
last_axes=-2,
broadcast_compatible=False)
shape.compare_batch_dimensions(
tensors=(tf.expand_dims(near, axis=-1), tf.expand_dims(far, axis=-1)),
tensor_names=("near", "far"),
last_axes=-1,
broadcast_compatible=False)
shape.compare_batch_dimensions(
tensors=(ray_org, tf.expand_dims(near, axis=-1)),
tensor_names=("ray_org", "near"),
last_axes=-2,
broadcast_compatible=False)

batch_dims = tf.shape(ray_org)[:-1]
random_z_values = sampling.stratified_1d(near * tf.ones(batch_dims),
far * tf.ones(batch_dims),
random_z_values = sampling.stratified_1d(near,
far,
n_samples)
points3d = _points_from_z_values(ray_org, ray_dir, random_z_values)
return points3d, random_z_values
Expand Down

0 comments on commit 39a22a2

Please sign in to comment.