From 04030916d74f4efa81df415b2d19a4891210d9fa Mon Sep 17 00:00:00 2001 From: Tobias Hangleiter Date: Thu, 5 Mar 2020 11:39:40 +0100 Subject: [PATCH] Add setter for `Basis.four_element_traces` (#16) * Add setter for Basis.four_element_traces() * Add test for setter --- filter_functions/basis.py | 4 ++++ tests/test_basis.py | 10 +++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/filter_functions/basis.py b/filter_functions/basis.py index 30734a8..a010380 100644 --- a/filter_functions/basis.py +++ b/filter_functions/basis.py @@ -368,6 +368,10 @@ def four_element_traces(self) -> COO: return self._four_element_traces + @four_element_traces.setter + def four_element_traces(self, traces): + self._four_element_traces = traces + def normalize(self) -> None: """Normalize the basis in-place""" if self.ndim == 2: diff --git a/tests/test_basis.py b/tests/test_basis.py index e782cf2..30b0703 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -111,10 +111,14 @@ def test_basis_properties(self): if base.d < 8: # Test very resource intense + ref = np.einsum('iab,jbc,kcd,lda', *(base,)*4) self.assertArrayAlmostEqual(base.four_element_traces.todense(), - np.einsum('iab,jbc,kcd,lda', - *(base,)*4), - atol=1e-16) + ref, atol=1e-16) + + # Test setter + base._four_element_traces = None + base.four_element_traces = ref + self.assertArrayEqual(base.four_element_traces, ref) base._print_checks()