Skip to content

Commit

Permalink
Adds typing information to geometry/graph_pooling.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380156388
  • Loading branch information
G4G authored and Copybara-Service committed Jun 18, 2021
1 parent 4b585f5 commit d4f2f6a
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions tensorflow_graphics/geometry/convolution/graph_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,20 @@
from __future__ import division
from __future__ import print_function

from typing import Callable

import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import type_alias


def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
def pool(data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
algorithm: str = 'max',
name: str = 'graph_pooling_pool') -> tf.Tensor:
# pyformat: disable
"""Implements graph pooling.
Expand Down Expand Up @@ -109,7 +116,10 @@ def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
return pooled


def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):
def unpool(data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
name: str = 'graph_pooling_unpool') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by inverting the pooling map.
Expand Down Expand Up @@ -176,12 +186,13 @@ def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):


def upsample_transposed_convolution(
data,
pool_map,
sizes,
kernel_size,
transposed_convolution_op,
name='graph_pooling_upsample_transposed_convolution'):
data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
kernel_size: int,
transposed_convolution_op: Callable[[type_alias.TensorLike],
type_alias.TensorLike],
name: str = 'graph_pooling_upsample_transposed_convolution') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by transposed convolution.
Expand Down

0 comments on commit d4f2f6a

Please sign in to comment.