In [None]:
import datetime
import multiprocessing
import pprint
import warnings

import numpy as np
import photutils
import poppy
import pysiaf
import scipy
import scipy.stats
import webbpsf
import yaml
from astropy.io import ascii as asc
from astropy.io import fits
from matplotlib.colors import LogNorm

from mirage.seed_image import catalog_seed_image
%pylab inline

### Goal: Translate a segment tip/tilt applied in a WebbPSF simulation into a detector pixel-displacement vector

In [None]:
nc = webbpsf.NIRCam()
nc, ote = webbpsf.enable_adjustable_ote(nc)

plt.imshow(ote.opd)
plt.show()
ote.print_state()

In [None]:
psf = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
webbpsf.display_psf(psf, vmin=1e-8, vmax=1e-5)

In [None]:
# Add some tilt
ote.reset()
ote.move_seg_local('A4', xtilt=-30, ytilt=0)
ote.print_state()

psf_tilted = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
webbpsf.display_psf(psf_tilted, vmin=1e-8, vmax=1e-5)

In [None]:
# Add some tilt and piston?
ote.reset()
ote.move_seg_local('A1', xtilt=20, ytilt=10, piston=100)
ote.print_state()

psf_tilted_pistoned = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
webbpsf.display_psf(psf_tilted_pistoned, vmin=1e-8, vmax=1e-5)

sources = photutils.find_peaks(psf_tilted_pistoned[1].data, 1e-4, box_size=50, subpixel=True)
print(sources)

# So it looks like the piston only has an effect on the order of ~10 pixels?

In [None]:
psf_tilted.info()

In [None]:
psf_data = psf_tilted[1].data

In [None]:
# Find the location of the two PSFs here

sources = photutils.find_peaks(psf_data,1e-4, box_size=50, subpixel=True)
print(sources)

im = plt.imshow(psf_data, norm=LogNorm(), clim=(1e-8, 1e-4))
plt.scatter(sources['x_centroid'], sources['y_centroid'], color='r', marker='+')
plt.colorbar(im)
plt.show()

In [None]:
x_pix_distance = sources['x_centroid'][1] - sources['x_centroid'][0]
y_pix_distance = sources['y_centroid'][1] - sources['y_centroid'][0]
print(x_pix_distance, y_pix_distance)
print(x_pix_distance/10, y_pix_distance/20)

## Quantify the relation between tilt and displacement

In [None]:
# Collect information on how the tilt-to-displacement ratio changes over the detector:
# tilt segment C1 by N_RANDOM_TILTS random (xtilt, ytilt) pairs drawn uniformly from
# [-30, 30) and record the centroids of the peaks found in each simulated PSF.
N_RANDOM_TILTS = 20
# Seeded so that Restart & Run All reproduces the same tilts (and hence the same fit).
rng = np.random.RandomState(0)
random_tilts = rng.uniform(-30, 30, size=(N_RANDOM_TILTS, 2))

results = []
for i, (x_tilt, y_tilt) in enumerate(random_tilts):
    ote.reset()
    ote.move_seg_local('C1', xtilt=x_tilt, ytilt=y_tilt)
    psf = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
    psf_data = psf[1].data
    sources = photutils.find_peaks(psf_data, 1e-4, box_size=50, subpixel=True)

    im = plt.imshow(psf_data, norm=LogNorm(), clim=(1e-8, 1e-4))
    plt.scatter(sources['x_centroid'], sources['y_centroid'], color='r', marker='+')
    plt.colorbar(im)
    plt.title('C1 xtilt={:.1f}, ytilt={:.1f}'.format(x_tilt, y_tilt))
    plt.show()

    # One record per run: [x_tilt, y_tilt, peak x centroids, peak y centroids]
    results.append([x_tilt, y_tilt, sources['x_centroid'], sources['y_centroid']])
    print(x_tilt, y_tilt)
    print('Completed {}/{} calculations'.format(i + 1, len(random_tilts)))

print(results)

