/
bsplines.py
352 lines (272 loc) · 9.99 KB
/
bsplines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
from sympy.core import S, sympify
from sympy.functions import Piecewise, piecewise_fold
from sympy.sets.sets import Interval
from sympy.core.cache import lru_cache
def _ivl(cond, x):
    """Return the ``(lo, hi)`` end points encoded by a spline condition.

    The non-default pieces of a spline ``Piecewise`` carry conditions such
    as ``(lo <= x) & (x <= hi)``.  Given such a two-term ``And`` and the
    variable ``x``, this returns the pair ``(lo, hi)``.

    Raises ``TypeError`` if ``cond`` is not a two-argument ``And``.
    """
    from sympy.logic.boolalg import And

    if not (isinstance(cond, And) and len(cond.args) == 2):
        raise TypeError('unexpected cond type: %s' % cond)

    lower_rel, upper_rel = cond.args
    # The two relations may arrive in either order inside the And.  The one
    # that has x on its lesser side is the upper bound, so make sure it
    # ends up second.
    if lower_rel.lts == x:
        lower_rel, upper_rel = upper_rel, lower_rel
    return lower_rel.lts, upper_rel.gts
def _add_splines(c, b1, d, b2, x):
    """Return the expanded spline ``c*b1 + d*b2``.

    ``b1`` and ``b2`` are spline ``Piecewise`` objects (or zero).  Their
    last argument is the ``(0, True)`` fallback, and every other piece has
    an interval condition readable by ``_ivl``.  ``c`` and ``d`` are the
    weights applied to each spline.

    If either weighted term is trivially zero, only the other term is
    folded.  Otherwise the pieces of both terms are merged span by span,
    assuming each side's pieces are already ordered by interval.
    """
    if b1 == S.Zero or c == S.Zero:
        return piecewise_fold(d * b2).expand()
    if b2 == S.Zero or d == S.Zero:
        return piecewise_fold(c * b1).expand()

    # Push each weight inside its Piecewise, then merge the two piece lists
    # without any further simplification.  The trailing (0, True) of each
    # side is dropped here and re-added once at the end.
    left = piecewise_fold(c * b1)
    right = piecewise_fold(d * b2)
    pending = list(right.args[:-1])
    merged = []

    for piece in left.args[:-1]:
        expr, cond = piece.expr, piece.cond
        lower = _ivl(cond, x)[0]

        for k, other in enumerate(pending):
            other_lower, other_upper = _ivl(other.cond, x)
            if other.cond == cond:
                # Same span on both sides: sum the expressions and consume
                # the right-hand piece.
                expr += other.expr
                del pending[k]
                break
            if other_lower < lower and other_upper <= lower:
                # This right-hand piece lies wholly before the current
                # left-hand span, so no left piece will match it; emit it
                # on its own.
                merged.append(other)
                del pending[k]
                break

        merged.append((expr, cond))

    # Any right-hand pieces still pending are appended in their original
    # order (assumed to follow the last left-hand span).
    merged.extend(pending)
    merged.append((0, True))
    return Piecewise(*merged, evaluate=False).expand()
