In [None]:
from pykoala.instruments.koala_ifu import koala_rss
from matplotlib import pyplot as plt
import numpy as np
from pykoala import __version__
import warnings
import importlib
# Silence all Python warnings so the notebook output stays readable.
# Comment out the following line to see the warnings again (e.g. while debugging).
warnings.filterwarnings("ignore")
from pykoala.corrections.throughput import ThroughputCorrection

In [None]:
# Read the combined sky flat; ThroughputCorrection.from_rss expects a list of RSS objects.
flat_rss = [koala_rss("data/combined_skyflat_red.fits")]

# Derive the fibre throughput from that list of flats.
# medfilt=100 is presumably a median-filter window size — TODO confirm in pykoala docs.
throughput_corr = ThroughputCorrection.from_rss(
    flat_rss,
    medfilt=100,
    clear_nan=True,
)

# Correct the very same flat exposure with the throughput just derived.
corrected_flat = throughput_corr.apply(flat_rss[0])

In [None]:
# Display the throughput derived from the (non wavelength-corrected) sky flat.
skyflat_throughput = throughput_corr.throughput
throughput_fig = skyflat_throughput.plot()

In [None]:
# Compare the sky flat before and after the throughput correction.
# Both panels deliberately share the colour limits of the ORIGINAL flat
# (its 5th and 95th NaN-ignoring percentiles), so the two images are shown on a
# common scale. The limits are computed once instead of once per keyword/panel.
flat_vmin, flat_vmax = np.nanpercentile(flat_rss[0].intensity.value, [5, 95])

fig, (ax_original, ax_corrected) = plt.subplots(ncols=2, figsize=(15, 8))
for ax, image, title in [
    (ax_original, flat_rss[0].intensity.value, 'Original'),
    (ax_corrected, corrected_flat.intensity.value, 'Throughput corrected'),
]:
    im = ax.imshow(image, interpolation='none', aspect='auto',
                   cmap='nipy_spectral', vmin=flat_vmin, vmax=flat_vmax)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

# Next: apply the wavelength correction first, then re-derive the throughput from the wavelength-corrected flat.

In [None]:
from pykoala.instruments.koala_ifu import koala_rss

from pykoala.corrections.wavelength import SolarCrossCorrOffset
from time import time
from pykoala.plotting.utils import plot_fibres
from matplotlib import pyplot as plt
import numpy as np


In [None]:
# Build the wavelength-offset corrector (from_fits() with no arguments,
# presumably loading a default solar reference spectrum — TODO confirm) and fit
# per-fibre offsets and LSF widths on the twilight sky flat.
solar_correction = SolarCrossCorrOffset.from_fits()

# Search grids, in pixels. np.linspace gives exact endpoints: the previous
# np.arange(-2, 2, 0.10) stopped at 1.9 (so +2 was never searched while -2 was),
# and its float-step accumulation meant a zero shift was not exactly on the grid.
PIX_SHIFT_GRID = np.linspace(-2.0, 2.0, 41)  # -2.0, -1.9, ..., 0.0, ..., 2.0
PIX_STD_GRID = np.linspace(0.1, 1.9, 19)     # 0.1, 0.2, ..., 1.9 (same range as before)
INSPECT_FIBRES = [400, 900]  # presumably fibres singled out for diagnostics — TODO confirm

solution = solar_correction.compute_shift_from_twilight(
    flat_rss[0], keep_features_frac=0.05,
    pix_shift_array=PIX_SHIFT_GRID,
    pix_std_array=PIX_STD_GRID,
    logspace=False, inspect_fibres=INSPECT_FIBRES)

In [None]:
# Apply the fitted wavelength offsets to the raw sky flat (not to `corrected_flat`);
# the throughput is re-derived from this wavelength-corrected flat further below.
raw_skyflat = flat_rss[0]
rss_corrected = solar_correction.apply(raw_skyflat)

In [None]:
# Quick look at the wavelength-corrected flat. aspect='auto' matches every other
# RSS image in this notebook; the trailing ';' suppresses the artist repr.
fig, ax = plt.subplots()
im = ax.imshow(rss_corrected.intensity.value, interpolation='none', aspect='auto')
ax.set_title('Wavelength corrected')
fig.colorbar(im, ax=ax);


