Skip to content

Commit

Permalink
Adds typing information to the package math/interpolation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 409519160
  • Loading branch information
G4G authored and Copybara-Service committed Nov 12, 2021
1 parent ffe56e0 commit f643ebe
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 45 deletions.
41 changes: 25 additions & 16 deletions tensorflow_graphics/math/interpolation/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from __future__ import print_function

import enum
import tensorflow as tf
from typing import Tuple, Union

import tensorflow as tf
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


class Degree(enum.IntEnum):
Expand All @@ -39,19 +41,19 @@ class Degree(enum.IntEnum):
QUARTIC = 4


def _constant(position):
def _constant(position: tf.Tensor) -> tf.Tensor:
"""B-Spline basis function of degree 0 for positions in the range [0, 1]."""
# A piecewise constant spline is discontinuous at the knots.
return tf.expand_dims(tf.clip_by_value(1.0 + position, 1.0, 1.0), axis=-1)


def _linear(position):
def _linear(position: tf.Tensor) -> tf.Tensor:
"""B-Spline basis functions of degree 1 for positions in the range [0, 1]."""
# Piecewise linear splines are C0 smooth.
return tf.stack((1.0 - position, position), axis=-1)


def _quadratic(position):
def _quadratic(position: tf.Tensor) -> tf.Tensor:
"""B-Spline basis functions of degree 2 for positions in the range [0, 1]."""
# We pre-calculate the terms that are used multiple times.
pos_sq = tf.pow(position, 2.0)
Expand All @@ -62,7 +64,7 @@ def _quadratic(position):
axis=-1)


def _cubic(position):
def _cubic(position: tf.Tensor) -> tf.Tensor:
"""B-Spline basis functions of degree 3 for positions in the range [0, 1]."""
# We pre-calculate the terms that are used multiple times.
neg_pos = 1.0 - position
Expand All @@ -77,7 +79,7 @@ def _cubic(position):
axis=-1)


def _quartic(position):
def _quartic(position: tf.Tensor) -> tf.Tensor:
"""B-Spline basis functions of degree 4 for positions in the range [0, 1]."""
# We pre-calculate the terms that are used multiple times.
neg_pos = 1.0 - position
Expand All @@ -96,12 +98,14 @@ def _quartic(position):
axis=-1)


