From 4151a4bf2e96e69ec0d3526645bd9d9cec239229 Mon Sep 17 00:00:00 2001 From: sigvaldm Date: Fri, 18 Mar 2022 03:01:50 -0700 Subject: [PATCH] Minor bugfixes --- localreg/rbfnet.py | 4 ---- test/test_rbf.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/localreg/rbfnet.py b/localreg/rbfnet.py index b27ede6..af6cd72 100644 --- a/localreg/rbfnet.py +++ b/localreg/rbfnet.py @@ -254,13 +254,9 @@ def compute_centers(self, input, num, random_state=None): inp_ = inp was_complex = False - print(inp) - print(inp_) - clustering = KMeans(n_clusters=num, random_state=random_state).fit(inp_) centers = clustering.cluster_centers_ - print(centers) if was_complex: centers = centers[:,:n_indeps]+1j*centers[:,n_indeps:] diff --git a/test/test_rbf.py b/test/test_rbf.py index 22f1b1d..c3b1c10 100644 --- a/test/test_rbf.py +++ b/test/test_rbf.py @@ -211,13 +211,13 @@ def test_keep_aspect(): def test_complex_input(): net = RBFnet() input = np.array([[1+1j], [1-1j], [-1-1j], [-1+1j]]) - output = np.real(x)+np.imag(x) + output = np.real(input)+np.imag(input) net.train(input, output, num=3) - assert np.allclose(output, net.predict(input)) + assert np.allclose(output, net.predict(input), atol=1e-5) def test_complex_output(): net = RBFnet() input = np.array([0,1,2]) output = input+1j*input net.train(input, output, num=2) - assert np.allclose(output, net.predict(input)) + assert np.allclose(output, net.predict(input), atol=1e-5)