# Corrections based on sky emission lines

Compute the relative wavelength offsets (in pixels) and flux scalings (arbitrary units) of each fibre, based on the sky emission lines detected through a wavelet filter.

# 1. Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
import numpy as np

In [None]:
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections import sky
from pykoala.corrections.throughput import ThroughputCorrection
from pykoala.corrections.wavelength import WavelengthCorrection
# Report the pyKOALA version in use (handy when comparing results across releases).
print("pyKOALA version: ", __version__)

The following imports will probably disappear in the final version of the tutorial.

In [None]:
import importlib
from astropy import stats
from astropy import units as u
from pykoala.plotting.utils import new_figure, colour_map
import scipy

# 2. Load the science data
This must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
# Choose which example RSS file to analyse: 'KOALA' or 'WEAVE'.
example = 'KOALA'
#example = 'WEAVE'

# Fail early on a typo: otherwise neither loading cell below matches, and
# `rss0` is not (re)loaded before the rest of the notebook uses it.
if example not in ('KOALA', 'WEAVE'):
    raise ValueError(f"Unknown example {example!r}; choose 'KOALA' or 'WEAVE'")

KOALA:

In [None]:
if example == 'KOALA':
    # Example KOALA RSS file; the path is relative to the current working directory.
    filename = "../data/27feb20036red.fits"
    rss0 = koala_ifu.koala_rss(filename)

WEAVE:

In [None]:
if example == 'WEAVE':
    # Example WEAVE RSS file; the path is relative to the current working directory.
    filename = "../data/weave/single_3042890.fit"
    rss0 = weave.weave_rss(filename)

## Summary

In [None]:
# Summarise the loaded RSS: object name, source file, metadata keys and processing log.
rss_info = rss0.info
print(f"Analysing object {rss_info['name']} read from {filename}")
print('- info:')
print(rss_info.keys())
print('- log:')
rss0.log.show()


# 3. Wavelet filter

In [None]:
# Development convenience: re-execute the pykoala.corrections.sky module so that
# edits to its source take effect without restarting the kernel.
importlib.reload(sky)

## First iteration

In [None]:
# First pass: run the wavelet filter on the uncorrected RSS.
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
# Throughput correction built from the throughput object derived by the first wavelet pass.
throughput_corr = ThroughputCorrection(throughput=wavelet1.get_throughput_object())

In [None]:
# Wavelength correction built from the per-fibre offsets estimated by the first wavelet pass.
wave_corr = WavelengthCorrection(offset=wavelet1.get_wavelength_offset())

In [None]:
# Apply the first-pass corrections: wavelength shift first, then throughput.
rss1 = throughput_corr.apply(wave_corr.apply(rss0))

## Second iteration

In [None]:
# Second pass: re-run the wavelet filter on the already-corrected RSS.
wavelet2 = sky.WaveletFilter(rss1)

In [None]:
# Throughput correction from the second wavelet pass (replaces the first-pass object).
throughput_corr = ThroughputCorrection(throughput=wavelet2.get_throughput_object())

In [None]:
# Wavelength correction from the second wavelet pass (replaces the first-pass object).
wave_corr = WavelengthCorrection(offset=wavelet2.get_wavelength_offset())

In [None]:
# Apply the second-pass corrections to rss1: wavelength shift first, then throughput.
rss2 = throughput_corr.apply(wave_corr.apply(rss1))

In [None]:
# Show the processing log of the twice-corrected RSS.
rss2.log.show()

# 4. Quality control plots

## Wavelet filter

In [None]:
# Quality-control plots provided by the first-pass wavelet filter.
wavelet1.qc_plots()

## Relative calibration

In [None]:
# Per-fibre relative throughput (top) and wavelength offset (bottom) from both
# wavelet iterations. For comparison, the top panel also shows, per fibre, the
# median (dashed) and 16th-84th percentile band (shaded) over wavelength of the
# filtered spectrum divided by the reference sky.
fig, axes = new_figure('sky-based relative calibration', nrows=2)

ax = axes[0, 0]
ax.set_ylabel('relative throughput')
ax.set_ylim(.4, 2.1)

