# Using WEAVE data with pyKOALA

## Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import AutoMinorLocator

import numpy as np
import os
import importlib
from astropy import units as u
from astropy.visualization import quantity_support
quantity_support()

from pykoala.corrections import sky

import warnings
# You may want to comment the following line
#warnings.filterwarnings("ignore")

In [None]:
from pykoala import __version__
from pykoala.instruments import weave
from pykoala import cubing

# pyKOALA version
print("pyKOALA version: ", __version__)

### Utility functions for plotting

In [None]:
def new_figure(fig_name, figsize=None, nrows=1, ncols=1, sharex='col', sharey='row', gridspec_kw=None):
    """Create (or re-create) a named figure filled with a grid of styled axes.

    Parameters
    ----------
    fig_name : str
        Name of the figure. An already open figure with the same name is
        closed first, and the name is also used as the figure title.
    figsize : tuple, optional
        Size forwarded to ``plt.figure``. Defaults to
        ``(9 + ncols, 4 + nrows)``.
    nrows, ncols : int, optional
        Shape of the axes grid.
    sharex, sharey : optional
        Axis-sharing policy forwarded to ``Figure.subplots``.
    gridspec_kw : dict, optional
        Grid options forwarded to ``Figure.subplots``. Defaults to
        ``{'hspace': 0, 'wspace': 0}``.

    Returns
    -------
    fig : the new figure.
    axes : 2-D array of axes (``squeeze=False``), so it is always indexed as
        ``axes[row, col]``.
    """
    if figsize is None:
        figsize = (9 + ncols, 4 + nrows)
    if gridspec_kw is None:
        # Built on every call so that no dict is shared between invocations.
        gridspec_kw = {'hspace': 0, 'wspace': 0}

    plt.close(fig_name)
    fig = plt.figure(fig_name, figsize=figsize)
    axes = fig.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                        sharex=sharex, sharey=sharey,
                        gridspec_kw=gridspec_kw
                       )
    #fig.set_tight_layout(True)
    for ax in axes.flat:
        # Minor ticks on both axes, ticks drawn on all four sides, faint grid.
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.yaxis.set_minor_locator(AutoMinorLocator())
        ax.tick_params(which='both', bottom=True, top=True, left=True, right=True)
        ax.tick_params(which='major', direction='inout', length=8, grid_alpha=.3)
        ax.tick_params(which='minor', direction='in', length=2, grid_alpha=.1)
        ax.grid(True, which='both')

    #fig.suptitle(f'{self.filename} {fig_name}')
    fig.suptitle(fig_name)

    return fig, axes


In [None]:
# Default colour map for colour_map(); invalid (masked/NaN) pixels are drawn in gray.
default_cmap = plt.get_cmap("gist_earth").copy()
default_cmap.set_bad('gray')


def colour_map(fig, ax, cblabel, data, cmap=default_cmap, norm=None, xlabel=None, x=None, ylabel=None, y=None):
    """Show a 2-D array as an image with an attached vertical colour bar.

    Parameters
    ----------
    fig, ax :
        Figure and axes to draw on.
    cblabel : str
        Label of the colour bar.
    data : 2-D array
        Values to display; axis 0 runs along y and axis 1 along x.
    cmap : optional
        Colour map (defaults to ``default_cmap``).
    norm : optional
        Colour normalisation. When omitted, a ``SymLogNorm`` spanning the
        1st-99th percentiles of ``data`` is built and the colour bar ticks
        are placed at the 1, 16, 50, 84 and 99 percentiles.
    xlabel, ylabel : str, optional
        Axis labels.
    x, y : 1-D array, optional
        Pixel-centre coordinates along each axis (only the first two and last
        two elements are used, to compute the image extent). Default to pixel
        indices.

    Returns
    -------
    im, cb : the image and the colour bar.
    """
    percentiles = np.array([1, 16, 50, 84, 99])
    if norm is None:
        ticks = np.nanpercentile(data, percentiles)
        # Linear threshold: median of the positive values. Fall back to 1 when
        # there are none, instead of passing the NaN median of an empty array.
        positive = data[data > 0]
        linthresh = np.median(positive) if positive.size > 0 else 1.0
        norm = colors.SymLogNorm(vmin=ticks[0], vmax=ticks[-1], linthresh=linthresh)
    else:
        ticks = None
    if y is None:
        y = np.arange(data.shape[0])
    if x is None:
        x = np.arange(data.shape[1])

    # Extend the extent by half a pixel on each side so that the first and last
    # pixel centres fall on x[0], x[-1], y[0] and y[-1].
    extent = (x[0] - (x[1] - x[0]) / 2, x[-1] + (x[-1] - x[-2]) / 2,
              y[0] - (y[1] - y[0]) / 2, y[-1] + (y[-1] - y[-2]) / 2)
    im = ax.imshow(data,
                   extent=extent,
                   interpolation='nearest', origin='lower',
                   cmap=cmap,
                   norm=norm,
                  )
    ax.set_aspect('auto')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    cb = fig.colorbar(im, ax=ax, orientation='vertical', shrink=.9)
    cb.ax.set_ylabel(cblabel)
    if ticks is not None:
        # One tick per percentile, labelled "value (percentile%)".
        labels = [f'{value:.3g} ({percent}%)' for value, percent in zip(ticks, percentiles)]
        cb.ax.set_yticks(ticks=ticks, labels=labels)
    cb.ax.tick_params(labelsize='small')

    return im, cb