In [None]:
def _convert_control_to_global(segment, xtilt, ytilt):
    """Convert vectors coordinates in the local segment control
    coordinate system to NIRCam detector X and Y coordinates.
    At least proportionally."""
    control_xaxis_rotations = {
        'A1': 180, 'A2': 120, 'A3': 60, 'A4': 0,'A5': -60, 
        'A6': -120, 'B1': 0, 'C1': 60, 'B2': -60, 'C2': 0, 
        'B3': -120, 'C3': -60, 'B4': -180, 'C4': -120, 
        'B5': -240, 'C5': -180, 'B6': -300, 'C6': -240
    }

    x_rot = control_xaxis_rotations[segment[:2]]  # degrees
    x_rot_rad = x_rot * np.pi / 180 # radians
    print('Rotating by {} deg ({} rad) to account for segment {}'.format(x_rot, x_rot_rad, segment))
    
    y_det = (xtilt * np.cos(x_rot_rad)) - (ytilt * np.sin(x_rot_rad))
    x_det = (xtilt * np.sin(x_rot_rad)) + (ytilt * np.cos(x_rot_rad))
    
    return x_det, y_det

In [None]:
# Fit |tilt projected onto each detector axis| against |peak separation| for the C1 runs
# above. The fit is done here so that this cell runs top to bottom without needing slopes
# defined by a later cell.
projected_tilts = [_convert_control_to_global('C1', x_tilt, y_tilt) for x_tilt, y_tilt, _, _ in results]
tilt_onto_x = [abs(x_det) for x_det, _ in projected_tilts]
tilt_onto_y = [abs(y_det) for _, y_det in projected_tilts]

# Separation between the first two detected peaks along each axis
x_displacement = [abs(x_centroids[1] - x_centroids[0]) for _, _, x_centroids, _ in results]
y_displacement = [abs(y_centroids[1] - y_centroids[0]) for _, _, _, y_centroids in results]

x_slope, x_intercept, _, _, _ = scipy.stats.linregress(tilt_onto_x, x_displacement)
y_slope, y_intercept, _, _, _ = scipy.stats.linregress(tilt_onto_y, y_displacement)
print('X slope and intercept:', x_slope, x_intercept)
print('Y slope and intercept:', y_slope, y_intercept)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
fit_tilts = np.linspace(0, 35, 10)

# move_seg_local tilts use webbpsf's default rotation unit (microradians), not microns.
ax1.scatter(tilt_onto_x, x_displacement)
ax1.plot(fit_tilts, fit_tilts * x_slope + x_intercept)
ax1.set_title('X Conversion')
ax1.set_xlabel('|Tilt onto X| (urad)')
ax1.set_ylabel('X Displacement (pixels)')

ax2.scatter(tilt_onto_y, y_displacement)
ax2.plot(fit_tilts, fit_tilts * y_slope + y_intercept)
ax2.set_title('Y Conversion')
ax2.set_xlabel('|Tilt onto Y| (urad)')
ax2.set_ylabel('Y Displacement (pixels)')

plt.show()


In [None]:
def generate_random_tilts(i, plot=False, segment_name=None, tilts=None):
    """Simulate one PSF with a single segment tilted, and locate its peaks.

    Intended for ``multiprocessing.Pool.map`` over tilt indices. It therefore
    builds its own NIRCam/OTE objects instead of using the notebook's
    ``nc``/``ote``.

    Parameters
    ----------
    i : int
        Row of ``tilts`` holding the (xtilt, ytilt) pair to apply.
    plot : bool
        If True, show the PSF with the detected peaks marked.
    segment_name : str, optional
        Segment to move. Defaults to the module-level ``segment`` variable
        (the original behaviour; pool workers must see that global).
    tilts : array-like of (xtilt, ytilt) rows, optional
        Tilt table. Defaults to the module-level ``random_tilts``.

    Returns
    -------
    list
        ``[x_tilt, y_tilt, x_centroids, y_centroids]`` for the detected peaks.
    """
    if segment_name is None:
        segment_name = segment  # global set by the per-segment loop
    if tilts is None:
        tilts = random_tilts
    x_tilt, y_tilt = tilts[i]

    nc = webbpsf.NIRCam()
    nc, ote = webbpsf.enable_adjustable_ote(nc)
    ote.move_seg_local(segment_name, xtilt=x_tilt, ytilt=y_tilt)
    psf = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
    psf_data = psf[1].data
    sources = photutils.find_peaks(psf_data, 1e-4, box_size=50, subpixel=True)

    if plot:
        im = plt.imshow(psf_data, norm=LogNorm(), clim=(1e-8, 1e-4))
        plt.scatter(sources['x_centroid'], sources['y_centroid'], color='r', marker='+')
        plt.colorbar(im)
        plt.show()

    parameters = [x_tilt, y_tilt, sources['x_centroid'], sources['y_centroid']]
    print('    Completed {}/{} simulations'.format(i + 1, len(tilts)))

    return parameters
    

