Skip to content

Commit

Permalink
Removed workaround for not being able to use tf.math.mod as it works …
Browse files Browse the repository at this point in the history
…now.

PiperOrigin-RevId: 258793478
  • Loading branch information
cem-keskin authored and Copybara-Service committed Jul 18, 2019
1 parent 5da4eb2 commit 98cda05
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions tensorflow_graphics/math/interpolation/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@
from tensorflow_graphics.util import shape


# TODO(b/131510643): remove when TF API is
def _mod(x, y):
return x - tf.cast((x // y) * y, dtype=x.dtype)


class Degree(enum.IntEnum):
"""Defines valid degrees for B-spline interpolation."""
CONSTANT = 0
Expand Down Expand Up @@ -125,7 +120,7 @@ def knot_weights(positions,
knots with nonzero weights. If set to True, the function returns the
weights of only the `degree` + 1 knots that are non-zero, as well as the
indices of the knots.
name: A name for this op. Defaults to "bsplines_knot_weights".
name: A name for this op. Defaults to "bspline_knot_weights".
Returns:
A tensor with dense weights for each control point, with the shape
Expand All @@ -138,7 +133,7 @@ def knot_weights(positions,
ValueError: If degree is greater than 4 or num_knots - 1, or less than 0.
InvalidArgumentError: If positions are not in the right range.
"""
with tf.compat.v1.name_scope(name, "bsplines_knot_weights", [positions]):
with tf.compat.v1.name_scope(name, "bspline_knot_weights", [positions]):
positions = tf.convert_to_tensor(value=positions)

if degree > 4 or degree < 0:
Expand Down Expand Up @@ -198,7 +193,7 @@ def knot_weights(positions,
shape=(-1,))
ind_col = tf.reshape(ind_col, shape=(-1,)) + tiled_shifts
if cyclical:
ind_col = _mod(ind_col, num_knots)
ind_col = tf.math.mod(ind_col, num_knots)
indices = tf.stack((tf.reshape(ind_row, shape=(-1,)), ind_col), axis=-1)
shape_indices = tf.concat((tf.reshape(
num_positions, shape=(1,)), tf.constant(
Expand All @@ -225,7 +220,7 @@ def interpolate_with_weights(knots, weights, name=None):
`C` is the number of knots.
weights: A tensor with shape `[A1, ..., An, C]` containing dense weights for
the knots, where `C` is the number of knots.
name: A name for this op. Defaults to "bsplines_interpolate_with_weights".
name: A name for this op. Defaults to "bspline_interpolate_with_weights".
Returns:
A tensor with shape `[A1, ..., An, B1, ..., Bk]`, which is the result of
Expand All @@ -234,7 +229,7 @@ def interpolate_with_weights(knots, weights, name=None):
Raises:
ValueError: If the last dimension of knots and weights is not equal.
"""
with tf.compat.v1.name_scope(name, "bsplines_interpolate_with_weights",
with tf.compat.v1.name_scope(name, "bspline_interpolate_with_weights",
[knots, weights]):
knots = tf.convert_to_tensor(value=knots)
weights = tf.convert_to_tensor(value=weights)
Expand Down

0 comments on commit 98cda05

Please sign in to comment.