# Corrections based on sky emission lines

Compute relative offsets in wavelength (in pixels) and flux (arbitrary units) based on the sky emission lines, detected through a wavelet filter

# 1. Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
from matplotlib import colors

import numpy as np

from astropy import units as u
from astropy import constants as c
from astropy.io import fits

from scipy.signal import fftconvolve
from scipy.stats import linregress

In [None]:
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections import sky
from pykoala.corrections.throughput import ThroughputCorrection
from pykoala.corrections.wavelength import WavelengthCorrection
from pykoala.plotting.utils import new_figure, plot_image, plot_fibres
print("pyKOALA version: ", __version__)

The following will probably disappear in the final version of the tutorial

In [None]:
%load_ext autoreload
%autoreload 2

# 2. Load the science data
This must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
# Selects which dataset is loaded: each loading cell below is guarded by
# `if example == ...`, so only the one matching this value does anything.
example = 'KOALA'
#example = 'WEAVE'

KOALA:

In [None]:
if example == 'KOALA':
    # Read the KOALA row-stacked spectra and keep a unit-less copy of the
    # wavelength axis (in Angstrom) for the plots further down.
    filename = "../data/27feb20036red.fits"
    rss0 = koala_ifu.koala_rss(filename)
    wavelength_AA = rss0.wavelength.to_value(u.AA)

WEAVE:

In [None]:
if example == 'WEAVE':
    # Candidate WEAVE exposures; leave exactly one `filename` uncommented.
    #filename = "../data/weave/v3/NGC5322_OB12063/L1/single_3045973.fit"
    filename = "../data/weave/v3/NGC4290_OB11113/L1/single_3039517.fit"
    #filename = "../data/weave/v3/WA_J024019.19+321544.10/single_3042890.fit"
    #filename = "../data/weave/solar/msp_3059302.fit"
    rss0 = weave.weave_rss(filename)
    wavelength_AA = rss0.wavelength.to_value(u.Angstrom)

    # Auxiliary products read straight from the same FITS file
    # (`fits` is imported with the other astropy modules in the first cell).
    # NOTE(review): the extension indices 1, 3, 5 and 6 are hard-coded; their
    # meaning (presumably spectra with/without sky subtraction, sensitivity
    # function and fibre table) is not visible here -- TODO confirm.
    with fits.open(filename) as hdul:
        # Only the first row of each image extension is used.
        sky_L1 = hdul[3].data[0] - hdul[1].data[0]
        sensitivity_function = hdul[5].data[0]
        fibtable = hdul[6].data

## Summary

In [None]:
# Overview of the loaded RSS: target name, available metadata keys and the
# processing history recorded so far.
target_name = rss0.info['name']
print(f"Analysing object {target_name} read from {filename}")
print('- info:', rss0.info.keys(), '- log:', sep='\n')
rss0.history.show()


# 3. Wavelet filter

## First iteration

In [None]:
# Wavelet filter built from the RSS exactly as it was loaded.
wavelet1 = sky.WaveletFilter(rss0)

In [None]:
# Throughput correction fed with the throughput object returned by wavelet1.
throughput_corr = ThroughputCorrection(throughput=wavelet1.get_throughput_object())

In [None]:
# Wavelength correction fed with the wavelength offset returned by wavelet1.
wave_corr = WavelengthCorrection(offset=wavelet1.get_wavelength_offset())

In [None]:
# rss1 = rss0 after the wavelength correction, then the throughput correction.
rss1 = wave_corr.apply(rss0)
rss1 = throughput_corr.apply(rss1)

## Second iteration

In [None]:
# Same procedure as the first iteration, now starting from the corrected rss1.
wavelet2 = sky.WaveletFilter(rss1)

In [None]:
# NOTE: this rebinds `throughput_corr`; from here on the name refers to the
# correction derived from wavelet2, not the one applied to rss0.
throughput_corr = ThroughputCorrection(throughput=wavelet2.get_throughput_object())

In [None]:
# NOTE: `wave_corr` is rebound as well (now derived from wavelet2).
wave_corr = WavelengthCorrection(offset=wavelet2.get_wavelength_offset())

In [None]:
# rss2 = rss1 after the second wavelength and throughput corrections.
rss2 = wave_corr.apply(rss1)
rss2 = throughput_corr.apply(rss2)

In [None]:
# Processing history recorded on the twice-corrected RSS.
rss2.history.show()

# 4. Sky subtraction

In [None]:
# 'fit' estimator on the uncorrected data, using the first wavelet filter.
sky0_fit = sky.SkyFromObject(rss0, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet1})

In [None]:
# 'fit' estimator on the once-corrected data (still with wavelet1); this is
# the sky model used by most of the figures below.
sky_fit = sky.SkyFromObject(rss1, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet1})