In [None]:
# Calculate the tilt-to-displacement slope and intercept for each segment.
# This runs N_TILTS PSF simulations per segment (slow).
N_TILTS = 20
tilt_to_displacement_params = {}

for seg_name in webbpsf.constants.SEGNAMES_WSS:
    # By default generate_random_tilts reads the module-level `segment` and
    # `random_tilts`, so both are (re)assigned here before the worker pool starts.
    # NOTE(review): this relies on workers inheriting notebook globals (the 'fork'
    # start method); it will not work with 'spawn'.
    segment = seg_name[:2]
    random_tilts = np.random.random((N_TILTS, 2)) * 60 - 30

    print('Calculating segment {}'.format(segment))

    # Use multiprocessing to generate all the PSFs. The `with` block shuts the
    # workers down, so each segment no longer leaves a pool of 6 processes running.
    with multiprocessing.Pool(6) as p:
        results = p.map(generate_random_tilts, range(N_TILTS))

    if len(results) != N_TILTS:
        print('Results is only {} long'.format(len(results)))
        print(results)
        raise ValueError('Something funky happened with the multiprocessing.')

    projected_tilts = [_convert_control_to_global(segment, x_tilt, y_tilt) for x_tilt, y_tilt, _, _ in results]
    tilt_onto_x = [abs(x_det) for x_det, _ in projected_tilts]
    tilt_onto_y = [abs(y_det) for _, y_det in projected_tilts]
    # Separation between the first two detected peaks along each axis
    x_displacement = [abs(x_centroids[1] - x_centroids[0]) for _, _, x_centroids, _ in results]
    y_displacement = [abs(y_centroids[1] - y_centroids[0]) for _, _, _, y_centroids in results]

    x_slope, x_intercept, _, _, _ = scipy.stats.linregress(tilt_onto_x, x_displacement)
    y_slope, y_intercept, _, _, _ = scipy.stats.linregress(tilt_onto_y, y_displacement)
    print('    X slope and intercept:', x_slope, x_intercept)
    print('    Y slope and intercept:', y_slope, y_intercept)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fit_tilts = np.linspace(0, 35, 10)

    # move_seg_local tilts use webbpsf's default rotation unit (microradians), not microns.
    ax1.scatter(tilt_onto_x, x_displacement)
    ax1.plot(fit_tilts, fit_tilts * x_slope + x_intercept)
    ax1.set_title('X Conversion')
    ax1.set_xlabel('|Tilt onto X| (urad)')
    ax1.set_ylabel('X Displacement (pixels)')

    ax2.scatter(tilt_onto_y, y_displacement)
    ax2.plot(fit_tilts, fit_tilts * y_slope + y_intercept)
    ax2.set_title('Y Conversion')
    ax2.set_xlabel('|Tilt onto Y| (urad)')
    ax2.set_ylabel('Y Displacement (pixels)')

    plt.show()

    # Same {axis: {'slope': ..., 'intercept': ...}} layout that the
    # tilt_conversion_parameters.yaml readers below expect. Plain floats let
    # yaml.safe_dump serialize the result.
    tilt_to_displacement_params[segment] = {
        'x': {'slope': float(x_slope), 'intercept': float(x_intercept)},
        'y': {'slope': float(y_slope), 'intercept': float(y_intercept)},
    }


