Skip to content

Commit

Permalink
Support conversion of structured objects to NP
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Nov 1, 2019
1 parent 06312f8 commit fec660f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
32 changes: 26 additions & 6 deletions lab/generic.py
Expand Up @@ -831,13 +831,13 @@ def argsort(a, axis=-1, descending=False):
"""


add_conversion_method(Number, NPNumeric, np.array)
add_conversion_method(AGNumeric, NPNumeric,
lambda x: convert(x._value, NPNumeric))
add_conversion_method(TFNumeric, NPNumeric, lambda x: x.numpy())
add_conversion_method(TorchNumeric, NPNumeric, lambda x: x.detach().numpy())
NPOrNum = {NPNumeric, Number} #: Type NumPy numeric or number.
add_conversion_method(AGNumeric, NPOrNum, lambda x: x._value)
add_conversion_method(TFNumeric, NPOrNum, lambda x: x.numpy())
add_conversion_method(TorchNumeric, NPOrNum, lambda x: x.detach().numpy())


@dispatch(object)
def to_numpy(a):
"""Convert an object to NumPy.
Expand All @@ -847,4 +847,24 @@ def to_numpy(a):
Returns:
`np.ndarray`: `a` as NumPy.
"""
return convert(a, NPNumeric)
return convert(a, NPOrNum)


@dispatch([object])
def to_numpy(*elements):
return to_numpy(elements)


@dispatch(list)
def to_numpy(a):
return [to_numpy(x) for x in a]


@dispatch(tuple)
def to_numpy(a):
return tuple(to_numpy(x) for x in a)


@dispatch(dict)
def to_numpy(a):
return {k: to_numpy(v) for k, v in a.items()}
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = backends
version = 0.3.1
version = 0.3.2
author = Wessel Bruinsma
author_email = wessel.p.bruinsma@gmail.com
description = A generic interface for linear algebra backends
Expand Down
19 changes: 19 additions & 0 deletions tests/test_generic.py
Expand Up @@ -277,3 +277,22 @@ def test_argsort():
def test_to_numpy():
check_function(B.to_numpy, (Tensor(),))
check_function(B.to_numpy, (Tensor(4),))


def test_to_numpy_multiple_objects():
assert B.to_numpy(tf.constant(1), tf.constant(1)) == (1, 1)


def test_to_numpy_list():
x = B.to_numpy([tf.constant(1)])
assert isinstance(x[0], (B.Number, B.NPNumeric))


def test_to_numpy_tuple():
x = B.to_numpy((tf.constant(1),))
assert isinstance(x[0], (B.Number, B.NPNumeric))


def test_to_numpy_dict():
x = B.to_numpy({'a': tf.constant(1)})
assert isinstance(x['a'], (B.Number, B.NPNumeric))

0 comments on commit fec660f

Please sign in to comment.