# Twilight corrections

**WARNING**: These are preliminary, potentially unstable tests based on twilight flats. Use them with (even more) caution (than the rest of the package ;^D).

# 1. Initialisation

In [None]:
%matplotlib ipympl
from matplotlib import pyplot as plt
import numpy as np
from scipy.signal import correlate

In [None]:
from pykoala import __version__
from pykoala.instruments import koala_ifu, weave
from pykoala.corrections.wavelength import SolarCrossCorrOffset
from pykoala.plotting.utils import new_figure, colour_map
# Report the installed pyKOALA version, so results can be traced to a release.
print("pyKOALA version: ", __version__)

The following cell is a development aid (automatic reloading of edited modules) and will probably disappear from the final version of the tutorial.

In [None]:
# Development helper: IPython's autoreload extension; mode 2 reloads all
# modules before executing each cell, so edits to pyKOALA are picked up.
%load_ext autoreload
%autoreload 2

# 2. Load the science data
The input must be a Row-Stacked Spectra (RSS) file. Please choose one of the following examples:

In [None]:
# Choose which example dataset to analyse: 'KOALA' or 'WEAVE'.
#example = 'KOALA'
example = 'WEAVE'
if example not in ('KOALA', 'WEAVE'):
    # Fail early: with any other value neither loading cell runs, and the
    # later cells would stop with a confusing NameError on `rss`.
    raise ValueError(f"Unknown example {example!r}; choose 'KOALA' or 'WEAVE'")

KOALA:

In [None]:
if example == 'KOALA':
    # Read the KOALA RSS file. It must be bound to `rss`, the name every later
    # cell uses (it was previously bound to `rss0`, which broke this path).
    filename = "../data/27feb20036red.fits"
    rss = koala_ifu.koala_rss(filename)
    wavelength_AA = rss.wavelength

WEAVE:

In [None]:
if example == 'WEAVE':
    # Alternative input (single science exposure), kept for reference. It would
    # also need a units module imported as `u` (presumably astropy.units),
    # which this notebook does not import:
    #   filename = "../data/weave/v3/WA_J024019.19+321544.10/single_3042890.fit"
    #   rss0 = weave.weave_rss(filename)
    #   wavelength_AA = rss0.wavelength.to_value(u.Angstrom)
    filename = "../data/weave/solar/msp_3059302.fit"
    rss = weave.weave_rss(filename)
    wavelength_AA = rss.wavelength

## Summary

In [None]:
# Short report on the loaded RSS: object name, source file, the keys of its
# `info` dictionary and its history log.
object_name = rss.info['name']
print(f"Analysing object {object_name} read from {filename}")

print('- info:')
print(rss.info.keys())

print('- log:')
rss.history.show()


# 3. Solar spectrum

In [None]:
# Reference solar spectrum; later cells read its `sun_wavelength` and
# `sun_intensity` attributes. `from_fits()` is called without arguments —
# presumably it loads a default solar template shipped with pyKOALA (TODO confirm).
solar_correction = SolarCrossCorrOffset.from_fits()

# 4. Transforms

## Normalisation

In [None]:
# Wavelength window (AA, same units as wavelength_AA) over which each spectrum
# is scaled to unit median.
norm_range = np.asarray((7000.0, 7100.0))

In [None]:
# Select a single fibre and scale its spectrum to unit median within norm_range.
#normalised_rss = np.nanmedian(rss.intensity, axis=0).copy()
fibre = 42
# .copy() is essential: indexing one row of rss.intensity (presumably a NumPy
# array) returns a view, so the in-place division below would otherwise
# overwrite the fibre's data stored in `rss` itself.
normalised_rss = rss.intensity[fibre].copy()

in_norm_range = (wavelength_AA >= norm_range[0]) & (wavelength_AA <= norm_range[1])
normalised_rss /= np.nanmedian(normalised_rss[in_norm_range])

