From ffd02044772c83f23c4cb4b3e6489c0dea12d10d Mon Sep 17 00:00:00 2001 From: Jesus Recuerda Hueso Date: Wed, 5 Jun 2019 14:55:31 +0200 Subject: [PATCH] Fix issue with numpy interop (#51) Reading from an Array: - When an Array of real or complex numbers with more than two dimensions was benig converted to numpy this conversion didn't return an np.array with a proper shape. Creating an Array: - When an Array of real numbers with more than two dimensions was built. The shape in the device (arrayfire) was not correct. - For complex numbers it occurrs with more than one dimension. --- khiva/array.py | 47 +++++++++++++-------------- tests/unit_tests/array_unit_tests.py | 24 +++++++++++++- tests/unit_tests/matrix_unit_tests.py | 4 +-- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/khiva/array.py b/khiva/array.py index 11e048b..e3b283e 100644 --- a/khiva/array.py +++ b/khiva/array.py @@ -14,7 +14,6 @@ import ctypes import logging import sys -from collections import deque from enum import Enum import numpy as np @@ -177,21 +176,24 @@ def _create_array(self, data): if isinstance(data, pd.DataFrame): data = data.values shape = np.array(data.shape) - shape = shape[shape > 1] - shape = deque(shape) - shape.rotate(1) + + if data.size > 1: + trimmed_dims = shape + for _ in range(0, 3): + if trimmed_dims[-1] == 1: + trimmed_dims = trimmed_dims[:-1] + shape = trimmed_dims[::-1] + else: + shape = np.array([1]) + c_array_n = (ctypes.c_longlong * len(shape))(*(np.array(shape)).astype(np.longlong)) c_ndims = ctypes.c_uint(len(shape)) c_complex = np.iscomplexobj(data) if c_complex: - data = np.array([data.real, data.imag]) - c = deque(range(1, len(data.shape))) - c.rotate(1) - c.append(0) - array_joint = np.transpose(data, c).flatten() - else: - array_joint = data.flatten() + data = np.dstack((data.real.flatten(), data.imag.flatten())) + + array_joint = data.flatten() c_array_joint = (_get_array_type(self.khiva_type.value) * len(array_joint))( *array_joint) @@ -212,25 +214,22 @@ def _get_data(self): c_result_array = (_get_array_type(self.khiva_type.value) * self.result_l)(*initialized_result_array) KhivaLibrary().c_khiva_library.get_data(ctypes.pointer(self.arr_reference), ctypes.pointer(c_result_array)) - dims = self.get_dims() - if dims[dims > 1].size > 0: - dims = dims[dims > 1] - else: - dims = np.array([1]) - a = np.array(c_result_array) if self._is_complex(): a = np.array(np.split(a, self.result_l / 2)) a = np.apply_along_axis(lambda args: [complex(*args)], 1, a) - a = a.reshape(dims) - c = deque(range(len(a.shape))) - c.rotate(-1) - a = np.transpose(a, c) + + # Clean up the last n dimensions if these are equal to 1 + if a.size > 1: + trimmed_dims = self.get_dims() + for _ in range(0, 3): + if trimmed_dims[-1] == 1: + trimmed_dims = trimmed_dims[:-1] else: - dims = deque(dims) - dims.rotate(1) - a = a.reshape(dims) + trimmed_dims = np.array([1]) + + a = a.reshape(trimmed_dims[::-1]) a = a.astype(_get_numpy_type(self.khiva_type.value)) return a diff --git a/tests/unit_tests/array_unit_tests.py b/tests/unit_tests/array_unit_tests.py index 5c3ef66..e86671d 100644 --- a/tests/unit_tests/array_unit_tests.py +++ b/tests/unit_tests/array_unit_tests.py @@ -30,21 +30,43 @@ class ArrayTest(unittest.TestCase): def setUp(self): set_backend(KHIVABackend.KHIVA_BACKEND_CPU) + def test_real_1d_creation(self): + a = Array([1, 5, 3, 1]) + np.testing.assert_array_equal(a.dims, np.array([4, 1, 1, 1])) + + def test_single_value_creation(self): + a = Array([1]) + np.testing.assert_array_equal(a.dims, np.array([1, 1, 1, 1])) + def test_real_1d(self): a = Array([1, 2, 3, 4, 5, 6, 7, 8]) expected = np.array([1, 2, 3, 4, 5, 6, 7, 8]) np.testing.assert_array_equal(a.to_numpy(), expected) + def test_real_2d_creation(self): + a = Array([[1, 5, 3, 1], [2, 6, 9, 8], [3, 4, 1, 3]]) + np.testing.assert_array_equal(a.dims, np.array([4, 3, 1, 1])) + def test_real_2d(self): a = Array([[1, 2, 3, 4], [5, 6, 7, 8]]) expected = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) np.testing.assert_array_equal(a.to_numpy(), expected) + def test_real_3d_creation(self): + a = Array([[[1, 5, 3, 1], [2, 6, 9, 8], [3, 4, 1, 3]], + [[3, 7, 4, 2], [4, 8, 1, 9], [1, 5, 9, 2]]]) + np.testing.assert_array_equal(a.dims, np.array([4, 3, 2, 1])) + def test_real_3d(self): a = Array([[[1, 5], [2, 6]], [[3, 7], [4, 8]]]) expected = np.array([[[1, 5], [2, 6]], [[3, 7], [4, 8]]]) np.testing.assert_array_equal(a.to_numpy(), expected) + def test_real_3d_large_column(self): + a = Array([[[1, 5, 3], [2, 6, 9]], [[3, 7, 4], [4, 8, 1]]]) + expected = np.array([[[1, 5, 3], [2, 6, 9]], [[3, 7, 4], [4, 8, 1]]]) + np.testing.assert_array_equal(a.to_numpy(), expected) + def test_real_4d(self): a = Array([[[[1, 9], [2, 10]], [[3, 11], [4, 12]]], [[[5, 13], [6, 14]], [[7, 15], [8, 16]]]]) expected = np.array([[[[1, 9], [2, 10]], [[3, 11], [4, 12]]], [[[5, 13], [6, 14]], [[7, 15], [8, 16]]]]) @@ -254,7 +276,7 @@ def testCols(self): def testRow(self): a = Array(np.transpose([[1, 2], [3, 4]]), dtype.s32) c = a.get_row(0) - np.testing.assert_array_equal(c.to_numpy(), [1, 2]) + np.testing.assert_array_equal(c.to_numpy(), np.transpose(np.array([[1, 2]]))) def testRows(self): a = Array(np.transpose([[1, 2], [3, 4], [5, 6]]), dtype.s32) diff --git a/tests/unit_tests/matrix_unit_tests.py b/tests/unit_tests/matrix_unit_tests.py index fba349f..5532138 100644 --- a/tests/unit_tests/matrix_unit_tests.py +++ b/tests/unit_tests/matrix_unit_tests.py @@ -68,8 +68,8 @@ def test_find_best_n_motifs_multiple_profiles(self): find_best_n_motifs_result = find_best_n_motifs(stomp_result[0], stomp_result[1], 3, 1) a = find_best_n_motifs_result[1].to_numpy() b = find_best_n_motifs_result[2].to_numpy() - np.testing.assert_array_almost_equal(a, np.array([[12, 12], [12, 12]]), decimal=self.DECIMAL) - np.testing.assert_array_almost_equal(b, np.array([[1, 1], [1, 1]]), decimal=self.DECIMAL) + np.testing.assert_array_almost_equal(a, np.array([[[12], [12]], [[12], [12]]]), decimal=self.DECIMAL) + np.testing.assert_array_almost_equal(b, np.array([[[1], [1]], [[1], [1]]]), decimal=self.DECIMAL) def test_find_best_n_motifs_mirror(self): stomp_result = stomp_self_join(Array([10.1, 11, 10.2, 10.15, 10.775, 10.1, 11, 10.2], dtype.f32), 3)