In [None]:
# 'fit' estimator on the twice-corrected data, using the second wavelet filter.
sky2_fit = sky.SkyFromObject(rss2, bckgr_estimator='fit', bckgr_params={'wavelet': wavelet2})

In [None]:
# Alternative estimators on rss1, for comparison with the fit.
sky_med = sky.SkyFromObject(rss1, bckgr_estimator='percentile')

In [None]:
sky_mad = sky.SkyFromObject(rss1, bckgr_estimator='mad', source_mask_nsigma=3)

In [None]:
# Compare the sky spectra returned by the different background estimators.
fig, axes = new_figure('sky model')

ax = axes[0, 0]

# Percentile-based sky with a band of +/- sqrt(variance) around it
ax.plot(wavelength_AA, sky_med.intensity, 'y-', alpha=.75, label='median')
std = np.sqrt(sky_med.variance)
ax.fill_between(wavelength_AA, sky_med.intensity-std, sky_med.intensity+std, color='y', alpha=.25, label='$16-84$ \\%')

ax.plot(wavelength_AA, sky_mad.intensity, 'k--', alpha=.5, label='MAD + 3-$\\sigma$ clip')

#ax.plot(wavelength_AA, sky0_fit.intensity, 'k:', alpha=.5, label='uncorrected fit')
ax.plot(wavelength_AA, sky_fit.intensity, 'k-', alpha=.75, label='linear fit')
#ax.plot(wavelength_AA, sky2_fit.intensity, 'r--', alpha=.5, label='iter 2')
if example == "WEAVE":
    # Sky derived from the WEAVE L1 file, only available for that example
    ax.plot(wavelength_AA, sky_L1, 'b-', alpha=.75, label='L1')

# The trailing semicolon hides the Legend repr; it replaces a stray "," line,
# which is not valid Python outside IPython.
ax.legend();

# 5. Quality control plots

## Wavelet filter

In [None]:
# Quality-control figures of the first wavelet filter (the one built from rss0).
wavelet1.qc_plots()

## Corrected intensities

In [None]:
fig, axes = new_figure(
    'skyline-based correction',
    nrows=2, ncols=4, sharex=False, sharey=False,
    gridspec_kw={'left': 0.07, 'right': 0.9, 'width_ratios': [1, 1, 1, .05],
                 'hspace': 0.25, 'wspace': 0.25},
)

# Wavelength axis as plain numbers (Angstrom) when it carries units
x = rss0.wavelength.to_value(u.AA) if isinstance(rss0.wavelength, u.Quantity) else rss0.wavelength

# One row per quantity; the columns are rss0 / rss1 / rss2.
# Each entry: (title of the first panel, extra keywords for it, the three images).
panel_rows = [
    ('intensity', {'ylabel': 'fibre'},
     [rss0.intensity, rss1.intensity, rss2.intensity]),
    ('subtracted', {},
     [rss0.intensity - sky0_fit.intensity,
      rss1.intensity - sky_fit.intensity,
      rss2.intensity - sky2_fit.intensity]),
]
for row, (title, first_kwargs, images) in enumerate(panel_rows):
    # The first panel of a row owns the colour bar (4th column); the other two
    # reuse the norm of the image drawn just before them.
    im, cb = plot_image(fig, axes[row, 0], title, images[0], x=x, cbax=axes[row, 3], **first_kwargs)
    for col in (1, 2):
        im, cb = plot_image(fig, axes[row, col], '.', images[col], x=x, norm=im.norm, cbax=False)

# Link pan/zoom of every image panel to the top-left one
reference_ax = axes[0, 0]
image_axes = [axes[r, c] for r in range(2) for c in range(3)]
for ax in image_axes[1:]:
    ax.sharex(reference_ax)
    ax.sharey(reference_ax)


## Wavelength calibration

In [None]:
# First entry of the offset data held by the current wavelength correction.
# NOTE: `wave_corr` was last assigned in the "Second iteration" section, so
# this is the correction derived from wavelet2.
wave_corr.offset.offset_data[0]

In [None]:
# The f-strings in this cell use double quotes outside and single quotes
# inside: reusing the same quote character inside a replacement field is a
# SyntaxError on Python versions older than 3.12.
fig, axes = new_figure(f"relative wavelength calibration - {rss0.info['name']}",
                       nrows=2, ncols=2,
                       sharex=False, sharey=False,
                       figsize=(8, 8),
                       gridspec_kw={'left': 0.1, 'right':0.95, 'bottom':0.05, 'top':0.95,
                                    'width_ratios': [1, .05], 'hspace': 0.2, 'wspace': 0.2})

# Top panel: offset of every fibre against fibre index (first iteration)
ax = axes[0, 0]
ax.set_ylabel('relative offset [pix]')
ax.set_ylim(-.55, .55)
ax.set_xlabel('fibre')