In [None]:
# Scale the reference solar spectrum to unit median within norm_range,
# working on a copy so the template held by solar_correction is untouched.
sun_in_norm_range = ((solar_correction.sun_wavelength >= norm_range[0])
                     & (solar_correction.sun_wavelength <= norm_range[1]))
normalised_sun = solar_correction.sun_intensity.copy()
normalised_sun /= np.nanmedian(normalised_sun[sun_in_norm_range])

## Wavelet filter

In [None]:
def integrate(y, x=None):
    """Return bin edges and the cumulative integral of ``y`` at those edges.

    Each sample ``y[i]`` is treated as constant over a bin around ``x[i]``.
    Interior edges are the midpoints between consecutive positions, and the
    two outer edges lie half a step beyond the first and last samples. Every
    bin therefore has the width of the local sampling step. The previous
    extrapolation ``2*x[0] - x[1]`` placed the outer edges a full step out,
    which made the first and last bins 1.5x too wide.

    Parameters
    ----------
    y : array_like
        Sampled values. NaNs contribute zero to the integral (``nancumsum``).
    x : array_like, optional
        Sample positions, at least two of them. Defaults to the pixel index
        ``0 .. y.size - 1``. They are expected to be increasing, because the
        edges are used as ``np.interp`` abscissae by the filters. This is not
        checked here.

    Returns
    -------
    x_edges : ndarray
        The ``y.size + 1`` bin edges.
    integral : ndarray
        Cumulative integral of ``y`` at each edge, starting from 0.
    """
    # asanyarray accepts lists while leaving ndarray (sub)classes untouched.
    y = np.asanyarray(y)
    if x is None:
        x = np.arange(y.size)
    x = np.asanyarray(x)

    first_edge = x[0] - (x[1] - x[0]) / 2
    last_edge = x[-1] + (x[-1] - x[-2]) / 2
    x_edges = np.hstack([first_edge, (x[1:] + x[:-1]) / 2, last_edge])
    dx = np.diff(x_edges)
    integral = np.hstack([0, np.nancumsum(y * dx)])
    return x_edges, integral

def mean_filter(x0, y0, x, h):
    """Running mean of ``y0`` (sampled at ``x0``) over a top-hat of full width ``h``.

    The result is evaluated at the positions ``x``. It is computed as the
    difference of the cumulative integral from ``integrate`` at ``x + h/2``
    and ``x - h/2``, divided by ``h``. ``np.interp`` clamps outside the edges
    of ``x0``, so any part of the window beyond the data contributes nothing,
    while the division is still by the full ``h``. Values within ``h/2`` of
    the ends are therefore diluted toward zero.
    """
    edges, cumulative = integrate(y0, x0)
    upper = np.interp(x + h / 2, edges, cumulative)
    lower = np.interp(x - h / 2, edges, cumulative)
    return (upper - lower) / h

def wavelet_filter(x0, y0, x, h):
    """Band-pass ``y0`` (sampled at ``x0``) with a difference of top-hat windows.

    At each position of ``x``, the mean of ``y0`` over a window of width
    ``3*h`` is subtracted from its mean over a window of width ``h``. The
    result is then divided by its own local RMS, obtained with
    ``mean_filter`` over a window of width ``30*h``. Where that RMS is zero
    the output is inf/NaN.

    Returns
    -------
    filtered : ndarray
        The RMS-normalised band-passed signal on the grid ``x``.
    norm : ndarray
        The local RMS used for the normalisation.
    """
    edges, cumulative = integrate(y0, x0)

    def window_mean(half_width, width):
        # Mean of y0 over [x - half_width, x + half_width] from the cumulative integral.
        return (np.interp(x + half_width, edges, cumulative)
                - np.interp(x - half_width, edges, cumulative)) / width

    band = window_mean(h/2, h) - window_mean(1.5*h, 3*h)
    norm = np.sqrt(mean_filter(x, band**2, x, 30*h))
    return band / norm, norm

