# Corrections based on sky emission lines

Compute relative offsets in wavelength (in pixels) and flux (in arbitrary units) from the sky emission lines, which are detected with a wavelet filter.

# 1. Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
import numpy as np

In [None]:
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections import sky
from pykoala.corrections.throughput import ThroughputCorrection
print("pyKOALA version: ", __version__)

The following will probably disappear in the final version of the tutorial

In [None]:
import importlib
from time import time
from astropy import stats
from astropy import units as u
from pykoala.plotting.utils import new_figure, colour_map

# 2. Load the science data
This must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
# Select which example data set the loading cells below read ('KOALA' or 'WEAVE')
example = 'KOALA'
#example = 'WEAVE'

KOALA:

In [None]:
if example == 'KOALA':
    filename = "../data/27feb20036red.fits"
    rss = koala_ifu.koala_rss(filename)
    # Keep the wavelength as plain numbers in Angstrom, as the variable name
    # promises: strip the unit when rss.wavelength is an astropy Quantity.
    rss_wavelength_AA = rss.wavelength
    if isinstance(rss_wavelength_AA, u.Quantity):
        rss_wavelength_AA = rss_wavelength_AA.to_value(u.Angstrom)

WEAVE:

In [None]:
if example == 'WEAVE':
    filename = "../data/weave/single_3042890.fit"
    rss = weave.weave_rss(filename)
    # Keep the wavelength as plain numbers in Angstrom; the isinstance guard
    # matches the one used by the plotting cells of this notebook.
    rss_wavelength_AA = rss.wavelength
    if isinstance(rss_wavelength_AA, u.Quantity):
        rss_wavelength_AA = rss_wavelength_AA.to_value(u.Angstrom)

## Basic information:

In [None]:
# Report the object name stored in rss.info and the file it was read from
print(f"Analysing object {rss.info['name']} read from {filename}")

# 3. Wavelet filter

In [None]:
# Reload the sky module so that edits to its source are picked up without restarting the kernel
importlib.reload(sky)
# Build the wavelet filter from the RSS; the cells below read its
# `filtered`, `sky`, `wavelength` and `fibre_throughput` attributes
wavelet = sky.WaveletFilter(rss)

In [None]:
# Quality-control plots provided by the filter object (see WaveletFilter.qc_plots)
wavelet.qc_plots()

# 4. Fibre throughput

In [None]:
fig, axes = new_figure('wavelet throughput')
ax = axes[0, 0]

# Median of the filtered spectra along axis 0, and the scatter of that median
reference = np.nanmedian(wavelet.filtered, axis=0)
reference_std = np.nanstd(reference)

# Throughput stored in the wavelet filter (black solid line)
ax.plot(wavelet.fibre_throughput, 'k-')

# 16th/50th/84th percentiles, along axis 1, of the filtered spectra divided by the sky
# (red dashed line and shaded band)
filtered_over_sky = wavelet.filtered / wavelet.sky[np.newaxis, :]
p16, p50, p84 = np.nanpercentile(filtered_over_sky, [16, 50, 84], axis=1)
ax.plot(p50, 'r--', alpha=.5)
ax.fill_between(np.arange(p50.size), p16, p84, color='r', alpha=0.1)

ax.set_ylim(.4, 2.1)

Test a single fibre:

In [None]:
fibre = 140

fig, axes = new_figure('test throughput', nrows=3, sharex=False, sharey=False, gridspec_kw={'hspace': .2})

# Top panel: histogram of (filtered spectrum / sky) for the selected fibre
ax = axes[0, 0]
ax.hist(wavelet.filtered[fibre] / wavelet.sky, bins=np.linspace(.5, 1.5, 101))

# Middle panel: filtered spectrum times its throughput (dashed), filtered spectrum (solid), sky (dotted)
ax = axes[1, 0]
middle_curves = (
    (wavelet.filtered[fibre]*wavelet.fibre_throughput[fibre], 'k--', .2),
    (wavelet.filtered[fibre], 'k-', .5),
    (wavelet.sky, 'b:', .5),
)
for curve, fmt, opacity in middle_curves:
    ax.plot(wavelet.wavelength, curve, fmt, alpha=opacity)

