Skip to content

Commit

Permalink
- Update unit tests for rasterizer to test different batch shapes.
Browse files Browse the repository at this point in the history
- Fix batch handling in rasterizer backend.

PiperOrigin-RevId: 352640436
  • Loading branch information
podlipensky authored and Copybara-Service committed Jan 25, 2021
1 parent 3866f98 commit 0bf3f5f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tensorflow_graphics/rendering/opengl/rasterization_backend.py
Expand Up @@ -182,6 +182,9 @@ def rasterize(vertices,
geometry_shader=geometry_shader,
fragment_shader=fragment_shader)

# Note: such slicing of the tensor will result in all batch dimensions being
# None for tensorflow graph mode. In eager mode, batch shapes will be
# preserved.
triangle_index = tf.cast(rasterized[..., 0], tf.int32) - 1
barycentric_coordinates = rasterized[..., 1:3]
barycentric_coordinates = tf.concat(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_graphics/rendering/triangle_rasterizer.py
Expand Up @@ -135,7 +135,7 @@ def rasterize(vertices,
backend)
outputs = {"mask": mask, "triangle_indices": triangle_index}

batch_shape = triangle_index.shape[:-3]
batch_shape = triangle_index.shape[:-2]
batch_shape = [_dim_value(dim) for dim in batch_shape]

vertices = tf.gather(vertices, triangles, axis=-2)
Expand Down

0 comments on commit 0bf3f5f

Please sign in to comment.