In [None]:
# Kernel scale (AA) for wavelet_filter: the narrow window is h wide, the broad
# one 3*h, and the RMS normalisation uses 30*h. (A dead `shift = 5` binding
# was removed: `shift` is redefined as the array of trial shifts before use.)
h = 3  # AA
# Both spectra are filtered onto the RSS wavelength grid, so they can be
# compared pixel by pixel in the correlations below.
rss_filtered, rss_norm = wavelet_filter(wavelength_AA, normalised_rss, wavelength_AA, h)
sun_filtered, sun_norm = wavelet_filter(solar_correction.sun_wavelength, normalised_sun, wavelength_AA, h)

## Cross-correlation

In [None]:
# Cross-correlation of the filtered fibre and solar spectra for a grid of
# trial shifts (AA). Row k corresponds to shift[k] and holds the h-wide
# running mean of rss_filtered(lambda) * sun_filtered(lambda - shift[k]).
# It is evaluated at wavelength_AA + shift[k].
# NOTE(review): the output grid is offset by the shift rather than being
# wavelength_AA itself — confirm that this is intended.
shift = np.linspace(-5, 5, 101)
shift_idx = np.argmin(shift**2)  # index of the trial shift closest to zero

cross_correlation = np.array([
    mean_filter(wavelength_AA,
                rss_filtered * np.interp(wavelength_AA, wavelength_AA + trial, sun_filtered),
                wavelength_AA + trial,
                h)
    for trial in shift
])



In [None]:
# Self-correlation of the filtered fibre spectrum on the same grid of trial
# shifts (AA). Row k holds the h-wide running mean of
# rss_filtered(lambda) * rss_filtered(lambda - shift[k]), evaluated at
# wavelength_AA + shift[k].
shift = np.linspace(-5, 5, 101)
shift_idx = np.argmin(shift**2)  # index of the trial shift closest to zero

self_correlation = np.array([
    mean_filter(wavelength_AA,
                rss_filtered * np.interp(wavelength_AA, wavelength_AA + offset, rss_filtered),
                wavelength_AA + offset,
                h)
    for offset in shift
])



In [None]:
fig, axes = new_figure('cross-correlation', nrows=3, ncols=2, sharey=False,
                       gridspec_kw={'width_ratios': [1, .02], 'hspace': 0., 'wspace': .1})

# Top panel: normalised fibre and solar spectra (log scale).
ax = axes[0, 0]
ax.set_ylabel('intensity [arbitrary units]')
ax.set_yscale('log')
ax.set_ylim(.03, 30)

ax.plot(wavelength_AA, normalised_rss, '-', label='rss')
ax.plot(solar_correction.sun_wavelength + shift[shift_idx], normalised_sun, '-', label='sun')

ax.legend()


# Middle panel: the wavelet-filtered fibre spectrum and its zero-shift self-correlation.
ax = axes[1, 0]
ax.set_ylabel('wavelet filter')

ax.plot(wavelength_AA, rss_filtered, '-', label='rss')
# Disabled alternatives (previously kept inside a bare string literal):
# ax.plot(wavelength_AA, sun_filtered, '-', label='sun')
# ax.plot(wavelength_AA, rss_filtered*sun_filtered, 'w:', label='product')
# ax.plot(wavelength_AA, cross_correlation[shift_idx], 'w--', label=f'shift {shift[shift_idx]}')
ax.plot(wavelength_AA, self_correlation[shift_idx], 'w:', label=f'self {shift[shift_idx]}')

ax.legend()

# Bottom panel: self-correlation map versus wavelength and trial shift. Overlaid
# is the mean trial shift, weighted by the squared cross-correlation.
# NOTE(review): the map shows self_correlation while the overlaid line comes
# from cross_correlation — confirm that this pairing is intended.
ax = axes[2, 0]
# Raw strings: '\A' is an invalid escape sequence in a normal string literal.
im, cb = colour_map(fig, ax, 'coeff', self_correlation,
                    x=wavelength_AA, xlabel=r'wavelength [$\AA$]',
                    y=shift, ylabel=r'shift [$\AA$]', cbax=axes[2, 1])