In [None]:
def calc_location_after_tilt(segment, xtilt, ytilt,
                             param_file='/user/lchambers/OTECommSims/tilt_conversion_parameters.yaml'):
    """Predict the pixel displacement of a segment PSF after a tilt.

    A single global linear model is used. The slopes and intercepts stored for
    every segment and both axes in ``param_file`` are averaged, then rounded
    (slope to 0.1, intercept to a whole pixel). The result is applied to the tilt
    projected onto detector X/Y by ``_convert_control_to_global``.

    Parameters
    ----------
    segment : str
        Segment name; only the first two characters are used by the projection.
    xtilt, ytilt : float
        Tilt in the segment's local control frame.
    param_file : str
        YAML file laid out as ``{segment: {axis: {'slope': s, 'intercept': b}}}``.

    Returns
    -------
    (float, float)
        The negated predicted (x, y) displacement in pixels.
    """
    with open(param_file, 'r') as stream:
        # safe_load: the file is plain data. yaml.load without a Loader is unsafe
        # and raises TypeError in PyYAML >= 6.
        conversion_params = yaml.safe_load(stream)

    axis_fits = [axis_fit for seg_fits in conversion_params.values() for axis_fit in seg_fits.values()]
    all_slopes = [axis_fit['slope'] for axis_fit in axis_fits]
    all_intercepts = [axis_fit['intercept'] for axis_fit in axis_fits]

    slope = np.average(all_slopes)
    intercept = np.average(all_intercepts)

    print('Slope = {} +/- {}'.format(slope, np.std(all_slopes)))
    print('Intercept = {} +/- {}'.format(intercept, np.std(all_intercepts)))

    slope = round(slope, 1)
    intercept = round(intercept, 0)

    tilt_onto_x, tilt_onto_y = _convert_control_to_global(segment, xtilt, ytilt)
    print('Tilt projected onto V2/V3:', tilt_onto_x, tilt_onto_y)

    # NOTE(review): the fits were made on |tilt| vs |displacement|; adding the
    # intercept to a signed (possibly negative) tilt may bias negative tilts -- confirm.
    x_displacement = tilt_onto_x * slope + intercept
    y_displacement = tilt_onto_y * slope + intercept

    return -x_displacement, -y_displacement

In [None]:
with open('/user/lchambers/OTECommSims/tilt_conversion_parameters.yaml', 'r') as stream:
    d = yaml.load(stream)

all_slopes = []
all_intercepts = []
for _, seg_dict in d.items():
    for axis, axis_dict in seg_dict.items():
        all_intercepts.append(axis_dict['intercept'])
        all_slopes.append(axis_dict['slope'])

slope = np.average(all_slopes)
intercept = np.average(all_intercepts)

print(slope, intercept)

In [None]:
plt.scatter(range(36), all_slopes)
plt.axhline(y=slope)

In [None]:
# TEST!!!!

# Add some tilt
segment = 'A4'
xtilt = 3
ytilt = 30
ote.reset()
ote.move_seg_local(segment, xtilt=xtilt, ytilt=ytilt)
# ote.print_state()

psf_tilted = nc.calc_psf(nlambda=30, oversample=1, fov_pixels=1024, add_distortion=False)
# webbpsf.display_psf(psf_tilted, vmin=1e-8, vmax=1e-5)
psf_data = psf_tilted[1].data

x_displacement, y_displacement = calc_location_after_tilt(segment, xtilt, ytilt)
print(x_displacement, y_displacement)

sources = photutils.find_peaks(psf_data, 1e-4, box_size=50, subpixel=True)
print(sources['x_centroid', 'y_centroid'])

