# Corrections based on sky emission lines

Compute the relative offsets between fibres in wavelength (in pixels) and flux (arbitrary units), based on the sky emission lines detected through a wavelet filter.

# 1. Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# pyKOALA modules used below: instrument readers (KOALA, WEAVE), the sky module
# (wavelet filter and sky models) and the throughput / wavelength corrections.
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections import sky
from pykoala.corrections.throughput import ThroughputCorrection
from pykoala.corrections.wavelength import WavelengthCorrection
print("pyKOALA version: ", __version__)

The following imports will probably disappear in the final version of the tutorial.

In [None]:
import importlib
from astropy import stats
from astropy import units as u
from pykoala.plotting.utils import new_figure, colour_map
import scipy

# 2. Load the science data
This must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
# Dataset to analyse: "KOALA" or "WEAVE" (selects which loading cell below runs).
example = "KOALA"

KOALA:

In [None]:
if example == 'KOALA':
    filename = "../data/27feb20036red.fits"
    rss0 = koala_ifu.koala_rss(filename)
    # Later cells treat wavelength_AA as a plain float array in Angstrom (e.g.
    # np.searchsorted with a float wavelength). Convert it here in case the
    # reader returns an astropy Quantity, as the WEAVE branch below does.
    wavelength_AA = rss0.wavelength
    if isinstance(wavelength_AA, u.Quantity):
        wavelength_AA = wavelength_AA.to_value(u.AA)

WEAVE:

In [None]:
if example == 'WEAVE':
    # The WEAVE reader's wavelengths carry units; strip them to a plain
    # float array in Angstrom for plotting and indexing.
    filename = "../data/weave/single_3042890.fit"
    rss0 = weave.weave_rss(filename)
    wavelength_AA = rss0.wavelength.to_value(u.Angstrom)

## Summary

In [None]:
# Quick overview of the loaded RSS: object name, available info keys and the
# processing history recorded so far.
print(f"Analysing object {rss0.info['name']} read from {filename}")
print('- info:', rss0.info.keys(), sep='\n')
print('- log:')
rss0.history.show()


# 3. Wavelet filter

In [None]:
# Reload pykoala.corrections.sky so that edits to the module are picked up
# without restarting the kernel.
importlib.reload(sky)

## First iteration

In [None]:
# First pass: run the wavelet filter on the uncorrected RSS (rss0).
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
# Relative fibre throughput estimated by the first wavelet pass.
throughput_corr = ThroughputCorrection(throughput=wavelet1.get_throughput_object())

In [None]:
# Per-fibre wavelength offsets (in pixels) estimated by the first wavelet pass.
wave_corr = WavelengthCorrection(offset=wavelet1.get_wavelength_offset())

In [None]:
# First-iteration corrected RSS: apply the wavelength offsets to the raw data
# first, then the relative fibre throughput.
rss1 = throughput_corr.apply(wave_corr.apply(rss0))

## Second iteration

In [None]:
# Second pass: re-run the wavelet filter on the first-iteration corrected RSS.
wavelet2 = sky.WaveletFilter(rss1)

In [None]:
# Residual relative throughput estimated by the second wavelet pass.
throughput_corr = ThroughputCorrection(throughput=wavelet2.get_throughput_object())

In [None]:
# Residual per-fibre wavelength offsets estimated by the second wavelet pass.
wave_corr = WavelengthCorrection(offset=wavelet2.get_wavelength_offset())

In [None]:
# Second-iteration corrected RSS: wavelength offsets first, then throughput,
# both taken from the second wavelet pass and applied on top of rss1.
rss2 = throughput_corr.apply(wave_corr.apply(rss1))

In [None]:
# Show the processing log of the twice-corrected RSS.
rss2.history.show()

# 4. Sky subtraction

In [None]:
# Reload the sky module again before building the sky models.
# NOTE(review): wavelet1/wavelet2 were created from the previously loaded
# module; after a reload they remain instances of the old class objects.
importlib.reload(sky)