#best = np.nanargmax(cross_correlation, axis=0)
# Take the weighted mean of the shift values directly. Truncating the mean
# *index* with astype(int) biased the curve by up to one step, and it cast
# NaN columns to an arbitrary int (then clipped to -5 AA). Columns without
# valid weights now stay NaN and appear as gaps.
weights = cross_correlation**2
with np.errstate(invalid='ignore', divide='ignore'):
    best_shift = (np.nansum(weights * shift[:, np.newaxis], axis=0)
                  / np.nansum(weights, axis=0))
ax.plot(wavelength_AA, best_shift, color='k', alpha=.5)
ax.set_xlim(wavelength_AA[0], wavelength_AA[-1])

axes[0, 1].axis('off')
axes[1, 1].axis('off')


In [None]:
# Autocorrelation of the filtered fibre spectrum after setting non-finite
# values to 0 and scaling to unit standard deviation. The cell output is the
# number of correlation samples above half the peak, followed by the peak.
x = np.nan_to_num(rss_filtered, nan=0.0, posinf=0.0, neginf=0.0)
x = x / np.std(x)
x = correlate(x, x)
np.count_nonzero(x > 0.5*np.nanmax(x)), np.nanmax(x)

In [None]:
# Standard deviation of the filtered fibre spectrum, with non-finite values set to 0.
x = np.nan_to_num(rss_filtered, nan=0.0, posinf=0.0, neginf=0.0)
np.std(x)

In [None]:
# Filtered fibre spectrum with non-finite values replaced by 0, ready for correlate().
x = np.nan_to_num(rss_filtered, nan=0.0, posinf=0.0, neginf=0.0)
fig, axes = new_figure('scipy_cross-corr', nrows=3)

# Top panel: autocorrelations of the standardised, unfiltered spectra.
ax = axes[0, 0]
#ax.plot(correlate(x, x) / wavelength_AA.size)
def y(x):
    """Return the autocorrelation of the standardised signal ``x``, per sample.

    ``x`` is shifted to zero mean and scaled to unit standard deviation, with
    both statistics computed ignoring NaNs. Non-finite standardised values are
    then set to 0. Without this step a single NaN (spectra here are expected
    to contain some) would turn the whole correlation into NaN. The full
    ``scipy.signal.correlate`` output, of length ``2*x.size - 1``, is divided
    by ``x.size``, so for an all-finite input the zero-lag value is 1.
    """
    x = np.asanyarray(x)
    standardised = (x - np.nanmean(x)) / np.nanstd(x)
    standardised = np.where(np.isfinite(standardised), standardised, 0)
    return correlate(standardised, standardised) / x.size
ax.plot(y(normalised_rss), '+-')
ax.plot(y(normalised_sun))

# Middle panel: auto- and cross-correlations of the wavelet-filtered spectra,
# each divided by the number of RSS wavelength pixels.
n_pixels = wavelength_AA.size
ax = axes[1, 0]
ax.plot(correlate(x, x) / n_pixels, alpha=.5)
ax.plot(correlate(sun_filtered, sun_filtered) / n_pixels, alpha=.5)
ax.plot(correlate(x, sun_filtered) / n_pixels)

# Bottom panel: the filtered spectra themselves.
ax = axes[2, 0]
ax.plot(x)
ax.plot(sun_filtered)

In [None]:
fig, axes = new_figure('LSF')

# Cross-correlation averaged over wavelength (axis 1) for each trial shift.
# Presumably a rough line-spread-function estimate, as the figure name suggests.
mean_cross_correlation = np.nanmean(cross_correlation, axis=1)
ax = axes[0, 0]
ax.plot(shift, mean_cross_correlation, '+-')

In [None]:
# Display the grid of trial wavelength shifts (AA) used in the correlations above.
shift