Skip to content

Commit

Permalink
- Fix use case when screen_dimensions or lower_left_corner don't …
Browse files Browse the repository at this point in the history
…have batch dimensions, while other tensors do have batch dimensions.

- Unify rasterization backend output to produce channel dimension for `mask` and `triangle_index`

PiperOrigin-RevId: 352640436
  • Loading branch information
podlipensky authored and Copybara-Service committed Jan 26, 2021
1 parent 3866f98 commit d51fe97
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 18 deletions.
13 changes: 13 additions & 0 deletions tensorflow_graphics/rendering/opengl/math.py
Expand Up @@ -426,6 +426,19 @@ def perspective_correct_barycentrics(triangle_vertices_model_space,
tensor_name="triangle_vertices_model_space",
has_dim_equals=((-2, 3), (-1, 3)))

lower_left_corner = tf.convert_to_tensor(value=lower_left_corner)
screen_dimensions = tf.convert_to_tensor(value=screen_dimensions)
lower_left_corner = shape.add_batch_dimensions(
lower_left_corner,
"lower_left_corner",
model_to_eye_matrix.shape[:-2],
last_axis=-2)
screen_dimensions = shape.add_batch_dimensions(
screen_dimensions,
"screen_dimensions",
model_to_eye_matrix.shape[:-2],
last_axis=-2)