In [None]:
# Sky model fitted on the uncorrected RSS using the first-pass wavelet filter
# (plotted as 'uncorrected fit').
sky0_fit = sky.SkyFromObject(rss0, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet1})

In [None]:
# Sky model fitted on the first-iteration corrected RSS (first-pass wavelet filter).
sky_fit = sky.SkyFromObject(rss1, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet1})

In [None]:
# Sky model fitted on the second-iteration corrected RSS (second-pass wavelet filter).
sky2_fit = sky.SkyFromObject(rss2, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet2})

In [None]:
# Percentile-based sky estimate on rss1 (plotted as 'median').
sky_med = sky.SkyFromObject(rss1, bckgr_estimator='percentile')

In [None]:
# MAD-based sky estimate on rss1, with a 3-sigma source mask.
sky_mad = sky.SkyFromObject(rss1, bckgr_estimator='mad', source_mask_nsigma=3)

In [None]:
fig, axes = new_figure('sky model')

ax = axes[0, 0]

# Percentile-based estimate, with a +/- sqrt(variance) band around it
ax.plot(wavelength_AA, sky_med.intensity, 'b-', alpha=.5, label='median')
std = np.sqrt(sky_med.variance)
ax.fill_between(wavelength_AA, sky_med.intensity-std, sky_med.intensity+std, color='b', alpha=.1, label='$16-84$ \\%')

ax.plot(wavelength_AA, sky_mad.intensity, 'b--', alpha=.5, label='MAD + 3-$\\sigma$ clip')

# Wavelet-based fits on the raw, first- and second-iteration corrected RSS
ax.plot(wavelength_AA, sky0_fit.intensity, 'k:', alpha=.5, label='uncorrected fit')
ax.plot(wavelength_AA, sky_fit.intensity, 'k-', alpha=.5, label='fit')
ax.plot(wavelength_AA, sky2_fit.intensity, 'r-', alpha=.5, label='iter 2')

# Trailing semicolon suppresses the notebook echo of the Legend object.
ax.legend();

# 5. Quality control plots

## Wavelet filter

In [None]:
# Quality-control plots produced by the first wavelet pass.
wavelet1.qc_plots()

## Relative calibration

In [None]:
fig, axes = new_figure('sky-based relative calibration', nrows=2)

# Top panel: relative fibre throughput from each wavelet pass, compared with
# per-fibre percentiles (over wavelength) of the filtered / sky ratio.
ax = axes[0, 0]
ax.set_ylabel('relative throughput')
ax.set_ylim(.4, 2.1)

ax.plot(wavelet1.fibre_throughput, 'k-', label='iter 1')
ax.plot(wavelet2.fibre_throughput, 'k-', alpha=.5, label='iter 2')

# Median (dashed) of filtered / sky for every fibre; the shaded band is the
# 16-84 percentile range for the first pass only.
p16, p50, p84 = np.nanpercentile(wavelet1.filtered / wavelet1.sky[np.newaxis, :], [16, 50, 84], axis=1)
ax.plot(p50, 'r--', alpha=.5, label='filtered / sky (iter 1)')
ax.fill_between(np.arange(p50.size), p16, p84, color='r', alpha=0.1)
p16, p50, p84 = np.nanpercentile(wavelet2.filtered / wavelet2.sky[np.newaxis, :], [16, 50, 84], axis=1)
ax.plot(p50, 'y--', alpha=.5, label='filtered / sky (iter 2)')

ax.legend()


# Bottom panel: relative wavelength offset of each fibre, in pixels.
ax = axes[1, 0]
ax.set_ylabel('relative offset [pix]')

ax.plot(wavelet1.fibre_offset, 'k-', label='iter 1')
ax.plot(wavelet2.fibre_offset, 'k-', alpha=.5, label='iter 2')

ax.legend()
# Trailing semicolon suppresses the notebook echo of the Text object.
ax.set_xlabel('fibre');

In [None]:
from pykoala.plotting import utils
from matplotlib import colors
importlib.reload(utils)

