diff --git a/tensorflow_graphics/image/pyramid.py b/tensorflow_graphics/image/pyramid.py index 86bf08a54..ce088fd0c 100644 --- a/tensorflow_graphics/image/pyramid.py +++ b/tensorflow_graphics/image/pyramid.py @@ -21,15 +21,19 @@ from __future__ import division from __future__ import print_function +from typing import Callable, List, Optional, Tuple + import numpy as np from six.moves import range import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _downsample(image, kernel): +def _downsample(image: type_alias.TensorLike, + kernel: type_alias.TensorLike) -> tf.Tensor: """Downsamples the image using a convolution with stride 2. Args: @@ -48,7 +52,8 @@ def _downsample(image, kernel): input=image, filters=kernel, strides=[1, 2, 2, 1], padding="SAME") -def _binomial_kernel(num_channels, dtype=tf.float32): +def _binomial_kernel(num_channels: int, + dtype: tf.DType = tf.float32) -> tf.Tensor: """Creates a 5x5 binomial kernel. Args: @@ -65,7 +70,9 @@ def _binomial_kernel(num_channels, dtype=tf.float32): return tf.constant(kernel, dtype=dtype) * tf.eye(num_channels, dtype=dtype) -def _build_pyramid(image, sampler, num_levels): +def _build_pyramid(image: type_alias.TensorLike, sampler: Callable[ + [type_alias.TensorLike, type_alias.TensorLike], type_alias.TensorLike], + num_levels: int) -> List[tf.Tensor]: """Creates the different levels of the pyramid. Args: @@ -87,7 +94,8 @@ def _build_pyramid(image, sampler, num_levels): return levels -def _split(image, kernel): +def _split(image: type_alias.TensorLike, + kernel: type_alias.TensorLike) -> Tuple[tf.Tensor, tf.Tensor]: """Splits the image into high and low frequencies. This is achieved by smoothing the input image and substracting the smoothed @@ -112,7 +120,10 @@ def _split(image, kernel): return high, low -def _upsample(image, kernel, output_shape=None): +def _upsample(image: type_alias.TensorLike, + kernel: type_alias.TensorLike, + output_shape: Optional[type_alias.TensorLike] = None + ) -> tf.Tensor: """Upsamples the image using a transposed convolution with stride 2. Args: @@ -139,7 +150,9 @@ def _upsample(image, kernel, output_shape=None): padding="SAME") -def downsample(image, num_levels, name="pyramid_downsample"): +def downsample(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_downsample") -> List[tf.Tensor]: """Generates the different levels of the pyramid (downsampling). Args: @@ -165,7 +178,8 @@ def downsample(image, num_levels, name="pyramid_downsample"): return _build_pyramid(image, _downsample, num_levels) -def merge(levels, name="pyramid_merge"): +def merge(levels: List[type_alias.TensorLike], + name: str = "pyramid_merge") -> tf.Tensor: """Merges the different levels of the pyramid back to an image. Args: @@ -196,7 +210,9 @@ def merge(levels, name="pyramid_merge"): return image -def split(image, num_levels, name="pyramid_split"): +def split(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_split") -> List[tf.Tensor]: """Generates the different levels of the pyramid. Args: @@ -228,7 +244,9 @@ def split(image, num_levels, name="pyramid_split"): return levels -def upsample(image, num_levels, name="pyramid_upsample"): +def upsample(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_upsample") -> List[tf.Tensor]: """Generates the different levels of the pyramid (upsampling). Args: