# Corrections based on sky emission lines

Compute relative offsets in wavelength (in pixels) and flux (arbitrary units) based on the sky emission lines, which are detected through a wavelet filter.

# 1. Initialisation

## Imports

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
import numpy as np

In [None]:
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections import sky
from pykoala.corrections.throughput import ThroughputCorrection
from pykoala.corrections.wavelength import WavelengthCorrection
from pykoala.corrections.sky import SkySubsCorrection
print("pyKOALA version: ", __version__)

The following will probably disappear in the final version of the tutorial

In [None]:
from astropy import stats
from astropy import units as u
from astropy.stats import biweight_location, biweight_scale
from pykoala.plotting.utils import new_figure, plot_image, plot_fibres
from matplotlib.colors import LogNorm, Normalize
import scipy
from pykoala.ancillary import symmetric_background

In [None]:
%load_ext autoreload
%autoreload 2

## Load the science data
This must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
example = 'KOALA'
#example = 'WEAVE'

KOALA:

In [None]:
if example == 'KOALA':
    filename = f"../data/27feb20036red.fits"
    rss0 = koala_ifu.koala_rss(filename)
    #wavelength_AA = rss0.wavelength.to_value(u.Angstrom)

WEAVE:

In [None]:
if example == 'WEAVE':
    # Old versions of the pipeline
    #filename = "../data/weave/v3/NGC5322_OB12063/L1/single_3045973.fit"
    #filename = "../data/weave/v3/NGC4290_OB11113/L1/single_3039517.fit"
    #filename = "../data/weave/v3/WA_J024019.19+321544.10/single_3042890.fit"
    #filename = "../data/weave/v0.9/OB11162/single_3063947.fit"
    #filename = "../data/weave/v0.9/OB12709/single_3058745.fit"
    # Twilight (probably old pipeline)
    #filename = "../data/weave/solar/msp_3059302.fit" # WARNING: Doesn't conform to data model (requires tweaks)
    # Pablo's data (probably old pipeline)
    #filename = "../data/weave/ws2023b2-012/L1/single_3041989.fit"
    #filename = "../data/weave/ws2023b2-012/L1/single_3041991.fit"
    #filename = "../data/weave/ws2023b2-012/L1/single_3041993.fit"

    # Latest version of the pipeline
    filename = "../data/weave/v0.91/OB11162/single_3063947.fit"
    filename = "../data/weave/v0.91/OB12709/single_3058745.fit"
    
    rss0 = weave.weave_rss(filename)
    #wavelength_AA = rss0.wavelength.to_value(u.Angstrom)

### Data summary

In [None]:
print(f"Analysing object {rss0.info['name']} read from {filename}")
print('- info:')
print(rss0.info.keys())
print('- log:')
rss0.history.show()


# 2. Sky spectrum

In [None]:
help(sky.SkyFromObject.__init__)

In [None]:
sky_model = sky.SkyFromObject(rss0, qc_plots={'show': False})

In [None]:
sky_model.qc_plots['sky_fibres']

In [None]:
sky_model.qc_plots['sky_model']

The default behaviour tries to read the list of sky fibres from the `info` attribute of the `DataContainer`, and estimates them in case they are not present.
Alternatively, one may resort to other types of `BackgroundEstimator`:

In [None]:
help(sky.BackgroundEstimator)

In [None]:
sky_sigma_clip = sky.SkyFromObject(rss0, bckgr_estimator='mad', sky_fibres='all', source_mask_nsigma=3, qc_plots={'show': False})

In [None]:
sky_sigma_clip.qc_plots.get('sky_fibres', None)

In [None]:
sky_sigma_clip.qc_plots.get('sky_model', None)

In [None]:
sky_mode = sky.SkyFromObject(rss0, bckgr_estimator='mode', sky_fibres='all', qc_plots={'show': False})

In [None]:
sky_mode.qc_plots.get('sky_fibres', None)

In [None]:
sky_mode.qc_plots.get('sky_model', None)

## -- Single wavelength test --

In [None]:
# Candidate wavelengths that were tried; only the last assignment is active
#wl = 6582
#wl = 6620
#wl = 6700
#wl = 7718
wl = 8344.5
# Index of the last wavelength pixel strictly below `wl` (searchsorted returns the
# insertion point, hence the -1); assumes rss0.wavelength is sorted in increasing order
idx = np.searchsorted(rss0.wavelength, wl << u.Angstrom) - 1

# Monochromatic intensity of every fibre at that pixel, as a plain array (units stripped)
intensity = rss0.rss_intensity[:, idx].value
# NaN-aware mean along axis 1 for each fibre, used as a proxy for its total flux
total_flux = np.nanmean(rss0.rss_intensity, axis=1)
# NOTE(review): `intensity` has its units stripped but `total_flux` does not, so this
# ratio presumably carries inverse intensity units -- confirm this is intended
intensity_norm = intensity / total_flux

# Running mean of the fibre fluxes, accumulated from the faintest fibre upwards
sorted_by_flux = np.argsort(total_flux)
flux_mean = np.nancumsum(total_flux[sorted_by_flux]) / np.arange(1, total_flux.size+1)
# NOTE(review): half_sample is not referenced by any other cell visible in this notebook
half_sample = total_flux.size // 2

In [None]:
fig, axes = new_figure('single_wavelength_sky', nrows=2)

ax = axes[0, 0]
ax.set_ylabel(f'intensity ($\\lambda={rss0.wavelength[idx]:.2f}\\ \\AA$)')
vmin = np.nanmin(intensity[sky_model.sky_fibres])
vmax = np.nanmax(intensity[sky_model.sky_fibres])
h = .1 * (vmax - vmin)
ax.set_ylim(vmin-h, vmax+2*h)

ax.plot(total_flux, intensity, 'k.', alpha=.1)
ax.plot(total_flux[sky_model.sky_fibres], intensity[sky_model.sky_fibres], 'b+', alpha=.5)
ax.axhline(sky_model.intensity[idx], c='k', ls='--')
std = np.sqrt(sky_model.variance[idx])
ax.axhline(sky_model.intensity[idx] + std, c='k', ls=':')
ax.axhline(sky_model.intensity[idx] - std, c='k', ls=':')