# Two fibre maps (throughput on top, wavelength offset below), each with its
# own colour bar in the narrow right-hand column.
map_layout = {'left': 0.07, 'right': 0.9, 'width_ratios': [1, .05], 'hspace': 0.05, 'wspace': 0.25}
fig, axes = new_figure('relative calibration maps',
                       nrows=2, ncols=2, sharey=False, figsize=(8, 10),
                       gridspec_kw=map_layout)

map_panels = [('relative throughput', wavelet1.fibre_throughput),
              ('relative offset', wavelet1.fibre_offset)]
for row, (panel_title, panel_values) in enumerate(map_panels):
    im, cb = utils.fibre_map(fig, axes[row, 0], panel_title, rss0, panel_values,
                             cbax=axes[row, 1], cmap='Spectral', norm=colors.Normalize())

axes[0, 0].sharey(axes[1, 0])


## Single fibre tests

In [None]:
# Index of the fibre inspected in the single-fibre tests below.
fibre = 42

In [None]:
fig, axes = new_figure('single throughput', nrows=3, sharex=False, sharey=False, gridspec_kw={'hspace': .2})

# Top: histogram of the wavelet-filtered fibre divided by the wavelet sky,
# with (grey) and without (blue) multiplying back the fibre throughput.
ax = axes[0, 0]
ratio_bins = np.linspace(0, 7.5, 101)
ax.hist(wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre] / wavelet1.sky, bins=ratio_bins, color='k', alpha=.2)
ax.hist(wavelet1.filtered[fibre] / wavelet1.sky, bins=ratio_bins, color='b', alpha=.5)

# Middle: wavelet-filtered spectrum of the selected fibre for both passes
ax = axes[1, 0]
ax.set_ylabel('wavelet')
ax.plot(wavelet1.wavelength, wavelet1.sky, 'b:', alpha=.5, label='sky')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre], 'k--', alpha=.2, label='data')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre], 'k-', alpha=.5, label='throughput-corrected')
ax.plot(wavelet2.wavelength, wavelet2.filtered[fibre], 'r-', alpha=.5, label='iter 2')
ax.legend()


# Bottom: original, corrected and sky-subtracted spectra of the selected fibre
ax = axes[2, 0]
ax.set_ylabel('intensity')
ax.sharex(axes[1, 0])

x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

ax.plot(x, rss0.intensity[fibre], 'k--', alpha=.2, label=f'fibre {fibre}')
ax.plot(x, rss1.intensity[fibre], 'k-', alpha=.5, label='corrected')
ax.plot(x, rss1.intensity[fibre] - sky_fit.intensity, 'c-', alpha=.5, label='sky-subtracted')
ax.plot(x, rss2.intensity[fibre] - sky2_fit.intensity, 'c--', alpha=.5, label='iter 2')

ax.legend()
# Escaped backslash: '\A' is an invalid escape sequence (SyntaxWarning since
# Python 3.12); the resulting string is unchanged.
ax.set_xlabel('wavelength [\\AA]')


In [None]:
fig, axes = new_figure('single wavelength calibration')

ax = axes[0, 0]

# Cross-correlate the selected fibre's wavelet-filtered spectrum with the
# median filtered spectrum of all fibres, and keep lags in [-scale, +scale].
mid = wavelet1.wavelength.size // 2  # zero-lag sample of a mode='same' correlation
s = wavelet1.scale
template = np.nanmedian(wavelet1.filtered, axis=0)
template[~ np.isfinite(template)] = 0
# fftconvolve spreads a single NaN over every output sample, so the fibre
# spectrum must be cleaned of non-finite values just like the template.
fibre_spectrum = np.where(np.isfinite(wavelet1.filtered[fibre]), wavelet1.filtered[fibre], 0)
x = scipy.signal.fftconvolve(fibre_spectrum, template[::-1], mode='same')[mid-s:mid+s+1]
idx = np.arange(x.size)