ax.plot(wavelet1.fibre_offset, 'k-', label='iter 1')
#ax.plot(wavelet2.fibre_offset, 'k-', alpha=.25, label='iter 2')

# The colour-bar column is not needed in the top row
axes[0, 1].set_axis_off()

# Bottom panel: the same offsets handed to plot_fibres, colour bar on its right
ax, pc, cb = plot_fibres(fig, axes[1, 0], cblabel='relative offset [pix]', rss=rss0, data=wavelet1.fibre_offset, cbax=axes[1, 1], cmap='Spectral')

# Hand-picked axis limits for two specific targets
if rss0.info['name'] == "NGC5322":
    ax.set_xlim(207.285, 207.345)
    ax.set_ylim(60.178, 60.203)
if rss0.info['name'] == "J122047.09+580537.3":
    ax.set_xlim(185.175, 185.222)
    ax.set_ylim(58.080, 58.105)
fig.savefig(f"{rss0.info['name']}_offset.png")
# NOTE(review): `wave_corr` was rebound in the "Second iteration" section, so
# the values written here come from wavelet2 while the figure above shows
# wavelet1.fibre_offset -- confirm that this is intended.
np.savetxt(f"{rss0.info['name']}_offset.txt", wave_corr.offset.offset_data[:, 0].value)

## Throughput correction

In [None]:
# The f-strings in this cell use double quotes outside and single quotes
# inside: reusing the same quote character inside a replacement field is a
# SyntaxError on Python versions older than 3.12.
# The figure title also had a typo ("throuput").
fig, axes = new_figure(f"throughput correction - {rss0.info['name']}",
                       nrows=2, ncols=2,
                       sharex=False, sharey=False,
                       figsize=(8, 8),
                       gridspec_kw={'left': 0.1, 'right':0.95, 'bottom':0.05, 'top':0.95,
                                    'width_ratios': [1, .05], 'hspace': 0.2, 'wspace': 0.2})

# Top panel: throughput of every fibre against fibre index (first iteration).
# The y-label used to read 'relative offset [pix]', copied from the
# wavelength-calibration figure; it now matches the colour-bar label below.
ax = axes[0, 0]
ax.set_ylabel('relative throughput')
ax.set_ylim(.75, 1.25)
ax.set_xlabel('fibre')

ax.plot(wavelet1.fibre_throughput, 'k-', label='iter 1')
#ax.plot(wavelet2.fibre_throughput, 'k-', alpha=.25, label='iter 2')

# The colour-bar column is not needed in the top row
axes[0, 1].set_axis_off()

# Bottom panel: the same values handed to plot_fibres, colour bar on its right
ax, pc, cb = plot_fibres(fig, axes[1, 0], cblabel='relative throughput', rss=rss0, data=wavelet1.fibre_throughput, cbax=axes[1, 1], cmap='Spectral')

# Hand-picked axis limits for two specific targets
if rss0.info['name'] == "NGC5322":
    ax.set_xlim(207.285, 207.345)
    ax.set_ylim(60.178, 60.203)
if rss0.info['name'] == "J122047.09+580537.3":
    ax.set_xlim(185.175, 185.222)
    ax.set_ylim(58.080, 58.105)
fig.savefig(f"{rss0.info['name']}_throughput.png")
# NOTE(review): `throughput_corr` was rebound in the "Second iteration"
# section, so the values written here come from wavelet2 while the figure
# above shows wavelet1.fibre_throughput -- confirm that this is intended.
np.savetxt(f"{rss0.info['name']}_throughput.txt", throughput_corr.throughput.throughput_data[:, 0].value)

## Sky subtraction

In [None]:
# The f-strings in this cell use double quotes outside and single quotes
# inside: reusing the same quote character inside a replacement field is a
# SyntaxError on Python versions older than 3.12.
fig, axes = new_figure(f"sky_subtraction - {rss0.info['name']}", nrows=3, sharey=True)

# Panel 1: the sky spectrum itself
ax = axes[0, 0]
ax.set_ylabel('sky counts')
ax.set_yscale('log')
ax.set_ylim(50, 5e4)
#ax.plot(self.wavelength, self.sky_counts[0], 'r-', alpha=.5, label='original sky')
if example == "WEAVE":
    ax.plot(wavelength_AA, sky_L1, 'b-', alpha=.5, label='L1')
ax.plot(sky_fit.wavelength, sky_fit.intensity, 'k-', alpha=.75, label='PyKOALA')
ax.legend()

# Panel 2: sky-subtracted spectrum averaged over all fibres
ax = axes[1, 0]
ax.set_ylabel('mean spectrum')
if example == "WEAVE":
    ax.plot(wavelength_AA, np.nanmean(rss0.intensity, axis=0).value - sky_L1, 'b-', alpha=.5, label='L1')