ax = axes[1, 0]
ax.set_ylabel(f'normalised intensity')
vmin = np.nanmin(intensity_norm[sky_model.sky_fibres])
vmax = np.nanmax(intensity_norm[sky_model.sky_fibres])
h = .1 * (vmax - vmin)
ax.set_ylim(vmin-h, vmax+2*h)

ax.plot(total_flux, intensity_norm, 'k.', alpha=.1)
ax.plot(total_flux[sky_model.sky_fibres], intensity_norm[sky_model.sky_fibres], 'b+', alpha=.5)
sky_flux = np.nanmean(sky_model.intensity.value)
ax.axhline(sky_model.intensity.value[idx]/sky_flux, c='k', ls='--')
ax.axhline((sky_model.intensity[idx] + std).value/sky_flux, c='k', ls=':')
ax.axhline((sky_model.intensity[idx] - std).value/sky_flux, c='k', ls=':')

ax.set_xlabel('total flux (mean fibre intensity)')
vmin, vmax = (flux_mean[0], flux_mean[-1])
h = .1 * (vmax - vmin)
ax.set_xlim(vmin-h, vmax+2*h)
#h = th_flux-bg_flux
#ax.set_xlim(bg_flux-2*h, th_flux+6*h)
#ax.set_xscale('log')

# -- STOP --

In [None]:
# Deliberately interrupt "Run All" here: the cells below are work in progress.
# A real exception with a message is raised (``raise -1`` only stops by accident,
# through a confusing "exceptions must derive from BaseException" TypeError).
raise RuntimeError("STOP: the cells below this point are work in progress")

# 4. Wavelet-based corrections

## Wavelet

In [None]:
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
# 1. Estimate the FWHM of emission lines from the autocorrelation of the sky spectrum (here the `sky_mode` model).

# Replace non-finite sky values by zero so that they do not propagate into the correlation
x = np.where(np.isfinite(sky_mode.intensity), sky_mode.intensity, 0)
# Remove the mean level before correlating (`x` is modified in place)
x -= np.nanmean(x)

In [None]:
# Autocorrelation of the mean-subtracted sky spectrum (mode='same' keeps the input length)
x = scipy.signal.correlate(x, x, mode='same')
# Half-width: half the number of samples above 50% of the autocorrelation peak, rounded up.
# NOTE(review): every sample above the threshold is counted, not only those contiguous with the main peak
h = (np.count_nonzero(x > 0.5*np.nanmax(x)) + 1) // 2
# h = 0
# The filter scale is always odd (2*h + 1); the slices that build `wavelength_filtered` rely on this
scale = 2*h + 1
print(f'> Wavelet filter scale: {scale} pixels')

In [None]:
# 2. Apply a (mexican top hat) wavelet filter to detect features on that scale (i.e. filter out the continuum).

# Cumulative sum along axis 1 (NaNs count as zero): box sums become differences of two entries
x = np.nancumsum(rss0.intensity, axis=1)
# Mean over a central window of `scale` pixels ...
rss_filtered = (x[:, 2*scale:-scale] - x[:, scale:-2*scale]) / scale
# ... minus the mean over the enclosing window of 3*scale pixels.
# The output is 3*scale pixels shorter than the input along axis 1.
rss_filtered -= (x[:, 3*scale:] - x[:, :-3*scale]) / (3*scale)

In [None]:
# Same top-hat wavelet filter as applied to the RSS, here on the 1D `sky_mode` spectrum
x = np.nancumsum(sky_mode.intensity)
sky_filtered = (x[2*scale:-scale] - x[scale:-2*scale]) / scale
sky_filtered -= (x[3*scale:] - x[:-3*scale]) / (3*scale)
# Pixels whose filtered sky value lies outside the interquartile range (below p25 or above p75)
sky_filtered_p25, sky_filtered_p75 = np.nanpercentile(sky_filtered, [25, 75])
sky_intense = np.where((sky_filtered < sky_filtered_p25) | (sky_filtered > sky_filtered_p75))[0]

In [None]:
# Wavelength and sky error trimmed to the length of the filtered arrays: with
# scale == 2*h + 1 this slice drops 3*scale pixels, like the wavelet filter does
wavelength_filtered = rss0.wavelength[scale+h+1 : -scale-h]
sky_error_filtered = np.sqrt(sky_mode.variance)[scale+h+1 : -scale-h]
# Pixel index along the filtered spectra (used as the abscissa for interpolation)
filtered_idx = np.arange(wavelength_filtered.size)
# Weight of each pixel: squared signal-to-noise ratio of the filtered sky, normalised to unit mean
sky_filtered_weight = (sky_filtered / sky_error_filtered)**2
sky_filtered_weight /= np.nanmean(sky_filtered_weight)

## Offset

In [None]:
def plot_fibre_wavelet_diff(ax, fibre):
    """Plot the sky-subtracted wavelet signal of one fibre in units of the sky error.

    Reads the notebook-level arrays ``rss_filtered``, ``sky_filtered``,
    ``sky_error_filtered``, ``sky_filtered_weight`` and ``wavelength_filtered``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``rss_filtered`` to plot.

    The legend shows the plain and the ``sky_filtered_weight``-weighted RMS of
    the plotted curve.
    """
    residual_snr = (rss_filtered[fibre] - sky_filtered) / sky_error_filtered
    rms = np.sqrt(np.nanmean(residual_snr**2))
    weighted_rms = np.sqrt(np.nanmean(sky_filtered_weight*residual_snr**2))
    ax.plot(wavelength_filtered, residual_snr, 'k-', alpha=.5,
            label=f'SNR = {rms:.3f} {weighted_rms:.3f}')
    ax.legend()

fig, axes = new_figure('individual_wavelets', nrows=3, ncols=1, sharex=True, sharey=True)
#fig.supylabel('correlation with sky')
fig.supxlabel('wavelength')