### Load the science data

In [None]:
# Re-import the WEAVE module so that changes made to it on disk are picked up.
importlib.reload(weave)

# Read the test file through the pyKOALA WEAVE reader and report its target.
filename = "../tests/weave/single_3042890.fit"
rss = weave.weave_rss(filename)
print(
    f"File {filename} corresponds to object {rss.info['name']}"
)


In [None]:
# Wavelength axis as plain numbers in Angstrom (units stripped); used below as
# the x coordinate of the images and line plots.
rss_wavelength = rss.wavelength.to_value(u.Angstrom)

## Wavelet test

In [None]:
# Running sum along axis 1 with NaNs counted as zero; box averages are obtained
# below as differences of this cumulative sum.
cumulative = np.nancumsum(rss.intensity, axis=1)

In [None]:
# Box-filter "wavelet": a narrow box average minus a box average three times
# wider, both centred on the same pixel and computed from the cumulative sum.
wavelet_h = 3  # 1.5 Angstrom (~ HWHM)
wavelet_s = 2*wavelet_h + 1  # full (odd) width of the narrow box, in pixels
# Wavelengths of the pixels on which the boxes are centred; the edges where the
# wide box does not fit are dropped (3*wavelet_s pixels in total).
wavelet_wavelength = rss.wavelength[wavelet_s + wavelet_h + 1 : -wavelet_s - wavelet_h ]
# Mean over the central wavelet_s pixels ...
wavelet_intensity = (cumulative[:, 2*wavelet_s:-wavelet_s] - cumulative[:, wavelet_s:-2*wavelet_s]) / wavelet_s
# ... minus the mean over the 3*wavelet_s pixels around the same centre.
wavelet_intensity -= (cumulative[:,3*wavelet_s:] - cumulative[:,:-3*wavelet_s]) / (3*wavelet_s)

In [None]:
fig, axes = new_figure('wavelet')

# The filtered data are narrower than the full spectra (edge pixels are
# dropped), so the x axis must use the matching trimmed wavelength array.
im, cb = colour_map(fig, axes[0, 0], 'wavelet', wavelet_intensity,
                    x=wavelet_wavelength.to_value(u.Angstrom),
                    xlabel=r'$\lambda\ [\AA]$', ylabel='spec_id')

# Alternative quick look at a single spectrum (disabled):
# ax = axes[0, 0]
# ax.plot(rss.wavelength.to(u.Angstrom), rss.intensity[0])
# ax.plot(wavelet_wavelength, wavelet_intensity[0])
# #ax.plot(rss.wavelength, cumulative[0])


In [None]:
# Summary of the filtered data along axis 0 (the axis labelled 'spec_id' in the
# image above) as a function of wavelength.
fig, axes = new_figure('wavelet sky')

