Skip to content

Commit

Permalink
Allows for using map_fn instead of vectorized_map.
Browse files Browse the repository at this point in the history
map_fn uses significantly less memory.

PiperOrigin-RevId: 401519292
  • Loading branch information
chaene authored and Copybara-Service committed Oct 15, 2021
1 parent 6c53748 commit 8935d36
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
21 changes: 16 additions & 5 deletions tensorflow_graphics/rendering/barycentrics.py
Expand Up @@ -23,8 +23,10 @@


def differentiable_barycentrics(
framebuffer: fb.Framebuffer, clip_space_vertices: type_alias.TensorLike,
triangles: type_alias.TensorLike) -> fb.Framebuffer:
framebuffer: fb.Framebuffer,
clip_space_vertices: type_alias.TensorLike,
triangles: type_alias.TensorLike,
use_vectorized_map: bool = True) -> fb.Framebuffer:
"""Computes differentiable barycentric coordinates from a Framebuffer.
The barycentric coordinates will be differentiable w.r.t. the input vertices.
Expand All @@ -39,6 +41,7 @@ def differentiable_barycentrics(
triangles: a 2-D int32 tensor with shape [triangle_count, 3] or a 3-D tensor
with shape [batch, triangle_count, 3] containing per-triangle vertex
indices in counter-clockwise order.
use_vectorized_map: If true uses vectorized_map otherwise uses map_fn.
Returns:
a copy of `framebuffer`, but the differentiable barycentric coordinates will
Expand Down Expand Up @@ -95,9 +98,17 @@ def compute_barycentrics_fn(
barycentric_coords = tf.transpose(barycentric_coords, perm=[1, 2, 3, 0])
return barycentric_coords

per_image_barycentrics = tf.vectorized_map(
compute_barycentrics_fn,
(clip_space_vertices, triangles, framebuffer.triangle_id))
if use_vectorized_map:
per_image_barycentrics = tf.vectorized_map(
compute_barycentrics_fn,
(clip_space_vertices, triangles, framebuffer.triangle_id))
else:
num_meshes = tf.shape(clip_space_vertices)[0]
triangles_repeated = tf.repeat(triangles, repeats=num_meshes, axis=0)
per_image_barycentrics = tf.map_fn(
compute_barycentrics_fn,
(clip_space_vertices, triangles_repeated, framebuffer.triangle_id),
fn_output_signature=tf.TensorSpec(shape=(1, None, None, 3)))

barycentric_coords = tf.stack(per_image_barycentrics, axis=0)
# After stacking barycentrics will have layers dimension no matter what.
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/rendering/triangle_rasterizer.py
Expand Up @@ -41,6 +41,7 @@ def rasterize(
view_projection_matrix: type_alias.TensorLike,
image_size: Tuple[int, int],
enable_cull_face: bool = True,
use_vectorized_map: bool = True,
backend: enum.Enum = rasterization_backend.RasterizationBackends.OPENGL,
name: str = "triangle_rasterizer_rasterize"
) -> Dict[str, type_alias.TensorLike]:
Expand All @@ -64,6 +65,8 @@ def rasterize(
the rasterized image.
enable_cull_face: Enables BACK face culling when True, and no culling when
False.
use_vectorized_map: If true uses vectorized_map for barycentrics
computations otherwise uses map_fn.
backend: A rasterization_backend.RasterizationBackends enum containing the
backend method to use for rasterization.
name: A name for this op. Defaults to "triangle_rasterizer_rasterize".
Expand Down Expand Up @@ -124,7 +127,7 @@ def rasterize(
clip_space_vertices = utils.transform_homogeneous(view_projection_matrix,
vertices)
rasterized = barycentrics_module.differentiable_barycentrics(
rasterized, clip_space_vertices, triangles)
rasterized, clip_space_vertices, triangles, use_vectorized_map)
barycentrics = rasterized.barycentrics.value
outputs["barycentrics"] = utils.restore_batch_dims(
rasterized.foreground_mask * barycentrics, input_batch_shape)
Expand Down

0 comments on commit 8935d36

Please sign in to comment.