ax.plot(wavelength_AA, np.nanmean(rss1.intensity, axis=0) - sky_fit.intensity, 'k-', alpha=.75, label='PyKOALA')

# Panel 3: sky-subtracted spectrum of the fibre with the highest mean intensity
ax = axes[2, 0]
peak = np.nanargmax(np.nanmean(rss0.intensity, axis=1))
if example == "WEAVE":
    ax.plot(wavelength_AA, rss0.intensity[peak].value - sky_L1, 'b-', alpha=.5, label='L1')
ax.plot(wavelength_AA, rss1.intensity[peak] - sky_fit.intensity, 'k-', alpha=.75, label='PyKOALA')
ax.set_ylabel(f'fibre {peak} (peak)')

ax.set_xlabel(r'wavelength [$\AA$]')
fig.savefig(f"{rss0.info['name']}_sky-subtraction.png")

Now, let's get into the gory details of the linear fit

In [None]:
# Mean intensity of every fibre, and the thresholds that split the fibres
# into a low, a medium and a high flux group.
flux = np.nanmean(rss1.intensity, axis=1)
mean_flux = np.nanmean(flux)
flux_cut_low = np.nanmedian(flux[flux < mean_flux])
flux_cut_hi = 2*mean_flux - flux_cut_low

is_low = flux < flux_cut_low
is_med = (flux > flux_cut_low) & (flux < mean_flux)
is_hi = (flux > mean_flux) & (flux < flux_cut_hi)

# Mean flux and mean spectrum of each group
flux_low, flux_med, flux_hi = (np.nanmean(flux[group]) for group in (is_low, is_med, is_hi))
I_low, I_med, I_hi = (np.nanmean(rss1.intensity[group, :], axis=0) for group in (is_low, is_med, is_hi))

# Straight line through the low and high groups, one (m, b) pair per wavelength
m = (I_hi - I_low) / (flux_hi - flux_low)
b = I_low - m * flux_low

# Spectra predicted by the line for a grid of candidate sky fluxes
n_candidates = int(100*flux_cut_hi/flux_cut_low)
sky_flux_candidate = np.linspace(0, flux_cut_hi, n_candidates)
sky_filtered = m[np.newaxis, :] * sky_flux_candidate[:, np.newaxis] + b[np.newaxis, :]

# Filter every candidate spectrum: running mean over s pixels minus running
# mean over 3*s pixels, both obtained from the cumulative sum.
s = wavelet1.scale
x = np.nancumsum(sky_filtered, axis=1)
narrow_mean = (x[:, 2*s:-s] - x[:, s:-2*s]) / s
wide_mean = (x[:, 3*s:] - x[:, :-3*s]) / (3*s)
sky_filtered = narrow_mean - wide_mean

# Scatter of each filtered candidate, weighted towards / away from the sky lines
chi2_sky = np.nanstd(sky_filtered*wavelet1.sky_weight - wavelet1.sky, axis=1)
chi2_no_sky = np.nanstd(sky_filtered*(1 - wavelet1.sky_weight), axis=1)

# Adopt the candidate with the smallest scatter away from the sky lines
best_candidate = np.nanargmin(chi2_no_sky)
sky_flux = sky_flux_candidate[best_candidate]
sky_intensity = b + m*sky_flux

In [None]:
# Scatter of the filtered candidate spectra as a function of candidate sky flux.
fig, axes = new_figure('get sky flux', nrows=1)

ax = axes[0, 0]
# Both curves are labelled (with the names of the arrays they show) so that
# the legend identifies them; before, only the vertical line had an entry.
ax.plot(sky_flux_candidate, chi2_sky, 'y-', label='chi2_sky')
ax.plot(sky_flux_candidate, chi2_no_sky, 'r-', label='chi2_no_sky')
ax.axvline(sky_flux, c='k', ls=':', label=f'sky flux = {sky_flux:.2f}')
ax.set_xlabel('sky flux candidate')
# The trailing semicolon hides the Legend repr; it replaces a stray "," line,
# which is not valid Python outside IPython.
ax.legend();

In [None]:
# Slope, intercept and filtered sky spectrum of the linear fit computed above.
fig, axes = new_figure('fit coefficients', nrows=4)

# Panel 1: slope per wavelength. The legend shows its mean, so the label uses
# angle brackets (it used to say \sum although np.nanmean is printed).
ax = axes[0, 0]
ax.set_ylabel(r'$m_\lambda$')
ax.plot(wavelength_AA, m, 'k-', label=r'$\langle m_\lambda \rangle$ = '+f'{np.nanmean(m):.3f}')
#ax.plot(wavelet1.wavelength, wavelet1.sky_weight, 'y-')
#ax.plot(rss1.wavelength, sky_intensity / np.nanmean(sky_intensity), 'c-')
ax.legend()