ax = axes[0, 0]
#ax.plot(wavelet_wavelength, np.nanmedian(wavelet_intensity, axis=0), 'r-')
# NaN-aware mean at each wavelength (black dashed line)
ax.plot(wavelet_wavelength.to(u.Angstrom), np.nanmean(wavelet_intensity, axis=0), 'k--')

# Median (red line) and 16th-84th percentile band (shaded) at each wavelength
p16, p50, p84 = np.nanpercentile(wavelet_intensity, [16, 50, 84], axis=0)
ax.plot(wavelet_wavelength, p50, 'r-', alpha=.5)
ax.fill_between(wavelet_wavelength, p16, p84, color='r', alpha=0.1)

In [None]:
# Subtract from every row its own NaN-aware mean.
normed = wavelet_intensity - np.nanmean(wavelet_intensity, axis=1)[:, np.newaxis]

In [None]:
# Scale every row to unit RMS (in place).
normed /= np.sqrt(np.nanmean(normed**2, axis=1))[:, np.newaxis]

In [None]:
# Trial wavelength shifts. They must exist before the buffer is sized; the same
# grid is used again when the correlation is plotted further down.
offset = np.linspace(-10, 10, 501) * u.AA
# Uninitialised buffer with one row per offset and one column per spectrum.
correlation = np.empty((offset.size, wavelet_intensity.shape[0]))
correlation.shape

In [None]:
# Display the shape of the normalised array (interactive sanity check).
normed.shape

In [None]:
# Shape check for a single offset: row 0 is shifted by offset[0] through
# interpolation, multiplied with every row, and averaged along axis 1.
# NOTE(review): requires `offset` to already exist in the kernel namespace.
np.nanmean(normed*np.interp(wavelet_wavelength, wavelet_wavelength-offset[0], normed[0])[np.newaxis, :], axis=1).shape


In [None]:
def compute_correlation(wavelength, intensity, intensity_ref, offsets):
    '''
    Compute cross-correlation with respect to a reference spectrum.

    Every spectrum and the reference are mean-subtracted and scaled to unit
    RMS (ignoring NaNs). For each offset the normalised reference is shifted
    through linear interpolation (``np.interp``) and the NaN-aware mean of its
    product with every normalised spectrum is stored.

    Parameters
    ----------
    wavelength : 1-D array
        Sampling points shared by the spectra and the reference
        (``np.interp`` expects them to be increasing).
    intensity : 1-D or 2-D array
        A single spectrum, or one spectrum per row.
    intensity_ref : 1-D array
        Reference spectrum, sampled at ``wavelength``.
    offsets : 1-D array
        Shifts applied to the reference, expressed like ``wavelength``.

    Returns
    -------
    correlation : array of shape ``(offsets.size, n_spectra)``
        ``n_spectra`` is 1 when ``intensity`` is one-dimensional.

    Raises
    ------
    ValueError
        If ``intensity`` is not 1-D or 2-D, or ``intensity_ref`` is not 1-D.
    '''
    if intensity.ndim == 2:
        normed = intensity - np.nanmean(intensity, axis=1)[:, np.newaxis]
    elif intensity.ndim == 1:
        # Promote a single spectrum to a one-row table.
        normed = np.reshape(intensity - np.nanmean(intensity), (1, intensity.size))
    else:
        raise ValueError(f'intensity has {intensity.ndim} dimensions (expected 1 or 2)')
    if intensity_ref.ndim != 1:
        raise ValueError(f'reference intensity has {intensity_ref.ndim} dimensions (expected 1)')

    normed /= np.sqrt(np.nanmean(normed**2, axis=1))[:, np.newaxis]
    normed_ref = intensity_ref - np.nanmean(intensity_ref)
    normed_ref /= np.sqrt(np.nanmean(normed_ref**2))
    # One column per spectrum: use the row count of the 2-D working array, which
    # is 1 for a one-dimensional input (not its number of samples).
    correlation = np.empty((offsets.size, normed.shape[0]))
    for i, offset in enumerate(offsets):
        shifted_ref = np.interp(wavelength, wavelength - offset, normed_ref)
        correlation[i] = np.nanmean(normed * shifted_ref[np.newaxis, :], axis=1)
    return correlation

