In [None]:
import iris_pipeline

# Apply iris_pipeline's patches to jwst.datamodels (the details live in iris_pipeline;
# presumably this lets jwst load IRIS data models — confirm). It runs before the
# `from jwst import datamodels` cell, which may matter for the patch to take effect.
iris_pipeline.monkeypatch_jwst_datamodels()

In [None]:
import astropy.units as u

In [None]:
import numpy as np
from jwst import datamodels

In [None]:
from test_utils import get_data_from_url

In [None]:
# Dataset identifier understood by test_utils.get_data_from_url; the helper returns a
# local path (opened with datamodels.open in the next cell) — presumably it downloads
# the file on first use.
RAW_SCIENCE_FILE_ID = "17903858"
raw_science_filename = get_data_from_url(RAW_SCIENCE_FILE_ID)

In [None]:
# Open the raw science file as a datamodel; `input_model` is used by the cells below.
input_model = datamodels.open(raw_science_filename)

The input model does not yet have a WCS object (the `meta.wcs` ASDF extension).

In [None]:
# Sanity check: the raw model has no `meta.wcs` object yet; the AssignWcs step below
# produces a model that does.
assert not hasattr(input_model.meta, "wcs")

It also lacks the FITS WCS keywords, so we add them here.

In [None]:
# Minimal tangent-plane (TAN) FITS WCS:
#   crval1/2 - reference sky position (RA, Dec) in degrees,
#   crpix1/2 - reference pixel (presumably near the centre of a 4096x4096 array — confirm),
#   cdelt1/2 - pixel scale, in degrees per pixel (the FITS default unit for celestial axes).
# The keywords are assigned in the order listed.
fits_wcs_keywords = {
    "ctype1": "RA---TAN",
    "ctype2": "DEC--TAN",
    "cdelt1": 1e-6,
    "cdelt2": 1e-6,
    "crval1": 265,
    "crval2": -29,
    "crpix1": 2048.12,
    "crpix2": 2048.12,
}
for keyword, value in fits_wcs_keywords.items():
    setattr(input_model.meta.wcsinfo, keyword, value)

For now, the Assign WCS step simply parses the `wcsinfo` metadata and creates a `gwcs.WCS` instance whose coordinate transformations are built with `astropy.modeling`.

In [None]:
# Run the AssignWcs step on the model carrying the keywords set above; the returned
# model's `meta.wcs` is the gwcs object exercised by the checks below.
output_model = iris_pipeline.assign_wcs.AssignWcsStep.call(input_model)

In [None]:
from astropy.tests.helper import assert_quantity_allclose

The WCS object can be called with pixel coordinates and returns the corresponding sky coordinates.
Here we double-check that the reference pixel maps back to the reference sky position given in the input metadata.

In [None]:
# Round-trip check: evaluating the gwcs at the reference pixel (crpix) must give back
# the reference sky coordinates (crval, in degrees).
# NOTE(review): assert_quantity_allclose defaults to rtol=1e-7 with no atol. Here that
# allows ~2.65e-5 deg in RA and ~2.9e-6 deg in Dec, i.e. several to tens of pixels at
# 1e-6 deg/pixel, so a one-pixel origin offset would go unnoticed. Consider an explicit
# angular `atol` once the pixel-origin convention of the gwcs is confirmed.
wcsinfo = input_model.meta.wcsinfo
reference_sky = (wcsinfo.crval1 * u.deg, wcsinfo.crval2 * u.deg)
reference_pixel = (wcsinfo.crpix1 * u.pix, wcsinfo.crpix2 * u.pix)
assert_quantity_allclose(reference_sky, output_model.meta.wcs(*reference_pixel))

## Compare with the standard `astropy` WCS

We can write the model, with the WCS keywords in its header, to a FITS file, parse that file with `astropy.wcs`, and compare the transformations of the two WCS objects at the four corners of the array.

In [None]:
from astropy import wcs

In [None]:
# Serialize the model, including the wcsinfo keywords set above, to a FITS file in the
# current working directory so that astropy can read the WCS back from the header.
# NOTE(review): the file is never removed; consider writing to a temporary directory.
input_model.to_fits("temp_wcs.fits", overwrite=True)

In [None]:
# Build an astropy WCS from the file written above. Given a path, astropy.wcs.WCS parses
# the primary HDU header, so this assumes the model schema writes the wcsinfo keywords
# to the primary header (TODO confirm).
astropy_fits_wcs = wcs.WCS("temp_wcs.fits")

In [None]:
# World coordinates from astropy at pixel (0, 0), returned as plain numbers (degrees);
# astropy's pixel_to_world_values uses 0-based pixel indices.
astropy_fits_wcs.pixel_to_world_values(0,0)

In [None]:
# The same pixel evaluated with the gwcs, for a side-by-side look at the previous output.
output_model.meta.wcs(0*u.pix, 0*u.pix)

In [None]:
# Corner pixel coordinates, used for both axes in the comparison loop below.
# NOTE(review): with 0-based indexing the last pixel of a 4096-pixel axis is 4095, so
# 4096 lies one past the array edge — confirm the array size and the intended corners.
pixels = [0, 4096] * u.pix

In [None]:
# Compare astropy and gwcs at every (x, y) combination of the corner coordinates, with x
# varying slowest. Each pixel pair is printed before its check so a failure is easy to locate.
corner_pixels = [(x, y) for x in pixels for y in pixels]
for pix_x, pix_y in corner_pixels:
    print(pix_x, pix_y)
    # astropy returns bare numbers in degrees; attach the unit before comparing with gwcs.
    expected_sky = astropy_fits_wcs.pixel_to_world_values(pix_x, pix_y) * u.deg
    assert_quantity_allclose(expected_sky, output_model.meta.wcs(pix_x, pix_y))