# Panel 2: intercept per wavelength together with the adopted sky spectrum.
# The legend now reports the mean of b (it printed the mean of m by mistake).
ax = axes[1, 0]
ax.set_ylabel(r'$b_\lambda$')
#ax.plot(wavelength_AA, sky_fit.intensity, 'r-')
ax.plot(wavelength_AA, sky_intensity, 'c-', label=f'sky_flux = {sky_flux:.3f}')
ax.plot(wavelength_AA, b, 'k-', label=r'$\langle b_\lambda \rangle$ = '+f'{np.nanmean(b):.3f}')
#ax.set_ylim(-10, 500)
ax.legend()

# Panel 3: adopted sky spectrum after the same filter as in the fit
# (running mean over s pixels minus running mean over 3*s pixels)
ax = axes[2, 0]
s = wavelet1.scale
ax.set_ylabel(f'wavelet b ({s} pix)')

x = np.nancumsum(sky_intensity)
sky_filtered = (x[2*s:-s] - x[s:-2*s]) / s
sky_filtered -= (x[3*s:] - x[:-3*s]) / (3*s)
ax.plot(wavelet1.wavelength, sky_filtered, 'r-', alpha=.5)
ax.plot(wavelet1.wavelength, sky_filtered * wavelet1.sky_weight, 'k-', alpha=.5)
ax.plot(wavelet1.wavelength, wavelet1.sky, 'c-', alpha=.5)

# Panel 4: the weights used in the products above
ax = axes[3, 0]
ax.plot(wavelet1.wavelength, wavelet1.sky_weight)

## Flux calibration

In [None]:
def fit_line(x, y, n_clip=2):
    """Fit a straight line ``y = intercept + slope*x`` with median clipping.

    Non-finite points are discarded and an ordinary least-squares fit is made
    with `scipy.stats.linregress`. Then, ``n_clip`` times, only the points
    whose absolute residual is strictly below the median absolute residual
    are kept and the line is fitted again.

    Parameters
    ----------
    x, y : array-like, 1-D, same length
        Data to fit; must support boolean-mask indexing.
    n_clip : int, optional
        Number of clip-and-refit passes. The default (2) reproduces the
        original three-fit behaviour; 0 gives a plain least-squares fit.

    Returns
    -------
    slope, intercept : float
        Coefficients of the last successful fit.

    Raises
    ------
    ValueError
        Propagated from `linregress` when the finite input cannot be fitted
        (e.g. it is empty).
    """
    finite = np.isfinite(x) & np.isfinite(y)
    xx, yy = x[finite], y[finite]
    fit = linregress(xx, yy)
    for _ in range(n_clip):
        abs_deviation = np.abs(yy - (fit.intercept + fit.slope*xx))
        keep = abs_deviation < np.median(abs_deviation)
        # Clipping can leave nothing to fit (e.g. all residuals are equal for
        # a perfect fit); stop and keep the previous fit instead of crashing.
        if np.count_nonzero(keep) < 2 or np.ptp(xx[keep]) == 0:
            break
        xx, yy = xx[keep], yy[keep]
        fit = linregress(xx, yy)
    return fit.slope, fit.intercept

In [None]:
if example == "WEAVE":
    # Reference wavelength of the r band and the flux density of a 3631 Jy
    # source at that wavelength, expressed in erg/s/cm^2/Angstrom.
    l_ref_r = 6201.20 * u.Angstrom
    flux_unit = u.erg/u.s/u.cm**2/u.Angstrom
    zero_point_r = (3631*u.Jy * c.c/(l_ref_r)**2).to_value(flux_unit)

    # Pixel range covered by the band: wavelength limits -> (truncated) indices
    pixel_index = np.arange(wavelength_AA.size)
    i0, i1 = np.interp([5391.11, 7038.08], wavelength_AA, pixel_index).astype(int)
    r_band = slice(i0, i1)

    # Mean sky-subtracted counts of every fibre inside the band, for the L1
    # sky and for the two PyKOALA iterations
    r_L1 = np.nanmean((rss0.intensity.value - sky_L1)[:, r_band], axis=1)
    r_fit = np.nanmean((rss1.intensity.value - sky_fit.intensity.value)[:, r_band], axis=1)
    r_iter2 = np.nanmean((rss2.intensity.value - sky2_fit.intensity.value)[:, r_band], axis=1)

    # Flux expected from the catalogue magnitudes, and the inverse of the
    # sensitivity function interpolated at the reference wavelength
    ref_flux_r = zero_point_r * np.power(10, -.4*(fibtable['MAG_R']))
    m0_r = 1/np.interp(l_ref_r, rss0.wavelength, sensitivity_function)

    # Robust linear relations between expected flux and measured counts
    m_r_counts, sky_r_counts = fit_line(ref_flux_r, r_L1)
    m_r_new, sky_r_new = fit_line(ref_flux_r, r_fit)

