Skip to content

Commit

Permalink
Merge pull request #353 from jtravs/interp2d_rectangular_fixes
Browse files Browse the repository at this point in the history
ENH: interpolate: use RectBivariateSpline for grid-form data

Closes Trac #286. Using the more robust spline approach works
around issues met in #898 (strange knot selection), #1364
and #1072 (knot selection warnings), and #776 (crash on large data).
  • Loading branch information
pv committed Nov 22, 2012
2 parents fcfc438 + 23b7fd8 commit 152de29
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 10 deletions.
44 changes: 34 additions & 10 deletions scipy/interpolate/interpolate.py
Expand Up @@ -14,6 +14,7 @@

import fitpack
import _fitpack
import dfitpack

def reduce_sometrue(a):
all = a
Expand Down Expand Up @@ -141,15 +142,29 @@ class interp2d(object):

def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
fill_value=np.nan):
self.x, self.y, self.z = map(ravel, map(asarray, [x, y, z]))

if len(self.z) == len(self.x) * len(self.y):
self.x, self.y = meshgrid(x,y)
self.x, self.y = map(ravel, [self.x, self.y])
if len(self.x) != len(self.y):
raise ValueError("x and y must have equal lengths")
if len(self.z) != len(self.x):
raise ValueError("Invalid length for input z")
self.x, self.y, self.z = map(asarray, [x, y, z])
self.x, self.y = map(ravel, [self.x, self.y])

if self.z.size == len(self.x) * len(self.y):
rectangular_grid = True
if not all(self.x[1:] > self.x[:-1]):
j = np.argsort(self.x)
self.x = self.x[j]
self.z = self.z[:,j]
if not all(self.y[1:] > self.y[:-1]):
j = np.argsort(self.y)
self.y = self.y[j]
self.z = self.z[j,:]
self.z = ravel(self.z.T)
else:
rectangular_grid = False
self.z = ravel(self.z)
if len(self.x) != len(self.y):
raise ValueError(
"x and y must have equal lengths for non rectangular grid")
if len(self.z) != len(self.x):
raise ValueError(
"Invalid length for input z for non rectangular grid")

try:
kx = ky = {'linear' : 1,
Expand All @@ -158,7 +173,16 @@ def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
except KeyError:
raise ValueError("Unsupported interpolation type.")

self.tck = fitpack.bisplrep(self.x, self.y, self.z, kx=kx, ky=ky, s=0.)
if not rectangular_grid:
# TODO: surfit is really not meant for interpolation!
self.tck = fitpack.bisplrep(self.x, self.y, self.z,
kx=kx, ky=ky, s=0.0)
else:
nx, tx, ny, ty, c, fp, ier = dfitpack.regrid_smth(
self.x, self.y, self.z, None, None, None, None,
kx=kx, ky=ky, s=0.0)
self.tck = (tx[:nx], ty[:ny], c[:(nx - kx - 1) * (ny - ky - 1)],
kx, ky)

def __call__(self,x,y,dx=0,dy=0):
"""Interpolate the function.
Expand Down
33 changes: 33 additions & 0 deletions scipy/interpolate/tests/test_interpolate.py
Expand Up @@ -25,6 +25,39 @@ def test_interp2d_meshgrid_input(self):
I = interp2d(x, y, z)
assert_almost_equal(I(1.0, 2.0), sin(2.0), decimal=2)

def test_interp2d_meshgrid_input_unsorted(self):
np.random.seed(1234)
x = linspace(0, 2, 16)
y = linspace(0, pi, 21)

z = sin(x[None,:] + y[:,None]/2.)
ip1 = interp2d(x.copy(), y.copy(), z, kind='cubic')

np.random.shuffle(x)
z = sin(x[None,:] + y[:,None]/2.)
ip2 = interp2d(x.copy(), y.copy(), z, kind='cubic')

np.random.shuffle(x)
np.random.shuffle(y)
z = sin(x[None,:] + y[:,None]/2.)
ip3 = interp2d(x, y, z, kind='cubic')

x = linspace(0, 2, 31)
y = linspace(0, pi, 30)

assert_equal(ip1(x, y), ip2(x, y))
assert_equal(ip1(x, y), ip3(x, y))

def test_interp2d_linear(self):
# Ticket #898
a = np.zeros([5, 5])
a[2, 2] = 1.0
x = y = np.arange(5)
b = interp2d(x, y, a, 'linear')
assert_almost_equal(b(2.0, 1.5), np.array([0.5]), decimal=2)
assert_almost_equal(b(2.0, 2.5), np.array([0.5]), decimal=2)


class TestInterp1D(object):

def setUp(self):
Expand Down

0 comments on commit 152de29

Please sign in to comment.