def knot_weights(positions,
num_knots,
degree,
cyclical,
sparse_mode=False,
name="bspline_knot_weights"):
def knot_weights(
positions: type_alias.TensorLike,
num_knots: type_alias.TensorLike,
degree: int,
cyclical: bool,
sparse_mode: bool = False,
name: str = "bspline_knot_weights"
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
"""Function that converts cardinal B-spline positions to knot weights.
Note:
Expand Down Expand Up @@ -209,9 +213,10 @@ def knot_weights(positions,
return tf.reshape(weights, shape=shape_weights)


def interpolate_with_weights(knots,
weights,
name="bspline_interpolate_with_weights"):
def interpolate_with_weights(
knots: type_alias.TensorLike,
weights: type_alias.TensorLike,
name: str = "bspline_interpolate_with_weights") -> tf.Tensor:
"""Interpolates knots using knot weights.
Note:
Expand Down Expand Up @@ -241,7 +246,11 @@ def interpolate_with_weights(knots,
return tf.tensordot(weights, knots, (-1, -1))


def interpolate(knots, positions, degree, cyclical, name="bspline_interpolate"):
def interpolate(knots: type_alias.TensorLike,
positions: type_alias.TensorLike,
degree: int,
cyclical: bool,
name: str = "bspline_interpolate") -> tf.Tensor:
"""Applies B-spline interpolation to input control points (knots).
Note:
Expand Down
48 changes: 29 additions & 19 deletions tensorflow_graphics/math/interpolation/slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
from __future__ import print_function

import enum
import tensorflow as tf
from typing import Optional, Tuple, Union

import tensorflow as tf
from tensorflow_graphics.math import vector
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


class InterpolationType(enum.Enum):
Expand All @@ -54,7 +56,9 @@ class InterpolationType(enum.Enum):
QUATERNION = 1


def _safe_dot(vector1, vector2, eps):
def _safe_dot(vector1: type_alias.TensorLike,
vector2: type_alias.TensorLike,
eps: type_alias.Float) -> type_alias.Float:
"""Calculates dot product while ensuring it is in the range [-1, 1]."""
dot_product = vector.dot(vector1, vector2)
# Safely shrink to make sure machine precision does not cause the dot
Expand All @@ -63,12 +67,12 @@ def _safe_dot(vector1, vector2, eps):
vector=dot_product, minval=-1.0, maxval=1.0, open_bounds=False, eps=eps)


def interpolate(vector1,
vector2,
percent,
method=InterpolationType.QUATERNION,
eps=None,
name=None):
def interpolate(vector1: type_alias.TensorLike,
vector2: type_alias.TensorLike,
percent: type_alias.Float,
method: InterpolationType = InterpolationType.QUATERNION,
eps: Optional[type_alias.Float] = None,
name: Optional[str] = None) -> tf.Tensor:
"""Applies slerp to vectors or quaternions.
Args:
Expand Down Expand Up @@ -109,11 +113,12 @@ def interpolate(vector1,
return interpolate_with_weights(vector1, vector2, weight1, weight2)


def interpolate_with_weights(vector1,
vector2,
weight1,
weight2,
name="interpolate_with_weights"):
def interpolate_with_weights(
vector1: type_alias.TensorLike,
vector2: type_alias.TensorLike,
weight1: Union[type_alias.Float, type_alias.TensorLike],
weight2: Union[type_alias.Float, type_alias.TensorLike],
name: str = "interpolate_with_weights") -> tf.Tensor:
"""Interpolates vectors by taking their weighted sum.
Interpolation for all variants of slerp is a simple weighted sum over inputs.
Expand Down Expand Up @@ -141,11 +146,12 @@ def interpolate_with_weights(vector1,
return weight1 * vector1 + weight2 * vector2


def quaternion_weights(quaternion1,
quaternion2,
percent,
eps=None,
name="quaternion_weights"):
def quaternion_weights(
quaternion1: type_alias.TensorLike,
quaternion2: type_alias.TensorLike,
percent: Union[type_alias.Float, type_alias.TensorLike],
eps: Optional[type_alias.Float] = None,
name: str = "quaternion_weights") -> Tuple[tf.Tensor, tf.Tensor]:
"""Calculates slerp weights for two normalized quaternions.
Given a percent and two normalized quaternions, this function returns the
Expand Down Expand Up @@ -214,7 +220,11 @@ def quaternion_weights(quaternion1,
return scale1, scale2


def vector_weights(vector1, vector2, percent, eps=None, name="vector_weights"):
def vector_weights(vector1: type_alias.TensorLike,
vector2: type_alias.TensorLike,
percent: Union[type_alias.Float, type_alias.TensorLike],
eps: Optional[type_alias.Float] = None,
name: str = "vector_weights") -> Tuple[tf.Tensor, tf.Tensor]:
"""Spherical linear interpolation (slerp) between two unnormalized vectors.
This function applies geometric slerp to unnormalized vectors by first
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/math/interpolation/trilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def interpolate(grid_3d, sampling_points, name="trilinear_interpolate"):
def interpolate(grid_3d: type_alias.TensorLike,
sampling_points: type_alias.TensorLike,
name: str = "trilinear_interpolate") -> tf.Tensor:
"""Trilinear interpolation on a 3D regular grid.
Args:
Expand Down
21 changes: 12 additions & 9 deletions tensorflow_graphics/math/interpolation/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def interpolate(points,
weights,
indices,
normalize=True,
allow_negative_weights=False,
name="weighted_interpolate"):
def interpolate(points: type_alias.TensorLike,
weights: type_alias.TensorLike,
indices: type_alias.TensorLike,
normalize: bool = True,
allow_negative_weights: bool = False,
name: str = "weighted_interpolate") -> type_alias.TensorLike:
"""Weighted interpolation for M-D point sets.
Given an M-D point set, this function can be used to generate a new point set
Expand Down Expand Up @@ -93,9 +94,11 @@ def interpolate(points,
point_lists, tf.expand_dims(weights, axis=-1), axis=-2, keepdims=False)


def get_barycentric_coordinates(triangle_vertices,
pixels,
name="rasterizer_get_barycentric_coordinates"):
def get_barycentric_coordinates(
triangle_vertices: type_alias.TensorLike,
pixels: type_alias.TensorLike,
name: str = "rasterizer_get_barycentric_coordinates"
) -> type_alias.TensorLike:
"""Computes the barycentric coordinates of pixels for 2D triangles.
Barycentric coordinates of a point `p` are represented as coefficients
Expand Down

0 comments on commit f643ebe

Please sign in to comment.