In [None]:
if example == "WEAVE":
    # Reference wavelength of the i band and the flux density of a 3631 Jy
    # source at that wavelength, expressed in erg/s/cm^2/Angstrom.
    # (zero_point_r and ref_flux_r are not recomputed here: the r-band cell
    # above, which this cell already needs for l_ref_r, defines them.)
    l_ref_i = 7534.96 * u.Angstrom
    zero_point_i = (3631*u.Jy * c.c/(l_ref_i)**2).to_value(u.erg/u.s/u.cm**2/u.Angstrom)

    # FIXME(review): these wavelength limits are identical to the ones used
    # for the r band, so i_L1 / i_fit / i_iter2 are averaged over the r-band
    # window. Replace them with the i-band limits once confirmed.
    i0, i1 = np.interp([5391.11, 7038.08], wavelength_AA, np.arange(wavelength_AA.size)).astype(int)
    i_L1 = np.nanmean((rss0.intensity.value - sky_L1)[:, i0:i1], axis=1)
    i_fit = np.nanmean((rss1.intensity.value - sky_fit.intensity.value)[:, i0:i1], axis=1)
    i_iter2 = np.nanmean((rss2.intensity.value - sky2_fit.intensity.value)[:, i0:i1], axis=1)

    # Flux expected from the catalogue magnitudes, and the inverse of the
    # sensitivity function interpolated at the reference wavelength
    ref_flux_i = zero_point_i * np.power(10, -.4*(fibtable['MAG_I']))
    m0_i = 1/np.interp(l_ref_i, rss0.wavelength, sensitivity_function)

    # Fit the i-band quantities (the r-band arrays were passed here by
    # mistake, making these results copies of the r-band ones).
    m_i_counts, sky_i_counts = fit_line(ref_flux_i, i_L1)
    m_i_new, sky_i_new = fit_line(ref_flux_i, i_fit)

In [None]:
if example == "WEAVE":
    # The f-strings in this cell use double quotes outside and single quotes
    # inside: reusing the same quote character inside a replacement field is
    # a SyntaxError on Python versions older than 3.12.
    fig, axes = new_figure(f"flux calibration - {rss0.info['name']}", nrows=1, ncols=2, sharey=True)

    # Magnitude grid and the corresponding fluxes in each band
    xlim_mag = np.array([15.5, 25.5])
    xx_mag = np.linspace(xlim_mag[0], xlim_mag[1], 101)
    xx_r = zero_point_r * np.power(10, -.4*xx_mag)
    xx_i = zero_point_i * np.power(10, -.4*xx_mag)


    # Left panel: r band
    ax = axes[0, 0]
    ax.set_ylabel('counts [ADU/s/pix]')
    ax.set_yscale('log')
    ax.set_ylim(.3, 3e3)
    #ax.set_ylim(30, 3e4)

    ax.errorbar(fibtable['MAG_R'], r_L1, fmt='b.', alpha=.1, label='L1')
    ax.errorbar(fibtable['MAG_R'], r_fit, fmt='k.', alpha=.1, label='PyKOALA')
    #ax.errorbar(fibtable['MAG_R'], r_iter2, fmt='r.', alpha=.1, label='iter 2')

    ax.plot(xx_mag, m0_r*xx_r, 'r--', label='expected $r$ sensitivity')
    #ax.plot(xx, m_r_counts*xx + sky_r_counts, 'r-', label=f'sens x {m0_r/m_r_counts:.2f}, {sky_r_counts:+.2f} sky')
    #ax.plot(xx, m_r_new*xx + sky_r_new, 'k-', label=f'sens x {m0_r/m_r_new:.2f}, {sky_r_new:+.2f} sky')
    ax.legend()


    # Right panel: i band
    ax = axes[0, 1]
    ax.set_yscale('log')

    ax.errorbar(fibtable['MAG_I'], i_L1, fmt='b.', alpha=.1, label='L1')
    ax.errorbar(fibtable['MAG_I'], i_fit, fmt='k.', alpha=.1, label='PyKOALA')
    #ax.errorbar(fibtable['MAG_I'], i_iter2, fmt='r.', alpha=.1, label='iter 2')

    ax.plot(xx_mag, m0_i*xx_i, 'y--', label='expected $i$ sensitivity')
    #ax.plot(xx, m_i_counts*xx + sky_i_counts, 'r-', label=f'sens x {m0_r/m_r_counts:.2f}, {sky_r_counts:+.2f} sky')
    #ax.plot(xx, m_i_new*xx + sky_i_new, 'k-', label=f'sens x {m0_r/m_r_new:.2f}, {sky_r_new:+.2f} sky')
    ax.legend()


    for ax in axes[-1, :]:
        #ax.set_xscale('log')
        ax.set_xlim(xlim_mag)
        #ax.set_xlim(-1e-17, 1.1e-16)
        #ax.set_ylim(-50, 550)

    # The x axes show the catalogue magnitudes (MAG_R / MAG_I), not fluxes,
    # so the labels give magnitudes rather than erg/s/cm^2/AA.
    axes[-1, 0].set_xlabel(r'Pan-STARRS $r$ [mag]')
    axes[-1, 1].set_xlabel(r'Pan-STARRS $i$ [mag]')

    fig.savefig(f"{rss0.info['name']}_flux_calibration.png")