plot_fibre_wavelet_diff(axes[0, 0], 42)
#plot_fibre_wavelet_diff(axes[0, 0], np.nanargmin(fibre_offset))
plot_fibre_wavelet_diff(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
#plot_fibre_wavelet_diff(axes[2, 0], np.nanargmax(fibre_offset))
plot_fibre_wavelet_diff(axes[2, 0], 69)

axes[0, 0].set_ylim(-2.50, 2.5)

In [None]:
def sky_sub_residual(fibre, offset):
    """Residual of the sky subtraction for one fibre after shifting it by `offset` pixels.

    The filtered spectrum of the fibre is linearly interpolated at
    ``filtered_idx - offset``, the filtered sky is subtracted and the result is
    divided by the sky error. Reads the notebook-level arrays ``filtered_idx``,
    ``rss_filtered``, ``sky_filtered``, ``sky_error_filtered`` and
    ``sky_filtered_weight``.

    Parameters
    ----------
    fibre : int
        Row of ``rss_filtered``.
    offset : float
        Shift, in pixels, applied to the fibre spectrum.

    Returns
    -------
    tuple
        ``(rms, weighted_rms)`` of the residual, the second one weighted by
        ``sky_filtered_weight``.
    """
    shifted = np.interp(filtered_idx-offset, filtered_idx, rss_filtered[fibre])
    residual_snr = (shifted - sky_filtered) / sky_error_filtered
    rms = np.sqrt(np.nanmean(residual_snr**2))
    weighted_rms = np.sqrt(np.nanmean(sky_filtered_weight*residual_snr**2))
    return rms, weighted_rms

for off in np.linspace(-1.5, 1.5, 31):
    print(f'{off:.2f} {sky_sub_residual(42, off)}')

## Old stuff

In [None]:
# 5. Estimate wavelength offset of each fibre (in pixels) from cross-correlation with the sky.

mid = sky_filtered.size // 2
# Convolving with the reversed sky is a cross-correlation; only the 2*scale+1 lags
# around index `mid` are kept (presumably centred on zero lag for mode='same' -- confirm)
x = scipy.signal.fftconvolve(
    rss_filtered, sky_filtered[np.newaxis, ::-1], mode='same', axes=1)[:, mid-scale:mid+scale+1]
idx = np.arange(x.shape[1])
# Only positive correlations contribute to the centroid
weight = np.where(x > 0, x, 0)
# Offset of each fibre: correlation-weighted mean lag, with lag = idx - scale
fibre_offset = np.nansum(
    (idx - scale)[np.newaxis, :] * weight, axis=1) / np.nansum(weight, axis=1)

In [None]:
def plot_fibre_sky_corr(ax, fibre):
    """Plot the cross-correlation of one fibre with the sky and mark its offset.

    Reads the notebook-level ``x`` (one cross-correlation per row), ``scale``
    and ``fibre_offset`` at call time.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``x`` / entry of ``fibre_offset`` to show.
    """
    offset = fibre_offset[fibre]
    ax.plot(x[fibre], 'b-', alpha=.2)
    # Dotted line: index `scale`; solid line: the same index shifted by the fibre offset
    ax.axvline(scale, c='b', ls=':', alpha=.2)
    ax.axvline(scale+offset, c='k', ls='-', label=f'fibre {fibre} ({offset:.3f})')
    ax.legend()


fig, axes = new_figure('individual_wavelet_correlation', nrows=3)
fig.supylabel('correlation with sky')
fig.supxlabel(f'scale ({scale} pix) + offset')

plot_fibre_sky_corr(axes[0, 0], np.nanargmin(fibre_offset))
plot_fibre_sky_corr(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
plot_fibre_sky_corr(axes[2, 0], np.nanargmax(fibre_offset))

In [None]:

def rescale(base, target):
    """Scale `base` so that it best matches `target`.

    `base` is first normalised to unit RMS (ignoring NaN); the returned array is
    that normalised `base` multiplied by the NaN-aware mean of its product with
    `target`.

    Parameters
    ----------
    base, target : array-like of the same shape

    Returns
    -------
    Array with the shape of `base`.
    """
    rms = np.sqrt(np.nanmean(base**2))
    unit_base = base / rms
    amplitude = np.nanmean(target*unit_base)
    return amplitude * unit_base

def plot_fibre_wavelet(ax, fibre):
    """Overplot the wavelet-filtered sky and the wavelet-filtered spectrum of one fibre.

    Reads the notebook-level arrays ``wavelength_filtered``, ``sky_filtered``,
    ``rss_filtered`` and ``fibre_offset``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``rss_filtered`` to plot; its offset is shown in the legend.
    """
    ax.plot(wavelength_filtered, sky_filtered, 'b--', alpha=.2, label='sky')
    ax.plot(wavelength_filtered, rss_filtered[fibre], 'k-', alpha=.5,
            label=f'fibre {fibre} ({fibre_offset[fibre]:.3f})')
    ax.legend()

def diff(x, y):
    """Subtract from `x` its component along `y`.

    `y` is normalised to unit RMS (ignoring NaN) and the coefficient is the
    NaN-aware mean of the product of `x` with that normalised `y`.

    Parameters
    ----------
    x, y : array-like of the same shape

    Returns
    -------
    ``x`` minus the coefficient times the normalised ``y``.
    """
    y_unit = y / np.sqrt(np.nanmean(y**2))
    return x - np.nanmean(x*y_unit)*y_unit

def plot_fibre_wavelet_diff(ax, fibre):
    """Plot the sky-removed wavelet signal of one fibre before and after shifting it.

    Reads the notebook-level ``rss_filtered``, ``sky_filtered`` and
    ``fibre_offset`` and uses the notebook-level function ``diff``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``rss_filtered`` to plot.
    """
    # Pixel index along the filtered spectrum. The notebook-level `idx` must not be
    # used here: other cells set it to the 2*scale+1 cross-correlation lags, whose
    # length does not match the spectrum and makes np.interp fail.
    pixel = np.arange(rss_filtered.shape[1])
    ax.plot(diff(rss_filtered[fibre], sky_filtered), 'b--', alpha=.2, label='original')
    corrected = np.interp(pixel+fibre_offset[fibre], pixel, rss_filtered[fibre])
    ax.plot(diff(corrected, sky_filtered), 'r-', alpha=.5, label='corrected')
    ax.legend()

fig, axes = new_figure('individual_wavelets', nrows=3, ncols=1, sharex=True, sharey=True)
fig.supylabel('correlation with sky')
fig.supxlabel('wavelength')

plot_fibre_wavelet(axes[0, 0], np.nanargmin(fibre_offset))
plot_fibre_wavelet(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
plot_fibre_wavelet(axes[2, 0], np.nanargmax(fibre_offset))

#plot_fibre_wavelet_diff(axes[0, 1], np.nanargmin(fibre_offset))
#plot_fibre_wavelet_diff(axes[1, 1], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
#plot_fibre_wavelet_diff(axes[2, 1], np.nanargmax(fibre_offset))

axes[0, 0].set_ylim(-250, 250)

In [None]:
fig, axes = new_figure('offset comparison', nrows=1)
ax = axes[0, 0]
ax.plot(fibre_offset, wavelet1.fibre_offset, 'b.', alpha=.1)
ax.plot([-0.6, 0.8], [-0.6, 0.8])

# 4. Identify sky fibres

## Wavelet

In [None]:
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
# 1. Estimate the FWHM of emission lines from the autocorrelation of the median (~ sky) spectrum.
# NOTE(review): this cell repeats verbatim a cell of the "4. Wavelet-based corrections" section above.

# Non-finite sky values are set to zero and the mean is removed (in place) before correlating
x = np.where(np.isfinite(sky_mode.intensity), sky_mode.intensity, 0)
x -= np.nanmean(x)

In [None]:
# NOTE(review): this cell repeats verbatim a cell of the "4. Wavelet-based corrections" section above.
# Autocorrelation of the sky spectrum; `h` is half the number of samples above 50% of its peak
x = scipy.signal.correlate(x, x, mode='same')
h = (np.count_nonzero(x > 0.5*np.nanmax(x)) + 1) // 2
# h = 0
# The filter scale is always odd
scale = 2*h + 1
print(f'> Wavelet filter scale: {scale} pixels')

In [None]:
# 2. Apply a (mexican top hat) wavelet filter to detect features on that scale (i.e. filter out the continuum).
# NOTE(review): this cell repeats verbatim a cell of the "4. Wavelet-based corrections" section above.

# Box means are computed as differences of the NaN-aware cumulative sum along axis 1:
# mean over `scale` central pixels minus mean over the enclosing 3*scale pixels
x = np.nancumsum(rss0.intensity, axis=1)
rss_filtered = (x[:, 2*scale:-scale] - x[:, scale:-2*scale]) / scale
rss_filtered -= (x[:, 3*scale:] - x[:, :-3*scale]) / (3*scale)

In [None]:
# NOTE(review): this cell repeats verbatim a cell of the "4. Wavelet-based corrections" section above.
# Same top-hat wavelet filter, applied to the 1D `sky_mode` spectrum
x = np.nancumsum(sky_mode.intensity)
sky_filtered = (x[2*scale:-scale] - x[scale:-2*scale]) / scale
sky_filtered -= (x[3*scale:] - x[:-3*scale]) / (3*scale)
# Pixels whose filtered sky value lies outside the interquartile range
sky_filtered_p25, sky_filtered_p75 = np.nanpercentile(sky_filtered, [25, 75])
sky_intense = np.where((sky_filtered < sky_filtered_p25) | (sky_filtered > sky_filtered_p75))[0]

In [None]:
# 5. Estimate wavelength offset of each fibre (in pixels) from cross-correlation with the sky.
# NOTE(review): this cell repeats verbatim a cell of the "Old stuff" section above.

mid = sky_filtered.size // 2
# Cross-correlation with the sky (convolution with the reversed sky), keeping 2*scale+1 lags around `mid`
x = scipy.signal.fftconvolve(
    rss_filtered, sky_filtered[np.newaxis, ::-1], mode='same', axes=1)[:, mid-scale:mid+scale+1]
idx = np.arange(x.shape[1])
# Offset of each fibre: mean lag (idx - scale) weighted by the positive part of the correlation
weight = np.where(x > 0, x, 0)
fibre_offset = np.nansum(
    (idx - scale)[np.newaxis, :] * weight, axis=1) / np.nansum(weight, axis=1)

In [None]:
def plot_fibre_sky_corr(ax, fibre):
    """Draw the sky cross-correlation of a fibre together with its estimated offset.

    Uses the notebook-level ``x`` (cross-correlation, one row per fibre),
    ``scale`` and ``fibre_offset`` as they are when the function is called.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Index of the fibre to show.
    """
    shift = fibre_offset[fibre]
    legend_text = f'fibre {fibre} ({shift:.3f})'
    ax.plot(x[fibre], 'b-', alpha=.1)
    # Dotted line: index `scale`; solid line: the same index shifted by the fibre offset
    ax.axvline(scale, c='b', ls=':', alpha=.1)
    ax.axvline(scale+shift, c='k', ls='-', label=legend_text)
    ax.legend()


# Cross-correlation with the sky for three representative fibres:
# smallest offset, brightest fibre (highest mean intensity) and largest offset.
fig, axes = new_figure('kk', nrows=3)
# Axis labels describe what is plotted (correlation versus lag index), as in the
# equivalent figure of the "Old stuff" section
fig.supylabel('correlation with sky')
fig.supxlabel(f'scale ({scale} pix) + offset')

plot_fibre_sky_corr(axes[0, 0], np.nanargmin(fibre_offset))
plot_fibre_sky_corr(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
# plot_fibre_sky_corr already adds the legend to each panel
plot_fibre_sky_corr(axes[2, 0], np.nanargmax(fibre_offset))

In [None]:
fig, axes = new_figure('offset comparison', nrows=1)
ax = axes[0, 0]
ax.plot(fibre_offset, wavelet1.fibre_offset, 'b.', alpha=.1)
ax.plot([-0.6, 0.8], [-0.6, 0.8])

In [None]:
# Throughput of each fibre: median ratio between its filtered spectrum and the
# filtered sky, restricted to the `sky_intense` pixels
fibre_throughput = np.nanmedian(rss_filtered[:, sky_intense] / sky_filtered[np.newaxis, sky_intense], axis=1)
#fibre_throughput /= np.nanmedian(fibre_throughput)
# Normalise the throughputs to a median of one and move that factor into the sky
median_throughput = np.nanmedian(fibre_throughput)
fibre_throughput /= median_throughput
# NOTE(review): `sky_filtered` is rescaled in place, so re-running this cell without
# recomputing `sky_filtered` first changes the values it works on
sky_filtered *= median_throughput

In [None]:
fig, axes = new_figure('kk')
wl = rss0.wavelength[scale + h + 1: - scale - h]
axes[0, 0].plot(wl, sky_filtered)
axes[0, 0].plot(wl[sky_intense], sky_filtered[sky_intense], 'k+')

In [None]:
#fibre = 42
fibre = np.nanargmax(np.nanmean(rss0.intensity, axis=1))

In [None]:
def plot_fibre_vs_sky(ax, fibre):
    """Scatter the throughput-scaled sky residual of one fibre against the filtered sky.

    Reads the notebook-level arrays ``sky_filtered``, ``rss_filtered`` and
    ``fibre_throughput``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``rss_filtered`` to plot; its throughput is shown in the legend.
    """
    throughput = fibre_throughput[fibre]
    residual = rss_filtered[fibre] - throughput*sky_filtered
    ax.plot(sky_filtered, residual, 'b.', alpha=.1,
            label=f'fibre {fibre} ({throughput:.4f})')
    ax.axhline(0, c='k', ls=':')
    ax.legend()


fig, axes = new_figure('fibre vs sky', nrows=3)
fig.supylabel('filetered intensity')
fig.supxlabel('filetered sky')

plot_fibre_vs_sky(axes[0, 0], np.nanargmin(fibre_throughput))
plot_fibre_vs_sky(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
plot_fibre_vs_sky(axes[2, 0], np.nanargmax(fibre_throughput))

In [None]:
def plot_fibre_sky_sub(ax, fibre):
    """Plot the spectrum of one fibre before and after throughput correction and sky subtraction.

    Reads the notebook-level ``rss0``, ``fibre_throughput``, ``sky_mode``,
    ``median_throughput`` and ``example``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    fibre : int
        Row of ``rss0.intensity`` to plot; its throughput is shown in the legend.
    """
    throughput = fibre_throughput[fibre]
    ax.plot(rss0.wavelength, rss0.intensity[fibre], 'b--', alpha=.1, label=f'fibre {fibre} ({throughput:.4f})')
    # `fibre_throughput` is the fibre-to-sky ratio, so the raw intensity is corrected
    # by dividing by it before subtracting the (median-throughput scaled) sky
    ax.plot(rss0.wavelength, rss0.intensity[fibre]/throughput - sky_mode.intensity*median_throughput, 'k-', alpha=.5)
    if example == "WEAVE":
        # Reference: subtraction of the sky stored under rss0.info['sky_CASU']
        ax.plot(rss0.wavelength, rss0.intensity[fibre] - rss0.info['sky_CASU'], 'r-', alpha=.5)
    ax.legend()


fig, axes = new_figure('fibre sky subtraction', nrows=3)
fig.supylabel('intensity')
fig.supxlabel('wavelength')

plot_fibre_sky_sub(axes[0, 0], np.nanargmin(fibre_throughput))
plot_fibre_sky_sub(axes[1, 0], np.nanargmax(np.nanmean(rss0.intensity, axis=1)))
plot_fibre_sky_sub(axes[2, 0], np.nanargmax(fibre_throughput))

In [None]:
fig, axes = new_figure('throughput comparison', nrows=1)
fig.suptitle(f'fibre {fibre}')

ax = axes[0, 0]
ax.plot(fibre_throughput, wavelet1.fibre_throughput, 'b.', alpha=.1)
ax.plot([0.6, 1.4], [0.6, 1.4])

# -- STOP --

In [None]:
# Deliberately interrupt "Run All" here: the cells above this point are exploratory.
# A real exception with a message is raised (``raise -1`` only stops by accident,
# through a confusing "exceptions must derive from BaseException" TypeError).
raise RuntimeError("STOP: exploratory section finished; run the cells below manually")

# 3. Wavelet filter

## First iteration

In [None]:
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
throughput_corr = ThroughputCorrection(throughput=wavelet1.get_throughput_object())

In [None]:
wave_corr = WavelengthCorrection(offset=wavelet1.get_wavelength_offset())

In [None]:
rss1 = wave_corr.apply(rss0)
rss1 = throughput_corr.apply(rss1)

## Second iteration

In [None]:
#wavelet2 = sky.WaveletFilter(rss1)

In [None]:
#throughput_corr = ThroughputCorrection(throughput=wavelet2.get_throughput_object())

In [None]:
#wave_corr = WavelengthCorrection(offset=wavelet2.get_wavelength_offset())

In [None]:
#rss2 = wave_corr.apply(rss1)
#rss2 = throughput_corr.apply(rss2)

In [None]:
#rss2.history.show()

# 5. Create datacube

In [None]:
sky_corr = SkySubsCorrection(sky_mode)
rss1 = sky_corr.apply(rss1)[0]

In [None]:
from pykoala.cubing import CubeInterpolator
interpolator = CubeInterpolator(rss_set=[rss1], spatial_pix_size=2*u.arcsec, spectra_pix_size=1*u.Angstrom)

In [None]:
cube = interpolator.build_cube()

In [None]:
cube.to_fits('cube.fits', overwrite=True)

# 6. Quality control plots

## Wavelet filter

In [None]:
wavelet1.qc_plots()

## Relative calibration

In [None]:
# Relative throughput and wavelength offset of every fibre, as estimated by the wavelet filter
fig, axes = new_figure('sky-based relative calibration', nrows=2)

# The second iteration (`wavelet2`) is commented out in the "Second iteration"
# section, so it is only overplotted when it has actually been computed.
has_second_iteration = 'wavelet2' in globals()

ax = axes[0, 0]
ax.set_ylabel('relative throughput')
#ax.set_ylim(.4, 2.1)

ax.plot(wavelet1.fibre_throughput, 'k-', label='iter 1')
if has_second_iteration:
    ax.plot(wavelet2.fibre_throughput, 'k-', alpha=.25, label='iter 2')

#p16, p50, p84 = np.nanpercentile(wavelet1.filtered / wavelet1.sky[np.newaxis, :], [16, 50, 84], axis=1)
#ax.plot(p50, 'r--', alpha=.5)
#ax.fill_between(np.arange(p50.size), p16, p84, color='r', alpha=0.1)
#p16, p50, p84 = np.nanpercentile(wavelet2.filtered / wavelet2.sky[np.newaxis, :], [16, 50, 84], axis=1)
#ax.plot(p50, 'y--', alpha=.5)

ax.legend()


ax = axes[1, 0]
ax.set_ylabel('relative offset [pix]')

ax.plot(wavelet1.fibre_offset, 'k-', label='iter 1')
if has_second_iteration:
    ax.plot(wavelet2.fibre_offset, 'k-', alpha=.25, label='iter 2')

ax.legend()
# The trailing semicolon keeps the notebook from echoing the returned Text object
ax.set_xlabel('fibre');

In [None]:
fig, axes = new_figure('relative calibration maps',
                       nrows=2, ncols=2, sharey=False, figsize=(8, 10),
                       gridspec_kw={'left': 0.07, 'right':0.9, 'width_ratios': [1, .05], 'hspace': 0.05, 'wspace': 0.25})

plot_fibres(fig, axes[0, 0], rss0, data=wavelet1.fibre_throughput, cbax=axes[0, 1], cmap='Spectral', cblabel='relative throughput')
plot_fibres(fig, axes[1, 0], rss0, data=wavelet1.fibre_offset, cbax=axes[1, 1], cmap='Spectral', cblabel='relative offset')

axes[0, 0].sharey(axes[1, 0])

## Single fibre tests

In [None]:
fibre = 42
#fibre = np.nanargmax(np.nanmean(rss1.intensity, axis=1))

In [None]:
# Diagnostics of the throughput correction for the selected `fibre`
fig, axes = new_figure('single throughput', nrows=3, sharex=False, sharey=False, gridspec_kw={'hspace': .2})

# `wavelet2`, `rss2`, `sky_fit` and `sky2_fit` are not created by any active cell of
# this notebook; they are overplotted only if they exist in the namespace, so
# that the cell also runs on a fresh kernel.
available = globals()

# Top: distribution of the fibre-to-sky ratio, before (grey) and after (blue) the correction
ax = axes[0, 0]
ax.hist(wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre] / wavelet1.sky, bins=np.linspace(0, 7.5, 101), color='k', alpha=.2)
ax.hist(wavelet1.filtered[fibre] / wavelet1.sky, bins=np.linspace(0, 7.5, 101), color='b', alpha=.5)

# Middle: wavelet-filtered sky and fibre spectra
ax = axes[1, 0]
ax.set_ylabel('wavelet')
ax.plot(wavelet1.wavelength, wavelet1.sky, 'b:', alpha=.5, label='sky')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre], 'k--', alpha=.2, label='data')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre], 'k-', alpha=.5, label='throughput-corrected')
if 'wavelet2' in available:
    ax.plot(wavelet2.wavelength, wavelet2.filtered[fibre], 'r-', alpha=.5, label='iter 2')