In [None]:
im = plt.imshow(psf_data, norm=LogNorm(), clim=(1e-8, 1e-4))
# plt.scatter(512, 512, color='blue', marker='*')
plt.scatter(512 + x_displacement, 512 + y_displacement, color='r', marker='+')
# plt.scatter(512 + y_displacement, 512 - x_displacement, color='r', marker='X')
plt.scatter(sources['x_centroid'], sources['y_centroid'], color='grey', marker='+')

plt.colorbar(im)
plt.show()

# Make diagram of OTE with segment labels and coordinate systems

In [None]:
import matplotlib
import numpy as np
import poppy
import webbpsf

In [None]:
def get_transmission_no_struts(self, wave):
    """Rasterize the segmented primary onto ``wave``'s grid, leaving out the SMSS struts.

    Written to be called with a ``WebbPrimaryAperture`` as ``self``. It appears to be
    adapted from the aperture's own transmission method, with the strut masking
    removed.

    Parameters
    ----------
    self : object
        Must provide ``segdata`` (iterable of (segname, vertices) pairs) and
        ``label_segments``.
    wave : poppy.Wavefront
        Supplies ``coordinates()`` (returning y, x arrays) and ``shape``.

    Returns
    -------
    numpy.ndarray of shape (npix, npix)
        0 outside the segments. Inside a segment the value is 1, or, when
        ``self.label_segments`` is true, the integer after the '-' in the segment name.
    """
    y, x = wave.coordinates()
    # (N, 2) array of (x, y) sample points for Path.contains_points, built without a Python loop
    pts = np.column_stack((x.ravel(), y.ravel()))
    npix = wave.shape[0]
    out = np.zeros((npix, npix))

    # Paint the segments but leave out the SMSS struts (strut masking intentionally omitted).
    for segname, vertices in self.segdata:
        inside = matplotlib.path.Path(vertices).contains_points(pts).reshape(npix, npix)
        out[inside] = 1 if not self.label_segments else int(segname.split('-')[1])
    return out

In [None]:
wave = poppy.Wavefront()
primary = webbpsf.optics.WebbPrimaryAperture(label_segments=False)
pupil = get_transmission_no_struts(primary, wave)

In [None]:
plt.imshow(pupil)
fits.writeto('JWST_pupil_no_struts.fits', pupil)

In [None]:
def convert_meters_to_pixels(x_meters, y_meters, npix=1024, half_width=3.99609375):
    """Linearly map pupil-plane coordinates to pixel positions in an npix x npix pupil image.

    A coordinate of ``+/- half_width`` maps to the array edges, and 0 maps to pixel
    ``npix / 2``. The units of ``half_width`` are presumed to be meters, matching
    the pupil ``seg_centers`` passed in below -- TODO confirm against poppy's
    Wavefront coordinates. The defaults reproduce the original hard-coded values
    for the 1024-pixel pupil. Works elementwise on numpy arrays as well as scalars.

    Parameters
    ----------
    x_meters, y_meters : float or array-like
        Pupil-plane coordinates.
    npix : int
        Width of the (square) pupil image in pixels.
    half_width : float
        Coordinate value that maps to the edge of the image.

    Returns
    -------
    (x_pixels, y_pixels)
    """
    pixels_per_unit = npix / (half_width * 2)
    center = npix / 2
    x_pixels = x_meters * pixels_per_unit + center
    y_pixels = y_meters * pixels_per_unit + center

    return x_pixels, y_pixels

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 15))

# Private copy of YlOrRd (same 256-entry lookup table). Calling set_under on
# plt.cm.YlOrRd itself would modify matplotlib's registered colormap for every later plot.
cmap = matplotlib.colors.ListedColormap(plt.cm.YlOrRd(np.linspace(0, 1, 256)))
cmap.set_under(color='white')  # pupil values below vmin (outside the segments) are white
ax.imshow(pupil, cmap=cmap, vmin=0.5, vmax=2.2)
ax.invert_yaxis()

