Skip to content

Commit

Permalink
Prevent negative indices for tf.boolean_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Oct 3, 2021
1 parent e622689 commit 2b854d8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
11 changes: 7 additions & 4 deletions lab/tensorflow/shaping.py
@@ -1,9 +1,11 @@
import tensorflow as tf
from plum import Union

from . import dispatch, B, Numeric
import tensorflow as tf

from ..shape import unwrap_dimension
from ..types import Int, TFNumeric, NPNumeric
from ..types import Int, NPNumeric, TFNumeric
from ..util import resolve_axis
from . import B, Numeric, dispatch

__all__ = []

Expand Down Expand Up @@ -81,7 +83,8 @@ def take(a: TFNumeric, indices_or_mask, axis: Int = 0):

# Perform taking operation.
if is_mask:
result = tf.boolean_mask(a, indices_or_mask, axis=axis)
# `tf.boolean_mask` isn't happy with negative axes.
result = tf.boolean_mask(a, indices_or_mask, axis=resolve_axis(a, axis))
else:
result = tf.gather(a, indices_or_mask, axis=axis)

Expand Down
19 changes: 9 additions & 10 deletions tests/test_shaping.py
@@ -1,20 +1,19 @@
import lab as B
import numpy as np
import pytest
import tensorflow as tf
from plum import NotFoundLookupError

import lab as B
from lab.shape import Shape
from plum import NotFoundLookupError

# noinspection PyUnresolvedReferences
from .util import (
check_function,
Tensor,
Matrix,
Value,
List,
Matrix,
Tensor,
Tuple,
Value,
approx,
check_function,
check_lazy_shapes,
)

Expand Down Expand Up @@ -293,13 +292,13 @@ def test_take_consistency(check_lazy_shapes):
check_function(
B.take,
(Matrix(3, 3), Value([0, 1], [True, True, False])),
{"axis": Value(0, 1)},
{"axis": Value(0, 1, -1)},
)


def test_take_consistency_order(check_lazy_shapes):
# Check order of indices.
check_function(B.take, (Matrix(3, 4), Value([2, 1])), {"axis": Value(0, 1)})
check_function(B.take, (Matrix(3, 4), Value([2, 1])), {"axis": Value(0, 1, -1)})


def test_take_indices_rank(check_lazy_shapes):
Expand All @@ -315,7 +314,7 @@ def test_take_indices_rank(check_lazy_shapes):
)
def test_take_list_tuple(check_lazy_shapes, indices_or_mask):
check_function(
B.take, (Matrix(3, 3, 3), Value(indices_or_mask)), {"axis": Value(0, 1, 2)}
B.take, (Matrix(3, 3, 3), Value(indices_or_mask)), {"axis": Value(0, 1, 2, -1)}
)


Expand Down

0 comments on commit 2b854d8

Please sign in to comment.