# Cross-correlation of the first 100 filtered spectra against the first one.
fig, axes = new_figure('correlation')

ax = axes[0, 0]
# 501 trial shifts between -10 and +10 Angstrom
offset = np.linspace(-10, 10, 501) * u.AA
# Row 0 is the reference, so it is also part of the correlated sample.
correlation = compute_correlation(wavelet_wavelength, wavelet_intensity[:100], wavelet_intensity[0], offset)
# Median and 16th-84th percentile band over the spectra, at each offset
p16, p50, p84 = np.nanpercentile(correlation, [16, 50, 84], axis=1)
ax.plot(offset, p50, 'k-+')
ax.fill_between(offset, p16, p84, color='k', alpha=0.1)

In [None]:
# Compare every row of the filtered data with a median reference spectrum.
fig, axes = new_figure('wavelet throughput')

# Reference: median along axis 0, and the standard deviation of that median
reference = np.nanmedian(wavelet_intensity, axis=0)
reference_std = np.nanstd(reference)

ax = axes[0, 0]
# Per-row percentiles of the ratio to the reference (red line and shaded band)
p16, p50, p84 = np.nanpercentile(wavelet_intensity / reference[np.newaxis, :], [16, 50, 84], axis=1)
ax.plot(p50, 'r-')
ax.fill_between(np.arange(p50.size), p16, p84, color='r', alpha=0.1)

# Per-row standard deviation in units of the reference standard deviation (blue)
ax.plot(np.nanstd(wavelet_intensity, axis=1) / reference_std, 'b-')

## Self-calibration based on sky lines

In [None]:
# Re-import the sky module, then build the self-calibration object from the RSS.
importlib.reload(sky)
rss_autocal = sky.SkySelfCalibration(rss)

### Lines and continuum

In [None]:
# Three stacked maps: intensity minus continuum, the continuum itself, and
# their difference divided by the per-spectrum continuum scale.
fig, axes = new_figure('lines-continuum', nrows=3)

im, cb = colour_map(fig, axes[0, 0], 'lines', rss.intensity - rss_autocal.continuum.intensity, x=rss_wavelength, xlabel=r'$\lambda\ [\AA]$', ylabel='spec_id')
im, cb = colour_map(fig, axes[1, 0], 'continuum', rss_autocal.continuum.intensity, x=rss_wavelength, xlabel=r'$\lambda\ [\AA]$', ylabel='spec_id')
im, cb = colour_map(fig, axes[2, 0], 'line emission S/N', (rss.intensity - rss_autocal.continuum.intensity) / rss_autocal.continuum.scale[:, np.newaxis], x=rss_wavelength, xlabel=r'$\lambda\ [\AA]$', ylabel='spec_id')

# Link the y axes of the three rows. Axes.sharey() is the supported API;
# get_shared_y_axes().join() was deprecated in Matplotlib 3.6 and later removed.
for ax in axes[1:, 0]:
    ax.sharey(axes[0, 0])


### Sky lines

In [None]:
# Sky spectrum with the strong sky lines of the continuum model highlighted.
fig, axes = new_figure('sky_lines')

ax = axes[0, 0]
ax.plot(rss_wavelength, rss_autocal.sky_intensity, 'k-')
for line in rss_autocal.continuum.strong_sky_lines:
    # Shade the span between the 'left' and 'right' pixel indices of the line
    ax.axvspan(rss_wavelength[line['left']], rss_wavelength[line['right']], color='c', alpha=.1)
    # Vertical mark at the tabulated 'sky_wavelength' of the line
    ax.axvline(line['sky_wavelength'], c='b', alpha=.5)

In [None]:
# Estimate the line profile (LSF) of the sky spectrum from the strong sky lines.
importlib.reload(sky)
lsf = sky.LSF_estimator(wavelength_range=20*u.Angstrom, resolution=0.01*u.Angstrom)
sky_lsf = lsf.find_LSF(rss_autocal.continuum.strong_sky_lines['sky_wavelength'], rss_autocal.dc.wavelength, rss_autocal.sky_intensity)
# Offset of the profile: mean of delta_lambda weighted by the squared profile
weight = sky_lsf**2
sky_offset = np.sum(weight*lsf.delta_lambda) / np.sum(weight)