arrow_length = 100  # pixels, before quiver's scale factor
for seg_name, center in primary.seg_centers.items():
    label = seg_name[:2]
    x_center, y_center = convert_meters_to_pixels(*center)
    ax.text(x_center, y_center, label, color='black',
            horizontalalignment='center', verticalalignment='center',
            fontweight='bold', fontsize=25)

    # Segment control-frame axes: X at the segment's rotation angle, Y 90 degrees further.
    # NOTE(review): _control_xaxis_rotations is a private webbpsf attribute.
    x_axis_deg = ote._control_xaxis_rotations[label]
    for axis_name, angle_deg in (('X', x_axis_deg), ('Y', x_axis_deg + 90)):
        angle = angle_deg * np.pi / 180
        dx = arrow_length * np.cos(angle)
        dy = arrow_length * np.sin(angle)
        ax.quiver([x_center], [y_center], [dx], [dy], color=['white'], angles='xy',
                  scale_units='xy', scale=1.5, width=5e-3)
        ax.text(x_center + 0.7 * dx, y_center + 0.7 * dy, axis_name, color='black',
                horizontalalignment='center', verticalalignment='center', fontsize=20)

# Add V2/V3 vectors from pixel (512, 512)
for axis_name, (dx, dy) in (('V2', (arrow_length, 0)), ('V3', (0, arrow_length))):
    ax.quiver([512], [512], [dx], [dy], angles='xy', scale_units='xy', scale=1.7, width=5e-3)
    ax.text(512 + 0.7 * dx, 512 + 0.7 * dy, axis_name, color='black',
            horizontalalignment='center', verticalalignment='center', fontsize=20)

ax.set_xlim(90, 910)
ax.set_ylim(50, 950)
ax.axis('off')  # hides ticks and axis labels; the V2/V3 arrows give the orientation
plt.savefig('JWST Segments and Coordinates.png', transparent=True)
plt.show()

## Figure out how to modify the point source list made in `catalog_seed_image`

In [None]:
ps_file = '/Users/lchambers/TEL/mirage/mirage/catalogs/2MASS_RA146.88deg_Dec63.25deg.list'
gtab = asc.read(ps_file)
# Look at the header lines to see if inputs
# are in units of pixels or RA, Dec
pixel_flag = False
try:
    if 'position_pixels' in gtab.meta['comments'][0:4]:
        pixel_flag = True
except:
    pass
# Check to see if magnitude system is specified
# in the comments. If not default to AB mag
msys = 'abmag'
condition = ('stmag' in gtab.meta['comments'][0:4]) | ('vegamag' in gtab.meta['comments'][0:4])
if condition:
    msys = [l for l in gtab.meta['comments'][0:4] if 'mag' in l][0]
    msys = msys.lower()
    
lines = gtab

In [None]:
ra = 146.9190376246782
dec = 63.24114957538154