ax.plot(idx - s, x/np.max(x), 'k-', label=f'fibre {fibre}')
ax.axvline(wavelet1.fibre_offset[fibre], c='k', ls=':', label=f'offset = {wavelet1.fibre_offset[fibre]:.2f} pix')
# Mirror the curve about the measured offset to check its symmetry by eye
ax.plot(2*wavelet1.fibre_offset[fibre] - (idx - s), x/np.max(x), 'k--', alpha=.25, label='reflected cross-correlation')

ax.legend()
ax.set_ylabel('cross-correlation with sky')
# Trailing semicolon suppresses the notebook echo of the Text object.
ax.set_xlabel('offset [pix]');

## Corrected intensities

In [None]:
# Top row: intensity of the raw (left) and skyline-corrected (right) RSS.
# Bottom row: the same after subtracting the fitted sky. The right-hand panels
# reuse the colour normalisation of their left-hand neighbour.
fig, axes = new_figure('skyline-based correction',
                       nrows=2, ncols=3, sharex=False, sharey=False,
                       gridspec_kw={'left': 0.07, 'right': 0.9, 'width_ratios': [1, 1, .05], 'hspace': 0.25, 'wspace': 0.25})

x = rss0.wavelength.to_value(u.AA) if isinstance(rss0.wavelength, u.Quantity) else rss0.wavelength

panel_rows = [
    ('intensity', rss0.intensity, rss1.intensity, {'ylabel': 'fibre'}),
    ('subtracted', rss0.intensity - sky0_fit.intensity, rss1.intensity - sky_fit.intensity, {}),
]
for row, (row_title, before, after, extra_kw) in enumerate(panel_rows):
    im, cb = colour_map(fig, axes[row, 0], row_title, before, x=x, cbax=axes[row, 2], **extra_kw)
    im, cb = colour_map(fig, axes[row, 1], '.', after, x=x, norm=im.norm, cbax=False)

reference_ax = axes[0, 0]
for ax in (axes[0, 1], axes[1, 0], axes[1, 1]):
    ax.sharex(reference_ax)
    ax.sharey(reference_ax)


## Sky subtraction

In [None]:
# Empirical sky estimate from the fibre-to-fibre relation between the mean
# fibre intensity ("flux") and the intensity at each wavelength.
flux = np.nanmean(rss1.intensity, axis=1)  # mean intensity of each fibre
mean_flux = np.nanmean(flux)
flux_cut_low = np.nanmedian(flux[flux < mean_flux])
flux_cut_hi = 2*mean_flux - flux_cut_low  # flux_cut_low mirrored about the mean

# Three flux bins: faint (< flux_cut_low), intermediate (up to the mean) and
# bright (mean to flux_cut_hi).
faint = flux < flux_cut_low
middle = (flux > flux_cut_low) & (flux < mean_flux)
bright = (flux > mean_flux) & (flux < flux_cut_hi)
flux_low = np.nanmean(flux[faint])
flux_med = np.nanmean(flux[middle])
flux_hi = np.nanmean(flux[bright])

I_low = np.nanmean(rss1.intensity[faint, :], axis=0)
I_med = np.nanmean(rss1.intensity[middle, :], axis=0)
I_hi = np.nanmean(rss1.intensity[bright, :], axis=0)
# Straight line through the faint and bright bins at every wavelength:
# I(flux) = m * flux + b
m = (I_hi - I_low) / (flux_hi - flux_low)
b = I_low - m * flux_low

# Candidate sky fluxes in steps of 1% of the faintest fibre. Fibres whose mean
# is NaN or <= 0 (e.g. dead fibres) are excluded: np.min over them yields a NaN
# or non-positive step and np.arange then raises. NaN compares False, so the
# mask drops those fibres too.
positive_flux = flux[flux > 0]
if positive_flux.size == 0:
    raise ValueError("no fibre with a positive mean intensity; cannot build the sky-flux grid")
