# Using WEAVE data with pyKOALA

## Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.ticker import AutoMinorLocator

import numpy as np
import os
import importlib
from time import time
from astropy import units as u
from astropy.table import Table

import warnings
# You may want to comment the following line
#warnings.filterwarnings("ignore")

In [None]:
from koala import __version__
from koala.instruments import weave
from koala import cubing

# pyKOALA version
print("pyKOALA version: ", __version__)

### Utility functions for plotting

In [None]:
def new_figure(fig_name, figsize=None, nrows=1, ncols=1, sharex='col', sharey='row', gridspec_kw=None):
    """Create (or re-create) a named figure holding a grid of pre-styled axes.

    Any open figure with the same name is closed first, so re-running a cell
    replaces the figure instead of drawing on top of the old one.

    Parameters
    ----------
    fig_name : str
        Figure identifier passed to `plt.figure`; also used as the suptitle.
    figsize : tuple, optional
        Figure size passed to `plt.figure`. Defaults to
        `(9 + ncols, 4 + nrows)`.
    nrows, ncols : int
        Shape of the axes grid.
    sharex, sharey : bool or str
        Axis-sharing mode, forwarded to `Figure.subplots`.
    gridspec_kw : dict, optional
        Forwarded to `Figure.subplots`. Defaults to
        `{'hspace': 0, 'wspace': 0}` (panels without gaps).

    Returns
    -------
    fig : matplotlib Figure
    axes : 2D array of Axes
        Always two-dimensional (`squeeze=False`), so index as `axes[row, col]`.
    """
    if figsize is None:
        figsize = (9 + ncols, 4 + nrows)
    if gridspec_kw is None:
        # Built per call: a dict literal in the signature would be a single
        # object shared by every call.
        gridspec_kw = {'hspace': 0, 'wspace': 0}

    plt.close(fig_name)
    fig = plt.figure(fig_name, figsize=figsize)
    axes = fig.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                        sharex=sharex, sharey=sharey,
                        gridspec_kw=gridspec_kw)

    for ax in axes.flat:
        # Minor ticks on both axes, ticks on all four sides, faint grid.
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.yaxis.set_minor_locator(AutoMinorLocator())
        ax.tick_params(which='both', bottom=True, top=True, left=True, right=True)
        ax.tick_params(which='major', direction='inout', length=8, grid_alpha=.3)
        ax.tick_params(which='minor', direction='in', length=2, grid_alpha=.1)
        ax.grid(True, which='both')

    fig.suptitle(fig_name)

    return fig, axes


In [None]:
# Default colour map for `colour_map`: a copy of "gist_earth" (so the
# set_bad call below does not alter the colormap returned by plt.get_cmap)
# that draws invalid (e.g. NaN) pixels in gray.
default_cmap = plt.get_cmap("gist_earth").copy()
default_cmap.set_bad('gray')