# Bottom panel (x axis tied to the middle one): RSS intensity of the fibre before and after
# dividing by its throughput, the same minus the median spectrum, and the median spectrum itself
ax = axes[2, 0]
ax.sharex(axes[1, 0])
median_sky = np.nanmedian(rss.intensity, axis=0)
x = rss.wavelength
x = x.to_value(u.AA) if isinstance(x, u.Quantity) else x
ax.plot(x, rss.intensity[fibre], 'k--', alpha=.2)
ax.plot(x, rss.intensity[fibre]/wavelet.fibre_throughput[fibre], 'k-', alpha=.5)
ax.plot(x, rss.intensity[fibre]/wavelet.fibre_throughput[fibre] - median_sky, 'c-', alpha=.5)
ax.plot(x, median_sky, 'b:', alpha=.5)

,

## Create and apply throughput correction

In [None]:
# Wrap the throughput object provided by the wavelet filter in a ThroughputCorrection
throughput_corr = ThroughputCorrection(throughput=wavelet.get_throughput_object())

In [None]:
# Apply the correction; the next cell compares rss (before) with rss_out (after)
rss_out = throughput_corr.apply(rss)

In [None]:
fig, axes = new_figure('throughput correction', nrows=2, ncols=2, sharey=False, gridspec_kw={'width_ratios': [1, .02], 'hspace': 0.05, 'wspace': .1})

# Wavelength axis as plain numbers (unit stripped when it is a Quantity)
x = rss.wavelength
x = x.to_value(u.AA) if isinstance(x, u.Quantity) else x

# One row per data set: intensity minus its median along axis 0,
# with the colour bar drawn in the narrow right-hand column
for row, (title, data) in enumerate((('before', rss), ('after', rss_out))):
    residual = data.intensity - np.nanmedian(data.intensity, axis=0)
    im, cb = colour_map(fig, axes[row, 0], title, residual, x=x, ylabel='spec_id', cbax=axes[row, 1])

axes[1, 0].sharey(axes[0, 0])


# 5. Wavelength calibration

# -- OLD STUFF --

In [None]:
# Deliberate stop: the cells below belong to the "OLD STUFF" section, so halt a
# "Run All" here with an explicit message instead of the unrelated TypeError
# that `raise -1` produces.
raise RuntimeError("Stopping here: the cells below ('OLD STUFF') are outdated")

## Wavelength calibration

In [None]:
def compute_correlation(wavelength, intensity, intensity_ref, offsets):
    '''
    Compute cross-correlation with respect to a reference spectrum.

    Every spectrum and the reference are mean-subtracted and divided by their
    root mean square (both NaN-aware). For each trial offset, the normalised
    reference is re-sampled with ``np.interp(wavelength, wavelength - offset, ...)``
    and the NaN-aware mean of its product with each normalised spectrum is stored.

    Parameters
    ----------
    wavelength : 1D array
        Wavelength grid on which `intensity` and `intensity_ref` are sampled.
        NOTE(review): ``np.interp`` expects increasing sample points -- confirm
        that callers always pass an increasing grid.
    intensity : 1D or 2D array
        A single spectrum, or a stack with one spectrum per row.
    intensity_ref : 1D array
        Reference spectrum.
    offsets : 1D array
        Trial offsets, subtracted from `wavelength` before interpolating.

    Returns
    -------
    correlation : ndarray of shape (offsets.size, number of spectra)
        ``correlation[i, j]`` is the value for ``offsets[i]`` and spectrum ``j``.

    Raises
    ------
    ValueError
        If `intensity` is neither 1D nor 2D, or `intensity_ref` is not 1D.
    '''
    # Validate the inputs before doing (or announcing) any work
    if intensity.ndim not in (1, 2):
        raise ValueError(f'intensity has {intensity.ndim} dimensions (expected 1 or 2)')
    if intensity_ref.ndim != 1:
        raise ValueError(f'reference intensity has {intensity_ref.ndim} dimensions (expected 1)')

    t0 = time()
    print('> Computing cross-correlation; please be patient...')

    # Work on a 2D (spectrum, wavelength) array in both cases
    if intensity.ndim == 2:
        normed = intensity - np.nanmean(intensity, axis=1)[:, np.newaxis]
    else:
        normed = np.reshape(intensity - np.nanmean(intensity), (1, intensity.size))

    # Unit root mean square for every spectrum and for the reference
    normed /= np.sqrt(np.nanmean(normed**2, axis=1))[:, np.newaxis]
    normed_ref = intensity_ref - np.nanmean(intensity_ref)
    normed_ref /= np.sqrt(np.nanmean(normed_ref**2))

    correlation = np.empty((offsets.size, normed.shape[0]))
    for i, offset in enumerate(offsets):
        # TODO: proper interpolation based on cumulant
        shifted_ref = np.interp(wavelength, wavelength-offset, normed_ref)
        correlation[i] = np.nanmean(normed*shifted_ref[np.newaxis, :], axis=1)
    print(f"  Done ({time()-t0:.3g} s)")
    return correlation

