/
fit.py
182 lines (151 loc) · 6.85 KB
/
fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""tools to fit to an algebraic decay."""
# Copyright 2018-2020 TeNPy Developers, GNU GPLv3
import numpy as np
import scipy.optimize as optimize
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = [
'alg_decay', 'linear_fit', 'lin_fit_res', 'alg_decay_fit_res', 'alg_decay_fit',
'alg_decay_fits', 'plot_alg_decay_fit', 'fit_with_sum_of_exp', 'sum_of_exp'
]
def alg_decay(x, a, b, c):
    """Evaluate the algebraic decay ``a * x**(-b) + c``.

    Parameters
    ----------
    x : scalar or array
        Point(s) at which to evaluate the decay.
    a : scalar
        Prefactor of the power-law term.
    b : scalar
        Decay exponent; the power-law term is ``x**(-b)``.
    c : scalar
        Constant offset added to the power-law term.

    Returns
    -------
    The value ``a * x**(-b) + c``, with the same shape as `x`.
    """
    power_law_term = a * x**(-b)
    return power_law_term + c
def linear_fit(x, y):
    """Perform a least-squares linear fit of `y` to ``a*x + b``.

    Parameters
    ----------
    x, y : 1D array_like
        Abscissae and ordinates of the data points; both must be one-dimensional.

    Returns
    -------
    a : float
        Slope of the fitted line.
    b : float
        Intercept of the fitted line.
    res : float
        Sum of squared deviations ``sum(|a*x + b - y|**2)`` of the fit.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    assert x.ndim == 1 and y.ndim == 1
    # design matrix with columns (x, 1), such that ``design @ [a, b]`` approximates y
    design = np.vstack([x, np.ones(len(x))]).T
    coeffs, residuals, _, _ = np.linalg.lstsq(design, y, rcond=None)
    if residuals.size > 0:
        res = residuals[0]
    else:
        # lstsq returns an empty residual array if there are no more data points than
        # parameters or if the design matrix is rank deficient (e.g. all x equal).
        # Evaluate the residual explicitly instead of failing with an IndexError.
        res = np.sum(np.abs(np.dot(design, coeffs) - y)**2)
    return coeffs[0], coeffs[1], res
def lin_fit_res(x, y):
    """Return the least-squares residue of a linear fit of `y` versus `x`.

    Parameters
    ----------
    x, y : 1D ndarray
        Data points to be fitted with a straight line.

    Returns
    -------
    res : float
        The residue reported by :func:`numpy.linalg.lstsq`. If `lstsq` reports no residue,
        the spread ``max(y) - min(y)`` is returned instead.
    """
    assert x.ndim == 1 and y.ndim == 1
    # columns (x, 1): slope and intercept of the line
    design = np.column_stack((x, np.ones(len(x))))
    residuals = np.linalg.lstsq(design, y)[1]
    if residuals.size == 0:
        # no residue available from lstsq: fall back to the spread of the data
        return np.max(y) - np.min(y)
    return residuals[0]
def alg_decay_fit_res(log_b, x, y):
    """Return the residue of an algebraic decay fit with fixed exponent ``b = exp(log_b)``.

    The data `y` is fitted linearly against ``x**(-np.exp(log_b))``; the residue of that
    linear fit (see :func:`lin_fit_res`) is returned.
    """
    exponent = -np.exp(log_b)
    return lin_fit_res(x**exponent, y)