print(f'Sky lines FWHM = {lsf.find_FWHM(sky_lsf):.4g}, offset = {sky_offset:.4g}')

In [None]:
# Sky line profile, its weighted-mean offset, and the profile mirrored about
# that offset as a visual check of its symmetry.
fig, axes = new_figure('sky_LSF')
ax = axes[0, 0]

ax.plot(lsf.delta_lambda, sky_lsf, 'k-',
        label=f'FWHM = {lsf.find_FWHM(sky_lsf):.4g}')
ax.axvline(sky_offset, color='k', linestyle='--',
           label=f'offset = {sky_offset:.4g}')
ax.plot(2*sky_offset - lsf.delta_lambda, sky_lsf, 'k:',
        label='symmetric')
ax.legend()

### Single spectrum

In [None]:
from astropy import stats

# Sky-line measurements for one spectrum: wavelength offsets (top row) and
# intensity ratios (bottom row), plotted against the measured line wavelength
# (left) and the sky line intensity (right).
spec_id = 42
line_wavelength, line_intensity = rss_autocal.measure_lines(spec_id)

fig, axes = new_figure('single_calibration', nrows=2, ncols=2)

ax = axes[0, 0]
ax.set_ylabel(r'relative $\Delta\lambda\ [\AA]$')
# Measured minus tabulated wavelength of every strong sky line
y = line_wavelength - rss_autocal.continuum.strong_sky_lines['sky_wavelength']
# Biweight location and scale of the offsets (red line and shaded band)
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.', label=f'sky lines in spec_id={spec_id}')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1)
ax.legend()

