/
test_rbf.py
103 lines (92 loc) · 3.2 KB
/
test_rbf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python
# Created by John Travers, Robert Hetland, 2007
""" Test functions for rbf module """
import numpy as np
from numpy.testing import assert_, assert_array_almost_equal, assert_almost_equal
from numpy import linspace, sin, random, exp, allclose
from scipy.interpolate.rbf import Rbf
# Names of the radial basis functions exercised by this module; each entry
# is passed to Rbf through its `function` keyword by the check_* helpers.
FUNCTIONS = ('multiquadric', 'inverse multiquadric', 'gaussian',
             'cubic', 'quintic', 'thin-plate', 'linear')
def check_rbf1d_interpolation(function):
    """Verify that a 1-D Rbf interpolant reproduces its data at the nodes.

    `function` is forwarded unchanged to Rbf as its `function` keyword.
    """
    nodes = linspace(0, 10, 9)
    samples = sin(nodes)
    interpolant = Rbf(nodes, samples, function=function)
    # Evaluating on the whole node array must give back the sampled values.
    assert_array_almost_equal(samples, interpolant(nodes))
    # A plain Python float is evaluated too, not only arrays.
    first_node = float(nodes[0])
    assert_almost_equal(interpolant(first_node), samples[0])
def check_rbf2d_interpolation(function):
    """Check that the Rbf function interpolates through the nodes (2D).

    Fifty scattered nodes in [-2, 2) x [-2, 2) carry complex-valued data;
    the interpolant built with ``epsilon=2`` and the given `function` must
    return that data when evaluated back at the nodes.
    """
    # A locally seeded generator makes any failure reproducible and leaves
    # numpy's global random state untouched.
    rng = random.RandomState(1234)
    x = rng.rand(50, 1)*4 - 2
    y = rng.rand(50, 1)*4 - 2
    z = x*exp(-x**2 - 1j*y**2)
    rbf = Rbf(x, y, z, epsilon=2, function=function)
    # reshape() rather than assigning to .shape: in-place shape assignment
    # is discouraged by NumPy.
    zi = rbf(x, y).reshape(x.shape)
    assert_array_almost_equal(z, zi)
def check_rbf3d_interpolation(function):
    """Check that the Rbf function interpolates through the nodes (3D).

    Fifty scattered nodes in [-2, 2)^3 carry real-valued data; the
    interpolant built with ``epsilon=2`` and the given `function` must
    return that data when evaluated back at the nodes.
    """
    # A locally seeded generator makes any failure reproducible and leaves
    # numpy's global random state untouched.
    rng = random.RandomState(1234)
    x = rng.rand(50, 1)*4 - 2
    y = rng.rand(50, 1)*4 - 2
    z = rng.rand(50, 1)*4 - 2
    d = x*exp(-x**2 - y**2)
    rbf = Rbf(x, y, z, d, epsilon=2, function=function)
    # reshape() rather than assigning to .shape: in-place shape assignment
    # is discouraged by NumPy.
    di = rbf(x, y, z).reshape(x.shape)
    assert_array_almost_equal(di, d)
def test_rbf_interpolation():
    """Run the 1-D, 2-D and 3-D node-interpolation checks for every entry
    of FUNCTIONS.

    The checks are called directly instead of being yielded: generator
    ("yield") tests are a nose convention that pytest >= 4 no longer
    executes, so the yielded checks would never run there. The first
    failing check aborts the remaining ones.
    """
    # Floating-point warnings are silenced while the checks run; the
    # previous error-handling settings are restored on exit.
    with np.errstate(all="ignore"):
        for function in FUNCTIONS:
            check_rbf1d_interpolation(function)
            check_rbf2d_interpolation(function)
            check_rbf3d_interpolation(function)
def check_rbf1d_regularity(function, atol):
    """Verify that a 1-D Rbf interpolant of sin(x) built on 9 nodes stays
    within `atol` of sin(x) on a denser 100-point grid, i.e. also between
    the nodes."""
    nodes = linspace(0, 10, 9)
    interpolant = Rbf(nodes, sin(nodes), function=function)
    dense = linspace(0, 10, 100)
    approx = interpolant(dense)
    exact = sin(dense)
    # To inspect a failure visually, plot the nodes as markers together
    # with `exact` and `approx` over `dense`.
    worst = abs(approx - exact).max()
    assert_(allclose(approx, exact, atol=atol), "abs-diff: %f" % worst)
def test_rbf_regularity():
    """Run the 1-D regularity check for every entry of FUNCTIONS, each with
    its own absolute tolerance.

    The checks are called directly instead of being yielded: generator
    ("yield") tests are a nose convention that pytest >= 4 no longer
    executes, so the yielded checks would never run there. The first
    failing check aborts the remaining ones.
    """
    # Absolute tolerance per basis function; a function without an entry
    # falls back to 1e-2 below.
    tolerances = {
        'multiquadric': 0.05,
        'inverse multiquadric': 0.02,
        'gaussian': 0.01,
        'cubic': 0.15,
        'quintic': 0.1,
        'thin-plate': 0.1,
        'linear': 0.2,
    }
    # Floating-point warnings are silenced while the checks run; the
    # previous error-handling settings are restored on exit.
    with np.errstate(all="ignore"):
        for function in FUNCTIONS:
            check_rbf1d_regularity(function, tolerances.get(function, 1e-2))
def test_default_construction():
    """Regression test for ticket #1228: an Rbf built without naming a
    basis function (the default multiquadric) must construct and
    interpolate through its nodes."""
    nodes = linspace(0, 10, 9)
    samples = sin(nodes)
    # No `function` keyword here: the default is what is under test.
    interpolant = Rbf(nodes, samples)
    assert_array_almost_equal(samples, interpolant(nodes))
def test_function_is_callable():
    """Verify that Rbf accepts a callable as its `function` argument and
    still interpolates through the nodes."""
    def linfunc(x):
        # Identity basis function, used as the callable under test.
        return x

    nodes = linspace(0, 10, 9)
    samples = sin(nodes)
    interpolant = Rbf(nodes, samples, function=linfunc)
    assert_array_almost_equal(samples, interpolant(nodes))