In [None]:
# Trial offsets: 501 evenly spaced values from -10 to 10
offset = np.linspace(-10, 10, 501)
# Correlate every wavelet-filtered spectrum with the wavelet sky spectrum at each trial offset
old_correlation = compute_correlation(wavelet.wavelength, wavelet.filtered, wavelet.sky, offset)

In [None]:
# One offset per spectrum: ratio of NaN-aware means over the trial offsets,
# i.e. the trial offsets averaged with the squared correlation as weight
old_fibre_offset = np.nanmean(offset[:, np.newaxis] * old_correlation**2, axis=0) / np.nanmean(old_correlation**2, axis=0)

## Self-calibration based on sky lines

In [None]:
# Reload the sky module so that edits to its source are picked up without restarting the kernel
importlib.reload(sky)
# Self-calibration object; the cells below read its continuum, sky intensity,
# wavelength offsets and relative throughput
rss_autocal = sky.SkySelfCalibration(rss)

### Lines and continuum

In [None]:
fig, axes = new_figure('lines-continuum', nrows=3)

# Raw string: the LaTeX label contains backslashes that are not valid Python escape sequences
wavelength_label = r'$\lambda\ [\AA]$'
autocal_continuum = rss_autocal.continuum
# Intensity left after subtracting the continuum
line_emission = rss.intensity - autocal_continuum.intensity

# The wavelength axis is the `rss_wavelength_AA` array defined when the data were loaded
im, cb = colour_map(fig, axes[0, 0], 'lines', line_emission, x=rss_wavelength_AA, xlabel=wavelength_label, ylabel='spec_id')
im, cb = colour_map(fig, axes[1, 0], 'continuum', autocal_continuum.intensity, x=rss_wavelength_AA, xlabel=wavelength_label, ylabel='spec_id')
im, cb = colour_map(fig, axes[2, 0], 'line emission S/N', line_emission / autocal_continuum.scale[:, np.newaxis], x=rss_wavelength_AA, xlabel=wavelength_label, ylabel='spec_id')

# Tie the y axes of the three panels together. `get_shared_y_axes().join(...)` was removed in
# Matplotlib 3.8; `Axes.sharey` is the replacement already used elsewhere in this notebook.
# NOTE(review): assumes new_figure has not already shared these axes with another one -- confirm.
for lower_ax in axes[1:, 0]:
    lower_ax.sharey(axes[0, 0])


### Sky lines

In [None]:
fig, axes = new_figure('sky_lines')

ax = axes[0, 0]
# The wavelength axis is the `rss_wavelength_AA` array defined when the data were loaded
ax.plot(rss_wavelength_AA, rss_autocal.sky_intensity, 'k-')
for line in rss_autocal.continuum.strong_sky_lines:
    # Shade the interval between the 'left' and 'right' pixel indices of the line
    # and mark its 'sky_wavelength' with a vertical line
    ax.axvspan(rss_wavelength_AA[line['left']], rss_wavelength_AA[line['right']], color='c', alpha=.1)
    ax.axvline(line['sky_wavelength'].to_value(u.Angstrom), c='b', alpha=.5)

In [None]:
# Reload the sky module so that edits to its source are picked up without restarting the kernel
importlib.reload(sky)

lsf = sky.LSF_estimator(wavelength_range=20*u.Angstrom, resolution=0.01*u.Angstrom)
sky_lsf = lsf.find_LSF(
    rss_autocal.continuum.strong_sky_lines['sky_wavelength'],
    rss_autocal.dc.wavelength,
    rss_autocal.sky_intensity,
)

# Offset: mean of delta_lambda weighted by the squared LSF
weight = sky_lsf**2
sky_offset = np.sum(weight*lsf.delta_lambda) / np.sum(weight)

sky_fwhm = lsf.find_FWHM(sky_lsf)
print(f'Sky lines FWHM = {sky_fwhm:.4g}, offset = {sky_offset:.4g}')

In [None]:
fig, axes = new_figure('sky_LSF')
ax = axes[0, 0]

# Measured LSF (solid), the weighted offset (dashed vertical line) and the LSF
# mirrored about that offset (dotted)
fwhm_label = f'FWHM = {lsf.find_FWHM(sky_lsf):.4g}'
ax.plot(lsf.delta_lambda, sky_lsf, 'k-', label=fwhm_label)

offset_label = f'offset = {sky_offset:.4g}'
ax.axvline(sky_offset.to_value(u.Angstrom), c='k', ls='--', label=offset_label)