def colour_map(fig, ax, cblabel, data, cmap=default_cmap, norm=None, xlabel=None, x=None, ylabel=None, y=None):
    """Show a 2D array as an image with a labelled colour bar.

    Parameters
    ----------
    fig, ax : matplotlib Figure and Axes
        Where the image and its colour bar are drawn.
    cblabel : str
        Label of the colour bar.
    data : 2D array
        Image to display; `data.shape[0]` runs along y, `data.shape[1]` along x.
    cmap : matplotlib colormap, optional
        Defaults to `default_cmap`.
    norm : matplotlib normalization, optional
        If omitted, a `SymLogNorm` spanning the 1st-99th percentiles of `data`
        is used, linear below the median of the positive values, and the
        colour bar ticks are placed at the 1/16/50/84/99th percentiles.
    xlabel, ylabel : str, optional
        Axis labels (left untouched when None).
    x, y : 1D arrays, optional
        Pixel-centre coordinates along each axis (at least two elements each);
        default to pixel indices.

    Returns
    -------
    im : the image returned by `ax.imshow`
    cb : the colour bar returned by `fig.colorbar`
    """
    if norm is None:
        percentiles = np.array([1, 16, 50, 84, 99])
        ticks = np.nanpercentile(data, percentiles)
        linthresh = np.median(data[data > 0])
        norm = colors.SymLogNorm(vmin=ticks[0], vmax=ticks[-1], linthresh=linthresh)
    else:
        ticks = None
    if y is None:
        y = np.arange(data.shape[0])
    if x is None:
        x = np.arange(data.shape[1])

    # Extend the image by half a pixel on every side so that the x/y values
    # fall on pixel centres rather than on pixel edges.
    extent = (x[0] - (x[1] - x[0]) / 2, x[-1] + (x[-1] - x[-2]) / 2,
              y[0] - (y[1] - y[0]) / 2, y[-1] + (y[-1] - y[-2]) / 2)
    im = ax.imshow(data,
                   extent=extent,
                   interpolation='nearest', origin='lower',
                   cmap=cmap,
                   norm=norm,
                   )
    ax.set_aspect('auto')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    cb = fig.colorbar(im, ax=ax, orientation='vertical', shrink=.9)
    cb.ax.set_ylabel(cblabel)
    if ticks is not None:
        # Label each tick as "value (percentile%)". The previous format string
        # had an unbalanced ")" and an invalid "\%" escape.
        tick_labels = [f'{value:.3g} ({percent}%)' for value, percent in zip(ticks, percentiles)]
        cb.ax.set_yticks(ticks=ticks, labels=tick_labels)
    cb.ax.tick_params(labelsize='small')

    return im, cb


### Load the science data

In [None]:
# Reload the instrument module so that edits to koala.instruments.weave are
# picked up without restarting the kernel.
importlib.reload(weave)

# Relative path: resolved against the current working directory.
filename = "../tests/weave/single_3042890.fit"
rss = weave.weave_rss(filename)
print(f"File {filename} corresponds to object {rss.info['name']}")


In [None]:
# Wavelength axis converted to Angstrom and stripped of its unit (plain values),
# used as the x coordinate of the plots and fits below.
rss_wavelength = rss.wavelength.to_value(u.Angstrom)

## Fibre throughput

### Lines and continuum

In [None]:
# TODO: guess min_separation

def find_continuum(x, y, min_separation=10):
    """Fit the lower envelope of a single spectrum.

    1) Find local minima ("valleys"): samples that are the minimum of the
       window of `min_separation` samples on either side.
    2) Interpolate linearly between them (never exceeding the data).
    3) Add a "typical" (~median) positive offset of the data above the envelope.

    The input arrays are left untouched.

    Parameters
    ----------
    x : 1D array
        Sample positions (e.g. wavelength), increasing, same size as `y`.
    y : 1D array
        Spectrum; non-finite samples are ignored.
    min_separation : int, optional
        Half-width, in samples, of the window used to detect local minima.

    Returns
    -------
    continuum : 1D float array
        Lower envelope plus the typical offset.
    offset : float
        The typical offset that was added. It is 0 when no sample lies above
        the envelope, and NaN when `y` has no finite sample.
    """
    x = np.asarray(x)
    # Float copy: the caller's array (often a row view of a larger RSS array)
    # must not be modified by the masking below.
    work = np.array(y, dtype=float)

    # Non-finite samples are set to +inf so they can never be a window minimum.
    work[~np.isfinite(work)] = np.inf
    # NOTE(review): the upper bound leaves out the last centre whose window
    # still fits (i = size - min_separation - 1); kept unchanged so results
    # match the previous implementation.
    valleys = [
        i for i in range(min_separation, work.size - min_separation - 1)
        if np.isfinite(work[i])
        and np.argmin(work[i - min_separation:i + min_separation + 1]) == min_separation
    ]
    work[np.isinf(work)] = np.nan

    if valleys:
        envelope = np.interp(x, x[valleys], work[valleys])
    else:
        # No local minimum (spectrum too short, flat, or entirely invalid):
        # np.interp would raise on empty sample points, so fall back to a flat
        # envelope at the global minimum (NaN if there is no valid sample).
        floor = np.nanmin(work) if np.isfinite(work).any() else np.nan
        envelope = np.full(work.shape, floor)
    continuum = np.fmin(work, envelope)

    residual = work - continuum
    positive = residual[residual > 0]
    if positive.size == 0:
        offset = 0.0 if np.isfinite(residual).any() else np.nan
        return continuum + offset, offset

    # Percentiles 1..50 of the positive residuals; `density` is the number of
    # percentile steps per unit residual. The offset is the median of the
    # percentiles whose density exceeds half of the maximum.
    levels = np.nanpercentile(positive, np.linspace(1, 50, 51))
    density = (np.arange(levels.size) + 1) / levels
    offset = np.median(levels[density > np.max(density) / 2])

    return continuum + offset, offset