ax.legend()


# Bottom: intensities before/after correction and after sky subtraction
ax = axes[2, 0]
ax.set_ylabel('intensity')
ax.sharex(axes[1, 0])

x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

ax.plot(x, rss0.intensity[fibre], 'k--', alpha=.2, label=f'fibre {fibre}')
ax.plot(x, rss1.intensity[fibre], 'k-', alpha=.5, label='corrected')
if 'sky_fit' in available:
    ax.plot(x, rss1.intensity[fibre] - sky_fit.intensity, 'c-', alpha=.5, label='sky-subtracted')
if 'rss2' in available and 'sky2_fit' in available:
    ax.plot(x, rss2.intensity[fibre] - sky2_fit.intensity, 'c--', alpha=.5, label='iter 2')
ax.plot(x, rss0.intensity[fibre] - sky_mode.intensity, 'k--', alpha=.5, label='mode')

ax.legend()
ax.set_xlabel(r'wavelength [$\AA$]')


In [None]:
# Cross-correlation of the selected `fibre` with the median filtered spectrum
fig, axes = new_figure(f'single wavelength calibration')

ax = axes[0, 0]

mid = wavelet1.wavelength.size // 2
s = wavelet1.scale
# Reference spectrum: median of the filtered spectra over fibres, non-finite values set to zero
x = np.nanmedian(wavelet1.filtered, axis=0)
x[~ np.isfinite(x)] = 0
# Cross-correlation (convolution with the reversed reference), keeping 2*s+1 lags around `mid`
x = scipy.signal.fftconvolve(wavelet1.filtered[fibre], x[::-1], mode='same')[mid-s:mid+s+1]
idx = np.arange(x.size)