@lru_cache(maxsize=128)
def bspline_basis(d, knots, n, x):
    """
    The $n$-th B-spline at $x$ of degree $d$ with knots.

    Explanation
    ===========

    B-Splines are piecewise polynomials of degree $d$. They are defined on a
    set of knots, which is a sequence of integers or floats.

    Examples
    ========

    The 0th degree splines have a value of 1 on a single interval:

        >>> from sympy import bspline_basis
        >>> from sympy.abc import x
        >>> d = 0
        >>> knots = tuple(range(5))
        >>> bspline_basis(d, knots, 0, x)
        Piecewise((1, (x >= 0) & (x <= 1)), (0, True))

    For a given ``(d, knots)`` there are ``len(knots)-d-1`` B-splines
    defined, that are indexed by ``n`` (starting at 0).

    Here is an example of a cubic B-spline:

        >>> bspline_basis(3, tuple(range(5)), 0, x)
        Piecewise((x**3/6, (x >= 0) & (x <= 1)),
                  (-x**3/2 + 2*x**2 - 2*x + 2/3,
                  (x >= 1) & (x <= 2)),
                  (x**3/2 - 4*x**2 + 10*x - 22/3,
                  (x >= 2) & (x <= 3)),
                  (-x**3/6 + 2*x**2 - 8*x + 32/3,
                  (x >= 3) & (x <= 4)),
                  (0, True))

    By repeating knot points, you can introduce discontinuities in the
    B-splines and their derivatives:

        >>> d = 1
        >>> knots = (0, 0, 2, 3, 4)
        >>> bspline_basis(d, knots, 0, x)
        Piecewise((1 - x/2, (x >= 0) & (x <= 2)), (0, True))

    It is quite time consuming to construct and evaluate B-splines. If
    you need to evaluate a B-spline many times, it is best to lambdify them
    first:

        >>> from sympy import lambdify
        >>> d = 3
        >>> knots = tuple(range(10))
        >>> b0 = bspline_basis(d, knots, 0, x)
        >>> f = lambdify(x, b0)
        >>> y = f(0.5)

    Parameters
    ==========

    d : integer
        degree of bspline; must be non-negative

    knots : tuple of numbers
        knot points of the bspline.  Results are memoized with
        ``lru_cache``, so this must be hashable (pass a tuple, not a list).

    n : integer
        index of the B-spline; must satisfy ``n >= 0`` and
        ``n + d + 1 <= len(knots) - 1``

    x : symbol
        variable in which the returned expression is written

    Returns
    =======

    A ``Piecewise`` in ``x``.  It may reduce to ``0`` when every knot
    span involved in the recursion is degenerate.

    Raises
    ======

    ValueError
        If ``d`` or ``n`` is negative, or if ``n + d + 1`` exceeds
        ``len(knots) - 1``.

    See Also
    ========

    bspline_basis_set

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/B-spline

    """
    from sympy.core.symbol import Dummy

    # Build everything in terms of an assumption-free Dummy so the interval
    # conditions do not evaluate.  The caller's symbol is restored at the end.
    xvar = x
    x = Dummy()

    knots = tuple(sympify(k) for k in knots)
    d = int(d)
    n = int(n)
    n_knots = len(knots)
    n_intervals = n_knots - 1

    if d < 0:
        raise ValueError("degree must be non-negative: %r" % d)
    if n < 0:
        # A negative index would silently wrap around ``knots``.
        raise ValueError("n must be non-negative: %r" % n)
    if n + d + 1 > n_intervals:
        raise ValueError("n + d + 1 must not exceed len(knots) - 1")

    if d == 0:
        result = Piecewise(
            (S.One, Interval(knots[n], knots[n + 1]).contains(x)), (0, True)
        )
    else:
        # Cox-de Boor recursion:
        #   B(n, d) = A * B(n, d-1) + B * B(n+1, d-1)
        # A term whose knot span has zero width is dropped entirely
        # (the usual 0/0 := 0 convention for repeated knots).
        denom = knots[n + d + 1] - knots[n + 1]
        if denom != S.Zero:
            B = (knots[n + d + 1] - x) / denom
            b2 = bspline_basis(d - 1, knots, n + 1, x)
        else:
            b2 = B = S.Zero

        denom = knots[n + d] - knots[n]
        if denom != S.Zero:
            A = (x - knots[n]) / denom
            b1 = bspline_basis(d - 1, knots, n, x)
        else:
            b1 = A = S.Zero

        result = _add_splines(A, b1, B, b2, x)

    # Return the result in terms of the user-given x.
    return result.xreplace({x: xvar})
def bspline_basis_set(d, knots, x):
    """
    Return the ``len(knots)-d-1`` B-splines at *x* of degree *d*
    with *knots*.

    Explanation
    ===========

    This function returns a list of piecewise polynomials that are the
    ``len(knots)-d-1`` B-splines of degree *d* for the given knots.
    This function calls ``bspline_basis(d, knots, n, x)`` for different
    values of *n*.

    Examples
    ========

    >>> from sympy import bspline_basis_set
    >>> from sympy.abc import x
    >>> d = 2
    >>> knots = range(5)
    >>> splines = bspline_basis_set(d, knots, x)
    >>> splines
    [Piecewise((x**2/2, (x >= 0) & (x <= 1)),
               (-x**2 + 3*x - 3/2, (x >= 1) & (x <= 2)),
               (x**2/2 - 3*x + 9/2, (x >= 2) & (x <= 3)),
               (0, True)),
    Piecewise((x**2/2 - x + 1/2, (x >= 1) & (x <= 2)),
              (-x**2 + 5*x - 11/2, (x >= 2) & (x <= 3)),
              (x**2/2 - 4*x + 8, (x >= 3) & (x <= 4)),
              (0, True))]

    Parameters
    ==========

    d : integer
        degree of bspline

    knots : iterable of numbers
        knot points of the bspline.  Any iterable (list, tuple, range,
        generator, ...) is accepted.  It is converted to a tuple once,
        because ``bspline_basis`` is memoized and needs hashable arguments.

    x : symbol

    See Also
    ========

    bspline_basis

    """
    # Materialize once.  This supports one-shot iterables (which have no
    # len()) and avoids rebuilding the tuple for every basis function.
    knots = tuple(knots)
    n_splines = len(knots) - d - 1
    return [bspline_basis(d, knots, i, x) for i in range(n_splines)]