ax = axes[0, 1]
ax.plot(rss_autocal.continuum.strong_sky_lines['sky_intensity'], y, 'k.')
# Percentiles of the sky line intensity, used as x support of the shaded band
xx = np.nanpercentile(rss_autocal.continuum.strong_sky_lines['sky_intensity'], np.linspace(0, 100, 101))
ax.axhline(location, c='r', ls='-')
ax.fill_between(xx, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()


ax = axes[1, 0]
ax.set_ylabel('relative throughput')
# Measured line intensity divided by the sky line intensity
y = line_intensity/rss_autocal.continuum.strong_sky_lines['sky_intensity']
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
# Show the labelled band, as done in the other panels
ax.legend()

ax.set_xlabel(r'$\lambda_{line}\ [\AA]$')

ax = axes[1, 1]
ax.plot(rss_autocal.continuum.strong_sky_lines['sky_intensity'], y, 'k.')
xx = np.nanpercentile(rss_autocal.continuum.strong_sky_lines['sky_intensity'], np.linspace(0, 100, 101))
ax.axhline(location, c='r', ls='-')
ax.fill_between(xx, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

ax.set_xlabel('Line intensity')
ax.set_xscale('log')


In [None]:
# Self-calibration results for all spectra: wavelength offset (top) and
# relative throughput (bottom), each with its error band.
fig, axes = new_figure('sky_self_calibration', nrows=2)
spec_id = np.arange(rss_autocal.wavelength_offset.size)


ax = axes[0, 0]
ax.set_ylabel(r'relative $\Delta\lambda\ [\AA]$')

ax.plot(spec_id, rss_autocal.wavelength_offset, 'k-')
ax.fill_between(spec_id, rss_autocal.wavelength_offset-rss_autocal.wavelength_offset_err, rss_autocal.wavelength_offset+rss_autocal.wavelength_offset_err, color='k', alpha=.1)


ax = axes[1, 0]
ax.set_ylabel('relative throughput')

ax.plot(spec_id, rss_autocal.relative_throughput, 'k-')
ax.fill_between(spec_id, rss_autocal.relative_throughput-rss_autocal.relative_throughput_err, rss_autocal.relative_throughput+rss_autocal.relative_throughput_err, color='k', alpha=.1)
# Estimates from the wavelet test for comparison: median ratio to the reference
# (scaled by 0.97, red) and standard-deviation ratio (blue)
ax.plot(.97*np.nanmedian(wavelet_intensity / reference[np.newaxis, :], axis=1), 'r-')
ax.plot(np.nanstd(wavelet_intensity, axis=1) / reference_std, 'b-')

ax.set_xlabel('spec_id')


# --- STOP ---

In [None]:
# Deliberate stop so that "run all" does not continue past this point.
# `raise -1` only worked by accident (raising a non-exception is a TypeError).
raise RuntimeError('--- STOP --- execution halted on purpose')

In [None]:
# NOTE(review): this imports from `koala`, whereas the top of the notebook
# imports from `pykoala` - confirm which package name is current.
import koala.corrections.sky as sky
importlib.reload(sky)
# Rebinds `sky` from the module to the first value returned by the estimator;
# the module is no longer reachable under that name afterwards.
sky, sigma = sky.BackgroundEstimator.linear(rss.intensity_corrected)

In [None]:
# Mean of every row (NaN-aware), and the rows divided by that mean.
total_flux = np.nanmean(rss.intensity_corrected, axis=1)
I = rss.intensity_corrected / total_flux[:, np.newaxis]

In [None]:
# RMS of every row (NaN-aware).
total_norm = np.sqrt(np.nanmean(rss.intensity_corrected**2, axis=1))

In [None]:
# Median over axis 0 of the mean-normalised rows, taken as sky template.
sed_sky = np.nanmedian(I, axis=0)
# Divide by the mean square (not the RMS) of the template.
sed_sky /= np.nanmean(sed_sky**2)

In [None]:
# Per-row mean of the product between the sky template and the data.
sky_flux = np.nanmean(sed_sky[np.newaxis, :]*rss.intensity_corrected, axis=1)

In [None]:
# Display the median of the per-row sky flux.
np.nanmedian(sky_flux)

In [None]:
# Scatter plot of the per-row sky flux against the per-row mean flux.
fig, axes = new_figure('I-flux')

ax = axes[0, 0]
ax.plot(total_flux, sky_flux, 'k.')

In [None]:
# Corrected intensity (top) and its mean-normalised version (bottom) as images.
fig, axes = new_figure('map', nrows=2)

im, cb = colour_map(fig, axes[0, 0], 'I original', rss.intensity_corrected)
im, cb = colour_map(fig, axes[1, 0], 'I norm', I)


In [None]:
# Histogram of the per-row mean flux on logarithmic axes.
fig, axes = new_figure('hist')

ax = axes[0, 0]
# Bin edges at 26 evenly spaced percentiles, i.e. 25 bins holding roughly the
# same number of values each.
ax.hist(total_flux, bins=np.nanpercentile(total_flux, np.linspace(0, 100, 26)), density=True)
ax.set_yscale('log')
ax.set_xscale('log')

In [None]:
# Object (top) and sky (bottom) templates as a function of wavelength.
fig, axes = new_figure('spectrum', nrows=2)

ax = axes[0, 0]
# NOTE(review): `sed_obj` is not assigned anywhere in the visible notebook.
ax.plot(rss_wavelength, sed_obj)

ax = axes[1, 0]
ax.plot(rss_wavelength, sed_sky)

In [None]:
# Mean spectrum before (top) and after (middle) subtracting `sky`, and `sky`
# itself (bottom). Here `sky` is expected to be the array returned by the
# background estimator, not the module of the same name.
fig, axes = new_figure('sky_spectrum', nrows=3)

ax = axes[0, 0]
ax.plot(rss_wavelength, np.nanmean(rss.intensity, axis=0))

ax = axes[1, 0]
ax.plot(rss_wavelength, np.nanmean(rss.intensity-sky[np.newaxis, :], axis=0))

ax = axes[2, 0]
ax.plot(rss_wavelength, sky)

## Registration

### Image Cross-correlation

The most sophisticated method included in pyKOALA to register extended sources is based on the cross-correlation of two images.

# NOTE(review): imports from `koala` (the rest of the notebook uses `pykoala`)
# and `sci_rss` is not defined anywhere in the visible notebook.
from koala.register.registration import register_crosscorr

figures = register_crosscorr(sci_rss, plot=True, quick_cube_pix_size=1.)
# Show every figure returned by the registration
for fig in figures:
    plt.show(plt.figure(fig))

### Centroid finding

A simple approach to find the offset between the different RSS is to find the center of light of the images (assuming that they contain the same sources).

# NOTE(review): imports from `koala` (the rest of the notebook uses `pykoala`)
# and `sci_rss` is not defined anywhere in the visible notebook.
from koala.register.registration import register_centroid

figures = register_centroid(sci_rss, plot=True, quick_cube_pix_size=0.2,
                            centroider='gauss',
                            #subbox=[[150, 200], [20, 100]]
                           )

# Show every figure returned by the registration
for fig in figures:
    plt.show(plt.figure(fig))

### Manual

Alternatively, it is also possible to provide a manual offset for the input RSS frames.

In [None]:
from koala.register.registration import register_manual

#register_manual(sci_rss, [[0, 0], [1.5, 0], [3, 1.5]], absolute=False)

For interpolating RSS data into a 3D datacube we will make use of the function *build_cube*. This method requires as input:
- A list of RSS objects. 
- The desired dimensions of the cube expressed as a 2-element tuple, corresponding to (ra, dec) in arcseconds.
- The pixel size of the cube in arcseconds.
- A list containing the ADR correction for every RSS (it can contain None) in the form: [(ADR_ra_1, ADR_dec_1), (ADR_ra_2, ADR_dec_2), (None, None)].
- Additional information to be included in *cube_info*

In [None]:
# Interpolate the single RSS into a 90 x 90 cube with a pixel size of 0.5.
importlib.reload(cubing)
# NOTE(review): the inline comment says (dec, ra) while the text above says
# (ra, dec); it makes no difference for a square cube, but confirm the order.
cube = cubing.build_cube(rss_set=[rss],
                  cube_size_arcsec=(90, 90),  # (dec, ra)
                  pixel_size_arcsec=.5,
                  )
# Name the cube after the first space-separated word of the RSS name
cube.info['name'] = rss.info['name'].split(' ')[0]

## Sky subtraction

In [None]:
# NOTE(review): this imports from `koala`, whereas the top of the notebook
# imports from `pykoala` - confirm which package name is current.
import koala.corrections.sky as sky
#import importlib
importlib.reload(sky)

# Build a sky model from the cube itself and apply the sky correction;
# `cube` is replaced by the first value returned by apply().
skymodel = sky.SkyFromObject(cube, bckgr_estimator='mad', source_mask_nsigma=3, remove_cont=False)
skycorrection = sky.SkySubsCorrection(skymodel)
cube, _ = skycorrection.apply(cube)

## Plots

### Mean and median intensity maps

In [None]:
# Mean (top row) and median (bottom row) along axis 0 of the cube, for the raw
# (left column) and corrected (right column) intensity.
fig, axes = new_figure('cube_projection', nrows=2, ncols=2,
                       sharex=True, sharey=True,
                       figsize=(12, 8), gridspec_kw={'hspace': .1, 'wspace': 0.5})

im, cb = colour_map(fig, axes[0, 0], 'raw mean', np.nanmean(cube.intensity, axis=0))
im, cb = colour_map(fig, axes[0, 1], 'corrected mean', np.nanmean(cube.intensity_corrected, axis=0))
im, cb = colour_map(fig, axes[1, 0], 'raw median', np.nanmedian(cube.intensity, axis=0))
im, cb = colour_map(fig, axes[1, 1], 'corrected median', np.nanmedian(cube.intensity_corrected, axis=0))


In [None]:
# Mean spectrum of the cube over axes 1 and 2: raw (top), corrected (middle),
# and the background of the sky model (bottom).
fig, axes = new_figure('mean_spectrum', nrows=3)

ax = axes[0, 0]
ax.plot(cube.wavelength, np.nanmean(cube.intensity, axis=(1,2)))

ax = axes[1, 0]
ax.plot(cube.wavelength, np.nanmean(cube.intensity_corrected, axis=(1,2)))
# Fixed vertical range for the corrected spectrum
ax.set_ylim(-.005, .01)

ax = axes[2, 0]
ax.plot(skymodel.wavelength, skymodel.bckgr)
