Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions tensorflow_graphics/rendering/kernels/rasterization_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def rasterize(vertices: tf.Tensor,
triangles: tf.Tensor,
view_projection_matrices: tf.Tensor,
image_size: Tuple[int, int],
num_layers=1,
face_culling_mode=FaceCullingMode.NONE,
enable_cull_face: bool,
num_layers: int,
name=None):
"""Rasterizes the scene.

Expand All @@ -83,9 +83,10 @@ def rasterize(vertices: tf.Tensor,
batches of view projection matrices.
image_size: An tuple of integers (width, height) containing the dimensions
in pixels of the rasterized image.
num_layers: Number of depth layers to render. Analytic barycentric gradients
are only available if num_layers is 1.
face_culling_mode: one of FaceCullingMode. Defaults to NONE.
enable_cull_face: A boolean, which will enable BACK face culling when True
and no face culling when False.
num_layers: Number of depth layers to render. Output tensors shape depends
on whether num_layers=1 or not.
name: A name for this op. Defaults to 'rasterization_backend_cpu_rasterize'.

Returns:
Expand Down Expand Up @@ -139,6 +140,7 @@ def rasterize(vertices: tf.Tensor,
per_image_triangle_ids = []
per_image_masks = []
image_width, image_height = image_size
face_culling_mode = FaceCullingMode.BACK if enable_cull_face else FaceCullingMode.NONE
for batch_index in range(batch_size):
clip_vertices_slice = vertices[batch_index, ...]

Expand Down
12 changes: 11 additions & 1 deletion tensorflow_graphics/rendering/opengl/rasterization_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def rasterize(vertices,
triangles,
view_projection_matrices,
image_size,
enable_cull_face,
num_layers,
name=None):
"""Rasterizes the scene.

