Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions tensorflow_graphics/image/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@
from __future__ import division
from __future__ import print_function

from typing import Callable, List, Optional, Tuple

import numpy as np
from six.moves import range
import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _downsample(image, kernel):
def _downsample(image: type_alias.TensorLike,
kernel: type_alias.TensorLike) -> tf.Tensor:
"""Downsamples the image using a convolution with stride 2.

Args:
Expand All @@ -48,7 +52,8 @@ def _downsample(image, kernel):
input=image, filters=kernel, strides=[1, 2, 2, 1], padding="SAME")


def _binomial_kernel(num_channels, dtype=tf.float32):
def _binomial_kernel(num_channels: int,
dtype: tf.DType = tf.float32) -> tf.Tensor:
"""Creates a 5x5 binomial kernel.

Args:
Expand All @@ -65,7 +70,9 @@ def _binomial_kernel(num_channels, dtype=tf.float32):
return tf.constant(kernel, dtype=dtype) * tf.eye(num_channels, dtype=dtype)


def _build_pyramid(image, sampler, num_levels):
def _build_pyramid(image: type_alias.TensorLike, sampler: Callable[
[type_alias.TensorLike, type_alias.TensorLike], type_alias.TensorLike],
num_levels: int) -> List[tf.Tensor]:
"""Creates the different levels of the pyramid.

Args:
Expand All @@ -87,7 +94,8 @@ def _build_pyramid(image, sampler, num_levels):
return levels


def _split(image, kernel):
def _split(image: type_alias.TensorLike,
kernel: type_alias.TensorLike) -> Tuple[tf.Tensor, tf.Tensor]:
"""Splits the image into high and low frequencies.

This is achieved by smoothing the input image and substracting the smoothed
Expand All @@ -112,7 +120,10 @@ def _split(image, kernel):
return high, low


def _upsample(image, kernel, output_shape=None):
def _upsample(image: type_alias.TensorLike,
kernel: type_alias.TensorLike,
output_shape: Optional[type_alias.TensorLike] = None
) -> tf.Tensor:
"""Upsamples the image using a transposed convolution with stride 2.

Args:
Expand All @@ -139,7 +150,9 @@ def _upsample(image, kernel, output_shape=None):
padding="SAME")


def downsample(image, num_levels, name="pyramid_downsample"):
def downsample(image: type_alias.TensorLike,
num_levels: int,
name: str = "pyramid_downsample") -> List[tf.Tensor]:
"""Generates the different levels of the pyramid (downsampling).

Args:
Expand All @@ -165,7 +178,8 @@ def downsample(image, num_levels, name="pyramid_downsample"):
return _build_pyramid(image, _downsample, num_levels)


def merge(levels, name="pyramid_merge"):
def merge(levels: List[type_alias.TensorLike],
name: str = "pyramid_merge") -> tf.Tensor:
"""Merges the different levels of the pyramid back to an image.

Args:
Expand Down Expand Up @@ -196,7 +210,9 @@ def merge(levels, name="pyramid_merge"):
return image


def split(image, num_levels, name="pyramid_split"):
def split(image: type_alias.TensorLike,
num_levels: int,
name: str = "pyramid_split") -> List[tf.Tensor]:
"""Generates the different levels of the pyramid.

Args:
Expand Down Expand Up @@ -228,7 +244,9 @@ def split(image, num_levels, name="pyramid_split"):
return levels


def upsample(image, num_levels, name="pyramid_upsample"):
def upsample(image: type_alias.TensorLike,
num_levels: int,
name: str = "pyramid_upsample") -> List[tf.Tensor]:
"""Generates the different levels of the pyramid (upsampling).

Args:
Expand Down