def alg_decay_fit(x, y, npts=5, power_range=(0.01, 4.), power_mesh=[60, 10]):
    """Fit `y` to the form ``a*x**(-b) + c``.

    The exponent `b` is found by a brute-force search on a logarithmic grid, which is refined
    once per entry of `power_mesh`; for each trial `b`, the parameters `a` and `c` follow from
    a linear fit of `y` against ``x**(-b)``.

    Parameters
    ----------
    x, y : 1D array_like
        The data to be fitted; both need to have the same length.
    npts : int
        Maximum number of points to fit, at least 3.
        If ``npts < len(x)``, only the last `npts` points are fitted.
    power_range : tuple of float
        ``(min, max)`` range (both positive) in which the coarsest grid searches for `b`.
        The refined grids are centered around the best previous value and are not clipped to
        this range.
    power_mesh : list of int
        Specifies how fine to search for the optimal `b`.
        E.g., if ``power_mesh = [60, 10]``, it first divides `power_range` into 60 intervals,
        and then divides those intervals by 10. (The default list is never modified.)

    Returns
    -------
    abc : list
        The triplet ``[a, b, c]``. If the search ends up at the lower end of `power_range`,
        the fit is considered as failed and ``[0., 0., y[-1]]`` is returned.

    Raises
    ------
    ValueError
        If ``npts < 3``, if `power_mesh` is empty, or if no data points are given.
    """
    x = np.array(x)
    y = np.array(y)
    assert x.ndim == 1 and y.ndim == 1
    assert len(x) == len(y)
    if npts < 3:
        raise ValueError("npts must be at least 3 to fit a, b and c; got {0!r}".format(npts))
    if len(power_mesh) == 0:
        raise ValueError("power_mesh needs at least one entry")
    if len(x) == 0:
        raise ValueError("need at least one data point to fit")
    if len(x) > npts:
        x = x[-npts:]
        y = y[-npts:]
    global_log_power_range = (np.log(power_range[0]), np.log(power_range[1]))
    log_power_range = global_log_power_range
    for level, mesh in enumerate(power_mesh):
        # number of grid points, inclusive of both ends; the refined windows span two
        # intervals of the previous grid, hence the factor 2 for all but the first level
        brute_Ns = (mesh if level == 0 else 2 * mesh) + 1
        log_power_step = (log_power_range[1] - log_power_range[0]) / float(brute_Ns - 1)
        brute_fit = optimize.brute(alg_decay_fit_res, [log_power_range], (x, y),
                                   Ns=brute_Ns,
                                   finish=None)
        if brute_fit <= global_log_power_range[0] + 1e-6:
            # The optimum sits at the lower end of the allowed range: no meaningful
            # exponent was found. Fall back to a constant given by the last data point.
            return [0., 0., y[-1]]
        # zoom in: search one grid step to either side of the current optimum
        log_power_range = (brute_fit - log_power_step, brute_fit + log_power_step)
    l_fit = linear_fit(x**(-np.exp(brute_fit)), y)
    return [l_fit[0], np.exp(brute_fit), l_fit[1]]
def alg_decay_fits(x, ys, npts=5, power_range=(0.01, 4.), power_mesh=[60, 10]):
    """Fit several data sets `ys` to the form ``a * x**(-b) + c``.

    Parameters
    ----------
    x : 1D array_like
        Common x-values of all data sets.
    ys : array_like
        The data sets; the last axis needs to have the same length as `x`.
    npts, power_range, power_mesh :
        Passed on unchanged to :func:`alg_decay_fit`.

    Returns
    -------
    abc : ndarray
        Array of the fit parameters ``[a, b, c]``; the shape is the shape of `ys` with the
        last axis replaced by an axis of length 3.
    """
    x = np.array(x)
    if x.ndim != 1:
        raise ValueError
    ys = np.array(ys)
    leading_shape, n_last = ys.shape[:-1], ys.shape[-1]
    assert n_last == len(x)
    # treat `ys` as a flat list of 1D data sets and fit them one by one
    fits = []
    for y_row in ys.reshape(-1, len(x)):
        fits.append(
            alg_decay_fit(x, y_row, npts=npts, power_range=power_range, power_mesh=power_mesh))
    return np.array(fits).reshape(leading_shape + (3, ))