# Window half-width (in samples) passed to `find_continuum` in the per-spectrum loop below.
min_separation = 10


In [None]:
# Fit the continuum of every spectrum (row of rss.intensity) independently,
# and report the elapsed time.
n_spectra = rss.intensity.shape[0]
print(f"> Find continuum for {n_spectra} spectra:")
t0 = time()

# continuum: one fitted continuum per spectrum (same shape as rss.intensity);
# continuum_err: the scalar offset returned by find_continuum for each spectrum.
continuum = np.zeros_like(rss.intensity)
continuum_err = np.zeros(n_spectra)
for i in range(n_spectra):
    continuum[i], continuum_err[i] = find_continuum(rss_wavelength, rss.intensity[i], min_separation)

print(f"  Done ({time()-t0:.3g} s)")

In [None]:
# Three stacked maps (spectrum index vs wavelength): continuum-subtracted
# intensity, fitted continuum, and their ratio to the per-spectrum offset (S/N).
fig, axes = new_figure('lines-continuum', nrows=3)

im, cb = colour_map(fig, axes[0, 0], 'lines', rss.intensity - continuum, x=rss_wavelength)
im, cb = colour_map(fig, axes[1, 0], 'continuum', continuum, x=rss_wavelength)
im, cb = colour_map(fig, axes[2, 0], 'S/N', (rss.intensity - continuum) / continuum_err[:, np.newaxis], x=rss_wavelength)


### Strong sky lines

In [None]:
# Signal-to-noise of the continuum-subtracted spectra, and its 16/50/84th
# percentiles across spectra (axis 0) at every wavelength.
SN = (rss.intensity - continuum) / continuum_err[:, np.newaxis]
SN_p16, SN_p50, SN_p84 = np.nanpercentile(SN, [16, 50, 84], axis=0)

# Flag wavelengths where the 16th percentile exceeds half of the 84th
# (2*p16 - p84 > 0), which also requires p16 > 0.
line_mask = 2*SN_p16 - SN_p84 > 0
# Force both ends to be unflagged so every flagged run has a rising and a falling edge.
line_mask[0] = False
line_mask[-1] = False

# line_left: last unflagged index before each flagged run;
# line_right (after the += 1): first unflagged index after each run.
# NOTE(review): slices [left:right] used later therefore include one unflagged
# sample on the left but none on the right — confirm this asymmetry is intended.
line_left = np.where(~line_mask[:-1] & line_mask[1:])[0]
line_right = np.where(line_mask[:-1] & ~line_mask[1:])[0]
line_right += 1
print(f'  {line_left.size} strong sky lines ({np.count_nonzero(line_mask)} out of {rss_wavelength.size} wavelengths)')

In [None]:
# One row per detected line, holding the index range found above.
strong_sky_lines = Table((line_left, line_right), names=('left', 'right'))

# To be filled later:
# (float columns initialised to 0 for every row)
strong_sky_lines.add_column(0., name='sky_wavelength')
strong_sky_lines.add_column(0., name='sky_intensity')

In [None]:
# S/N percentiles versus wavelength, with the detected line ranges shaded.
fig, axes = new_figure('strong_lines')

ax = axes[0, 0]
# Median S/N (black), 16-84 percentile band (red), and the detection
# statistic 2*p16 - p84 (dotted) whose positive values define the lines.
ax.plot(rss_wavelength, SN_p50, 'k-')
ax.fill_between(rss_wavelength, SN_p16, SN_p84, color='r', alpha=.1)
ax.plot(rss_wavelength, 2*SN_p16-SN_p84, 'k:')
for line in strong_sky_lines:
    ax.axvspan(rss_wavelength[line['left']], rss_wavelength[line['right']], color='c', alpha=.1)

