diff --git a/lab/tensorflow/shaping.py b/lab/tensorflow/shaping.py index da50e5a..39876f5 100644 --- a/lab/tensorflow/shaping.py +++ b/lab/tensorflow/shaping.py @@ -1,9 +1,11 @@ -import tensorflow as tf from plum import Union -from . import dispatch, B, Numeric +import tensorflow as tf + from ..shape import unwrap_dimension -from ..types import Int, TFNumeric, NPNumeric +from ..types import Int, NPNumeric, TFNumeric +from ..util import resolve_axis +from . import B, Numeric, dispatch __all__ = [] @@ -81,7 +83,8 @@ def take(a: TFNumeric, indices_or_mask, axis: Int = 0): # Perform taking operation. if is_mask: - result = tf.boolean_mask(a, indices_or_mask, axis=axis) + # `tf.boolean_mask` isn't happy with negative axes. + result = tf.boolean_mask(a, indices_or_mask, axis=resolve_axis(a, axis)) else: result = tf.gather(a, indices_or_mask, axis=axis) diff --git a/tests/test_shaping.py b/tests/test_shaping.py index c62527a..3f56992 100644 --- a/tests/test_shaping.py +++ b/tests/test_shaping.py @@ -1,20 +1,19 @@ +import lab as B import numpy as np import pytest import tensorflow as tf -from plum import NotFoundLookupError - -import lab as B from lab.shape import Shape +from plum import NotFoundLookupError # noinspection PyUnresolvedReferences from .util import ( - check_function, - Tensor, - Matrix, - Value, List, + Matrix, + Tensor, Tuple, + Value, approx, + check_function, check_lazy_shapes, ) @@ -293,13 +292,13 @@ def test_take_consistency(check_lazy_shapes): check_function( B.take, (Matrix(3, 3), Value([0, 1], [True, True, False])), - {"axis": Value(0, 1)}, + {"axis": Value(0, 1, -1)}, ) def test_take_consistency_order(check_lazy_shapes): # Check order of indices. - check_function(B.take, (Matrix(3, 4), Value([2, 1])), {"axis": Value(0, 1)}) + check_function(B.take, (Matrix(3, 4), Value([2, 1])), {"axis": Value(0, 1, -1)}) def test_take_indices_rank(check_lazy_shapes): @@ -315,7 +314,7 @@ def test_take_indices_rank(check_lazy_shapes): ) def test_take_list_tuple(check_lazy_shapes, indices_or_mask): check_function( - B.take, (Matrix(3, 3, 3), Value(indices_or_mask)), {"axis": Value(0, 1, 2)} + B.take, (Matrix(3, 3, 3), Value(indices_or_mask)), {"axis": Value(0, 1, 2, -1)} )