## Single fibre tests

In [None]:
# Index of the fibre inspected in the single-fibre figures below
# (used to index wavelet1.filtered, wavelet1.fibre_throughput and rss*.intensity).
fibre = 42

In [None]:
# Throughput diagnostics for the single fibre selected above.
fig, axes = new_figure('single throughput', nrows=3, sharex=False, sharey=False, gridspec_kw={'hspace': .2})

# Panel 1: distribution of the fibre-to-sky ratio of the filtered signal,
# with (grey) and without (yellow) the fibre_throughput factor
ax = axes[0, 0]
ax.hist(wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre] / wavelet1.sky, bins=np.linspace(0, 7.5, 101), color='k', alpha=.2)
ax.hist(wavelet1.filtered[fibre] / wavelet1.sky, bins=np.linspace(0, 7.5, 101), color='y', alpha=.5)

# Panel 2: filtered signal of the fibre compared with the filtered sky
ax = axes[1, 0]
ax.set_ylabel('wavelet')
ax.plot(wavelet1.wavelength, wavelet1.sky, 'y:', alpha=.5, label='sky')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre]*wavelet1.fibre_throughput[fibre], 'k--', alpha=.2, label='data')
ax.plot(wavelet1.wavelength, wavelet1.filtered[fibre], 'k-', alpha=.5, label='throughput-corrected')
ax.plot(wavelet2.wavelength, wavelet2.filtered[fibre], 'r-', alpha=.5, label='iter 2')
ax.legend()


# Panel 3: spectrum of the fibre before and after corrections and sky subtraction
ax = axes[2, 0]
ax.set_ylabel('intensity')
ax.sharex(axes[1, 0])

x = rss0.wavelength
if isinstance(x, u.Quantity):
    x = x.to_value(u.AA)

ax.plot(x, rss0.intensity[fibre], 'k--', alpha=.2, label=f'fibre {fibre}')
ax.plot(x, rss1.intensity[fibre], 'k-', alpha=.5, label='corrected')
ax.plot(x, rss1.intensity[fibre] - sky_fit.intensity, 'c-', alpha=.5, label='sky-subtracted')
ax.plot(x, rss2.intensity[fibre] - sky2_fit.intensity, 'c--', alpha=.5, label='iter 2')

ax.legend()
# Raw string: '\A' is an invalid escape sequence in a normal string literal
# (SyntaxWarning on recent Python); the resulting text is unchanged.
ax.set_xlabel(r'wavelength [\AA]')


In [None]:
# Cross-correlation of the selected fibre with the median filtered spectrum.
fig, axes = new_figure('single wavelength calibration')

ax = axes[0, 0]

mid = wavelet1.wavelength.size // 2
s = wavelet1.scale
# Median filtered spectrum over all fibres, with non-finite values zeroed
x = np.nanmedian(wavelet1.filtered, axis=0)
x[~ np.isfinite(x)] = 0
# Convolving with the reversed template gives the cross-correlation;
# only the lags within +/- s pixels of the centre are kept.
x = fftconvolve(wavelet1.filtered[fibre], x[::-1], mode='same')[mid-s:mid+s+1]
idx = np.arange(x.size)

ax.plot(idx - s, x/np.max(x), 'k-', label=f'fibre {fibre}')
oo = wavelet1.fibre_offset[fibre].to_value(u.pix)
ax.axvline(oo, c='k', ls=':', label=f'offset = {wavelet1.fibre_offset[fibre]:.2f}')
# Same curve mirrored about the measured offset
ax.plot(2*oo - (idx - s), x/np.max(x), 'k--', alpha=.5, label='reflected cross-correlation')

ax.legend()
ax.set_ylabel('cross-correlation with sky')
# The trailing semicolon hides the Text repr; it replaces a stray "," line,
# which is not valid Python outside IPython.
ax.set_xlabel('offset [pix]');

## Single wavelength