ax.plot(wavelet1.fibre_throughput, 'k-', label='iter 1')
ax.plot(wavelet2.fibre_throughput, 'k-', alpha=.5, label='iter 2')

# `sky` is broadcast along axis 1 (wavelength), so percentiles over axis=1
# yield one value per fibre.
p16, p50, p84 = np.nanpercentile(wavelet1.filtered / wavelet1.sky[np.newaxis, :], [16, 50, 84], axis=1)
ax.plot(p50, 'r--', alpha=.5, label='iter 1: filtered / sky')
ax.fill_between(np.arange(p50.size), p16, p84, color='r', alpha=0.1)
p16, p50, p84 = np.nanpercentile(wavelet2.filtered / wavelet2.sky[np.newaxis, :], [16, 50, 84], axis=1)
ax.plot(p50, 'y--', alpha=.5, label='iter 2: filtered / sky')
ax.fill_between(np.arange(p50.size), p16, p84, color='y', alpha=0.1)

ax.legend()


ax = axes[1, 0]
ax.set_ylabel('relative offset [pix]')

ax.plot(wavelet1.fibre_offset, 'k-', label='iter 1')
ax.plot(wavelet2.fibre_offset, 'k-', alpha=.5, label='iter 2')

ax.legend()
# Trailing ';' suppresses Jupyter's echo of the return value.
ax.set_xlabel('fibre');

## Single fibre tests

In [None]:
# Index of the fibre examined in the single-fibre diagnostics below (arbitrary choice).
fibre = 42

In [None]:
# Single-fibre throughput diagnostics for `fibre`:
#  - top: histograms of the fibre's first-pass filtered spectrum divided by the
#    sky, with (grey) and without (blue) the fibre throughput multiplied back in;
#  - middle: wavelet-filtered spectra against the first-pass sky;
#  - bottom: RSS intensities, the median over fibres, and sky-subtracted
#    spectra for the raw data and both correction iterations.
fig, axes = new_figure('single throughput', nrows=3, sharex=False, sharey=False, gridspec_kw={'hspace': .2})

ax = axes[0, 0]
ratio_bins = np.linspace(0, 7.5, 101)
ax.hist(wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre] / wavelet1.sky, bins=ratio_bins, color='k', alpha=.2, label='data / sky')
ax.hist(wavelet1.filtered[fibre] / wavelet1.sky, bins=ratio_bins, color='b', alpha=.5, label='throughput-corrected / sky')
ax.legend()

ax = axes[1, 0]
ax.set_ylabel('wavelet')
ax.plot(wavelet1.wavelength, wavelet1.sky, 'b:', alpha=.5, label='sky')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre], 'k--', alpha=.2, label='data')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre], 'k-', alpha=.5, label='throughput-corrected')
ax.plot(wavelet2.wavelength, wavelet2.filtered[fibre], 'r-', alpha=.5, label='iter 2')
ax.legend()


ax = axes[2, 0]
ax.set_ylabel('intensity')
ax.sharex(axes[1, 0])

# Plot against plain floats in Angstrom when the wavelength carries astropy units.
x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

# Raw data; the median is taken over fibres (axis 0) at each wavelength.
ax.plot(x, rss0.intensity[fibre], 'k--', alpha=.2, label=f'fibre {fibre}')
median_sky = np.nanmedian(rss0.intensity, axis=0)
ax.plot(x, median_sky, 'c-', alpha=.5, label='median sky (raw)')
ax.plot(x, rss0.intensity[fibre] - median_sky, 'c--', alpha=.2, label='sky-subtracted (raw)')

# After the first correction pass.
ax.plot(x, rss1.intensity[fibre], 'k-', alpha=.5, label='corrected')
median_sky = np.nanmedian(rss1.intensity, axis=0)
ax.plot(x, median_sky, 'b:', alpha=.5, label='median sky (corrected)')
ax.plot(x, rss1.intensity[fibre] - median_sky, 'c-', alpha=.5, label='sky-subtracted (corrected)')

