In [1]:
import numpy as np
import astropy.units as u


from scipy.optimize import least_squares
from astropy.io import fits
from astropy.coordinates import SkyCoord  
from astropy.wcs.utils import fit_wcs_from_points, celestial_frame_to_wcs

In [2]:
def _linear_wcs_fit(params, lon, lat, x, y, w_obj):
    """
    Objective function for fitting linear terms.

    Parameters
    ----------
    params : array
        6 element array. First 4 elements are PC matrix, last 2 are CRPIX.
    lon, lat: array
        Sky coordinates.
    x, y: array
        Pixel coordinates
    w_obj: `~astropy.wcs.WCS`
        WCS object
    """
    pc = params[0:4]
    crpix = params[4:6]

    w_obj.wcs.pc = ((pc[0], pc[1]), (pc[2], pc[3]))
    w_obj.wcs.crpix = crpix
    lon2, lat2 = w_obj.wcs_pix2world(x, y, 0)

    lat_resids = lat - lat2
    lon_resids = lon - lon2
    # In case the longitude has wrapped around
    lon_resids = np.mod(lon_resids - 180.0, 360.0) - 180.0

    resids = np.concatenate((lon_resids * np.cos(np.radians(lat)), lat_resids))

    return resids

def mod_fit_wcs_from_points(
    xy, world_coords, proj_point="center", projection="TAN", plate_scales = (1, 1), crpix = (2400, 2400)):

    xp, yp = xy
    try:
        lon, lat = world_coords.data.lon.deg, world_coords.data.lat.deg
    except AttributeError:
        unit_sph = world_coords.unit_spherical
        lon, lat = unit_sph.lon.deg, unit_sph.lat.deg

    # verify input
    if (type(proj_point) != type(world_coords)) and (proj_point != "center"):
        raise ValueError(
            "proj_point must be set to 'center', or an"
            "`~astropy.coordinates.SkyCoord` object with "
            "a pair of points."
        )

    use_center_as_proj_point = str(proj_point) == "center"

    if not use_center_as_proj_point:
        assert proj_point.size == 1

    wcs = celestial_frame_to_wcs(frame=world_coords.frame, projection=projection)
    wcs.wcs.cdelt = plate_scales            
    
    # Change PC to CD, since cdelt will be set to 1
    if wcs.wcs.has_cd():
        wcs.wcs.pc = wcs.wcs.cd
        wcs.wcs.__delattr__("cd")

    # compute bounding box for sources in image coordinates:
    xpmin, xpmax, ypmin, ypmax = xp.min(), xp.max(), yp.min(), yp.max()

    # set pixel_shape to span of input points
    wcs.pixel_shape = (
        1 if xpmax <= 0.0 else int(np.ceil(xpmax)),
        1 if ypmax <= 0.0 else int(np.ceil(ypmax)),
    )

    # determine CRVAL from input
    close = lambda l, p: p[np.argmin(np.abs(l))]
    if use_center_as_proj_point:  # use center of input points
        sc1 = SkyCoord(lon.min() * u.deg, lat.max() * u.deg)
        sc2 = SkyCoord(lon.max() * u.deg, lat.min() * u.deg)
        pa = sc1.position_angle(sc2)
        sep = sc1.separation(sc2)
        midpoint_sc = sc1.directional_offset_by(pa, sep / 2)
        wcs.wcs.crval = (midpoint_sc.data.lon.deg, midpoint_sc.data.lat.deg)
        wcs.wcs.crpix = ((xpmax + xpmin) / 2.0, (ypmax + ypmin) / 2.0)
    else:  # convert units, initial guess for crpix
        proj_point.transform_to(world_coords)
        wcs.wcs.crval = (proj_point.data.lon.deg, proj_point.data.lat.deg)
        wcs.wcs.crpix = (
            close(lon - wcs.wcs.crval[0], xp + 1),
            close(lon - wcs.wcs.crval[1], yp + 1),
        )

    # fit linear terms, assign to wcs
    # use (1, 0, 0, 1) as initial guess, in case input wcs was passed in
    # and cd terms are way off.
    # Use bounds to require that the fit center pixel is on the input image
    if xpmin == xpmax:
        xpmin, xpmax = xpmin - 0.5, xpmax + 0.5
    if ypmin == ypmax:
        ypmin, ypmax = ypmin - 0.5, ypmax + 0.5

    p0 = np.concatenate([wcs.wcs.pc.flatten(), wcs.wcs.crpix.flatten()])
    fit = least_squares(
        _linear_wcs_fit,
        p0,
        args=(lon, lat, xp, yp, wcs),
        bounds=[
            [-np.inf, -np.inf, -np.inf, -np.inf, xpmin + 1, ypmin + 1],
            [np.inf, np.inf, np.inf, np.inf, xpmax + 1, ypmax + 1],
        ],
    )
    wcs.wcs.crpix = np.array(fit.x[4:6])
    wcs.wcs.pc = np.array(fit.x[0:4].reshape((2, 2)))

    # crvals = wcs.wcs_pix2world(crpix[0], crpix[1], 1)
    # crvals = tuple([float(crval) for crval in crvals])
    # wcs.wcs.crpix = crpix
    # wcs.wcs.crval = crvals
    
    return wcs

In [3]:
hdu = fits.open('combined_FUV_PC00F2_CPS_coo.dat.corr')

field_x = hdu[1].data['field_x']
field_y = hdu[1].data['field_y']
index_ra = hdu[1].data['index_ra']
index_dec = hdu[1].data['index_dec']
skycoo = SkyCoord(index_ra, index_dec, unit = 'deg')

In [4]:
default_wcs = fit_wcs_from_points((field_x, field_y), skycoo)
print(default_wcs)

WCS Keywords

Number of WCS axes: 2
CTYPE : 'RA---TAN'  'DEC--TAN'  
CRVAL : 17.37802453439897  -71.30617922027807  
CRPIX : 2314.364788026458  2327.2190940376995  
CD1_1 CD1_2  : 0.00011222117817702223  -2.8550310859426504e-05  
CD2_1 CD2_2  : 2.8570533913341137e-05  0.00011220567069389997  
NAXIS : 4158  4143


In [5]:
plate_scales = (0.4168705 / 3600, 0.416848964 / 3600)
test_wcs = mod_fit_wcs_from_points((field_x, field_y), skycoo, plate_scales = plate_scales)
print(test_wcs)

WCS Keywords

Number of WCS axes: 2
CTYPE : 'RA---TAN'  'DEC--TAN'  
CRVAL : 17.37802453439897  -71.30617922027807  
CRPIX : 2314.3647876048844  2327.219092787287  
PC1_1 PC1_2  : 0.9691168859463998  -0.24655407177549007  
PC2_1 PC2_2  : 0.24674145874149725  0.9690330281286604  
CDELT : 0.00011579736111111111  0.00011579137888888888  
NAXIS : 4158  4143


In [6]:
mod_wcs = test_wcs.deepcopy()
crvals = mod_wcs.wcs_pix2world(2400, 2400, 1)
crvals = tuple([float(crval) for crval in crvals])

mod_wcs.wcs.crpix = (2400, 2400)
mod_wcs.wcs.crval = crvals

print(default_wcs.wcs_pix2world(1000, 1000, 0 ))
print(test_wcs.wcs_pix2world(1000, 1000, 0 ))
print(mod_wcs.wcs_pix2world(1000, 1000, 0 ))

[array(17.03299742), array(-71.49219869)]
[array(17.03299742), array(-71.49219869)]
[array(17.03275768), array(-71.49215262)]