vertices_screen, vertices_w = model_to_screen(triangle_vertices_model_space,
model_to_eye_matrix,
perspective_matrix,
Expand Down
Expand Up @@ -128,7 +128,7 @@ def rasterize(vertices,
is associated to a pixel, the index is set to -1.
The second element in the tuple is of shape `[A1, ..., An, H, W, 3]` and
correspond to barycentric coordinates per pixel. The last element in the
tuple is of shape `[A1, ..., An, H, W]` and stores a value of `0` of the
tuple is of shape `[A1, ..., An, H, W, 1]` and stores a value of `0` of the
pixel is assciated with the background, and `1` with the foreground.
"""
with tf.compat.v1.name_scope(name, "rasterization_backend_rasterize",
Expand Down Expand Up @@ -183,12 +183,19 @@ def rasterize(vertices,
fragment_shader=fragment_shader)

triangle_index = tf.cast(rasterized[..., 0], tf.int32) - 1
# Slicing of the tensor will result in all batch dimensions being
# `None` for tensorflow graph mode, therefore we have to fix it in order to
# have explicit shape.
width, height = image_size
triangle_index = tf.reshape(triangle_index,
common_batch_shape + [height, width, 1])
barycentric_coordinates = rasterized[..., 1:3]
barycentric_coordinates = tf.concat(
(barycentric_coordinates, 1.0 - barycentric_coordinates[..., 0:1] -
barycentric_coordinates[..., 1:2]),
axis=-1)
mask = tf.cast(rasterized[..., 3], tf.int32)
mask = tf.reshape(mask, common_batch_shape + [height, width, 1])

return triangle_index, barycentric_coordinates, mask

Expand Down
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -96,8 +95,8 @@ def test_rasterize_batch_vertices_only(self):
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
self.assertAllEqual(mask[0, ...], tf.ones_like(mask[0, ...]))

gt_layer_1 = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH), np.float32)
gt_layer_1[_IMAGE_HEIGHT // 2:, _IMAGE_WIDTH // 2:] = 1.0
gt_layer_1 = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1), np.float32)
gt_layer_1[_IMAGE_HEIGHT // 2:, _IMAGE_WIDTH // 2:, 0] = 1.0
self.assertAllEqual(mask[1, ...], gt_layer_1)

def test_rasterize_batch_view_only(self):
Expand Down Expand Up @@ -138,15 +137,16 @@ def test_rasterize_preset(self):
(_IMAGE_WIDTH, _IMAGE_HEIGHT))

with self.subTest(name="triangle_index"):
groundtruth_triangle_index = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH),
groundtruth_triangle_index = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
dtype=np.int32)
groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2] = -1
groundtruth_triangle_index[:_IMAGE_HEIGHT // 2, _IMAGE_WIDTH // 2:] = 1
groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2, 0] = -1
groundtruth_triangle_index[:_IMAGE_HEIGHT // 2, _IMAGE_WIDTH // 2:, 0] = 1
self.assertAllEqual(groundtruth_triangle_index, predicted_triangle_index)

with self.subTest(name="mask"):
groundtruth_mask = np.ones((_IMAGE_HEIGHT, _IMAGE_WIDTH), dtype=np.int32)
groundtruth_mask[..., :_IMAGE_WIDTH // 2] = 0
groundtruth_mask = np.ones((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
dtype=np.int32)
groundtruth_mask[..., :_IMAGE_WIDTH // 2, 0] = 0
self.assertAllEqual(groundtruth_mask, predicted_mask)

attributes = np.array(
Expand Down
23 changes: 14 additions & 9 deletions tensorflow_graphics/rendering/triangle_rasterizer.py
Expand Up @@ -91,9 +91,9 @@ def rasterize(vertices,
name: A name for this op. Defaults to 'triangle_rasterizer_rasterize'.
Returns:
A dictionary. The key "mask" is of shape `[A1, ..., An, height, width]` and
stores a value of `0` of the pixel is assciated with the background, and `1`
with the foreground. The key "barycentrics" is of shape
A dictionary. The key "mask" is of shape `[A1, ..., An, height, width, 1]`
and stores a value of `0` of the pixel is assciated with the background,
and `1` with the foreground. The key "barycentrics" is of shape
`[A1, ..., An, height, width, 3]` and stores barycentric weights. Finally,
the dictionary contains perspective correct interpolated attributes of shape
`[A1, ..., An, height, width, K]` per entry in the `attributes` dictionary.
Expand Down Expand Up @@ -135,27 +135,32 @@ def rasterize(vertices,
backend)
outputs = {"mask": mask, "triangle_indices": triangle_index}

batch_shape = triangle_index.shape[:-3]
batch_shape = [_dim_value(dim) for dim in batch_shape]

vertices = tf.gather(vertices, triangles, axis=-2)

# Gather does not work on negative indices, which is the case for the pixel
# associated to the background.
triangle_index = triangle_index * mask
# Extract batch shape in order to make sure it is preserved after `gather`
# operation.
batch_shape = triangle_index.shape[:-3]
batch_shape = [_dim_value(dim) for dim in batch_shape]
# Remove last dimension of `triangle_index` in order to make it compatible
# with gather operations.
triangle_index_lean = tf.squeeze(triangle_index, axis=-1)
vertices_per_pixel = tf.gather(
vertices, triangle_index, axis=-3, batch_dims=len(batch_shape))
vertices, triangle_index_lean, axis=-3, batch_dims=len(batch_shape))
barycentrics = _perspective_correct_barycentrics(vertices_per_pixel,
model_to_eye_matrix,
perspective_matrix,
image_size_float)
mask_float = tf.cast(tf.expand_dims(mask, axis=-1), vertices.dtype)
mask_float = tf.cast(mask, vertices.dtype)
outputs["barycentrics"] = mask_float * barycentrics

for key, attribute in attributes.items():
attribute = tf.convert_to_tensor(value=attribute)
outputs[key] = mask_float * _perspective_correct_attributes(
attribute, barycentrics, triangles, triangle_index, len(batch_shape))
attribute, barycentrics, triangles, triangle_index_lean,
len(batch_shape))

return outputs

Expand Down
43 changes: 43 additions & 0 deletions tensorflow_graphics/util/shape.py
Expand Up @@ -19,6 +19,7 @@

import itertools

import numpy as np
import six
import tensorflow as tf

Expand Down Expand Up @@ -380,5 +381,47 @@ def is_static(tensor_shape):
return None not in tensor_shape.as_list()


def add_batch_dimensions(tensor, tensor_name, batch_shape, last_axis=None):
"""Broadcasts tensor to match batch dimensions.
It will either broadcast to all provided batch dimensions, therefore
increasing tensor shape by len(batch_shape) dimensions or will do nothing if
batch dimensions already present and equal to expected batch dimensions.
Args:
tensor: A tensor to broadcast of a shape [A1, ..., An, B1, ..., Bn]. Where
[A1, ..., An] is batch dimensions (it is allowed to have no batch
dimensions), and [B1, ..., Bn] are other tensor dimensions. If [A1, ...,
An] are present but different from values in `batch_shape` the error will
be thrown.
tensor_name: Name of `tensor` to be used in the error message if one is
batch_shape: list of `int` representing desired batch dimensions.
last_axis: An `int` corresponding to the last axis of the batch (with zero
based indices). For instance, if there is only a single batch dimension,
last axis should be `0`. If there is no batch dimensions it must be set to
`None`.
thrown.
Returns:
Tensor of a shape `batch_shape` + [B1, ..., Bn] or unmodified tensor if
`batch_shape` = [A1, ..., An].
Raises:
ValueError if tensor already has batch dimensions different from desired
one.
"""
if last_axis is not None:
last_axis = _fix_axes([tensor], [last_axis], allow_negative=True)[0]
tensor_batch_shape = tensor.shape.as_list()[:last_axis + 1]
if np.array_equal(tensor_batch_shape, batch_shape):
return tensor
elif tensor_batch_shape:
raise ValueError(
'Tensor {} has batch dimensions different from target '
'one. Found {}, but expected no batch dimensions or {}'.format(
tensor_name, tensor.shape[:last_axis + 1], batch_shape))

return tf.broadcast_to(tensor, batch_shape + list(tensor.shape))


# The util functions or classes are not exported.
__all__ = []

0 comments on commit d51fe97

Please sign in to comment.