In [None]:
def shift_sources_by_offset(lines, segment_offset, pixelflag, aperture=None,
                            ra=146.9190376246782, dec=63.24114957538154,
                            position_angle=111.):
    """Shift every catalogue source on the sky by a detector pixel offset.

    Each RA/Dec position is converted to the telescope (V2, V3) frame with a pysiaf
    attitude matrix. The matrix is built from ``aperture``'s reference point, the
    pointing (``ra``, ``dec``) and ``position_angle``. The pixel offset is converted
    to arcsec with the aperture's science-frame pixel scales and applied
    (V2 -= dx * XSciScale, V3 += dy * YSciScale). The result is converted back to RA/Dec.

    Parameters
    ----------
    lines : astropy Table
        Catalogue with 'x_or_RA', 'y_or_Dec' and 'magnitude' columns. Output rows are
        [RA, Dec, magnitude], so presumably the table has exactly these three
        columns -- TODO confirm for other catalogues.
    segment_offset : (float, float)
        (x, y) displacement in detector pixels.
    pixelflag : bool
        True if the catalogue positions are pixels. This is not supported.
    aperture : pysiaf aperture, optional
        Defaults to NIRCam NRCA3_FULL, the aperture used throughout this notebook.
    ra, dec : float
        Pointing in degrees.
    position_angle : float
        Position angle in degrees passed to the attitude matrix.

    Returns
    -------
    astropy Table
        Table with the same columns as ``lines``, holding the shifted rows.

    Raises
    ------
    NotImplementedError
        If ``pixelflag`` is true. Previously v2/v3 were left unbound (or stale from
        the previous row) in that case.
    """
    print('    Shifting point source locations by pixel offset {}'.format(segment_offset))
    if pixelflag:
        raise NotImplementedError('Shifting catalogues given in pixel coordinates is not supported.')
    if aperture is None:
        aperture = pysiaf.Siaf('NIRCam')['NRCA3_FULL']

    # Empty table with the same columns as the input, filled row by row below
    shifted_lines = lines.copy()
    shifted_lines.remove_rows(np.arange(0, len(shifted_lines)))

    print('    Position angle = ', position_angle)
    attitude_ref = pysiaf.utils.rotations.attitude(
        aperture.V2Ref, aperture.V3Ref, ra, dec, position_angle
    )

    # Convert the pixel displacement into angle once (loop-invariant)
    x_displacement, y_displacement = segment_offset
    x_arcsec = x_displacement * aperture.XSciScale  # XSciScale: arcsec/pixel
    y_arcsec = y_displacement * aperture.YSciScale  # YSciScale: arcsec/pixel

    for line in lines:
        # RA/Dec (sky frame) -> V2/V3 (telescope frame)
        v2, v3 = pysiaf.utils.rotations.getv2v3(attitude_ref, line['x_or_RA'], line['y_or_Dec'])
        v2 -= x_arcsec
        v3 += y_arcsec
        # V2/V3 -> RA/Dec
        new_ra, new_dec = pysiaf.utils.rotations.pointing(attitude_ref, v2, v3)

        # TODO: need to be smarter about which magnitude to use here
        shifted_lines.add_row([new_ra, new_dec, line['magnitude']])

    return shifted_lines

In [None]:
attitude_ref = pysiaf.utils.rotations.attitude(
        aperture.V2Ref, aperture.V3Ref,  ra, dec, 111.
    )
v2, v3 = pysiaf.utils.rotations.getv2v3(attitude_ref, lines['x_or_RA'], lines['y_or_Dec'])
plt.scatter(v2, v3)
plt.show()
x_or_RA, y_or_Dec = aperture.tel_to_det(v2, v3)
plt.scatter(x_or_RA, y_or_Dec)
plt.xlim(-16000, 16000)
plt.show()

# for i in range(len(v2)):
#     print(v2[i], v3[i], x_or_RA[i], y_or_Dec[i])
plt.scatter(v2, x_or_RA)
plt.xlabel('V2')
plt.ylabel('X Pixels')
plt.show()

In [None]:
shifted_lines = shift_sources_by_offset(lines, (-852.1060000000001, -278.58600000000007), pixel_flag)

In [None]:
plt.scatter(shifted_lines['x_or_RA'], shifted_lines['y_or_Dec'], label='shifted locations')
print(lines)
print(shifted_lines)

In [None]:
plt.scatter(shifted_lines['x_or_RA'], shifted_lines['y_or_Dec'], label='shifted locations')
plt.scatter(lines['x_or_RA'], lines['y_or_Dec'], label='original locations')
plt.scatter(ra, dec, label='target', marker='+')
plt.legend()

In [None]:
siaf = pysiaf.Siaf('NIRCam')
aperture = siaf['NRCA3_FULL']

aperture.det_to_tel(-100, 350)

In [None]:
cat = catalog_seed_image.Catalog_seed()
cat.paramfile = '/Users/lchambers/TEL/mirage/OTECommissioning/OTE01_reducedmosaic/yamls/TEST_jw01134001001_01101_00001_nrca3.yaml'
cat.make_seed()

In [None]:
def show(array,title,min=0,max=1000):
    plt.figure(figsize=(12,12))
    plt.imshow(array,clim=(min,max))
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04).set_label('DN$^{-}$/s')
    plt.show()

In [None]:
show(cat.seedimage, 'sdfjlskfjklsdjflsdjf', max=10)