sky_flux_candidate = np.arange(0, flux_cut_hi, .01*np.min(positive_flux))
# One model spectrum b + m * candidate per row
sky_filtered = m[np.newaxis, :] * sky_flux_candidate[:, np.newaxis] + b[np.newaxis, :]
# Filter each model with a box mean over s pixels minus a box mean over 3s
# pixels (computed from cumulative sums), at the first-pass wavelet scale.
x = np.nancumsum(sky_filtered, axis=1)
s = wavelet1.scale
sky_filtered = (x[:, 2*s:-s] - x[:, s:-2*s]) / s
sky_filtered -= (x[:, 3*s:] - x[:, :-3*s]) / (3*s)
chi2_sky = np.nanstd(sky_filtered*wavelet1.sky_weight - wavelet1.sky, axis=1)
chi2_no_sky = np.nanstd(sky_filtered*(1 - wavelet1.sky_weight), axis=1)

# Adopt the candidate that minimises the filtered signal weighted by (1 - sky_weight)
sky_flux = sky_flux_candidate[np.nanargmin(chi2_no_sky)]
sky_intensity = b + m*sky_flux

In [None]:
fig, axes = new_figure('get sky flux', nrows=1)

# Both scores as a function of the candidate sky flux; the adopted sky flux is
# the minimum of the non-sky term.
ax = axes[0, 0]
ax.plot(sky_flux_candidate, chi2_sky, 'b-', label='sky-weighted residual')
ax.plot(sky_flux_candidate, chi2_no_sky, 'r-', label='non-sky-weighted signal')
ax.axvline(sky_flux, c='k', ls=':', label=f'sky flux = {sky_flux:.2f}')
ax.legend()
ax.set_xlabel('candidate sky flux')
# Trailing semicolon suppresses the notebook echo of the Text object.
ax.set_ylabel('standard deviation');

In [None]:
fig, axes = new_figure('fit coefficients', nrows=4)

# Panel 1: slope m of the per-wavelength linear fit I = m * flux + b
ax = axes[0, 0]
ax.set_ylabel('m')
ax.plot(wavelength_AA, m, 'k-')
#ax.plot(wavelet1.wavelength, wavelet1.sky_weight, 'b-')
#ax.plot(rss1.wavelength, sky_intensity / np.nanmean(sky_intensity), 'c-')


# Panel 2: intercept b, compared with the SkyFromObject fit and with the
# linear model evaluated at the adopted sky flux
ax = axes[1, 0]
ax.set_ylabel('b')
ax.plot(wavelength_AA, sky_fit.intensity, 'r-', label='sky fit')
ax.plot(wavelength_AA, sky_intensity, 'c-', label='linear model at sky flux')
ax.plot(wavelength_AA, b, 'k-', label='b')
ax.legend()
#ax.set_ylim(-10, 500)


# Panel 3: sky_intensity filtered with a box mean over s pixels minus a box
# mean over 3s pixels, compared with the first-pass wavelet sky
ax = axes[2, 0]
s = wavelet1.scale
ax.set_ylabel(f'wavelet b ({s} pix)')

x = np.nancumsum(sky_intensity)
sky_filtered = (x[2*s:-s] - x[s:-2*s]) / s
sky_filtered -= (x[3*s:] - x[:-3*s]) / (3*s)
ax.plot(wavelet1.wavelength, sky_filtered, 'r-', alpha=.5, label='filtered model')
ax.plot(wavelet1.wavelength, sky_filtered * wavelet1.sky_weight, 'k-', alpha=.5, label='filtered model $\\times$ sky weight')
ax.plot(wavelet1.wavelength, wavelet1.sky, 'c-', alpha=.5, label='wavelet sky')
ax.legend()

# Panel 4: sky weight of the first wavelet pass
ax = axes[3, 0]
ax.set_ylabel('sky weight')
ax.plot(wavelet1.wavelength, wavelet1.sky_weight)
ax.set_xlabel('wavelength [\\AA]')

## Single wavelength

