/
nelder_mead.cljc
308 lines (271 loc) · 12.4 KB
/
nelder_mead.cljc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
;;
;; Copyright © 2017 Colin Smith.
;; This work is based on the Scmutils system of MIT/GNU Scheme:
;; Copyright © 2002 Massachusetts Institute of Technology
;;
;; This is free software; you can redistribute it and/or modify
;; it under the terms of the GNU General Public License as published by
;; the Free Software Foundation; either version 3 of the License, or (at
;; your option) any later version.
;;
;; This software is distributed in the hope that it will be useful, but
;; WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
;; General Public License for more details.
;;
;; You should have received a copy of the GNU General Public License
;; along with this code; if not, see <http://www.gnu.org/licenses/>.
;;
(ns sicmutils.numerical.multimin.nelder-mead
(:require [sicmutils.util :as u]))
(defn- v+
  "Returns a new vector holding the elementwise sum of vectors `l` and `r`.
  The result is truncated to the length of the shorter input."
  [l r]
  (into [] (map + l r)))
(defn- v-
  "Returns a new vector holding the elementwise difference `l - r`.
  The result is truncated to the length of the shorter input."
  [l r]
  (into [] (map - l r)))
(defn- v*
  "Returns a new vector with every entry of `v` multiplied by the scalar `s`."
  [s v]
  (mapv (partial * s) v))
(defn ^:private initial-simplex
  "Builds the starting simplex around the n-vector `x0`.

  Returns a vector of n+1 n-vectors: `x0` itself first, followed by one point
  per coordinate `i`, equal to `x0` with only entry `i` perturbed. A nonzero
  entry is multiplied by `(1 + nonzero-delta)`; a zero entry is replaced by
  `zero-delta`.

  Options (both optional):
  - `:nonzero-delta` relative perturbation for nonzero entries (default 0.05)
  - `:zero-delta` replacement value for zero entries (default 0.00025)"
  [x0 {:keys [nonzero-delta zero-delta]
       :or {nonzero-delta 0.05
            zero-delta 0.00025}}]
  (let [base    (vec x0)
        factor  (inc nonzero-delta)
        perturb (fn [xi]
                  (if (zero? xi)
                    zero-delta
                    (* factor xi)))]
    (reduce (fn [acc i]
              (conj acc (update base i perturb)))
            [base]
            (range (count base)))))
(defn ^:private sup-norm
  "Returns the largest absolute deviation between any entry of `simplex` and its
  first entry `x0`, where absolute values are computed with `u/compute-abs`.

  When the entries are sequential (points of the simplex), every coordinate of
  every point is compared with the matching coordinate of `x0`. Otherwise the
  entries are treated as plain numbers, e.g. the vector of function values."
  [[x0 :as simplex]]
  (let [deviations (if-not (sequential? x0)
                     (for [y simplex]
                       (- y x0))
                     (for [pt simplex
                           d  (v- pt x0)]
                       d))]
    (apply max (map u/compute-abs deviations))))
(defn ^:private counted
  "Instruments the single-argument function `f` with a call counter.

  Returns a pair `[calls wrapped]`:
  - `calls` is an atom holding the number of times `wrapped` has been invoked
  - `wrapped` increments `calls`, then returns `(f x)`"
  [f]
  (let [calls   (atom 0)
        wrapped (fn [x]
                  (swap! calls inc)
                  (f x))]
    [calls wrapped]))