# After the second correction pass.
ax.plot(x, rss2.intensity[fibre], 'r-', alpha=.5, label='iter 2')
median_sky = np.nanmedian(rss2.intensity, axis=0)
ax.plot(x, rss2.intensity[fibre] - median_sky, 'r:', alpha=.5, label='sky-subtracted (iter 2)')

ax.legend()
# Raw string: '\A' is not a valid Python escape; the label text is unchanged.
# Trailing ';' suppresses Jupyter's echo of the return value.
ax.set_xlabel(r'wavelength [\AA]');

In [None]:
# Wavelength-offset diagnostic for `fibre`: cross-correlate its wavelet-filtered
# spectrum with the median filtered spectrum over all fibres (the sky
# reference), and compare the lag profile with the offset estimated by the filter.
fig, axes = new_figure('single wavelength calibration')

ax = axes[0, 0]

mid = wavelet1.wavelength.size // 2
s = wavelet1.scale  # half-width (in pixels) of the lag window displayed
# A single non-finite sample would make the whole FFT-based correlation NaN,
# so both the reference and the fibre spectrum are zero-filled where not finite.
sky_template = np.nanmedian(wavelet1.filtered, axis=0)
sky_template[~ np.isfinite(sky_template)] = 0
fibre_spectrum = wavelet1.filtered[fibre]
fibre_spectrum = np.where(np.isfinite(fibre_spectrum), fibre_spectrum, 0)
# Convolving with the reversed reference is a cross-correlation. With
# mode='same' and equal-length inputs, zero lag lands at index `mid`, so the
# slice keeps lags -s..+s.
x = scipy.signal.fftconvolve(fibre_spectrum, sky_template[::-1], mode='same')[mid-s:mid+s+1]
idx = np.arange(x.size)
lag = idx - s

ax.plot(lag, x/np.max(x), 'k-', label=f'fibre {fibre}')
ax.axvline(wavelet1.fibre_offset[fibre], c='k', ls=':', label=f'offset = {wavelet1.fibre_offset[fibre]:.2f} pix')
# Mirror the profile about the estimated offset: a peak centred on the offset
# overlaps its own reflection.
ax.plot(2*wavelet1.fibre_offset[fibre] - lag, x/np.max(x), 'k--', alpha=.25, label='reflected cross-correlation')

ax.legend()
ax.set_ylabel('cross-correlation with sky')
# Trailing ';' suppresses Jupyter's echo of the return value.
ax.set_xlabel('offset [pix]');

## Corrected intensities

In [None]:
# Compare raw (left column) and first-pass corrected (right column) intensities.
# Top row: intensity; bottom row: intensity minus the median over fibres.
# Column 2 holds the colour bars; each right-hand panel reuses the norm of the
# panel to its left so both share one colour scale.
fig, axes = new_figure('skyline-based correction',
                       nrows=2, ncols=3, sharex=False, sharey=False,
                       gridspec_kw={'left': 0.07, 'right': 0.9,
                                    'width_ratios': [1, 1, .05],
                                    'hspace': 0.25, 'wspace': 0.25})

# Plot against plain floats in Angstrom when the wavelength carries astropy units.
x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

panel_rows = [
    (0, 'intensity', rss0.intensity, rss1.intensity, {'ylabel': 'fibre'}),
    (1, 'median-subtracted',
     rss0.intensity - np.nanmedian(rss0.intensity, axis=0),
     rss1.intensity - np.nanmedian(rss1.intensity, axis=0),
     {}),
]
for row, row_title, raw_map, corrected_map, extra_kwargs in panel_rows:
    im, cb = colour_map(fig, axes[row, 0], row_title, raw_map, x=x, cbax=axes[row, 2], **extra_kwargs)
    im, cb = colour_map(fig, axes[row, 1], '.', corrected_map, x=x, norm=im.norm, cbax=False)

# Link all image panels to the top-left one for synchronized zoom/pan.
reference_ax = axes[0, 0]
for ax in (axes[0, 1], axes[1, 0], axes[1, 1]):
    ax.sharex(reference_ax)
    ax.sharey(reference_ax)