In [None]:
# Per-fibre maps of the fit results. Component 0 of each solution entry is the
# wavelength offset and component 1 the LSF std (both in pixels, per the colour-bar
# labels); the top row shows solution['mean'], the bottom row solution['best-fit'].
fig, axs = plt.subplots(ncols=2, nrows=2, constrained_layout=True, sharex=True, sharey=True,
                        figsize=(8, 4))

fibre_map_panels = [
    # (row, col, solution key, component, colour-bar label, title)
    (0, 0, 'mean', 0, r'$\Delta\lambda$ (pix)', "Mean wavelength offset"),
    (0, 1, 'mean', 1, r'$\sigma$ (pix)', "Mean LSF std"),
    (1, 0, 'best-fit', 0, r'$\Delta\lambda$ (pix)', "Best-fit wavelength offset"),
    (1, 1, 'best-fit', 1, r'$\sigma$ (pix)', "Best-fit LSF std"),
]
for row, col, key, component, cblabel, title in fibre_map_panels:
    ax = axs[row, col]
    # A fresh Normalize per panel so each colour scale autoscales to its own data.
    plot_fibres(fig, ax, rss=flat_rss[0], data=solution[key][component],
                norm=plt.Normalize(), cmap='gnuplot', cblabel=cblabel)
    ax.set_title(title)

In [None]:
# Re-derive the fibre throughput, this time from the wavelength-corrected flat.
# NOTE: this rebinds `throughput_corr`; every later cell uses this new correction.
wavelength_corrected_flats = [rss_corrected]
throughput_corr = ThroughputCorrection.from_rss(
    wavelength_corrected_flats, medfilt=100, clear_nan=True)
new_corrected_flat = throughput_corr.apply(rss_corrected)


In [None]:
# Compare the wavelength-corrected flat before and after the NEW throughput
# correction. Both images are wavelength corrected, so neither is the original
# flat; the titles say so. Both panels share the left image's 5th-95th percentile
# colour limits, computed once.
wl_vmin, wl_vmax = np.nanpercentile(rss_corrected.intensity.value, [5, 95])

fig, (ax_before, ax_after) = plt.subplots(ncols=2, figsize=(15, 8))
for ax, image, title in [
    (ax_before, rss_corrected.intensity.value, 'Wavelength corrected'),
    (ax_after, new_corrected_flat.intensity.value, 'Wavelength + throughput corrected'),
]:
    im = ax.imshow(image, interpolation='none', aspect='auto',
                   cmap='nipy_spectral', vmin=wl_vmin, vmax=wl_vmax)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

In [None]:
# Display the throughput re-derived from the wavelength-corrected flat.
recomputed_throughput = throughput_corr.throughput
throughput_fig = recomputed_throughput.plot()

In [None]:
# Compare the scatter along axis 0 of the intensity array (across the rows of the
# images above — presumably fibres, TODO confirm) at each wavelength:
#   "Old": throughput correction only            (corrected_flat)
#   "New": wavelength + throughput correction    (new_corrected_flat)
# The NaN-ignoring standard deviations are computed once and reused in both panels.
new_std = np.nanstd(new_corrected_flat.intensity, axis=0)
old_std = np.nanstd(corrected_flat.intensity, axis=0)
# Relative dispersion: standard deviation divided by mean, per wavelength bin.
new_dispersion = new_std / np.nanmean(new_corrected_flat.intensity, axis=0)
old_dispersion = old_std / np.nanmean(corrected_flat.intensity, axis=0)

fig, axs = plt.subplots(nrows=2, figsize=(12, 8), sharex=True)
ax = axs[0]
ax.plot(new_corrected_flat.wavelength, new_dispersion, alpha=1, label='New')
ax.plot(corrected_flat.wavelength, old_dispersion, alpha=0.5, label='Old')
ax.set_ylim(0, 0.1)
ax.set_ylabel("STD / mean")
ax.legend()
ax = axs[1]
ax.plot(corrected_flat.wavelength, new_std / old_std)
# "old" matches the legend above: the denominator is the throughput-only flat.
ax.set_ylabel("STD(new) / STD(old)")
ax.set_ylim(0.1, 1.2)
ax.set_xlabel("Wavelength")