Skip to content

Commit

Permalink
Sample random patches from a perspective camera.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 362005941
  • Loading branch information
krematas authored and Copybara-Service committed Mar 29, 2021
1 parent 27f0e67 commit 37094bb
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 6 deletions.
93 changes: 87 additions & 6 deletions tensorflow_graphics/rendering/camera/perspective.py
Expand Up @@ -48,6 +48,7 @@
from typing import Tuple
import tensorflow as tf

from tensorflow_graphics.geometry.representation import grid
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
Expand Down Expand Up @@ -453,12 +454,12 @@ def random_rays(focal: tf.Tensor, principal_point: tf.Tensor,
width: The width of the image plane in pixels.
n_rays: The number M of rays to sample.
margin: The margin around the borders of the image.
name: A name for this op that defaults to "perspective_ray".
name: A name for this op that defaults to "random_rays".
Returns:
A tensor of shape `[A1, ..., An, M, 3]` with the ray directions and
a tensor of shape `[A1, ..., An, M, 2]` with the pixel locations
a tensor of shape `[A1, ..., An, M, 2]` with the pixel x, y locations.
"""
with tf.compat.v1.name_scope(name, "sample_rays_from_random_pixels",
with tf.compat.v1.name_scope(name, "random_rays",
[focal, principal_point]):
focal = tf.convert_to_tensor(value=focal)
principal_point = tf.convert_to_tensor(value=principal_point)
Expand All @@ -475,15 +476,95 @@ def random_rays(focal: tf.Tensor, principal_point: tf.Tensor,
broadcast_compatible=True)

batch_dims = tf.shape(focal)[:-1]
random_x = tf.random.uniform(tf.concat([batch_dims, [n_rays]], axis=0),
target_shape = tf.concat([batch_dims, [n_rays]], axis=0)
random_x = tf.random.uniform(target_shape,
minval=margin,
maxval=width - margin,
dtype=tf.int32)
random_y = tf.random.uniform(tf.concat([batch_dims, [n_rays]], axis=0),
random_y = tf.random.uniform(target_shape,
minval=margin,
maxval=height - margin,
dtype=tf.int32)
pixels = tf.cast(tf.stack((random_y, random_x), axis=-1), tf.float32)
pixels = tf.cast(tf.stack((random_x, random_y), axis=-1), tf.float32)
rays = ray(pixels,
tf.expand_dims(focal, -2),
tf.expand_dims(principal_point, -2))
return rays, tf.cast(pixels, tf.int32)


def random_patches(focal: tf.Tensor, principal_point: tf.Tensor,
height: int, width: int, patch_height: int, patch_width: int,
scale: float = 1.0,
indexing: str = "ij",
name: str = None) -> Tuple[tf.Tensor, tf.Tensor]:
"""Sample patches at different scales and from an image.
Args:
focal: A tensor of shape `[A1, ..., An, 2]`
principal_point: A tensor of shape `[A1, ..., An, 2]`
height: The height of the image plane in pixels.
width: The width of the image plane in pixels.
patch_height: The height M of the patch in pixels.
patch_width: The width N of the patch in pixels.
scale: The scale of the patch.
indexing: Indexing of the patch ('ij' or 'xy')
name: A name for this op that defaults to "random_patches".
Returns:
A tensor of shape `[A1, ..., An, M*N, 3]` where the last dimension is the
ray directions in 3D passing from the M*N pixels of the patch and
a tensor of shape `[A1, ..., An, M*N, 2]` with the pixel x, y locations.
"""
with tf.compat.v1.name_scope(name, "random_patches",
[focal, principal_point]):
focal = tf.convert_to_tensor(value=focal)
principal_point = tf.convert_to_tensor(value=principal_point)

shape.check_static(
tensor=focal, tensor_name="focal", has_dim_equals=(-1, 2))
shape.check_static(
tensor=principal_point,
tensor_name="principal_point",
has_dim_equals=(-1, 2))
shape.compare_batch_dimensions(
tensors=(focal, principal_point),
tensor_names=("focal", "principal_point"),
last_axes=-2,
broadcast_compatible=True)

if indexing not in ["xy", "ij"]:
raise ValueError("'axis' needs to be 'xy' or 'ij'")

batch_shape = tf.shape(focal)[:-1]
patch = grid.generate([0, 0],
[patch_width - 1, patch_height - 1],
[patch_width, patch_height])
if indexing == "xy":
patch = tf.reverse(patch, axis=[-1])
patch = tf.cast(patch, tf.float32)
patch = patch * scale

interm_shape = tf.concat([tf.ones_like(batch_shape), tf.shape(patch)],
axis=0)
patch = tf.reshape(patch, interm_shape)

random_y = tf.random.uniform(batch_shape,
minval=0,
maxval=height - int(patch_height * scale) + 1,
dtype=tf.int32)
random_x = tf.random.uniform(batch_shape,
minval=0,
maxval=width - int(patch_width * scale) + 1,
dtype=tf.int32)

patch_origins = tf.cast(tf.stack([random_x, random_y], axis=-1), tf.float32)
patch_origins = tf.expand_dims(tf.expand_dims(patch_origins, -2), -2)

pixels = tf.cast(patch + patch_origins, tf.float32)

final_shape = tf.concat([batch_shape, [patch_height * patch_width, 2]],
axis=0)
pixels = tf.reshape(pixels, final_shape)

rays = ray(pixels,
tf.expand_dims(focal, -2),
tf.expand_dims(principal_point, -2))
Expand Down
54 changes: 54 additions & 0 deletions tensorflow_graphics/rendering/camera/tests/perspective_test.py
Expand Up @@ -518,5 +518,59 @@ def test_random_rays_exception_exception_raised(self, error_msg,
self.assert_exception_is_raised(perspective.random_rays, error_msg, shapes,
height=height, width=width, n_rays=n_rays)

@parameterized.parameters(
(128, 128, 64, 64, (2,), (2,)),
(128, 256, 64, 64, (2,), (2,)),
(128, 128, 64, 72, (2,), (2,)),
(128, 256, 64, 72, (2,), (2,)),
(128, 128, 64, 64, (2, 2), (2, 2)),
(128, 128, 64, 64, (5, 3, 2), (5, 3, 2)),
(128, 128, 64, 64, (3, 2), (1, 2)),
(128, 128, 128, 128, (3, 2), (1, 2)),
)
def test_random_patches_exception_exception_not_raised(self,
height,
width,
patch_height,
patch_width,
*shapes):
"""Tests that the shape exceptions are not raised."""
self.assert_exception_is_not_raised(perspective.random_patches, shapes,
height=height, width=width,
patch_height=patch_height,
patch_width=patch_width)

@parameterized.parameters(
("must have exactly 2 dimensions in axis -1",
128, 128, 64, 64, (None,), (2,)),
("must have exactly 2 dimensions in axis -1",
128, 128, 64, 64, (2,), (None,)),
("Not all batch dimensions are broadcast-compatible.",
128, 128, 64, 64, (3, 2), (2, 2)),
)
def test_random_patches_exception_exception_raised(self, error_msg,
height, width,
patch_height, patch_width,
*shapes):
"""Tests that the shape exceptions are properly raised."""
self.assert_exception_is_raised(perspective.random_patches,
error_msg,
shapes,
height=height,
width=width,
patch_height=patch_height,
patch_width=patch_width)

@parameterized.parameters(
(((1., 1.), (1., 1.), 1, 1, 1, 1), (((-1., -1., 1.),), ((0., 0.),))),
)
def test_random_patches_preset(self, test_inputs, test_outputs):
"""Tests that the ray function gives the correct result."""
self.assert_output_is_correct(perspective.random_patches,
test_inputs,
test_outputs,
tile=False)


if __name__ == "__main__":
test_case.main()

0 comments on commit 37094bb

Please sign in to comment.