ax.plot(idx - s, x/np.max(x), 'k-', label=f'fibre {fibre}')
ax.axvline(wavelet1.fibre_offset[fibre], c='k', ls=':', label=f'offset = {wavelet1.fibre_offset[fibre]:.2f} pix')
# Same curve mirrored around the estimated offset, to judge the symmetry of the peak
ax.plot(2*wavelet1.fibre_offset[fibre] - (idx - s), x/np.max(x), 'k--', alpha=.25, label='reflected cross-correlation')

ax.legend()
ax.set_ylabel('cross-correlation with sky')
# The trailing semicolon keeps the notebook from echoing the returned Text object
ax.set_xlabel('offset [pix]');

## Corrected intensities

In [None]:
# Intensity maps (fibre versus wavelength) before and after the skyline-based
# corrections (top row) and after subtracting the sky (bottom row)
fig, axes = new_figure('skyline-based correction',
                       nrows=2, ncols=3, sharex=False, sharey=False,
                       gridspec_kw={'left': 0.07, 'right':0.9, 'width_ratios': [1, 1, .05], 'hspace': 0.25, 'wspace': 0.25})

x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

im, cb = plot_image(fig, axes[0, 0], 'intensity', rss0.intensity, x=x, ylabel='fibre', cbax=axes[0, 2])
im, cb = plot_image(fig, axes[0, 1], '.', rss1.intensity, x=x, norm=im.norm, cbax=False)