### Reference wavelength and flux

In [None]:
from astropy import stats

# Reference sky spectrum: biweight location of all spectra at each wavelength (axis 0).
sky = stats.biweight.biweight_location(rss.intensity, axis=0)

# Continuum of the reference spectrum, fitted with find_continuum's default
# min_separation (the explicit argument is commented out); the remainder is
# taken as the sky emission lines.
sky_cont, sky_err = find_continuum(rss_wavelength, sky)#, min_separation)
sky_lines = sky - sky_cont


In [None]:
# Fill the reference columns of the table: for each line, the
# intensity-weighted mean wavelength and the mean intensity of the
# continuum-subtracted reference sky over samples [left, right).
for line in strong_sky_lines:
    wavelength = rss_wavelength[line['left']:line['right']]
    spectrum = sky_lines[line['left']:line['right']]
    weight = spectrum
    line['sky_wavelength'] = np.nansum(weight*wavelength) / np.nansum(weight)
    line['sky_intensity'] = np.nanmean(spectrum)


In [None]:
# Reference sky spectrum with its fitted continuum and the line ranges shaded.
fig, axes = new_figure('sky')

ax = axes[0, 0]
ax.plot(rss_wavelength, sky, 'k-')
ax.plot(rss_wavelength, sky_cont, 'r-')
ax.plot(rss_wavelength, sky_cont + sky_err, 'r--')
# NOTE(review): the lower bound is drawn solid ('r-') while the upper bound is
# dashed ('r--') — possibly meant to be dashed as well.
ax.plot(rss_wavelength, sky_cont - sky_err, 'r-')
for line in strong_sky_lines:
    ax.axvspan(rss_wavelength[line['left']], rss_wavelength[line['right']], color='c', alpha=.1)
# Fixed vertical range for this plot.
ax.set_ylim(0, 1000)

In [None]:
# Select the spectrum with the highest mean intensity (NaNs ignored) and fit
# its continuum with find_continuum's default min_separation.
mean_flux = np.nanmean(rss.intensity, axis=1)
spec_id = np.nanargmax(mean_flux)
spec_id_cont, spec_id_err = find_continuum(rss_wavelength, rss.intensity[spec_id])#, min_separation)

In [None]:
# Diagnostic plot for the selected spectrum. A 2x2 grid is created, but only
# the left column (axes[0, 0] and axes[1, 0]) is drawn on.
fig, axes = new_figure('single_continuum-line', nrows=2, ncols=2)

# Top panel: spectrum, its continuum +/- offset band, and the continuum from
# the per-spectrum loop plus the reference sky lines.
ax = axes[0, 0]
ax.plot(rss_wavelength, rss.intensity[spec_id], 'k-', label=f'spec_id {spec_id}')
#ax.plot(rss_wavelength, rss.intensity[spec_id] - sky_lines, 'c-')
ax.fill_between(rss_wavelength, spec_id_cont-spec_id_err, spec_id_cont+spec_id_err, color='r', alpha=.1, label='continuum')
ax.plot(rss_wavelength, continuum[spec_id] + sky_lines, 'c-', label='continuum+sky lines')
ax.plot(rss_wavelength, spec_id_cont, 'r-')
for line in strong_sky_lines:
    ax.axvspan(rss_wavelength[line['left']], rss_wavelength[line['right']], color='c', alpha=.1)
ax.legend()