Expand All @@ -121,6 +123,10 @@ def rasterize(vertices,
batches of view projection matrices
image_size: An tuple of integers (width, height) containing the dimensions
in pixels of the rasterized image.
enable_cull_face: A boolean, which will enable BACK face culling when True
and no face culling when False. Default is True.
num_layers: Number of depth layers to render. Not supported by current
backend yet, but exists for interface compatibility.
name: A name for this op. Defaults to 'rasterization_backend_rasterize'.

Returns:
Expand All @@ -138,6 +144,10 @@ def rasterize(vertices,
"""
with tf.compat.v1.name_scope(name, "rasterization_backend_rasterize",
(vertices, triangles, view_projection_matrices)):

if num_layers != 1:
raise ValueError("OpenGL rasterizer only supports single layer.")

vertices = tf.convert_to_tensor(value=vertices)
triangles = tf.convert_to_tensor(value=triangles)
view_projection_matrices = tf.convert_to_tensor(
Expand Down Expand Up @@ -173,7 +183,7 @@ def rasterize(vertices,
rasterized = render_ops.rasterize(
num_points=geometry.shape[-3],
alpha_clear=0.0,
enable_cull_face=True,
enable_cull_face=enable_cull_face,
variable_names=("view_projection_matrix", "triangular_mesh"),
variable_kinds=("mat", "buffer"),
variable_values=(view_projection_matrices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
_IMAGE_HEIGHT = 5
_IMAGE_WIDTH = 7
_TRIANGLE_SIZE = 2.0
_ENABLE_CULL_FACE = True
_NUM_LAYERS = 1


def _generate_vertices_and_view_matrices():
Expand Down Expand Up @@ -56,7 +58,8 @@ def _generate_vertices_and_view_matrices():
def _proxy_rasterize(vertices, triangles, view_projection_matrices):
return rasterization_backend.rasterize(vertices, triangles,
view_projection_matrices,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
(_IMAGE_WIDTH, _IMAGE_HEIGHT),
_ENABLE_CULL_FACE, _NUM_LAYERS)


class RasterizationBackendTest(test_case.TestCase):
Expand Down Expand Up @@ -97,9 +100,7 @@ def test_rasterize_batch_vertices_only(self):
view_projection_matrix = [
view_projection_matrix[0], view_projection_matrix[0]
]
predicted_fb = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
predicted_fb = _proxy_rasterize(vertices, triangles, view_projection_matrix)
mask = predicted_fb.foreground_mask
self.assertAllEqual(mask[0, ...], tf.ones_like(mask[0, ...]))

Expand All @@ -111,9 +112,7 @@ def test_rasterize_batch_view_only(self):
triangles = np.array(((0, 1, 2),), np.int32)
vertices, view_projection_matrix = _generate_vertices_and_view_matrices()
vertices = np.array([vertices[0], vertices[0]], dtype=np.float32)
predicted_fb = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
predicted_fb = _proxy_rasterize(vertices, triangles, view_projection_matrix)
self.assertAllEqual(predicted_fb.foreground_mask[0, ...],
tf.ones_like(predicted_fb.foreground_mask[0, ...]))
self.assertAllEqual(predicted_fb.foreground_mask[1, ...],
Expand Down Expand Up @@ -144,9 +143,7 @@ def test_rasterize_preset(self):
dtype=np.float32)
triangles = np.array(((1, 2, 0), (0, 2, 3)), np.int32)

predicted_fb = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
predicted_fb = _proxy_rasterize(vertices, triangles, view_projection_matrix)

with self.subTest(name="triangle_index"):
groundtruth_triangle_index = np.zeros((1, _IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
Expand Down
9 changes: 8 additions & 1 deletion tensorflow_graphics/rendering/rasterization_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def rasterize(vertices,
triangles,
view_projection_matrices,
image_size,
enable_cull_face=True,
num_layers=1,
backend=RasterizationBackends.OPENGL):
"""Rasterizes the scene.

Expand All @@ -44,6 +46,11 @@ def rasterize(vertices,
batches of view projection matrices.
image_size: A tuple of integers (width, height) containing the dimensions
in pixels of the rasterized image.
enable_cull_face: A boolean, which will enable BACK face culling when True
and no face culling when False. Default is True.
num_layers: Number of depth layers to render. Output tensors shape depends
on whether num_layers=1 or not. Supported by CPU rasterizer only and does
nothing for OpenGL backend.
backend: An enum containing the backend method to use for rasterization.
Supported options are defined in the RasterizationBackends enum.

Expand Down Expand Up @@ -72,7 +79,7 @@ def rasterize(vertices,
raise KeyError("Backend is not supported: %s." % backend)

return backend_module.rasterize(vertices, triangles, view_projection_matrices,
image_size)
image_size, enable_cull_face, num_layers)


# API contains all public functions and classes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class RasterizationBackendTestBase(test_case.TestCase):
def setUp(self):
super().setUp()
self._backend = rasterization_backend.RasterizationBackends.OPENGL
self._num_layers = 1
self._enable_cull_face = True

def _create_placeholders(self, shapes, dtypes):
if tf.executing_eagerly():
Expand All @@ -41,28 +43,40 @@ def _create_placeholders(self, shapes, dtypes):
@parameterized.parameters((
((2, 7, 3), (5, 3), (2, 4, 4)),
(tf.float32, tf.int32, tf.float32),
(True,),
), (
((2, 7, 3), (5, 3), (2, 4, 4)),
(tf.float32, tf.int32, tf.float32),
(False,),
))
def test_rasterizer_rasterize_exception_not_raised(self, shapes, dtypes):
def test_rasterizer_rasterize_exception_not_raised(self, shapes, dtypes,
enable_cull_face):
"""Tests that supported backends do not raise exceptions."""
placeholders = self._create_placeholders(shapes, dtypes)
try:
rasterization_backend.rasterize(placeholders[0], placeholders[1],
placeholders[2], (600, 800),
enable_cull_face, self._num_layers,
self._backend)
except Exception as e: # pylint: disable=broad-except
self.fail('Exception raised: %s' % str(e))

@parameterized.parameters((
((1, 7, 3), (5, 3), (1, 4, 4)),
(tf.float32, tf.int32, tf.float32),
(True,),
), (
((1, 7, 3), (5, 3), (1, 4, 4)),
(tf.float32, tf.int32, tf.float32),
(False,),
))
def test_rasterizer_return_correct_batch_shapes(self, shapes, dtypes):
def test_rasterizer_return_correct_batch_shapes(self, shapes, dtypes,
enable_cull_face):
"""Tests that supported backends return correct shape."""
placeholders = self._create_placeholders(shapes, dtypes)
frame_buffer = rasterization_backend.rasterize(placeholders[0],
placeholders[1],
placeholders[2], (600, 800),
self._backend)
frame_buffer = rasterization_backend.rasterize(
placeholders[0], placeholders[1], placeholders[2], (600, 800),
enable_cull_face, self._num_layers, self._backend)
batch_size = shapes[0][0]
self.assertEqual([batch_size],
frame_buffer.triangle_id.get_shape().as_list()[:-3])
Expand All @@ -80,17 +94,19 @@ def test_rasterizer_rasterize_exception_raised(self, shapes, dtypes, backend):
placeholders = self._create_placeholders(shapes, dtypes)
with self.assertRaisesRegexp(KeyError, 'Backend is not supported'):
rasterization_backend.rasterize(placeholders[0], placeholders[1],
placeholders[2], (600, 800), backend)
placeholders[2], (600, 800),
self._enable_cull_face, self._num_layers,
backend)

def test_rasterizer_all_vertices_visible(self):
"""Renders simple triangle and asserts that it is fully visible."""
vertices = tf.convert_to_tensor([[[0, 0, 0], [10, 10, 0], [0, 10, 0]]],
dtype=tf.float32)
triangles = tf.convert_to_tensor([[0, 1, 2]], dtype=tf.int32)
view_projection_matrix = tf.expand_dims(tf.eye(4), axis=0)
frame_buffer = rasterization_backend.rasterize(vertices, triangles,
view_projection_matrix,
(100, 100), self._backend)
frame_buffer = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix, (100, 100),
self._enable_cull_face, self._num_layers, self._backend)
self.assertAllEqual(frame_buffer.triangle_id.shape[:-1],
frame_buffer.vertex_ids.shape[:-1])
# Assert that triangle is visible.
Expand Down
9 changes: 6 additions & 3 deletions tensorflow_graphics/rendering/triangle_rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,12 @@ def rasterize(vertices,
model_to_eye_matrix)

vertices = _merge_batch_dims(vertices, last_axis=-2)
rasterized = rasterization_backend.rasterize(vertices, triangles,
view_projection_matrix,
image_size_backend, backend)
rasterized = rasterization_backend.rasterize(
vertices,
triangles,
view_projection_matrix,
image_size_backend,
backend=backend)
outputs = {
"mask":
_restore_batch_dims(rasterized.foreground_mask, input_batch_shape),
Expand Down