def interpolating_spline(d, x, X, Y):
    """
    Return spline of degree *d*, passing through the given *X*
    and *Y* values.

    Explanation
    ===========

    This function returns a piecewise function such that each part is
    a polynomial of degree not greater than *d*. The value of *d*
    must be 1 or greater and the values of *X* must be strictly
    increasing.

    Examples
    ========

    >>> from sympy import interpolating_spline
    >>> from sympy.abc import x
    >>> interpolating_spline(1, x, [1, 2, 4, 7], [3, 6, 5, 7])
    Piecewise((3*x, (x >= 1) & (x <= 2)),
            (7 - x/2, (x >= 2) & (x <= 4)),
            (2*x/3 + 7/3, (x >= 4) & (x <= 7)))
    >>> interpolating_spline(3, x, [-2, 0, 1, 3, 4], [4, 2, 1, 1, 3])
    Piecewise((7*x**3/117 + 7*x**2/117 - 131*x/117 + 2, (x >= -2) & (x <= 1)),
            (10*x**3/117 - 2*x**2/117 - 122*x/117 + 77/39, (x >= 1) & (x <= 4)))

    Parameters
    ==========

    d : integer
        Degree of the spline; must be a positive integer (1 or greater).

    x : symbol

    X : list of strictly increasing numbers
        x-coordinates through which the spline passes.  There must be at
        least ``d + 1`` of them.

    Y : list of numbers
        y-coordinates through which the spline passes, one per entry of
        *X*.  The values need not be increasing.

    Returns
    =======

    A ``Piecewise`` defined only on ``[X[0], X[-1]]``.  There is no
    ``(0, True)`` fallback piece.

    Raises
    ======

    ValueError
        If *d* is not a positive integer, *X* and *Y* differ in length,
        there are fewer than ``d + 1`` points, or *X* is not strictly
        increasing.

    See Also
    ========

    bspline_basis_set, interpolating_poly

    """
    from sympy import symbols, Dummy
    from sympy.solvers.solveset import linsolve
    from sympy.matrices.dense import Matrix

    # Input sanitization
    d = sympify(d)
    if not (d.is_Integer and d.is_positive):
        raise ValueError("Spline degree must be a positive integer, not %s." % d)
    if len(X) != len(Y):
        raise ValueError("Number of X and Y coordinates must be the same.")
    if len(X) < d + 1:
        raise ValueError("Degree must be less than the number of control points.")
    if not all(a < b for a, b in zip(X, X[1:])):
        raise ValueError("The x-coordinates must be strictly increasing.")
    X = [sympify(i) for i in X]

    # Knot placement.  For odd d the interior data sites are used as knots.
    # For even d the midpoints between consecutive interior sites are used.
    # In both cases each end point is repeated d + 1 times, giving
    # len(X) + d + 1 knots.  That yields len(X) basis splines, i.e. one
    # unknown coefficient per data point.
    if d.is_odd:
        j = (d + 1) // 2
        interior_knots = X[j:-j]
    else:
        j = d // 2
        interior_knots = [
            (a + b)/2 for a, b in zip(X[j : -j - 1], X[j + 1 : -j])
        ]

    knots = [X[0]] * (d + 1) + list(interior_knots) + [X[-1]] * (d + 1)

    basis = bspline_basis_set(d, knots, x)

    # Collocation system: the spline must take the value Y[k] at X[k].
    A = [[b.subs(x, v) for b in basis] for v in X]
    coeff = linsolve((Matrix(A), Matrix(Y)), symbols("c0:{}".format(len(X)), cls=Dummy))
    coeff = list(coeff)[0]

    # Gather the distinct non-fallback conditions (knot spans) of all basis
    # functions and order them by their (lo, hi) end points.
    intervals = {c for b in basis for (e, c) in b.args if c != True}
    intervals = sorted(intervals, key=lambda cond: _ivl(cond, x))

    # Map each basis function's condition to its expression on that span.
    basis_dicts = [{c: e for (e, c) in b.args} for b in basis]

    spline = []
    for i in intervals:
        # A basis function that is not active on this span contributes 0.
        piece = sum(
            (c * bd.get(i, S.Zero) for (c, bd) in zip(coeff, basis_dicts)),
            S.Zero,
        )
        spline.append((piece, i))
    return Piecewise(*spline)