# Both bottom panels subtract the same sky model (`sky_mode`) and share one colour
# scale; `sky0_fit` is not defined anywhere in this notebook.
im, cb = plot_image(fig, axes[1, 0], 'subtracted', rss0.intensity - sky_mode.intensity, x=x, cbax=axes[1, 2])
#im, cb = plot_image(fig, axes[1, 1], '.', rss1.intensity - sky_fit.intensity, x=x, norm=im.norm, cbax=False)
im, cb = plot_image(fig, axes[1, 1], '.', rss1.intensity - sky_mode.intensity, x=x, norm=im.norm, cbax=False)

for ax in [axes[0, 1], axes[1, 0], axes[1, 1]]:
    ax.sharex(axes[0, 0])
    ax.sharey(axes[0, 0])


## Sky subtraction

In [None]:
# Mean intensity of every fibre and the flux cuts that define three groups of fibres
flux = np.nanmean(rss1.intensity, axis=1)
mean_flux = np.nanmean(flux)
flux_cut_low = np.nanmedian(flux[flux < mean_flux])
# Upper cut placed symmetrically to the lower one with respect to the mean
flux_cut_hi = 2*mean_flux - flux_cut_low
flux_low = np.nanmean(flux[flux < flux_cut_low])
flux_med = np.nanmean(flux[(flux > flux_cut_low) & (flux < mean_flux)])
flux_hi = np.nanmean(flux[(flux > mean_flux) & (flux < flux_cut_hi)])

