diff --git a/webbpsf/gridded_library.py b/webbpsf/gridded_library.py index 8ffc6189..63209bae 100644 --- a/webbpsf/gridded_library.py +++ b/webbpsf/gridded_library.py @@ -460,8 +460,8 @@ def writeto(self, data, meta, detector): hdu.writeto(file, overwrite=self.overwrite) -def display_psf_grid(grid, zoom_in=True, figsize=(12, 12)): - """ Display a PSF grid in a pair of lpots +def display_psf_grid(grid, zoom_in=True, figsize=(14, 12), scale_range=1e-4): + """ Display a PSF grid in a pair of plots Shows the NxN grid in NxN subplots, repeated to show first the individual PSFs, and then their differences @@ -476,6 +476,8 @@ def display_psf_grid(grid, zoom_in=True, figsize=(12, 12)): ------- grid : photutils.GriddedPSFModel object The grid of PSFs to be displayed. + scale_range : float + Dynamic range for display scale. vmin will be set to this factor timex vmax. """ @@ -506,19 +508,24 @@ def show_grid_helper(grid, data, title="Grid of PSFs", vmax=0, vmin=0, scale='lo for ix in range(n): for iy in range(n): i = ix*n+iy - axes[n-1-iy, ix].imshow(data[i], vmax=vmax, vmin=vmin, norm=norm) + im = axes[n-1-iy, ix].imshow(data[i], vmax=vmax, vmin=vmin, norm=norm) axes[n-1-iy, ix].xaxis.set_visible(False) axes[n-1-iy, ix].yaxis.set_visible(False) axes[n-1-iy, ix].set_title("{}".format(tuple_to_int(grid.grid_xypos[i]))) if zoom_in: axes[n-1-iy,ix].use_sticky_edges = False axes[n-1-iy,ix].margins(x=-0.25, y=-0.25) - plt.suptitle("{} for {} in {} ".format(title, + plt.suptitle("{} for {} in {} \noversampling: {}x".format(title, grid.meta['detector'][0], - grid.meta['filter'][0]), fontsize=16) + grid.meta['filter'][0], grid.oversampling), fontsize=16) + + + cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.95) + cbar.set_label("Intensity, relative to PSF sum = 1.0") + vmax = grid.data.max() - vmin = vmax/1e4 + vmin = vmax*scale_range show_grid_helper(grid, grid.data, vmax=vmax, vmin=vmin) meanpsf = np.mean(grid.data, axis=0) diff --git a/webbpsf/optics.py b/webbpsf/optics.py index e709056a..89bd5798 100644 --- a/webbpsf/optics.py +++ b/webbpsf/optics.py @@ -1610,3 +1610,168 @@ def __init__(self, instrument, include_oversize=True, **kwargs): # No need to subclass any of the methods; it's sufficient to set the custom # amplitude mask attribute value. + + +# Alternative implementation that just reads OPDs from some file +class LookupTableFieldDependentAberration(poppy.OpticalElement): + """ Retrieve OPDs from a lookup table over many field points. + This is pretty much a hack, hard-coded for a specific data delivery from Ball! + Intended for WFR4 data prep, not generalized beyond that. + + Parameters + ----------- + add_niriss_defocus: bool + add 0.8 microns PTV defocus to NIRISS only (for WFR4 test) + rm_ptt: bool + Remove piston, tip, tilt + rm_center_ptt : bool + If rm_ptt, use the center value for each detector rather than per field point + + + Usage: + ------ + inst = webbpsf.NIRCam() # or any other SI + inst._si_wfe_class = LookupTableFieldDependentAberration() + + """ + + def __init__(self, instrument, field_points_file=None, phasemap_file=None, + add_niriss_defocus=True, rm_ptt=True, rm_center_ptt=True, **kwargs): + super().__init__( + name="Aberrations", + **kwargs + ) + import warnings + + self.instrument = instrument + self.instr_name = instrument.name + + self.rm_ptt = rm_ptt + + if self.instr_name =='NIRCam': + self.instr_name += " "+self.instrument.module + elif self.instr_name == 'FGS': + self.instr_name = self.instrument.detector + + self.tel_coords = instrument._tel_coords() + + # load the OPD lookup map table (datacube) here + + fp_path = '/ifs/jwst/tel/wfr4_mirage_sims/phase_maps_from_ball/' + if field_points_file is None: + field_points_file = fp_path + 'The_Field_Coordinates.txt' + if phasemap_file is None: + phasemap_file = fp_path + 'phase_maps.fits' + + self.table = Table.read(field_points_file, format='ascii', names=('V2', 'V3')) + + self.yoffset = -7.8 + + self.table['V3'] += self.yoffset # Correct from -YAN to actual V3 + + self.phasemaps = fits.getdata(phasemap_file) + import webbpsf.constants + self.phasemap_pixelscale = webbpsf.constants.JWST_CIRCUMSCRIBED_DIAMETER/256 * units.meter / units.pixel + + # Determine the pupil sampling of the first aperture in the + # instrument's optical system + if isinstance(instrument.pupil, poppy.OpticalElement): + # This branch needed to handle the OTE Linear Model case + npix = instrument.pupil.shape[0] + self.pixelscale = instrument.pupil.pixelscale + else: + # these branches to handle FITS files, by name or as an object + if isinstance(instrument.pupil, fits.HDUList): + pupilheader = instrument.pupil[0].header + else: + pupilfile = os.path.join(instrument._datapath, "OPD", instrument.pupil) + pupilheader = fits.getheader(pupilfile) + + npix = pupilheader['NAXIS1'] + self.pixelscale = pupilheader['PUPLSCAL'] * units.meter / units.pixel + + #self.ztable = self.ztable_full[self.ztable_full['instrument'] == lookup_name] + + # Figure out the closest field point + + telcoords_am = self.tel_coords.to(units.arcmin).value + print(f"Requested field point has coord {telcoords_am}") + v2 = self.table['V2'] + + v3 = self.table['V3'] + r = np.sqrt((telcoords_am[0] - v2) ** 2 + (telcoords_am[1] - v3) ** 2) + closest = np.argmin(r) # if there are two field points with identical coords or equal distance just one is returned + + print(f"Closest field point is row {closest}: {self.table[closest]} ") + + # Save closest ISIM CV3 WFE measured field point for reference + self.row = self.table[closest] + + + self.name = "{instrument} at V2V3=({v2:.2f},{v3:.2f}) Lookup table WFE from ({v2t:.2f},{v3t:.2f})".format( + instrument=self.instr_name, + v2=telcoords_am[0], v3=telcoords_am[1], + v2t=self.row['V2'], v3t=self.row['V3'] + + ) + self.si_wfe_type = ("Lookup Table", + "SI + OTE WFE from supplied lookup table of phase maps.") + + # Retrieve the phase map + + phasemap = self.phasemaps[closest] + + if phasemap.shape[0] != 256: + raise NotImplementedError("Hard coded for Ball delivery of 256 pixel phase maps") + + # Resample to 1024 across, by replicating each pixel into a 4x4 block + resample_factor = 4 + phasemap_big = np.kron(phasemap, np.ones((resample_factor,resample_factor))) + + self.opd = phasemap_big * 1e-6 # Convert to microns + self.amplitude = np.ones_like(self.opd) + + if rm_ptt: + apmask = self.opd != 0 + if rm_center_ptt: + # Remove the PTT values at the center of each instrument, rather than per field point. This + # leaves in the field dependence but takes out the bulk offset + # These values are just a precomputed lookup table of the coefficients returned by the + # opd_expand_nonorthonormal call just below, for the center field point on each. + + coeffs_per_si = {"NIRCam A": [-3.50046880e-10, -7.29120639e-08, -1.39751567e-08], + "NIRCam B": [-2.45093780e-09, -2.51804001e-07, -2.64821753e-07], + "NIRISS": [-1.49297771e-09, -2.11111038e-06, -3.99881993e-07], + "FGS1": [ 9.86180620e-09, -5.94041500e-07, 1.18953161e-06], + "FGS2": [ 4.84327424e-09, -8.24285481e-07, 5.09791593e-07], + "MIRI": [-8.75766849e-09, -1.27850277e-06, -1.03467567e-06],} + coeffs = coeffs_per_si[self.instr_name] + else: + coeffs = poppy.zernike.opd_expand_nonorthonormal(self.opd, aperture=apmask, nterms=3) + ptt_only = poppy.zernike.opd_from_zernikes(coeffs, aperture=apmask, npix=self.opd.shape[0], outside=0) + self.opd -= ptt_only + print(f"Removing piston, tip, tilt from the input wavefront. Coeffs for {self.instr_name}: {coeffs},") + + if add_niriss_defocus and self.instr_name=='NIRISS': + # The Ball delivery was supposed to have defocused NIRISS for rehearsal purposes, but didn't. + # So fix that here. + self.instrument.options['defocus_waves'] = 0.8 + self.instrument.options['defocus_wavelength'] = 1e-6 # Add 0.8 microns PTV defocus + warnings.warn("Adding defocus=0.8 waves for NIRISS!") + + + def header_keywords(self): + """ Return info we would like to save in FITS header of output PSFs + """ + from collections import OrderedDict + keywords = OrderedDict() + keywords['SIWFETYP'] = self.si_wfe_type + keywords['SIWFEFPT'] = ( f"{self.row['V2']:.3f}, {self.row['V3']:.3f}", "Closest lookup table meas. field point") + return keywords + + # wrapper just to change default vmax + def display(self, *args, **kwargs): + if 'opd_vmax' not in kwargs: + kwargs.update({'opd_vmax': 2.5e-7}) + + return super().display(*args, **kwargs)