# Bottom panel: continuum-subtracted spectrum before (cyan) and after (black)
# removing the reference sky lines, with the combined offset as an error band.
ax = axes[1, 0]
ax.plot(rss_wavelength, rss.intensity[spec_id] - spec_id_cont, 'c-')
ax.plot(rss_wavelength, rss.intensity[spec_id] - spec_id_cont - sky_lines, 'k-')
error = np.sqrt(sky_err**2 + spec_id_err**2)
ax.fill_between(rss_wavelength, -error, error, color='r', alpha=.1)
for line in strong_sky_lines:
    # Reference values stored in the table (cyan).
    ax.axvspan(rss_wavelength[line['left']], rss_wavelength[line['right']], color='c', alpha=.1)
    ax.plot([rss_wavelength[line['left']], rss_wavelength[line['right']]], 2*[line['sky_intensity']], 'c-')
    ax.axvline(line['sky_wavelength'], c='c', ls='--')
    
    # Same two measurements for this spectrum (black): intensity-weighted mean
    # wavelength and mean intensity over samples [left, right).
    wavelength = rss_wavelength[line['left']:line['right']]
    spectrum = (rss.intensity[spec_id] - spec_id_cont)[line['left']:line['right']]
    weight = spectrum
    wavelength = np.nansum(weight*wavelength) / np.nansum(weight)
    intensity = np.nanmean(spectrum)
    ax.plot([rss_wavelength[line['left']], rss_wavelength[line['right']]], 2*[intensity], 'k:')
    ax.plot(wavelength, intensity, 'k.')
    #ax.axvline(wavelength, c='k', ls=':')
    


In [None]:
def measure_lines(lines, wavelength, intensity, continuum=None):
    """Measure the centroid and mean intensity of each line in one spectrum.

    Parameters
    ----------
    lines : sequence of rows
        Each row must provide `row['left']` and `row['right']`, the sample
        range `[left, right)` of a line.
    wavelength : 1D array
        Wavelength of every sample.
    intensity : 1D array
        Spectrum, same size as `wavelength`.
    continuum : 1D array, optional
        Continuum to subtract; fitted with `find_continuum` (default
        settings) when omitted.

    Returns
    -------
    line_wavelength : 1D array
        Intensity-weighted mean wavelength of each line.
    line_intensity : 1D array
        Mean continuum-subtracted intensity of each line.
    """
    if continuum is None:
        continuum, _ = find_continuum(wavelength, intensity)
    residual = intensity - continuum

    line_wavelength = np.zeros(len(lines))
    line_intensity = np.zeros(len(lines))
    # Iterate over the `lines` argument; the previous version looped over the
    # global `strong_sky_lines` and silently ignored what the caller passed.
    for i, line in enumerate(lines):
        section = slice(line['left'], line['right'])
        section_wavelength = wavelength[section]
        section_intensity = residual[section]
        weight = section_intensity
        line_wavelength[i] = np.nansum(weight*section_wavelength) / np.nansum(weight)
        line_intensity[i] = np.nanmean(section_intensity)
    return line_wavelength, line_intensity

In [None]:
# Compare the sky lines measured in a single spectrum with the reference
# values stored in `strong_sky_lines`.
spec_id = 9
line_wavelength, line_intensity = measure_lines(strong_sky_lines, rss_wavelength, rss.intensity[spec_id])

fig, axes = new_figure('single_calibration', nrows=2)

# Raw strings are used for every label containing TeX commands: sequences such
# as "\D", "\l", "\A" or "\p" are invalid escapes in ordinary string literals.

# Top panel: wavelength difference (measured - reference) for each line, with
# its biweight location and scale.
ax = axes[0, 0]
ax.set_ylabel(r'relative $\Delta\lambda\ [\AA]$')
y = line_wavelength - strong_sky_lines['sky_wavelength']
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.', label=f'lines in spec_id={spec_id}')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

# Bottom panel: intensity ratio (measured / reference) for each line.
ax = axes[1, 0]
ax.set_ylabel('relative throughput')
y = line_intensity/strong_sky_lines['sky_intensity']
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

ax.set_xlabel(r'$\lambda_{line}\ [\AA]$')



# Measured versus reference line intensity on log-log axes. `location` and
# `scale` still hold the intensity-ratio statistics computed just above.
fig, axes = new_figure('single_throughput')

ax = axes[0, 0]
ax.set_ylabel(f'Intensity in spec_id={spec_id}')
ax.set_yscale('log')