# Mean spectrum of each group
I_low = np.nanmean(rss1.intensity[flux < flux_cut_low, :], axis=0)
I_med = np.nanmean(rss1.intensity[(flux > flux_cut_low) & (flux < mean_flux), :], axis=0)
I_hi = np.nanmean(rss1.intensity[(flux > mean_flux) & (flux < flux_cut_hi), :], axis=0)
# Straight line through the low and high groups at every wavelength: I = m*flux + b
m = (I_hi - I_low) / (flux_hi - flux_low)
b = I_low - m * flux_low

# Candidate sky fluxes, in steps of 1% of the faintest fibre. The NaN-aware minimum
# is used because a fibre whose spectrum is entirely NaN would otherwise turn the
# step into NaN and make np.arange fail.
# NOTE(review): a non-positive minimum flux still yields an unusable step -- confirm it cannot happen
sky_flux_candidate = np.arange(0, flux_cut_hi.value, .01*np.nanmin(flux.value)) << flux.unit
sky_filtered = m[np.newaxis, :] * sky_flux_candidate[:, np.newaxis] + b[np.newaxis, :]
# Top-hat wavelet filter of every candidate spectrum (box means from cumulative sums)
x = np.nancumsum(sky_filtered, axis=1)
s = wavelet1.scale
sky_filtered = (x[:, 2*s:-s] - x[:, s:-2*s]) / s
sky_filtered -= (x[:, 3*s:] - x[:, :-3*s]) / (3*s)
# Scatter of each candidate relative to the wavelet sky (weighted by sky_weight) and
# scatter of the part weighted by (1 - sky_weight)
chi2_sky = np.nanstd(sky_filtered*wavelet1.sky_weight - wavelet1.sky, axis=1)
chi2_no_sky = np.nanstd(sky_filtered*(1 - wavelet1.sky_weight), axis=1)

# Adopt the candidate that minimises chi2_no_sky
sky_flux = sky_flux_candidate[np.nanargmin(chi2_no_sky)]
#sky_flux = 280 * u.adu
#sky_flux = np.mean(np.nanmedian(rss1.intensity, axis=0))
#sky_flux = background * u.adu
sky_intensity = b + m*sky_flux

In [None]:
background, threshold = symmetric_background(flux.value, fig_name='mean intensity')

In [None]:
fig, axes = new_figure('intensity hist')

ax = axes[0, 0]

x = np.linspace(0, 3*sky_flux, 101)

ax.hist(flux, bins=x, alpha=.5)
ax.axvline(background, c='k', ls='--', label=f'Mode = {background:.2f}')
ax.axvline(threshold, c='k', ls=':', label=f'th = {threshold:.2f}')

ax.legend()
ax.set_ylabel('number of fibres')
ax.set_xlabel(f'mean fibre intensity')
#ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
# Figures of merit as a function of the candidate sky flux, with the adopted value marked
fig, axes = new_figure('get sky flux', nrows=1)

ax = axes[0, 0]
ax.plot(sky_flux_candidate, chi2_sky, 'b-')
ax.plot(sky_flux_candidate, chi2_no_sky, 'r-')
ax.axvline(sky_flux, c='k', ls=':', label=f'sky flux = {sky_flux:.2f}')
# The trailing semicolon keeps the notebook from echoing the returned Legend object
ax.legend();

In [None]:
# Coefficients of the linear fit I = m*flux + b and the resulting sky estimates
fig, axes = new_figure('fit coefficients', nrows=4)

# `sky_fit` and `sky_mad` are not created by any active cell of this notebook; they
# are overplotted only if they exist in the namespace.
# NOTE(review): `sky_sigma_clip` (built with bckgr_estimator='mad') may be what
# `sky_mad` was meant to be -- confirm
available = globals()

ax = axes[0, 0]
ax.set_ylabel('m')
ax.plot(rss0.wavelength, m, 'k-')
#ax.plot(wavelet1.wavelength, wavelet1.sky_weight, 'b-')
#ax.plot(rss1.wavelength, sky_intensity / np.nanmean(sky_intensity), 'c-')


ax = axes[1, 0]
ax.set_ylabel('b')
if 'sky_fit' in available:
    ax.plot(rss0.wavelength, sky_fit.intensity, 'r-')
ax.plot(rss0.wavelength, sky_intensity, 'c-')
ax.plot(rss0.wavelength, b, 'k-')
if 'sky_mad' in available:
    ax.plot(rss0.wavelength, sky_mad.intensity, 'r:')