def plot_alg_decay_fit(plot_module, x, y, fit_par, xfunc=None, kwargs={}, plot_fit_args={}):
    """Plot data together with an algebraic decay fit and its extrapolated constant.

    Parameters
    ----------
    plot_module :
        Object providing a ``plot`` method, e.g. ``matplotlib.pyplot`` or a subplot.
    x, y : 1D array
        The data points.
    fit_par : sequence
        Triplet ``[a, b, c]`` describing the algebraic decay (see :func:`alg_decay`),
        e.g. the output of :func:`alg_decay_fit`.
    xfunc : function | None
        Optional function applied to all x-values before plotting, i.e. it rescales the
        x-axis. ``None`` means no rescaling.
    kwargs : dict
        Keyword arguments passed on to every call of ``plot_module.plot``.
    plot_fit_args : dict
        Controls how the fit is shown. Recognized keys are
        ``'show_data_points'`` (default True), ``'show_fit'`` (default True),
        ``'n_interp'`` (default 30, sets the spacing of the points of the fit curve),
        ``'extrap_line_start'`` and ``'extrap_line_end'`` (indices into `x` giving start and
        end of the dashed horizontal line drawn at height ``fit_par[2]``).
    """
    if xfunc is None:
        xfunc = lambda x: x
    options = plot_fit_args
    if options.get('show_data_points', True):
        plot_module.plot(xfunc(x), y, 'o', **kwargs)
    n_interp = options.get('n_interp', 30)
    if len(x) <= 1:
        return  # nothing to interpolate or extrapolate with less than two points
    x_lo = np.min(x)
    x_hi = np.max(x)
    # slightly more than the range covered by the data: from -3% to (just below) +110%
    interp_x = np.arange(-0.03, 1.1, 1. / n_interp) * (x_hi - x_lo) + x_lo
    if options.get('show_fit', True):
        plot_module.plot(xfunc(interp_x), alg_decay(interp_x, *fit_par), '-', **kwargs)
    # by default, the dashed line goes from the second-to-last data point to the end of the fit
    extrap_xrange = np.array([x[-2], np.max(interp_x)])
    if 'extrap_line_start' in options:
        start = options['extrap_line_start']
        try:
            extrap_xrange[0] = x[start]
        except IndexError:
            # index out of range: clip to the corresponding end of the interpolation range
            if start >= len(x):
                extrap_xrange[0] = np.max(interp_x)
            if start < -len(x):
                extrap_xrange[0] = np.min(interp_x)
    if 'extrap_line_end' in options:
        end = options['extrap_line_end']
        if end < len(x):
            try:
                extrap_xrange[1] = x[end]
            except IndexError:
                # collapse the range such that no line is drawn
                extrap_xrange[1] = extrap_xrange[0]
    if extrap_xrange[0] < extrap_xrange[1]:
        offset = fit_par[2]
        plot_module.plot(xfunc(extrap_xrange), [offset, offset], '--', **kwargs)
def fit_with_sum_of_exp(f, n, N=50):
    r"""Approximate a decaying function `f` with a sum of exponentials.

    MPOs can naturally represent long-range interactions with an exponential decay.
    A common technique for other (e.g. powerlaw) long-range interactions is to approximate them
    by sums of exponentials and to include them into the MPOs.
    This function allows to do that.

    The algorithm/implementation follows the appendix of [Murg2010]_.

    Parameters
    ----------
    f : function
        Decaying function to be approximated. Needs to accept a 1D numpy array `x`.
    n : int
        Number of exponentials to be used; needs to be smaller than `N`.
    N : int
        Number of points at which to evaluate/fit `f`;
        we evaluate and fit `f` at the points ``x = np.arange(1, N+1)``.

    Returns
    -------
    lambdas, prefactors : 1D arrays
        Such that :math:`f(k) \approx \sum_i x_i \lambda_i^k` for (integer) 1 <= `k` <= `N`,
        where :math:`x_i` are the `prefactors` and :math:`\lambda_i` the `lambdas`.
        The function :func:`sum_of_exp` evaluates this for given `x`.
    """
    assert n < N
    x = np.arange(1, N + 1)
    f_x = f(x)
    n_rows = N - n + 1
    # matrix of shifted copies of the sampled function: hankel[j, i] = f_x[i + j]
    hankel = np.zeros([n_rows, n])
    for col in range(n):
        hankel[:, col] = f_x[col:col + n_rows]
    Q = np.linalg.qr(hankel)[0]
    # map from all-but-the-last rows of Q to all-but-the-first rows of Q;
    # its eigenvalues are taken as the decay factors lambda_i
    shift = np.dot(np.linalg.pinv(Q[:-1, :]), Q[1:, :])
    lam = np.sort(np.linalg.eigvals(shift))[::-1]
    # least-square fit of the prefactors for the given lambdas
    W = np.power.outer(lam, x).T
    pref = np.linalg.lstsq(W, f_x, None)[0]
    return lam, pref
def sum_of_exp(lambdas, prefactors, x):
    """Evaluate ``sum_i prefactor[i] * lambda[i]**x`` for different `x`.

    See :func:`fit_with_sum_of_exp` for more details.

    Parameters
    ----------
    lambdas, prefactors : 1D array
        Bases and prefactors of the exponentials, of the same length.
    x : scalar or 1D array
        The exponent(s) at which to evaluate the sum.

    Returns
    -------
    The value of the sum for each entry of `x`; real if the imaginary part is close to zero.
    """
    powers = np.power.outer(lambdas, x)  # powers[i, ...] = lambdas[i]**x
    total = np.dot(powers.T, prefactors)
    return np.real_if_close(total)