mirrored_delta_lambda = -lsf.delta_lambda+2*sky_offset
ax.plot(mirrored_delta_lambda, sky_lsf, 'k:', label='symmetric')
ax.legend()

### Single spectrum

In [None]:
from astropy import stats

spec_id = 42
strong_lines = rss_autocal.continuum.strong_sky_lines
line_wavelength, line_intensity = rss_autocal.measure_lines(spec_id)
line_wavelength = line_wavelength.to_value(u.Angstrom)

fig, axes = new_figure('single_calibration', nrows=2, ncols=2)

# LaTeX labels are raw strings: their backslashes are not valid Python escape sequences.

# --- Top row: measured minus reference wavelength of every strong sky line
ax = axes[0, 0]
ax.set_ylabel(r'relative $\Delta\lambda\ [\AA]$')
y = line_wavelength - strong_lines['sky_wavelength'].to_value(u.Angstrom)
# Robust centre and spread (biweight), drawn as a red line and band
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.', label=f'sky lines in spec_id={spec_id}')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1)
ax.legend()

# Same quantity against the reference line intensity
ax = axes[0, 1]
ax.plot(strong_lines['sky_intensity'], y, 'k.')
# Percentiles of the reference intensities: x support of the shaded band
xx = np.nanpercentile(strong_lines['sky_intensity'], np.linspace(0, 100, 101))
ax.axhline(location, c='r', ls='-')
ax.fill_between(xx, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

# --- Bottom row: measured over reference intensity of every strong sky line
ax = axes[1, 0]
ax.set_ylabel('relative throughput')
y = line_intensity/strong_lines['sky_intensity']
location = stats.biweight.biweight_location(y)
scale = stats.biweight.biweight_scale(y)
ax.plot(line_wavelength, y, 'k.')
ax.axhline(location, c='r', ls='-')
ax.fill_between(line_wavelength, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
# The band carries a label, so show it like the other labelled panels do
ax.legend()

ax.set_xlabel(r'$\lambda_{line}\ [\AA]$')

ax = axes[1, 1]
ax.plot(strong_lines['sky_intensity'], y, 'k.')
xx = np.nanpercentile(strong_lines['sky_intensity'], np.linspace(0, 100, 101))
ax.axhline(location, c='r', ls='-')
ax.fill_between(xx, location-scale, location+scale, color='r', alpha=.1, label=rf'${location:.2f}\pm{scale:.2f}$')
ax.legend()

ax.set_xlabel('Line intensity')
ax.set_xscale('log')


In [None]:
fig, axes = new_figure('sky_self_calibration', nrows=2)
spec_id = np.arange(rss_autocal.wavelength_offset.size)

# --- Top panel: wavelength offset of every spectrum
ax = axes[0, 0]
# Raw string: the LaTeX label contains backslashes that are not valid Python escape sequences
ax.set_ylabel(r'relative $\Delta\lambda\ [\AA]$')

# Self-calibration offset (dashed) with its uncertainty band
ax.plot(spec_id, rss_autocal.wavelength_offset.to_value(u.Angstrom), 'k--', alpha=.5)
ax.fill_between(spec_id,
                (rss_autocal.wavelength_offset-rss_autocal.wavelength_offset_err).to_value(u.Angstrom),
                (rss_autocal.wavelength_offset+rss_autocal.wavelength_offset_err).to_value(u.Angstrom),
                color='k', alpha=.1)
# Cross-correlation offset computed earlier in the notebook as `old_fibre_offset`, sign flipped
ax.plot(spec_id, -old_fibre_offset, 'b-', alpha=.7)

# --- Bottom panel: relative throughput of every spectrum
ax = axes[1, 0]
ax.set_ylabel('relative throughput')

# Curve and uncertainty band are divided by the same median so that the band surrounds the curve
median_throughput = np.nanmedian(rss_autocal.relative_throughput)
normalised_throughput = rss_autocal.relative_throughput / median_throughput
normalised_throughput_err = rss_autocal.relative_throughput_err / median_throughput
ax.plot(spec_id, normalised_throughput, 'k--', alpha=.5)
ax.fill_between(spec_id, normalised_throughput-normalised_throughput_err, normalised_throughput+normalised_throughput_err, color='k', alpha=.1)
# Throughput from the wavelet filter, normalised by its own median
ax.plot(spec_id, wavelet.fibre_throughput / np.nanmedian(wavelet.fibre_throughput), 'b-', alpha=.7)

ax.set_xlabel('spec_id')