ax.plot(rss0.wavelength, sky_mode.intensity, 'k--')
#ax.set_ylim(-10, 500)


ax = axes[2, 0]
s = wavelet1.scale
ax.set_ylabel(f'wavelet b ({s} pix)')

# Top-hat wavelet filter of the adopted sky spectrum (box means from the cumulative sum)
x = np.nancumsum(sky_intensity)
sky_filtered = (x[2*s:-s] - x[s:-2*s]) / s
sky_filtered -= (x[3*s:] - x[:-3*s]) / (3*s)
ax.plot(wavelet1.wavelength, sky_filtered, 'r-', alpha=.5)
ax.plot(wavelet1.wavelength, sky_filtered * wavelet1.sky_weight, 'k-', alpha=.5)
ax.plot(wavelet1.wavelength, wavelet1.sky, 'c-', alpha=.5)

ax = axes[3, 0]
ax.plot(wavelet1.wavelength, wavelet1.sky_weight)

## Single wavelength

In [None]:
# Wavelength (in Angstrom) at which the monochromatic tests below are done
#wl = 6581
wl = 6620
#wl = 6700
# Units are attached to both sides so that the search is unit-aware, as in the
# "Single wavelength test" section; u.Quantity(...) accepts the wavelength axis
# whether it is already a Quantity or a plain array.
idx = np.searchsorted(u.Quantity(rss0.wavelength, u.Angstrom), wl << u.Angstrom)

In [None]:
bg, th = symmetric_background(rss1.intensity[:, idx].value, fig_name='monochrome mode')

In [None]:
fig, axes = new_figure('monochrome hist')

ax = axes[0, 0]

x = np.linspace(b[idx], 3*sky_intensity[idx] - 2*b[idx], 101)

ax.hist(rss1.intensity[:, idx], bins=x, alpha=.5)
ax.axvline(sky_intensity[idx], c='k', ls='--', label=f'Sky level = {sky_intensity[idx]:.2f}')
ax.axvline(bg, c='k', ls=':', label=f'mode = {bg:.2f}')

ax.legend()
ax.set_ylabel('number of fibres')
ax.set_xlabel(f'intensity at $\\lambda={rss0.wavelength[idx]:.2f}$ \\AA')
#ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
fig, axes = new_figure('linear fit')

ax = axes[0, 0]

ax.plot(flux, rss1.intensity[:, idx], 'k.', alpha=.1)
ax.axvline(mean_flux, c='r', ls='--', label=f'mean flux: {mean_flux:.2f}')
ax.axvline(flux_cut_low, c='r', ls=':', label=f'flux low: {flux_cut_low:.2f}')
ax.axvline(flux_cut_hi, c='r', ls=':', label=f'flux high: {flux_cut_hi:.2f}')

ax.axhline(bg, c='k', ls=':', label=f'mode: {bg:.2f}')
ax.axhline(th, c='k', ls='--', label=f'th: {th:.2f}')
ax.axvline(background, c='k', ls=':', label=f'mode fl: {background:.2f}')
ax.axvline(threshold, c='k', ls='--', label=f'th fl: {threshold:.2f}')

x = np.nanpercentile(flux, np.linspace(0, 100, 101))
ax.plot(x, m[idx]*x + b[idx], 'b:', label=f'm={m[idx]:.2f} b={b[idx]:.2f}')
ax.plot(u.Quantity([flux_low, flux_med, flux_hi]), u.Quantity([I_low[idx], I_med[idx], I_hi[idx]]), 'ro-')

ax.plot(sky_flux, sky_intensity[idx], 'co', label=f'Sky level = {sky_intensity[idx]:.2f}')

ax.legend()
ax.set_xlabel('mean fibre intensity')
ax.set_ylabel(f'intensity at $\\lambda={rss0.wavelength[idx]:.2f}\\ \\AA$')
ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
from pykoala.plotting import utils
from matplotlib import colors
# Fibre maps of the monochromatic intensity at pixel `idx`: original,
# throughput-corrected and after subtracting different sky estimates
fig, axes = new_figure('correction maps',
                       nrows=4, ncols=2, sharey=False, figsize=(8, 12),
                       gridspec_kw={'left': 0.07, 'right':0.9, 'width_ratios': [1, .05], 'hspace': 0.05, 'wspace': 0.25})

#(fig, ax, rss=None, x=None, y=None,
#                fibre_diam=None, data=None, 
#                patch_args={}, use_wcs=False, fix_limits=True,
#                cmap=DEFAULT_CMAP, norm=None, cbax=None, cblabel=None, 
#                norm_interval=visualization.MinMaxInterval, interval_args={},
#                stretch=visualization.LinearStretch, stretch_args={})
# Common logarithmic colour scale, from 1 up to the 90th percentile of the corrected intensity
norm = LogNorm(vmin=1, vmax=np.nanpercentile(rss1.intensity[:, idx].value, 90))
utils.plot_fibres(fig, axes[0, 0], rss0, data=rss0.intensity[:, idx], cbax=axes[0, 1], cblabel=f'original $I_\\lambda({rss0.wavelength[idx]:.2f} \\AA)$', norm=norm)
utils.plot_fibres(fig, axes[1, 0], rss1, data=rss1.intensity[:, idx], cbax=axes[1, 1], cblabel='throughput corrected', norm=norm)
#utils.plot_fibres(fig, axes[2, 0], rss1, data=rss1.intensity[:, idx] - sky_fit.intensity[idx], cbax=axes[2, 1], cblabel='sky subtracted', norm=norm)
# `sky_biweight` is not created by any active cell of this notebook; the panel is
# left empty unless it exists in the namespace.
if 'sky_biweight' in globals():
    utils.plot_fibres(fig, axes[2, 0], rss1, data=rss1.intensity[:, idx] - sky_biweight[idx], cbax=axes[2, 1], cblabel='biweight sky', norm=norm)
utils.plot_fibres(fig, axes[3, 0], rss1, data=rss1.intensity[:, idx] - sky_mode.intensity[idx], cbax=axes[3, 1], cblabel='mode sky', norm=norm)

for ax in axes[1:, 0]:
    ax.sharey(axes[0, 0])