(defn ^:private sort-by-f
  "Reorders `simplex` and its function values `f-simplex` together, in
  ascending order of function value. Returns `[sorted-simplex sorted-fsimplex]`
  as two vectors. Ties keep their original relative order.

  `dimension` must equal the length of each point in the simplex. Only the
  first `dimension + 1` entries are used. When omitted, `dimension` is taken
  from the length of the last point of `simplex`."
  ([simplex f-simplex]
   (sort-by-f simplex f-simplex (count (peek simplex))))
  ([simplex f-simplex dimension]
   (let [order   (->> (range (inc dimension))
                      (sort-by #(nth f-simplex %)))
         ;; Vectors are invoked as functions of their index here.
         reorder (fn [v] (mapv v order))]
     [(reorder simplex) (reorder f-simplex)])))
(defn ^:private step-defaults
  "Returns the coefficient map `{:alpha :beta :gamma :sigma}` used by one
  Nelder-Mead step.

  With `:adaptive?` true (the default), the expansion, contraction and shrink
  coefficients depend on `dimension` (n), following Gao and Han:
    alpha = 1, beta = 1 + 2/n, gamma = 3/4 - 1/(2n), sigma = 1 - 1/n.
  With `:adaptive?` false, the static values 1, 2, 1/2, 1/2 are used.

  Any of `:alpha`, `:beta`, `:gamma`, `:sigma` present in `m` override the
  computed defaults."
  [dimension {:keys [adaptive?]
              :or {adaptive? true}
              :as m}]
  (let [overrides (select-keys m [:alpha :beta :gamma :sigma])
        defaults  (if-not adaptive?
                    {:alpha 1.0
                     :beta  2.0
                     :gamma 0.5
                     :sigma 0.5}
                    {:alpha 1.0
                     :beta  (+ 1.0 (/ 2.0 dimension))
                     :gamma (- 0.75 (/ (* 2.0 dimension)))
                     :sigma (- 1.0 (/ dimension))})]
    (merge defaults overrides)))
(defn ^:private step-fn
  "Returns a function that performs a single step of Nelder-Mead. The returned
  function expects a simplex and f-simplex sorted in ascending order of
  function value, and returns a pair, again sorted:

  - [simplex, f(simplex)]

  [This Scholarpedia
  page](http://www.scholarpedia.org/article/Nelder-Mead_algorithm) provides a
  nice overview of the algorithm.

  The parameters in opts follow the convention from [Gao and Han's
  paper](https://www.researchgate.net/publication/225691623_Implementing_the_Nelder-Mead_simplex_algorithm_with_adaptive_parameters)
  introducing the adaptive parameter version of Nelder-Mead:

  :alpha - reflection coefficient
  :beta - expansion coefficient
  :gamma - contraction coefficient
  :sigma - shrink coefficient

  Any coefficient missing from `opts` is filled in by `step-defaults`. `f` is
  called on every newly generated point; a shrink re-evaluates every point of
  the simplex."
  ([f dimension opts]
   (let [{:keys [alpha beta sigma gamma]} (step-defaults dimension opts)]
     (letfn [;; Centroid of every point except the worst (last) one.
             (centroid-pt [simplex]
               (v* (/ dimension) (reduce v+ (pop simplex))))

             ;; Returns the point generated by reflecting the worst point across
             ;; the centroid of the simplex.
             (reflect [simplex centroid]
               (v- (v* (inc alpha) centroid)
                   (v* alpha (peek simplex))))

             ;; Returns the point generated by reflecting the worst point across
             ;; the centroid, and then stretching it in that direction by a factor
             ;; of beta.
             (reflect-expand [simplex centroid]
               (v- (v* (inc (* alpha beta)) centroid)
                   (v* (* alpha beta) (peek simplex))))

             ;; Returns the point generated by reflecting the worst point, then
             ;; shrinking it toward the centroid by a factor of gamma.
             (reflect-contract [simplex centroid]
               (v- (v* (inc (* gamma alpha)) centroid)
                   (v* (* gamma alpha) (peek simplex))))

             ;; Returns the point generated by shrinking the current worst point
             ;; toward the centroid by a factor of gamma.
             (contract [simplex centroid]
               (v+ (v* (- 1 gamma) centroid)
                   (v* gamma (peek simplex))))

             ;; Returns a simplex generated by scaling each point toward the best
             ;; point by the shrink factor $\sigma$; ie, by replacing all
             ;; points (except the best point $s_1$) with
             ;; $s_i = s_1 + \sigma (s_i - s_1)$. The result is re-evaluated
             ;; and re-sorted.
             (shrink [[s0 & rest]]
               (let [scale-toward-s0 #(v+ s0 (v* sigma (v- % s0)))
                     s (into [s0] (map scale-toward-s0 rest))]
                 (sort-by-f s (mapv f s) dimension)))]
       (fn [simplex [f-best :as f-simplex]]
         ;; Verify that inputs and outputs remain sorted by f value.
         ;;
         ;; The conditions must be plain expressions over `%`: a `#(...)`
         ;; literal is a function object, which is always truthy, so it would
         ;; never actually check the result.
         {:pre  [(apply <= f-simplex)]
          :post [(apply <= (second %))]}
         (let [swap-worst (fn [elem f-elem]
                            (let [s  (conj (pop simplex) elem)
                                  fs (conj (pop f-simplex) f-elem)]
                              (sort-by-f s fs dimension)))
               f-worst    (peek f-simplex)
               f-butworst (peek (pop f-simplex))
               centroid   (centroid-pt simplex)
               reflected  (reflect simplex centroid)
               fr         (f reflected)]
           (cond
             ;; If the reflected point is the best (minimal) point so far, replace
             ;; the worst point with either an expansion of the simplex around that
             ;; point, or the reflected point itself.
             ;;
             ;; f(reflected worst) < f(best)
             (< fr f-best)
             (let [expanded (reflect-expand simplex centroid)
                   fe       (f expanded)]
               (if (< fe fr)
                 (swap-worst expanded fe)
                 (swap-worst reflected fr)))

             ;; f(best) <= f(reflected worst) < f(second worst)
             ;;
             ;; Else, if the reflected worst point is better than the second worst
             ;; point, swap it for the worst point.
             (< fr f-butworst)
             (swap-worst reflected fr)

             ;; f(butworst) <= f(reflected worst) < f(worst)
             ;;
             ;; If the reflected point is still better than the worst point,
             ;; generate a point by shrinking the reflected point toward the
             ;; centroid. If this is better than (or equivalent to) the reflected
             ;; point, it replaces the worst point. Else, shrink the whole simplex.
             (< fr f-worst)
             (let [r-contracted (reflect-contract simplex centroid)
                   frc          (f r-contracted)]
               (if (<= frc fr)
                 (swap-worst r-contracted frc)
                 (shrink simplex)))

             ;; f(worst) <= f(reflected worst)
             ;;
             ;; Else, attempt to contract the existing worst point toward the
             ;; centroid. If the contracted point is better than the worst point,
             ;; swap it in; else, shrink the whole simplex.
             :else
             (let [contracted (contract simplex centroid)
                   fc         (f contracted)]
               (if (< fc f-worst)
                 (swap-worst contracted fc)
                 (shrink simplex))))))))))
(defn ^:private convergence-fn
  "Builds the convergence predicate for the minimizer.

  The returned function takes `simplex` and `f-simplex` and returns true only
  when both of these hold:
  - the `sup-norm` of the simplex is at most `:simplex-tolerance`
  - the `sup-norm` of the function values is at most `:fn-tolerance`

  Otherwise it returns false. Both tolerances default to 1e-4."
  [{:keys [simplex-tolerance fn-tolerance]
    :or {simplex-tolerance 1e-4
         fn-tolerance 1e-4}}]
  (fn converged? [simplex f-simplex]
    (let [within? (fn [coll tolerance]
                    (<= (sup-norm coll) tolerance))]
      ;; `and` short-circuits: the function values are only inspected when
      ;; the simplex itself is already small enough.
      (and (within? simplex simplex-tolerance)
           (within? f-simplex fn-tolerance)))))
(defn ^:private stop-fn
  "Takes an atom that, when dereferenced, returns a function call count, the
  dimension of the simplex, and an options map that may contain `:maxiter` and
  `:maxfun`. Each of these defaults to `200 * dimension`.

  Returns a function of `iterations` that returns true once either limit has
  been reached, i.e. `iterations >= maxiter` or `@f-counter >= maxfun`, and
  false otherwise.

  The comparisons use `>=` so that a loop checking before each step performs
  at most `maxiter` steps, as the `:maxiter` option promises. With `>`, one
  extra step ran and the reported iteration count exceeded the maximum. The
  function-call limit is only compared against the current count, so the
  calls made inside the last step can still push the total past `maxfun`."
  [f-counter dimension {:keys [maxiter maxfun]}]
  (let [maxiter (or maxiter (* dimension 200))
        maxfun  (or maxfun (* dimension 200))]
    (fn [iterations]
      (or (>= iterations maxiter)
          (>= @f-counter maxfun)))))
(defn nelder-mead
  "Find the minimum of the function f: R^n -> R, given an initial point q ∈ R^n.

  Returns a map with the keys:

  - `:result` the best point found
  - `:value` the value of `func` at that point
  - `:converged?` true if the convergence tolerances below were met; false if
    the search stopped because of the `:maxiter` / `:maxfun` limits
  - `:iterations` the number of Nelder-Mead steps performed
  - `:fncalls` the total number of calls made to `func`, including the calls
    that evaluate the initial simplex

  The options map may be omitted. It supports the following optional keyword
  arguments:

  `:callback` if supplied, the supplied fn will be invoked once per iteration
  (including the final one) with the iteration count, the current best point of
  the simplex and the value of f at that point.

  `:info?` is currently ignored by this implementation. The returned map above
  always includes the evaluation information.

  `:adaptive?` if true (the default), the Nelder-Mead parameters for
  contraction, expansion, reflection and shrinking will be set adaptively, as
  functions of the number of dimensions. If false they stay constant.

  `:alpha` sets the reflection coefficient used for each step of Nelder-Mead.

  `:beta` sets the expansion coefficient used for each step of Nelder-Mead.

  `:gamma` sets the contraction coefficient used for each step of Nelder-Mead.

  `:sigma` sets the shrink coefficient used for each step of Nelder-Mead.

  `:maxiter` Maximum number of iterations allowed for the minimizer. Defaults to
  200*dimension.

  `:maxfun` Maximum number of times the function can be evaluated before exiting.
  Defaults to 200*dimension.

  `:simplex-tolerance` When the absolute value of the max difference between the
  best point and any point in the simplex falls below this tolerance, the
  minimizer stops. Defaults to 1e-4.

  `:fn-tolerance` When the absolute value of the max difference between the best
  point's function value and the fn value of any point in the simplex falls
  below this tolerance, the minimizer stops. Defaults to 1e-4.

  Convergence requires both tolerances to be met at the same time.

  `:zero-delta` controls the value to which 0 entries in the initial vector are
  set during initial simplex generation. Defaults to 0.00025.

  `:nonzero-delta` factor by which entries in the initial vector are perturbed to
  generate the initial simplex. Defaults to 0.05.

  See Gao, F. and Han, L.
  Implementing the Nelder-Mead simplex algorithm with adaptive
  parameters. 2012. Computational Optimization and Applications.
  51:1, pp. 259-277

  I gratefully acknowledge the [Python implementation in
  SciPy](https://github.com/scipy/scipy/blob/589c9afe41774ee96ec121f1867361146add8276/scipy/optimize/optimize.py#L556:5)
  which I have imitated here."
  ([func x0]
   (nelder-mead func x0 {}))
  ([func x0 {:keys [callback] :as opts}]
   (let [callback      (or callback (constantly nil))
         dimension     (count x0)
         ;; Every evaluation goes through `f`, so `f-counter` tracks all calls.
         [f-counter f] (counted func)
         step          (step-fn f dimension opts)
         convergence?  (convergence-fn opts)
         stop?         (stop-fn f-counter dimension opts)
         simplex       (initial-simplex x0 opts)
         f-simplex     (mapv f simplex)]
     ;; The simplex is kept sorted by f value, so s0 / f0 are always the best
     ;; point and its value.
     (loop [[[s0 :as simplex] [f0 :as f-simplex]] (sort-by-f simplex f-simplex dimension)
            iteration 0]
       (callback iteration s0 f0)
       (let [converged? (convergence? simplex f-simplex)]
         (if (or converged? (stop? iteration))
           {:result     s0
            :value      f0
            :converged? converged?
            :iterations iteration
            :fncalls    @f-counter}
           (recur (step simplex f-simplex)
                  (inc iteration))))))))