Skip to content

Commit

Permalink
Migrate tensorflow_graphics/nn to Python 3.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 361949734
  • Loading branch information
tensorflower-gardener authored and Copybara-Service committed Mar 10, 2021
1 parent 5caaccb commit e110a12
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
11 changes: 9 additions & 2 deletions tensorflow_graphics/nn/layer/tests/graph_convolution_test.py
Expand Up @@ -13,8 +13,14 @@
# limitations under the License.
"""Tests for the graph convolution layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
from six.moves import range
from six.moves import zip
import tensorflow as tf

import tensorflow_graphics.nn.layer.graph_convolution as gc_layer
Expand Down Expand Up @@ -170,7 +176,7 @@ def test_feature_steered_convolution_layer_training(self):
for _ in range(num_training_iterations):
grads = tape.gradient(loss, trainable_variables)
tf.compat.v1.train.GradientDescentOptimizer(1e-4).apply_gradients(
zip(grads, trainable_variables))
list(zip(grads, trainable_variables)))
else:
output = gc_layer.feature_steered_convolution_layer(
data=data,
Expand Down Expand Up @@ -288,7 +294,8 @@ def test_dynamic_graph_convolution_keras_layer_training(self, reduction):
for _ in range(num_training_iterations):
grads = tape.gradient(loss, trainable_variables)
tf.compat.v1.train.GradientDescentOptimizer(1e-4).apply_gradients(
zip(grads, trainable_variables))
list(zip(grads, trainable_variables)))


if __name__ == "__main__":
test_case.main()
7 changes: 4 additions & 3 deletions tensorflow_graphics/nn/metric/intersection_over_union.py
Expand Up @@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import range
import tensorflow as tf

from tensorflow_graphics.util import asserts
Expand Down Expand Up @@ -66,12 +67,12 @@ def evaluate(ground_truth_labels,
predicted_labels = asserts.assert_binary(predicted_labels)

sum_ground_truth = tf.math.reduce_sum(
input_tensor=ground_truth_labels, axis=range(-grid_size, 0))
input_tensor=ground_truth_labels, axis=list(range(-grid_size, 0)))
sum_predictions = tf.math.reduce_sum(
input_tensor=predicted_labels, axis=range(-grid_size, 0))
input_tensor=predicted_labels, axis=list(range(-grid_size, 0)))
intersection = tf.math.reduce_sum(
input_tensor=ground_truth_labels * predicted_labels,
axis=range(-grid_size, 0))
axis=list(range(-grid_size, 0)))
union = sum_ground_truth + sum_predictions - intersection

return tf.where(
Expand Down

0 comments on commit e110a12

Please sign in to comment.