In [None]:
# Wavelength inspected in the figures below, on the same scale as wavelength_AA
wl = 6700
# Insertion index of `wl` in wavelength_AA, used to index the wavelength axis.
# NOTE(review): a `wl` beyond the last wavelength would give idx == wavelength_AA.size,
# which is out of range for the indexing below -- keep `wl` inside the covered range.
idx = np.searchsorted(wavelength_AA, wl)

In [None]:
# Linear fit at the selected wavelength: intensity there against mean fibre intensity.
fig, axes = new_figure('linear fit')
ax = axes[0, 0]

# One dot per fibre
ax.plot(flux, rss1.intensity[:, idx], 'k.')

# Flux thresholds used to define the fibre groups
for position, linestyle, text in (
        (mean_flux, '--', f'mean flux: {mean_flux:.2f}'),
        (flux_cut_low, ':', f'flux low: {flux_cut_low:.2f}'),
        (flux_cut_hi, ':', f'flux high: {flux_cut_hi:.2f}'),
):
    ax.axvline(position, c='r', ls=linestyle, label=text)

# Fitted line evaluated at the percentiles of the flux distribution,
# and the three group means it is based on
x = np.nanpercentile(flux, np.linspace(0, 100, 101))
ax.plot(x, m[idx]*x + b[idx], 'y:', label=f'm={m[idx]:.2f} b={b[idx]:.2f}')
group_flux = u.Quantity([flux_low, flux_med, flux_hi])
group_intensity = u.Quantity([I_low[idx], I_med[idx], I_hi[idx]])
ax.plot(group_flux, group_intensity, 'ro')

# Adopted sky level at this wavelength
ax.plot(sky_flux, sky_intensity[idx], 'co', label=f'Sky level = {sky_intensity[idx]:.2f}')

ax.legend()
ax.set_xlabel('mean fibre intensity')
ax.set_ylabel(f'intensity at $\\lambda={wavelength_AA[idx]:.2f}$ \\AA')
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale('log')

In [None]:
# Distribution of fibre intensities at the selected wavelength.
fig, axes = new_figure('monochrome hist')

ax = axes[0, 0]

# Bins run from the intercept b to the point 3*(sky - b) above it
x = np.linspace(b[idx], 3*sky_intensity[idx] - 2*b[idx], 101)

# Both artists are labelled: ax.legend() was previously called with nothing
# to show, which only produced a warning and an empty legend.
ax.hist(rss1.intensity[:, idx], bins=x, alpha=.5, label='fibres')
ax.axvline(sky_intensity[idx], c='k', ls='--', label=f'Sky level = {sky_intensity[idx]:.2f}')

ax.legend()
ax.set_ylabel('number of fibres')
ax.set_xlabel(f'intensity at $\\lambda={wavelength_AA[idx]:.2f}$ \\AA')
#ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
# Fibre maps at the selected wavelength, one row per processing stage.
fig, axes = new_figure(
    'correction maps',
    nrows=4, ncols=2, sharey=False, figsize=(8, 12),
    gridspec_kw={'left': 0.07, 'right': 0.9, 'width_ratios': [1, .05],
                 'hspace': 0.05, 'wspace': 0.25},
)

# Common logarithmic colour scale, set from the sky level and the original data
x = rss0.intensity[:, idx].value
norm = colors.LogNorm(vmin=.05*np.nanmedian(sky_intensity).value, vmax=np.nanpercentile(x, 99))

# (colour-bar label, RSS providing the fibres, values to map) for each row
map_rows = (
    (f'original $I_\\lambda({wavelength_AA[idx]:.2f} \\AA)$', rss0, x),
    ('throughput corrected', rss1, rss1.intensity[:, idx]),
    ('sky subtracted', rss1, rss1.intensity[:, idx] - sky_fit.intensity[idx]),
    ('iter2', rss2, rss2.intensity[:, idx] - sky2_fit.intensity[idx]),
)
for row, (row_label, row_rss, row_data) in enumerate(map_rows):
    plot_fibres(fig, axes[row, 0], cblabel=row_label, rss=row_rss, data=row_data, cbax=axes[row, 1], norm=norm)

if example == "WEAVE":
    # Hand-picked (xlim, ylim) windows for two specific targets
    zoom_windows = {
        "NGC5322": ((207.285, 207.345), (60.178, 60.203)),
        "J122047.09+580537.3": ((185.175, 185.222), (58.080, 58.105)),
    }
    window = zoom_windows.get(rss0.info['name'])
    if window is not None:
        axes[0, 0].set_xlim(*window[0])
        axes[0, 0].set_ylim(*window[1])

# The lower maps follow the vertical range of the top one
top_map = axes[0, 0]
for ax in axes[1:, 0]:
    ax.sharey(top_map)