In [None]:
# Wavelength (Angstrom) at which the single-wavelength diagnostics are evaluated.
wl = 6700
# wavelength_AA may be an astropy Quantity (the KOALA loading cell assigns
# rss0.wavelength unconverted), which np.searchsorted cannot compare with a
# plain float, so compare against a unitless Angstrom grid.
wavelength_grid = wavelength_AA.to_value(u.AA) if isinstance(wavelength_AA, u.Quantity) else np.asarray(wavelength_AA)
# searchsorted returns len(grid) when wl lies beyond the last wavelength;
# clip so that idx is always a valid index into the spectra.
idx = min(int(np.searchsorted(wavelength_grid, wl)), wavelength_grid.size - 1)

In [None]:
fig, axes = new_figure('linear fit')

ax = axes[0, 0]

# Every fibre: intensity at the selected wavelength vs. its mean intensity
ax.plot(flux, rss1.intensity[:, idx], 'k.')
# Flux thresholds that define the faint / intermediate / bright bins
for cut, linestyle, cut_name in ((mean_flux, '--', 'mean flux'),
                                 (flux_cut_low, ':', 'flux low'),
                                 (flux_cut_hi, ':', 'flux high')):
    ax.axvline(cut, c='r', ls=linestyle, label=f'{cut_name}: {cut:.2f}')

# Linear model over the fibre-flux percentiles, the binned points it was
# built from, and the adopted sky level
flux_grid = np.nanpercentile(flux, np.linspace(0, 100, 101))
slope, intercept = m[idx], b[idx]
ax.plot(flux_grid, slope*flux_grid + intercept, 'b:', label=f'm={slope:.2f} b={intercept:.2f}')
ax.plot([flux_low, flux_med, flux_hi], [I_low[idx], I_med[idx], I_hi[idx]], 'ro-')

sky_level = sky_intensity[idx]
ax.plot(sky_flux, sky_level, 'co', label=f'Sky level = {sky_level:.2f}')

ax.legend()
ax.set_xlabel('mean fibre intensity')
ax.set_ylabel(f'intensity at $\\lambda={wavelength_AA[idx]:.2f}$ \\AA')
ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
fig, axes = new_figure('monochrome hist')

ax = axes[0, 0]

# Bins span b .. 3*sky - 2*b, i.e. b .. b + 3*(sky - b). Sort the limits so
# the edges increase even when the sky level lies below b (negative slope m);
# otherwise hist() rejects the decreasing bin edges.
lo, hi = sorted((b[idx], 3*sky_intensity[idx] - 2*b[idx]))
x = np.linspace(lo, hi, 101)

ax.hist(rss1.intensity[:, idx], bins=x, alpha=.5, label='fibres')
ax.axvline(sky_intensity[idx], c='k', ls='--', label=f'sky level = {sky_intensity[idx]:.2f}')

ax.legend()
ax.set_ylabel('number of fibres')
ax.set_xlabel(f'intensity at $\\lambda={wavelength_AA[idx]:.2f}$ \\AA')
#ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
from pykoala.plotting import utils
from matplotlib import colors
importlib.reload(utils)

# Fibre maps of the intensity at the selected wavelength: original, after the
# skyline-based corrections, and after sky subtraction.
map_layout = {'left': 0.07, 'right': 0.9, 'width_ratios': [1, .05], 'hspace': 0.05, 'wspace': 0.25}
fig, axes = new_figure('correction maps',
                       nrows=3, ncols=2, sharey=False, figsize=(8, 12),
                       gridspec_kw=map_layout)

map_panels = (
    (f'original $I_\\lambda({wavelength_AA[idx]:.2f} \\AA)$', rss0, rss0.intensity[:, idx]),
    ('throughput corrected', rss1, rss1.intensity[:, idx]),
    ('sky subtracted', rss1, rss1.intensity[:, idx] - sky_fit.intensity[idx]),
)
for row, (panel_title, panel_rss, panel_values) in enumerate(map_panels):
    im, cb = utils.fibre_map(fig, axes[row, 0], panel_title, panel_rss, panel_values, cbax=axes[row, 1])

reference_ax = axes[0, 0]
for ax in axes[1:, 0]:
    ax.sharey(reference_ax)
