-
Notifications
You must be signed in to change notification settings - Fork 22
/
zern.py
356 lines (287 loc) · 13.6 KB
/
zern.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#!/usr/bin/env python
# encoding: utf-8
"""
@file zern.py
@brief Zernike basis function utilities
@package libtim.zern
@brief Zernike basis function utilities
@author Tim van Werkhoven (werkhoven@strw.leidenuniv.nl)
@copyright Creative Commons Attribution-Share Alike license versions 3.0 or higher, see http://creativecommons.org/licenses/by-sa/3.0/
@date 20120403
Construct and analyze Zernike basis functions
"""
#==========================================================================
# Import libraries here
#==========================================================================
import numpy as np
import libtim as tim
import libtim.im
#==========================================================================
# Defines
#==========================================================================
#==========================================================================
# Routines
#==========================================================================
from scipy.misc import factorial as fac
def zernike_rad(m, n, rho):
    """
    Make radial Zernike polynomial on coordinate grid **rho**.

    Evaluates

        R_n^|m|(rho) = sum_k (-1)^k (n-k)! / (k! ((n+|m|)/2-k)! ((n-|m|)/2-k)!) * rho^(n-2k)

    for k = 0 .. (n-|m|)/2. If n-|m| is odd, or n < |m|, the polynomial is identically zero.

    @param [in] m Azimuthal Zernike index (only its magnitude is used)
    @param [in] n Radial Zernike index
    @param [in] rho Radial coordinate grid
    @return Radial polynomial with identical shape as **rho**
    """
    # Local import keeps this routine independent of scipy.misc.factorial,
    # which is no longer shipped by current SciPy releases.
    from math import factorial

    # The radial polynomial only depends on |m|; integer indices keep the
    # factorial arguments and the loop bound exact.
    m = abs(int(m))
    n = int(n)

    if ((n - m) % 2 == 1):
        return rho*0.0

    wf = rho*0.0
    # Floor division: a float loop bound is rejected by range() under true division
    for k in range((n - m)//2 + 1):
        # The leading float factor forces float division on every Python version
        coeff = (-1.0)**k * factorial(n - k) / (factorial(k) * factorial((n + m)//2 - k) * factorial((n - m)//2 - k))
        wf += coeff * rho**(n - 2*k)
    return wf
def zernike(m, n, rho, phi, norm=True):
    """
    Calculate Zernike mode (m,n) on grid **rho** and **phi**.

    **rho** and **phi** should be radial and azimuthal coordinate grids of identical shape, respectively. Positive **m** selects the cosine term, negative **m** the sine term, and **m** == 0 the purely radial mode.

    @param [in] m Azimuthal Zernike index
    @param [in] n Radial Zernike index
    @param [in] rho Radial coordinate grid
    @param [in] phi Azimuthal coordinate grid
    @param [in] norm Normalize modes to unit variance
    @return Zernike mode (m,n) with identical shape as rho, phi
    @see <http://research.opt.indiana.edu/Library/VSIA/VSIA-2000_taskforce/TOPS4_2.html> and <http://research.opt.indiana.edu/Library/HVO/Handbook.html>.
    """
    # Normalisation constant: sqrt(2(n+1)) for m != 0, sqrt(n+1) for m == 0
    scale = (2*(n+1)/(1+(m==0)))**0.5 if norm else 1.0

    if m > 0:
        return scale*zernike_rad(m, n, rho) * np.cos(m * phi)
    elif m < 0:
        return scale*zernike_rad(-m, n, rho) * np.sin(-m * phi)
    else:
        return scale*zernike_rad(0, n, rho)
def noll_to_zern(j):
    """
    Convert linear Noll index to tuple of Zernike indices.

    j is the linear Noll coordinate, n is the radial Zernike index and m is the azimuthal Zernike index.

    @param [in] j Zernike mode Noll index (integer >= 1)
    @return (n, m) tuple of Zernike indices
    @exception ValueError if **j** is smaller than 1
    @see <https://oeis.org/A176988>.
    """
    # Reject every index below 1, not only 0: negative values would otherwise
    # skip the loop below and silently yield a meaningless (n, m) pair.
    if (j < 1):
        raise ValueError("Noll indices start at 1, %d is invalid." % j)

    # Find the radial order n: row n of the Zernike pyramid holds n+1 modes.
    # Afterwards j1 is the zero-based position of the mode inside that row.
    n = 0
    j1 = j-1
    while (j1 > n):
        n += 1
        j1 -= n

    # |m| has the parity of n and grows in steps of 2 along the row; the sign
    # follows the parity of j (even j -> positive m, odd j -> negative m).
    m = (-1)**j * ((n % 2) + 2 * ((j1 + ((n+1) % 2)) // 2))
    return (n, m)
def zernikel(j, rho, phi, norm=True):
    """
    Calculate Zernike mode with Noll index **j** on grid **rho** and **phi**.

    Translates **j** with noll_to_zern() and evaluates the mode with zernike().

    @param [in] j Zernike mode Noll index
    @param [in] rho Radial coordinate grid
    @param [in] phi Azimuthal coordinate grid
    @param [in] norm Normalize modes to unit variance
    @return Zernike mode as returned by zernike()
    """
    radial_order, azimuthal_order = noll_to_zern(j)
    return zernike(azimuthal_order, radial_order, rho, phi, norm)
# def zernikel(j, size=256, norm=True):
# """
# Calculate Zernike mode with Noll-index j on a square grid of <size>^2
# elements
# """
# n, m = noll_to_zern(j)
#
# grid = (np.indices((size, size), dtype=np.float) - 0.5*size) / (0.5*size)
# grid_rad = (grid[0]**2. + grid[1]**2.)**0.5
# grid_ang = np.arctan2(grid[0], grid[1])
# return zernike(m, n, grid_rad, grid_ang, norm)
def noll_to_zern_broken(j):
    """
    Previous and incorrect Noll-to-Zernike conversion.

    Stored for reference purposes. Fixed around 1321970330.98158 or Tue Nov 22 13:59:10 2011 UTC. All data generated before this has invalid mapping.

    @param [in] j Index in the old, incorrect numbering
    @exception DeprecationWarning Always raised; this function must not be used
    @deprecated Incorrect mapping, use noll_to_zern instead
    """
    # DeprecationWarning is the built-in class; the previously used name
    # 'DeprecatedWarning' does not exist and surfaced as a NameError instead.
    raise DeprecationWarning("Incorrect mapping, use noll_to_zern instead")

    # Intentionally unreachable: the old mapping is kept below for reference only.
    n = 0
    j1 = j
    while (j1 > n):
        n += 1
        j1 -= n
    m = -n+2*j1
    return (n, m)
def fix_noll_map(max):
"""
Translate old incorrect Noll coordinates to correct values.
This function repairs data generated with noll_to_zern_broken().
"""
return [(jold, jnew)
for jold in xrange(max)
for jnew in xrange(1, max)
if (noll_to_zern_broken(jold) == noll_to_zern(jnew))]
def zern_normalisation(nmodes=30):
    """
    Calculate normalisation vector.

    This function calculates a **nmodes** element vector with normalisation constants for Zernike modes that have not already been normalised. Element i belongs to Noll index i+1 and equals sqrt(2(n+1)) for modes with m != 0 and sqrt(n+1) for modes with m == 0, i.e. the same constant zernike() applies when **norm** is set.

    @param [in] nmodes Size of normalisation vector.
    @return Array with **nmodes** normalisation constants
    @see <http://research.opt.indiana.edu/Library/VSIA/VSIA-2000_taskforce/TOPS4_2.html> and <http://research.opt.indiana.edu/Library/HVO/Handbook.html>.
    """
    norms = []
    # Noll indices are 1-based. range() replaces the Python 2-only xrange().
    for j in range(1, nmodes+1):
        n, m = noll_to_zern(j)
        norms.append((2*(n+1)/(1+(m==0)))**0.5)
    return np.asanyarray(norms)
### Higher level Zernike generating / fitting functions
def calc_zern_basis(nmodes, rad, modestart=1, calc_covmat=False):
    """
    Calculate a basis of **nmodes** Zernike modes with radius **rad**.

    Nothing is masked: the modes fill the complete 2***rad** by 2***rad** square and are only orthogonal inside the unit circle. Crop them manually using the 'mask' entry in the returned dict.

    This output of this function can be used as cache for other functions.

    @param [in] nmodes Number of modes to generate
    @param [in] rad Radius of Zernike modes
    @param [in] modestart First mode to calculate (Noll index, i.e. 1=piston)
    @param [in] calc_covmat Return covariance matrix for Zernike modes, and its inverse
    @return Dict with entries 'modes' a list of Zernike modes, 'modesmat' a matrix of (nmodes, npixels), 'covmat' a covariance matrix for all these modes with 'covmat_in' its inverse (both None unless **calc_covmat** is set), 'mask' is a binary mask to crop only the orthogonal part of the modes. For **nmodes** <= 0 a placeholder dict with empty lists and zeros is returned.
    @exception ValueError if **rad** <= 0 or **modestart** <= 0
    """
    if (nmodes <= 0):
        return {'modes':[], 'modesmat':[], 'covmat':0, 'covmat_in':0, 'mask':[[0]]}
    if (rad <= 0):
        raise ValueError("radius should be > 0")
    if (modestart <= 0):
        raise ValueError("**modestart** Noll index should be > 0")

    # Use vectors instead of a grid matrix: a column and a row vector
    # broadcast to the full 2D angle grid inside arctan2
    rvec = ((np.arange(2.0*rad) - rad)/rad)
    r0 = rvec.reshape(-1,1)
    r1 = rvec.reshape(1,-1)

    # NOTE(review): presumably mk_rad_mask() returns the radial distance
    # normalised to 1 at the aperture edge -- confirm against libtim.im
    grid_rad = tim.im.mk_rad_mask(2*rad)
    grid_ang = np.arctan2(r0, r1)

    grid_mask = grid_rad <= 1

    # Build list of Zernike modes, these are *not* masked/cropped.
    # range() replaces the Python 2-only xrange().
    zern_modes = [zernikel(zmode, grid_rad, grid_ang) for zmode in range(modestart, nmodes+modestart)]

    # Convert modes to (nmodes, npixels) matrix
    zern_modes_mat = np.r_[zern_modes].reshape(nmodes, -1)

    covmat = covmat_in = None
    if (calc_covmat):
        # Calculate covariance matrix, using only pixels inside the mask
        covmat = np.array([[np.sum(zerni * zernj * grid_mask) for zerni in zern_modes] for zernj in zern_modes])
        # Invert covariance matrix using SVD
        covmat_in = np.linalg.pinv(covmat)

    # Create and return dict
    return {'modes': zern_modes, 'modesmat': zern_modes_mat, 'covmat':covmat, 'covmat_in':covmat_in, 'mask': grid_mask}
def fit_zernike(wavefront, zern_data=None, nmodes=10, startmode=1, fitweight=None, center=(-0.5, -0.5), rad=-0.5, rec_zern=True, err=None):
    """
    Fit **nmodes** Zernike modes to a **wavefront**.

    The **wavefront** will be fit to Zernike modes for a circle with radius **rad** with origin at **center**. **fitweight** is a weighting mask used when fitting the modes.

    If **center** or **rad** are between 0 and -1, the values will be interpreted as fractions of the image shape.

    **startmode** indicates the Zernike mode (Noll index) to start fitting with, i.e. **startmode**=4 will skip piston, tip and tilt modes. Modes below this one will be set to zero, which means that if **startmode** == **nmodes**, the returned vector will be all zeroes. This parameter is intended to ignore low order modes when fitting (piston, tip, tilt) as these can sometimes not be derived from data.

    If **err** is an empty list, it will be filled with measures for the fitting error:
    1. Mean squared difference
    2. Mean absolute difference
    3. Mean absolute difference squared

    The error measures need the reconstruction, so **err** is left untouched when **rec_zern** is False.

    This function uses **zern_data** as cache. If this is not given, it will be generated. See calc_zern_basis() for details. Pass the same dict on subsequent calls to re-use the basis; the dict is updated in place.

    @param [in] wavefront Input wavefront to fit
    @param [in] zern_data Zernike basis cache
    @param [in] nmodes Number of modes to fit
    @param [in] startmode Start fitting at this mode (Noll index)
    @param [in] fitweight Mask to use as weights when fitting
    @param [in] center Center of Zernike modes to fit
    @param [in] rad Radius of Zernike modes to fit
    @param [in] rec_zern Reconstruct Zernike modes and calculate errors.
    @param [out] err Fitting errors
    @return Tuple of (wf_zern_vec, wf_zern_rec, fitdiff) where the first element is a vector of Zernike mode amplitudes, the second element is a full 2D Zernike reconstruction and the last element is the 2D difference between the input wavefront and the full reconstruction. The last two are None if **rec_zern** is False.
    @exception ValueError if **rad**, **center** or **startmode** are out of range
    @see See calc_zern_basis() for details on **zern_data** cache
    """
    # A fresh dict per call avoids sharing one mutable default between callers
    if (zern_data is None):
        zern_data = {}

    if (rad < -1 or min(center) < -1):
        raise ValueError("illegal radius or center < -1")
    elif (rad > 0.5*max(wavefront.shape)):
        raise ValueError("radius exceeds wavefront shape?")
    elif (max(center) > max(wavefront.shape)-rad):
        raise ValueError("fitmask shape exceeds wavefront shape?")
    elif (startmode < 1):
        raise ValueError("startmode<1 is not a valid Noll index")

    # Convert rad and center if coordinates are fractional
    if (rad < 0):
        rad = -rad * min(wavefront.shape)
    if (min(center) < 0):
        center = -np.r_[center] * min(wavefront.shape)

    # Make cropping slices to select only central part of the wavefront.
    # The bounds can be floats after the fractional conversion above, which
    # array indexing rejects, so truncate them to integers explicitly.
    xslice = slice(int(center[0]-rad), int(center[0]+rad))
    yslice = slice(int(center[1]-rad), int(center[1]+rad))

    # Compute Zernike basis if absent, too short, or made for another radius
    if ('modes' not in zern_data
            or nmodes > len(zern_data['modes'])
            or zern_data['modes'][0].shape != (2*rad, 2*rad)):
        tmp_zern = calc_zern_basis(nmodes, rad)
        for key in ('modes', 'modesmat', 'covmat', 'covmat_in', 'mask'):
            zern_data[key] = tmp_zern[key]

    zern_basismat = zern_data['modesmat'][:nmodes]
    grid_mask = zern_data['mask']
    grid_vec = grid_mask.reshape(-1)

    # Crop out central region of wavefront, then only select the orthogonal
    # part of the Zernike modes (grid_mask)
    wf_inmask = (wavefront[yslice, xslice])[grid_mask]
    basis_inmask = zern_basismat[:, grid_vec]

    # 'is not None' is required here: comparing an array with '!= None' is
    # evaluated element-wise and has no single truth value.
    if (fitweight is not None):
        # Weighed LSQ fit with data. Only fit inside grid_mask
        weight = (fitweight[yslice, xslice])[grid_mask]
        wf_zern_vec = np.linalg.lstsq((basis_inmask * weight).T, wf_inmask * weight)[0]
    else:
        # LSQ fit with data. Only fit inside grid_mask
        wf_zern_vec = np.linalg.lstsq(basis_inmask.T, wf_inmask)[0]

    # Discard the modes below startmode
    wf_zern_vec[:startmode-1] = 0

    # Calculate full Zernike phase & fitting error
    if (rec_zern):
        # The cache already holds enough modes, so calc_zernike() re-uses it;
        # rad is passed for consistency with the basis that was fitted.
        wf_zern_rec = calc_zernike(wf_zern_vec, zern_data=zern_data, rad=rad)
        fitdiff = (wf_zern_rec - wavefront[yslice, xslice])
        # Outside the aperture, use the mean residual as a neutral value
        fitdiff[grid_mask == False] = fitdiff[grid_mask].mean()
    else:
        wf_zern_rec = None
        fitdiff = None

    if (err is not None and fitdiff is not None):
        # For calculating scalar fitting qualities, only use the area inside the mask
        fitresid = fitdiff[grid_mask]
        err.append((fitresid**2.0).mean())
        err.append(np.abs(fitresid).mean())
        err.append(np.abs(fitresid).mean()**2.0)

    return (wf_zern_vec, wf_zern_rec, fitdiff)
def calc_zernike(zern_vec, rad, zern_data=None, mask=True):
    """
    Construct wavefront with Zernike amplitudes **zern_vec**.

    Given vector **zern_vec** with the amplitude of Zernike modes, return the reconstructed wavefront with radius **rad**.

    This function uses **zern_data** as cache. If this is not given, or if it holds fewer modes than **zern_vec** has elements, the basis is generated with calc_zern_basis() and stored in **zern_data** in place. An existing cache with enough modes is used as is; its radius is not compared with **rad**.

    If **mask** is True, every mode is multiplied with the 'mask' entry of the cache, setting everything outside the Zernike aperture to zero. This is the default and will use orthogonal Zernikes. If this is False, the modes will not be cropped.

    @param [in] zern_vec 1D vector of Zernike amplitudes
    @param [in] rad Radius for Zernike modes to construct
    @param [in] zern_data Zernike basis cache
    @param [in] mask If True, set everything outside the Zernike aperture to zero, otherwise leave as is.
    @return Sum of the modes weighted with **zern_vec** (scalar 0 for an empty **zern_vec**)
    @see See calc_zern_basis() for details on **zern_data** cache and **mask**
    """
    # A fresh dict per call: a shared mutable default would hand the basis of
    # an earlier call (possibly made for a different radius) to later callers.
    if (zern_data is None):
        zern_data = {}

    # Compute Zernike basis if absent or insufficient
    if ('modes' not in zern_data or len(zern_vec) > len(zern_data['modes'])):
        tmp_zern = calc_zern_basis(len(zern_vec), rad)
        for key in ('modes', 'modesmat', 'covmat', 'covmat_in', 'mask'):
            zern_data[key] = tmp_zern[key]

    zern_basis = zern_data['modes']

    gridmask = 1
    if (mask):
        gridmask = zern_data['mask']

    # Reconstruct the wavefront by summing modes; zip() stops at the end of
    # zern_vec, so surplus modes in the cache are ignored.
    return sum((amp * mode * gridmask for amp, mode in zip(zern_vec, zern_basis)), 0)