In [None]:
# Work from the HST/ directory: the reference-file path used later
# ('../config/HST/...') is relative to it, and the local helper modules imported
# in the next cell presumably live there too.
# Guarded so that re-running this cell does not try to descend into HST/HST;
# raises FileNotFoundError (instead of only printing) if HST/ is missing.
import os

if os.path.basename(os.getcwd()) != 'HST':
    os.chdir('HST')

In [None]:
import pathlib
import re
import tempfile
from urllib.parse import urlparse
from urllib.request import urlretrieve

import asdf
import numpy as np
from astropy import units as u
from astropy.io import fits
from astropy.modeling import models
from gwcs import WCS
from gwcs import coordinate_frames as cf

from hst_grism_reffiles import create_tsgrism_wavelengthrange, create_grism_specwcs
from generate_wfc3_distortion import create_wfc3_distortion
from transform_models import WFC3IRForwardGrismDispersion, WFC3IRBackwardGrismDispersion

In [None]:
# --- Grism dispersion stage ----------------------------------------------------
# Load the WFC3/G141 dispersion models from the specwcs reference file and wrap
# them in the forward/backward dispersion transforms that leave the
# 'grism_detector' frame (the first frame of the WCS pipeline).
reference_files = {}  # NOTE(review): not used anywhere in the visible notebook

specwcs = asdf.open('../config/HST/WFC3_G141_specwcs.asdf').tree
displ, dispx, dispy = specwcs['displ'], specwcs['dispx'], specwcs['dispy']
invdispl, invdispx, invdispy = (specwcs['invdispl'], specwcs['invdispx'],
                                specwcs['invdispy'])
orders = specwcs['order']

gdetector = cf.Frame2D(name='grism_detector', axes_order=(0, 1), unit=(u.pix, u.pix))

# Forward uses the wavelength model plus the *inverse* x model; the backward
# transform swaps both (inverse wavelength model, forward x model).
det2det = WFC3IRForwardGrismDispersion(orders, lmodels=displ, xmodels=invdispx,
                                       ymodels=dispy)
det2det.inverse = WFC3IRBackwardGrismDispersion(orders, lmodels=invdispl, xmodels=dispx,
                                                ymodels=dispy)

# First step of the WCS pipeline; later cells add the imaging steps after it.
grism_pipeline = [(gdetector, det2det)]

In [None]:
from astropy.utils.data import download_file

# G141 grism FLT exposure from the aXe WFC3 cookbook; `cache=True` reuses the
# local copy on subsequent runs instead of downloading again.
FLT_URL = ('https://github.com/npirzkal/aXe_WFC3_Cookbook/raw/main/'
           'cookbook_data/G141/ib6o23rsq_flt.fits')
fn = download_file(FLT_URL, cache=True)
grism_image_hdulist = fits.open(fn)

In [None]:
# --- SIP distortion + linear WCS terms from the grism exposure's SCI header ------
# SIP keyword A_p_q (B_p_q) is the coefficient of u**p * v**q in f(u, v) (g(u, v)).
# astropy's Polynomial2D names the x**p * y**q coefficient `cp_q`, so the mapping
# is a plain rename. Only keywords of exactly the form `A_<int>_<int>` /
# `B_<int>_<int>` are used: a loose `A*` glob would also pick up any other keyword
# starting with "A" (e.g. inverse-SIP `AP_ORDER`, `AP_p_q`). Those either crash the
# rename (no "A_" substring -> IndexError) or yield an invalid Polynomial2D
# parameter name.
sci_header = grism_image_hdulist[1].header


def _sip_polynomial_coefficients(header, axis):
    """Collect the SIP coefficients for one distortion axis from a FITS header.

    Parameters
    ----------
    header : astropy.io.fits.Header
        Header holding the SIP keywords ``{axis}_ORDER`` and ``{axis}_p_q``.
    axis : str
        SIP polynomial prefix, ``'A'`` or ``'B'``.

    Returns
    -------
    order : int
        Value of ``{axis}_ORDER``.
    coeffs : dict
        ``{'cp_q': value}`` keyword arguments for ``models.Polynomial2D``.

    Raises
    ------
    KeyError
        If ``{axis}_ORDER`` is absent (header has no SIP distortion).
    """
    coeff_key = re.compile(rf'{re.escape(axis)}_(\d+)_(\d+)')
    coeffs = {}
    for key, value in header.items():
        match = coeff_key.fullmatch(key)
        if match:
            coeffs['c{}_{}'.format(*match.groups())] = value
    return header[f'{axis}_ORDER'], coeffs


a_order, apcoef = _sip_polynomial_coefficients(sci_header, 'A')
b_order, bpcoef = _sip_polynomial_coefficients(sci_header, 'B')

crpix = [sci_header['CRPIX1'], sci_header['CRPIX2']]
crval = [sci_header['CRVAL1'], sci_header['CRVAL2']]
cdmat = np.array([[sci_header['CD1_1'], sci_header['CD1_2']],
                  [sci_header['CD2_1'], sci_header['CD2_2']]])

a_poly = models.Polynomial2D(a_order, **apcoef)
b_poly = models.Polynomial2D(b_order, **bpcoef)

# Detector pixel -> ICRS sky for a direct-image position, assembled one stage at a
# time. Each `mr = mr | stage` reproduces the original left-to-right `|` chain,
# so the resulting compound model has the same structure. Trailing comments show
# the values leaving each stage.
mr = models.Shift(-(crpix[0] - 1)) & models.Shift(-(crpix[1] - 1))  # (u, v) relative to CRPIX; 0-based pixel inputs
mr = mr | models.Mapping((0, 1, 0, 1, 0, 1))                        # (u, v, u, v, u, v)
mr = mr | a_poly & b_poly & models.Identity(2)                      # (f(u,v), g(u,v), u, v)
mr = mr | models.Mapping((0, 2, 1, 3))                              # (f, u, g, v)
mr = mr | models.math.AddUfunc() & models.math.AddUfunc()           # (u + f, v + g)
mr = mr | models.AffineTransformation2D(matrix=cdmat)               # apply the CD matrix
mr = mr | models.Pix2Sky_TAN()                                       # TAN deprojection
mr = mr | models.RotateNative2Celestial(crval[0], crval[1], 180)    # rotate to (CRVAL1, CRVAL2), lon_pole=180

# --- Imaging stage and the full grism WCS --------------------------------------
# `mr & Identity(2)` takes four inputs: `mr` maps the first two (presumably the
# direct-image x, y produced by det2det) to sky, and Identity(2) passes the last
# two (wavelength and, presumably, spectral order) through unchanged.
# NOTE(review): the composite frames declare 3 axes (2 spatial + wavelength)
# while the transforms carry 4 values; the order value has no frame axis.
det_frame = cf.Frame2D(name="detector")
spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,),
                        axes_names=('wavelength',))
world_frame = cf.CelestialFrame(name="world", unit=(u.Unit("deg"), u.Unit("deg")),
                                axes_names=('lon', 'lat'), axes_order=(0, 1),
                                reference_frame="ICRS")

imagepipe = [
    (cf.CompositeFrame([det_frame, spec], name="detector"), mr & models.Identity(2)),
    (cf.CompositeFrame([world_frame, spec], name="sky"), None),
]

# Rebuild the full pipeline instead of `grism_pipeline.extend(imagepipe)`:
# extending the list defined in an earlier cell appends another copy of these
# steps every time this cell is re-run, leaving a `None` transform mid-pipeline.
grism_pipeline = [(gdetector, det2det), *imagepipe]

wcsobj = WCS(grism_pipeline)

In [None]:
# Display the assembled grism WCS for inspection.
wcsobj

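### Experiment: getting the inverse (sky → detector) transform to work

As a first step, build a purely 2-D imaging WCS from `mr` alone (no wavelength or order axes) and check whether it can be inverted.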

In [None]:
# Imaging-only (2-D) WCS: detector pixels -> ICRS sky through `mr`, without the
# wavelength/order pass-through axes used in the grism WCS.
det_frame = cf.Frame2D(name="detector")
world_frame = cf.CelestialFrame(name="world", unit=(u.Unit("deg"), u.Unit("deg")),
                                axes_names=('lon', 'lat'), axes_order=(0, 1),
                                reference_frame="ICRS")
imagepipe3 = [(det_frame, mr), (world_frame, None)]

wcsobj3 = WCS(imagepipe3)  # gwcs.WCS (imported from gwcs), not astropy.wcs.WCS

In [None]:
# Forward transform: detector pixel (x=507, y=507) -> sky (lon, lat) in degrees.
wcsobj3(507, 507)

In [None]:
# Invert a sky position back to detector pixels. This works because `mr` is still
# purely 2-D: no wavelength or order inputs/outputs have been added to it yet.
# NOTE(review): these coordinates are presumably the forward result for pixel
# (507, 507) from the cell above, pasted in by hand; confirm.
sky_lon, sky_lat = 53.07354713110038, -27.70724006671666
wcsobj3.invert(sky_lon, sky_lat)

In [None]:
# Inspect the transform gwcs stores between the 'detector' and 'world' frames;
# for this two-step pipeline it is the `mr` compound model.
wcsobj3.get_transform('detector', 'world')

In [None]:
# Does `mr` have an analytic inverse? An astropy compound model only has one when
# every component does; the SIP Polynomial2D terms are a likely blocker.
# If it has none, the wcsobj3.invert call above presumably relies on gwcs's
# numerical inversion. Confirm.
mr.has_inverse()