# Cut out data from FITS files

## Imports

In [None]:
import sys
import os
import pprint
import json
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D

# Make the local gleam checkout importable. The GLEAM_ROOT environment
# variable overrides the author's default location; the membership check
# avoids stacking duplicate sys.path entries when this cell is re-run.
root = os.environ.get("GLEAM_ROOT", "/Users/phdenzel/gleam")
if root not in sys.path:
    sys.path.append(root)
import gleam
from gleam.lensobject import LensObject
from gleam.multilens import MultiLens
from gleam.utils.lensing import downsample_model, upsample_model
from gleam.utils.plotting import IPColorbar, IPPointCache
from gleam.utils.plotting import plot_scalebar, plot_labelbox
from gleam.utils.rgb_map import lupton_like, asin_stack, hsv_stack
import gleam.utils.colors as gcl
# Register gleam's colormaps (GLEAMcmaps.vilux is used by the plotting cells).
gcl.GLEAMcmaps.register_all()

## Main

### Search for FITS files

In [None]:
# List the available lens directories, then the FITS files of the chosen lens.
# `!ls` is IPython shell syntax; its output lines are captured into a list.
directories = !ls -d delay_composites/*/
pprint.pprint(directories)
print("")
# NOTE: the lens directory is hard-coded here; keep it in sync with the label
# used in the final composite-plot cell.
filenames = !ls delay_composites/WFIJ2033-4723/*.fits
pprint.pprint(filenames)
rgb_data = {}
# The first three files in `ls` order become the R, G and B channels
# (index 0, 1, 2 of every per-channel list in rgb_data).
rgb_data['filenames'] = [filenames[0], filenames[1], filenames[2]]

### Read FITS files and extract data

In [None]:
%%script false
rgb_data['hdu'] = []
for filename in rgb_data['filenames']:
    hdu = fits.open(filename)
    print(len(hdu))
    for h in hdu:
        if isinstance(h, fits.hdu.image.ImageHDU):
            hdu = h
            break
    if isinstance(hdu, list):
        hdu = hdu[0]
    rgb_data['hdu'].append(hdu)
    print('#'*80)
    print(hdu)
    print(hdu.data.shape)
    # print(hdu.header['CD2_2'])
    #print(repr(hdu.header))
    try:
        print(hdu.header['ORIENTAT'])
    except:
        print('ORIENTAT not found')
    # print(repr(hdu.header))

positions = [p//2 for p in hdu.data.shape]

In [None]:
# Build one gleam LensObject per channel file, kept in R, G, B order.
rgb_data['lens_object'] = []
for f in rgb_data['filenames']:
    lo = LensObject(f, glscfactory_options={}, auto=False)
    rgb_data['lens_object'].append(lo)
    print("#" * 80)
    print(lo.__v__)

In [None]:
# Pixel scales (`px2arcsec`) of all channels; the smallest first component
# becomes `scale`, the common grid the next cell resamples every channel to.
pxscales = [l.px2arcsec for l in rgb_data['lens_object']]
for pxs in pxscales:
    # True when both pixel-scale components agree (square pixels).
    print("Rectangular: {}".format(abs(pxs[0] - pxs[1]) < 1e-8))
# 100 caps the result, exactly like the starting value of a running minimum.
scale = min([100] + [pxs[0] for pxs in pxscales])
print("Setting scale to: {}\nout of {}".format(scale, pxscales))

In [None]:
# Bring every channel onto one square grid with pixel scale `scale`,
# replacing each entry of rgb_data['lens_object'] in place.
for i, lo in enumerate(rgb_data['lens_object']):
    print(lo)
    # Cut out the largest square centred on the map; dims is (rows, cols).
    dims = lo.naxis2, lo.naxis1
    mindim = min(dims)
    # Half-widths below/above the centre; an odd mindim puts the extra pixel
    # on the upper side so the square is exactly mindim wide.
    mind2 = [mindim//2, mindim - mindim//2]
    data = lo.data[dims[0]//2-mind2[0]:dims[0]//2+mind2[1],
                   dims[1]//2-mind2[0]:dims[1]//2+mind2[1]]
    # Half-width of the square in px2arcsec units, and the matching number
    # of pixels per side on the common grid.
    mapr = (lo.px2arcsec[1] * data.shape[0])/2
    npix = int(2*mapr/scale + 0.5)
    shape = [npix, npix]
    extent = [-mapr, mapr, -mapr, mapr]
    # Resample only channels whose pixel scale differs from `scale`; since
    # `scale` is the smallest scale found in the previous cell, this adds pixels.
    if not (sum([abs(pa-scale) for pa in lo.px2arcsec]) < 2e-8):
        # NaN -> 0 (and +-inf -> large finite values) before interpolating.
        data = np.nan_to_num(data)
        # print("Downsampling map...")
        # data = downsample_model(kappa=data, pixel_scale=lo.px2arcsec[0],
        #                         extent=extent, shape=shape,
        #                         sanitize=False, verbose=True)
        print("Upsampling map...")
        data = upsample_model(kappa=data, pixel_scale=lo.px2arcsec[0],
                              extent=extent, shape=shape,
                              sanitize=False, verbose=True)
    # Rebuild the LensObject from its JSON dict with the new data, header
    # dimensions and pixel scale.
    jdict = lo.__json__
    jdict['data'] = data[:]
    jdict['hdr']['NAXIS1'] = data.shape[1]
    jdict['hdr']['NAXIS2'] = data.shape[0]
    jdict['px2arcsec'] = [scale, scale]
    jdict['mapr'] = mapr
    # NOTE(review): lens/source-image positions are dropped, presumably
    # because they refer to the uncropped map — confirm.
    jdict['lens'] = None
    jdict['srcimgs'] = []
    # presumably LensObject.from_jdict does not accept the '__type__' tag
    del jdict['__type__']
    rect_lo = LensObject.from_jdict(jdict)
    rgb_data['lens_object'][i] = rect_lo
    
# print(rgb_data['lens_object'][0].naxis1, rgb_data['lens_object'][0].naxis1)

### Plot the data and determine anchor points

In [None]:
# Per-channel anchor points and cutout sizes, filled in R, G, B order below.
rgb_data['position'], rgb_data['size'] = [], []

#### R

In [None]:
%matplotlib notebook
# Interactive plot of the R channel (rgb_data['lens_object'][0]) used to pick
# its anchor point; the next cell reads the picked point from `cache.xy`.
fig, ax = plt.subplots()
lo = rgb_data['lens_object'][0]
# vmax is half the peak value (presumably clipping the brightest pixels so
# fainter structure stays visible).
# NOTE(review): psf is a 3x3 kernel scaled by 1/25, so it sums to 0.36, not 1;
# with deconv=False it is presumably unused — confirm in LensObject.plot_f.
fig, ax, plt_out = lo.plot_f(fig, ax=ax, cmap=gcl.GLEAMcmaps.vilux, colorbar=True,
                             scalebar=False,
                             vmin=0, vmax=0.5*np.max(lo.data), source_images=False,
                             deconv=False, psf=np.ones((3, 3))/25.)
# Hook up the interactive colorbar and point picker on the plotted artists.
clrbar = IPColorbar(plt_out[1], plt_out[0])
clrbar.connect()
cache = IPPointCache(plt_out[0], use_modes=[])
cid = cache.connect()

In [None]:
# Anchor point of the R channel: the last point stored in `cache.xy`, rounded
# half-up to integer pixels ([x, y]), or None when nothing was picked.
position = [int(p+0.5) for p in cache.xy[-1]] if cache.xy else None
# Cutout size in pixels (set manually).
size = 199
# Per-lens manual presets; uncomment one to override the picked point.
# size = 599 # SDSSJ1004+4112
# position = [449, 523]  # B0218+357
# position = [684, 694] # B1600+434
# position = [2487, 2778]  # B1608+656
# position = [1849, 1715]  # DESJ0408-5354
# position = [596, 617]  # FBQ0951+2635
# position = [1006, 991] # HE0435-1223
# position = [1718, 2032]  # HE1104-1805
# position = []  # HE2149-2745
# position = [1222, 1207]  # PG1115+080
# position = [1155, 1160] # RXJ0911+0551
# position = [419, 310] # RXJ1131-1231
# position = [4271, 2372] # SDSSJ1004+4112
# position = [1601, 1639] # WFIJ2033-4723

print("Position: {}".format(position), "size: {}".format(size))

In [None]:
# Record the R channel's anchor point and cutout size (index 0).
# NOTE: this appends, so re-running the cell adds a duplicate entry; the
# stacking cell expects exactly one entry per channel in R, G, B order.
# Re-run from the cell that initialises rgb_data['position'] instead.
rgb_data['position'].append(position)
rgb_data['size'].append(size)

#### G

In [None]:
%matplotlib notebook
# Interactive plot of the G channel (rgb_data['lens_object'][1]) used to pick
# its anchor point; the next cell reads the picked point from `cache.xy`.
# Unlike the R and B cells, this one passes filter_nan=False.
fig, ax = plt.subplots()
lo = rgb_data['lens_object'][1]
# vmax is half the peak value (presumably clipping the brightest pixels so
# fainter structure stays visible).
# NOTE(review): psf is a 3x3 kernel scaled by 1/25, so it sums to 0.36, not 1;
# with deconv=False it is presumably unused — confirm in LensObject.plot_f.
fig, ax, plt_out = lo.plot_f(fig, ax=ax, cmap=gcl.GLEAMcmaps.vilux, colorbar=True,
                             scalebar=False, filter_nan=False,
                             vmin=0, vmax=0.5*np.max(lo.data), source_images=False,
                             deconv=False, psf=np.ones((3, 3))/25.)
# Hook up the interactive colorbar and point picker on the plotted artists.
clrbar = IPColorbar(plt_out[1], plt_out[0])
clrbar.connect()
cache = IPPointCache(plt_out[0], use_modes=[])
cid = cache.connect()

In [None]:
# Anchor point of the G channel: the last point stored in `cache.xy`, rounded
# half-up to integer pixels ([x, y]), or None when nothing was picked.
position = [int(p+0.5) for p in cache.xy[-1]] if cache.xy else None
# Cutout size in pixels (set manually).
size = 199
# Per-lens manual presets; uncomment one to override the picked point.
# size = 599 # SDSSJ1004+4112
# position = [449, 523]  # B0218+357
# position = [684, 694] # B1600+434
# position = [2487, 2778]  # B1608+656
# position = [2089, 1010]  # DESJ0408-5354
# position = [596, 617]  # FBQ0951+2635
# position = [873, 849]  # HE0435-1223
# position = [1718, 2032]  # HE1104-1805
# position = []  # HE2149-2745
# position = [1222, 1207]  # PG1115+080
# position = [464, 476] # RXJ0911+0551
# position = [1819, 2379] # RXJ1131-1231
# position = [4271, 2372] # SDSSJ1004+4112
# position = [709, 698] # WFIJ2033-4723

print("Position: {}".format(position), "size: {}".format(size))

In [None]:
# Record the G channel's anchor point and cutout size (index 1).
# NOTE: this appends, so re-running the cell adds a duplicate entry; the
# stacking cell expects exactly one entry per channel in R, G, B order.
# Re-run from the cell that initialises rgb_data['position'] instead.
rgb_data['position'].append(position)
rgb_data['size'].append(size)

#### B

In [None]:
%matplotlib notebook
# Interactive plot of the B channel (rgb_data['lens_object'][2]) used to pick
# its anchor point; the next cell reads the picked point from `cache.xy`.
fig, ax = plt.subplots()
lo = rgb_data['lens_object'][2]
# vmax is half the peak value (presumably clipping the brightest pixels so
# fainter structure stays visible).
# NOTE(review): psf is a 3x3 kernel scaled by 1/25, so it sums to 0.36, not 1;
# with deconv=False it is presumably unused — confirm in LensObject.plot_f.
fig, ax, plt_out = lo.plot_f(fig, ax=ax, cmap=gcl.GLEAMcmaps.vilux, colorbar=True,
                             scalebar=False,
                             vmin=0, vmax=0.5*np.max(lo.data), source_images=False,
                             deconv=False, psf=np.ones((3, 3))/25.)
# Hook up the interactive colorbar and point picker on the plotted artists.
clrbar = IPColorbar(plt_out[1], plt_out[0])
clrbar.connect()
cache = IPPointCache(plt_out[0], use_modes=[])
cid = cache.connect()

In [None]:
# Anchor point of the B channel: the last point stored in `cache.xy`, rounded
# half-up to integer pixels ([x, y]), or None when nothing was picked.
position = [int(p+0.5) for p in cache.xy[-1]] if cache.xy else None
# Cutout size in pixels (set manually).
size = 199
# Per-lens manual presets; uncomment one to override the picked point.
# size = 599 # SDSSJ1004+4112
# position = [449, 523]  # B0218+357
# position = [684, 694] # B1600+434
# position = [2487, 2778]  # B1608+656
# position = [2089, 1010]  # DESJ0408-5354
# position = [596, 617]  # FBQ0951+2635
# position = [1032, 816]  # HE0435-1223
# position = [2071, 1048]  # HE1104-1805
# position = []  # HE2149-2745
# position = [532, 524]  # PG1115+080
# position = [464, 476] # RXJ0911+0551
# position = [1836, 2379] # RXJ1131-1231
# position = [2420, 1961] # SDSSJ1004+4112
# position = [709, 698] # WFIJ2033-4723

print("Position: {}".format(position), "size: {}".format(size))

In [None]:
# Record the B channel's anchor point and cutout size (index 2).
# NOTE: this appends, so re-running the cell adds a duplicate entry; the
# stacking cell expects exactly one entry per channel in R, G, B order.
# Re-run from the cell that initialises rgb_data['position'] instead.
rgb_data['position'].append(position)
rgb_data['size'].append(size)

## Stack the channels

In [None]:
# Show the anchor points, then the cutout sizes, collected for R, G, B.
for key in ('position', 'size'):
    print(rgb_data[key])

In [None]:
# Collect each channel's ORIENTAT header value, the rotation angle (degrees,
# as passed to ndimage.rotate) that the stacking cell undoes. A channel
# without the keyword gets 0 (no rotation) instead of being skipped, so
# `orient` stays aligned one-to-one with rgb_data['lens_object'].
orient = []
for lo in rgb_data['lens_object']:
    print(lo)
    orient.append(lo.hdr['ORIENTAT'] if 'ORIENTAT' in lo.hdr else 0)
print(orient)

In [None]:
from scipy import ndimage
# Cut a (2*size)-pixel square around each channel's anchor point and rotate
# it by -ORIENTAT so all channels share one orientation; r, g, b are the
# resulting cut-outs.
lens_objects = rgb_data['lens_object']
# orient = [0, 0, 0]
# zip() below would silently drop or mis-pair channels if these lists
# disagree in length (e.g. a missing ORIENTAT, or an anchor cell run twice).
for label, values in (('orient', orient),
                      ("rgb_data['position']", rgb_data['position']),
                      ("rgb_data['size']", rgb_data['size'])):
    if len(values) != len(lens_objects):
        raise ValueError("{} has {} entries but there are {} channels".format(
            label, len(values), len(lens_objects)))
rgb = []
for lens, phi, p, s in zip(lens_objects, orient,
                           rgb_data['position'], rgb_data['size']):
    # Window of `s` pixels on each side of the anchor p = [x, y].
    y0, x0 = p[1] - s, p[0] - s
    if y0 < 0 or x0 < 0:
        # Negative starts would index from the far edge instead of clipping.
        raise ValueError("Anchor {} with size {} extends past the image edge".format(p, s))
    d = lens.data[y0:p[1]+s, x0:p[0]+s]
    if phi != 0:
        d = ndimage.rotate(d, -phi, reshape=False)
    rgb.append(d)
r, g, b = rgb
print(r.shape, g.shape, b.shape)

# r = (r - r.min()) / (r.max() - r.min())
# g = (g - g.min()) / (g.max() - g.min())
# b = (b - b.min()) / (b.max() - b.min())


In [None]:
# %%script false
plt.close()

# Combine the three cut-outs into one colour image with gleam's asin_stack.
# The numeric arguments are hand-tuned per lens. Exactly one preset must be
# active, otherwise `rgb_img` is undefined below; the active one matches the
# lens loaded in the file-search cell (WFIJ2033-4723).
# NOTE(review): presets marked "normalized" presumably expect the min-max
# normalisation lines at the end of the stacking cell to be uncommented.
# rgb_img = asin_stack(b, g, r, 0.7, 0.8, 15.1, 2.3, 0.5)  # B0218+357
# rgb_img = asin_stack(b, g, r, 2.9, 3.6, 23.1, 2.8, .5)  # B1600+434
# rgb_img = asin_stack(b, g, r, 1.0, 1.2, 0.2, 1.5, .5) # B1608+656
# rgb_img = asin_stack(r, b, g, 0.04, 1.0, 1.9, 3.6, .5) # DESJ0408-5354
# rgb_img = asin_stack(b, g, r, 0.4, 0.2, 3., 1.6, 1.)  # FBQ0951+2635
# rgb_img = asin_stack(r, g, b, 1.9, .9, 45.4, 0.19, .5)  # HE0435-1223
# rgb_img = asin_stack(r, g, b, 0.06, 0.03, 30.0, 1.3, .4)  # HE1104-1805
# rgb_img = asin_stack(r, g, b, 20, 10, 100, .2, .1)  # PG1115+080  # normalized
# rgb_img = asin_stack(r, g, b, 0.1, 0.4, 0.6, 15., .1) # RXJ0911+0551  # normalized
# rgb_img = asin_stack(r, g, b, 0.009, 0.00175, 0.005, 400., .5) # RXJ1131-1231
# rgb_img = asin_stack(r, g, b, 0.085, 0.175, 3.5, 20., .1) # SDSSJ1004+4112
rgb_img = asin_stack(r, g, b, 0.04, 0.6, 1.8, 1., .1) # WFIJ2033-4723

# Preview the stack; points recorded by IPPointCache on this view end up in
# cache.xy and select the composite centre in the next cell.
plt_out = plt.imshow(rgb_img, origin='lower')
cache = IPPointCache(plt_out, use_modes=[])
cid = cache.connect()
plt.show()


In [None]:
# %%script false
# Centre of the final composite: the last point stored in `cache.xy` on the
# RGB preview, rounded half-up to integer pixels ([x, y]), or None.
position = [int(p+0.5) for p in cache.xy[-1]] if cache.xy else None
# Composite size in pixels (set manually).
size = 199
# Per-lens manual presets; uncomment one to override the picked point.
# position = [99, 99]  # B0218+357
# position = [199, 200] # B1600+434
# position = [198, 199]  # B1608+656
# position = [200, 199]  # DESJ0408-5354
# position = [200, 214]  # FBQ0951+2635
# position = [111, 92]  # HE0435-1223
# position = [199, 195]  # HE1104-1805
# position = []  # HE2149-2745
# position = [199, 199]  # PG1115+080
# position = [139, 193] # RXJ0911+0551
# position = [198, 200] # RXJ1131-1231
# position = [499, 831] # SDSSJ1004+4112
# position = [300, 300] # WFIJ2033-4723

print("Position: {}".format(position), "size: {}".format(size))

In [None]:
# %%script false
from scipy import ndimage
# Fail with a clear message instead of a TypeError on `position[1]` below.
if position is None:
    raise ValueError("No composite centre selected: pick a point on the RGB "
                     "preview or uncomment a position preset.")
plt.close()
plt.figure(figsize=(6, 6))
# Crop a size x size window centred on `position` ([x, y] pixel indices).
# The copy keeps the alpha assignment below from writing into rgb_img itself.
half = size//2
composite = rgb_img[position[1]-half:position[1]+half+1,
                    position[0]-half:position[0]+half+1].copy()
# Half-width of the cut-out in units of the common pixel scale `scale`, so
# the axes are centred on the anchor.
R = scale*size/2
extent = [-R, R, -R, R]
# composite = ndimage.rotate(composite, 90, reshape=False)
# Make channel 3 fully opaque (assumes asin_stack returns RGBA — TODO confirm).
composite[:, :, 3] = 1
# imshow only accepts origin='upper' or 'lower'. The previous 'top' is
# rejected by current matplotlib; older versions treated any non-'upper'
# value like 'lower', so 'lower' keeps that rendering and matches the preview
# above. NOTE(review): confirm the orientation of the saved figure.
plt.imshow(composite, extent=extent, origin='lower')
# plt.imshow(composite, extent=extent, origin='lower', interpolation='spline16')
plt.axis('off')
img_ax = plt.gca()
img_ax.get_xaxis().set_visible(False)
img_ax.get_yaxis().set_visible(False)
plot_scalebar(extent[1], 1)
plot_labelbox(label='WFIJ2033-4723 HST-WFC3 IR/UVIS\n(F160W, F814W, F621M)', position='top left', fontsize=18)
plt.tight_layout()
# plt.savefig('{}_composite.pdf'.format(os.path.basename(os.path.dirname(filenames[0]))),
#            bbox_inches='tight', pad_inches=0)
plt.show()