ax.plot(strong_sky_lines['sky_intensity'], line_intensity, 'k.', label=f'lines in spec_id={spec_id}')
xx = np.nanpercentile(strong_sky_lines['sky_intensity'], np.linspace(0, 100, 101))
ax.plot(xx, location*xx, 'r-')
ax.fill_between(xx, (location-scale)*xx, (location+scale)*xx, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

ax.set_xlabel('Intensity in sky spectrum')
ax.set_xscale('log')


# --- STOP ---

In [None]:
# Deliberately abort a "Run All" at the STOP marker. `raise -1` only stopped
# execution by accident (raising an int is itself a TypeError); raise a real
# exception with an explanatory message instead.
raise RuntimeError("Stop here: the cells below this point are not meant to run automatically.")

In [None]:
import koala.corrections.sky as sky
importlib.reload(sky)
# NOTE(review): this assignment rebinds the name `sky` from the module imported
# above to the first value returned by BackgroundEstimator.linear (and replaces
# the reference sky spectrum computed earlier under the same name).
sky, sigma = sky.BackgroundEstimator.linear(rss.intensity_corrected)

In [None]:
# Mean intensity of each spectrum (NaNs ignored), and every spectrum divided by its own mean.
total_flux = np.nanmean(rss.intensity_corrected, axis=1)
I = rss.intensity_corrected / total_flux[:, np.newaxis]

In [None]:
# Root-mean-square intensity of each spectrum (NaNs ignored).
total_norm = np.sqrt(np.nanmean(rss.intensity_corrected**2, axis=1))

In [None]:
# Median of the normalised spectra at each wavelength, rescaled by the mean of its square.
sed_sky = np.nanmedian(I, axis=0)
sed_sky /= np.nanmean(sed_sky**2)

In [None]:
# Mean over wavelength of each spectrum multiplied by sed_sky: one value per spectrum.
sky_flux = np.nanmean(sed_sky[np.newaxis, :]*rss.intensity_corrected, axis=1)

In [None]:
# Bare expression: shown as the cell output when run in the notebook.
np.nanmedian(sky_flux)

In [None]:
# Scatter plot of sky_flux against total_flux, one point per spectrum.
fig, axes = new_figure('I-flux')

ax = axes[0, 0]
ax.plot(total_flux, sky_flux, 'k.')

In [None]:
# Corrected intensity before (top) and after (bottom) dividing each spectrum by its mean.
fig, axes = new_figure('map', nrows=2)

im, cb = colour_map(fig, axes[0, 0], 'I original', rss.intensity_corrected)
im, cb = colour_map(fig, axes[1, 0], 'I norm', I)


In [None]:
# Histogram of total_flux on log-log axes. The 26 bin edges are the
# 0, 4, ..., 100th percentiles of the data.
fig, axes = new_figure('hist')

ax = axes[0, 0]
ax.hist(total_flux, bins=np.nanpercentile(total_flux, np.linspace(0, 100, 26)), density=True)
ax.set_yscale('log')
ax.set_xscale('log')

In [None]:
fig, axes = new_figure('spectrum', nrows=2)

ax = axes[0, 0]
# NOTE(review): `sed_obj` is not defined in any cell visible in this notebook;
# this line raises NameError unless it is defined elsewhere.
ax.plot(rss_wavelength, sed_obj)

ax = axes[1, 0]
ax.plot(rss_wavelength, sed_sky)

In [None]:
# Mean spectrum before (top) and after (middle) subtracting `sky`, and `sky` itself (bottom).
fig, axes = new_figure('sky_spectrum', nrows=3)

ax = axes[0, 0]
ax.plot(rss_wavelength, np.nanmean(rss.intensity, axis=0))

ax = axes[1, 0]
# `sky` is broadcast along the spectrum axis, so the same array is subtracted from every spectrum.
ax.plot(rss_wavelength, np.nanmean(rss.intensity-sky[np.newaxis, :], axis=0))

ax = axes[2, 0]
ax.plot(rss_wavelength, sky)

## Registration

### Image Cross-correlation

The most sophisticated method included in pyKOALA to perform the registration of extended sources is based on the cross-correlation of two images.

from koala.register.registration import register_crosscorr

figures = register_crosscorr(sci_rss, plot=True, quick_cube_pix_size=1.)
for fig in figures:
    plt.show(plt.figure(fig))

### Centroid finding

A simple approach to find the offset between the different RSS is to find the center of light of the images (assuming that they contain the same sources).

from koala.register.registration import register_centroid

figures = register_centroid(sci_rss, plot=True, quick_cube_pix_size=0.2,
                            centroider='gauss',
                            #subbox=[[150, 200], [20, 100]]
                           )

for fig in figures:
    plt.show(plt.figure(fig))

### Manual

Alternatively, it is also possible to provide a manual offset for the input RSS frames.

In [None]:
from koala.register.registration import register_manual

#register_manual(sci_rss, [[0, 0], [1.5, 0], [3, 1.5]], absolute=False)

For interpolating RSS data into a 3D datacube we will make use of the function *build_cube*. This method requires as input:
- A list of RSS objects. 
- The desired dimensions of the cube expressed as a 2-element tuple, corresponding to (ra, dec) in arcseconds.
- The pixel size of the cube in arcseconds.
- A list containing the ADR correction for every RSS (it can contain None) in the form: [(ADR_ra_1, ADR_dec_1), (ADR_ra_2, ADR_dec_2), (None, None)].
- Additional information to be included in *cube_info*

In [None]:
# Reload the cubing module so that edits are picked up without restarting the kernel.
importlib.reload(cubing)
# Build a cube from the single RSS loaded above.
# NOTE(review): the inline comment below says (dec, ra), while the text above
# this cell describes the cube size as (ra, dec) — confirm the expected order.
cube = cubing.build_cube(rss_set=[rss],
                  cube_size_arcsec=(90, 90),  # (dec, ra)
                  pixel_size_arcsec=.5,
                  )
# Name the cube after the first space-separated token of the RSS name.
cube.info['name'] = rss.info['name'].split(' ')[0]

## Sky subtraction

In [None]:
# Re-import (and reload) the sky-correction module; this rebinds the name `sky` to the module.
import koala.corrections.sky as sky
#import importlib
importlib.reload(sky)

# Build a sky model from the cube and wrap it in a sky-subtraction correction.
skymodel = sky.SkyFromObject(cube, bckgr_estimator='mad', source_mask_nsigma=3, remove_cont=False)
skycorrection = sky.SkySubsCorrection(skymodel)
# `cube` is rebound to the first value returned by apply(); the second is discarded.
cube, _ = skycorrection.apply(cube)

## Plots

### Mean and median intensity maps

In [None]:
# 2x2 grid of maps collapsed along axis 0 of the cube: mean (top row) and
# median (bottom row) of the raw (left) and corrected (right) intensity.
fig, axes = new_figure('cube_projection', nrows=2, ncols=2,
                       sharex=True, sharey=True,
                       figsize=(12, 8), gridspec_kw={'hspace': .1, 'wspace': 0.5})

im, cb = colour_map(fig, axes[0, 0], 'raw mean', np.nanmean(cube.intensity, axis=0))
im, cb = colour_map(fig, axes[0, 1], 'corrected mean', np.nanmean(cube.intensity_corrected, axis=0))
im, cb = colour_map(fig, axes[1, 0], 'raw median', np.nanmedian(cube.intensity, axis=0))
im, cb = colour_map(fig, axes[1, 1], 'corrected median', np.nanmedian(cube.intensity_corrected, axis=0))


In [None]:
# Spectra averaged over axes 1 and 2 of the cube: raw (top) and corrected
# (middle) intensity, and the background of the sky model (bottom).
fig, axes = new_figure('mean_spectrum', nrows=3)

ax = axes[0, 0]
ax.plot(cube.wavelength, np.nanmean(cube.intensity, axis=(1,2)))

ax = axes[1, 0]
ax.plot(cube.wavelength, np.nanmean(cube.intensity_corrected, axis=(1,2)))
# Fixed vertical range for the corrected spectrum.
ax.set_ylim(-.005, .01)

ax = axes[2, 0]
ax.plot(skymodel.wavelength, skymodel.bckgr)
