In [None]:
# from astropy.wcs import WCS
import matplotlib.pyplot as plt
import numpy as np
import tess_cpm
import lightkurve as lk
import timeit
from astropy.io import fits 
from matplotlib.gridspec import GridSpec
from scipy.optimize import curve_fit
from scipy.ndimage import uniform_filter1d
from scipy.ndimage import median_filter

# Notebook-wide matplotlib defaults: 7x5-inch figures rendered at 400 dpi.
plt.rcParams.update({"figure.figsize": (7, 5), "figure.dpi": 400})

In [None]:
# plt.rcParams["figure.figsize"] = (14, 10)
# plt.rcParams["figure.dpi"] = 300

First, specify the path to the stack of FFI cutouts containing the source you're interested in.  
We've been using 100x100 cutouts (obtained with TESScut), but you can use smaller (or larger) cutouts; 
the smallest we've used is 32x32.  
This example uses a 100x100 FFI cutout.

In [None]:
# Alternative 80x80 cutouts of TIC 395130640 (Sectors 11 and 12):
# sec11 = "dwarfs/tess-s0011-3-4_169.234200_-80.464300_80x80_astrocut.fits"
# sec12 = "dwarfs/tess-s0012-3-3_169.234200_-80.464300_80x80_astrocut.fits"
# Alternative cutouts of the provided source:
# fits_file = "dwarfs/tess-s0010-2-1_162.328812_-53.319467_100x100_astrocut.fits"
# fits_file = "dwarfs/tess-s0010-3-2_162.328812_-53.319467_100x100_astrocut.fits"

# Active choice: the Sector 12 100x100 cutout of TIC 395130640.
fits_file = "dwarfs/tess-s0012-3-3_169.234200_-80.464300_100x100_astrocut.fits"

The current main interface to the TESS CPM package is through the Source class.  
You'll initialize an instance of the Source class by passing the path to the FFI cutouts.  
The `remove_bad` keyword argument specifies whether you want to remove the data points that have been flagged by the TESS QUALITY array. 

In [None]:
# Build the CPM Source from the cutout stack, dropping cadences flagged by the TESS QUALITY array.
dw = tess_cpm.Source(
    fits_file,
    remove_bad=True,
)

If you want to see the median flux image of your stack of FFI cutouts, use the `plot_cutout()` method.  
It's a good idea to check it for regions where flux might be missing (e.g., when the cutout is close to the edge of a detector).

In [None]:
# Median flux image of the whole cutout stack; the trailing ';' hides the return value.
dw.plot_cutout();

The next thing to do is specify the set of pixels you believe your source falls on.  
If you got your cutouts using TESScut by providing the coordinates of the source, the source will roughly be at the center of the image.  

You can specify the set of pixels with the `set_aperture` method.  
It currently only supports a rectangular set of pixels, although we hope to eventually support arbitrary apertures.  
You define the extent of the rectangular aperture with the `rowlims` and `collims` arguments. For each argument, pass a list giving the lower and upper limits of the aperture (both inclusive). For example, `rowlims=[50, 52]` means rows 50, 51, and 52.  

After specifying the aperture, you can visually check that it actually covers the pixels you're interested in by calling `plot_cutout` again.  
Just pass the `show_aperture=True` keyword argument; the overlaid aperture makes the pixels in the aperture look white. 
You can also restrict the plot to a region around the source (instead of the entire cutout) by specifying rows and columns the same way you defined the aperture.

In [None]:
# 2x2 aperture: rows 50-51 and columns 50-51 (limits are inclusive).
dw.set_aperture(rowlims=[50, 51], collims=[50, 51])
# Larger alternative: dw.set_aperture(rowlims=[47, 52], collims=[48, 53])

# Zoom on rows/cols 45-55 with the aperture overlaid to confirm placement.
dw.plot_cutout(rowlims=[45, 55], collims=[45, 55], show_aperture=True);

After specifying the set of pixels you're interested in, you can check the pixel light curves using the `plot_pix_by_pix` method.  

In [None]:
# With no arguments this shows the raw flux of every pixel in the aperture.
dw.plot_pix_by_pix();

In [None]:
# "normalized_flux" gives the zero-centered, median-normalized flux of each aperture pixel.
dw.plot_pix_by_pix(data_type="normalized_flux");

Next, choose the model components you want to add.   
You'll definitely want to add the causal pixel model with the `add_cpm_model` method.  

In [None]:
# Add the causal pixel model. predictor_method='similar_brightness' presumably selects
# predictor pixels of similar brightness to each target pixel (see help(dw.add_cpm_model));
# call dw.add_cpm_model() instead for the default predictor selection.
# (The leftover `dw.add_cpm_model?` help line was removed: it is IPython-only syntax,
# not Python, and only displays the docstring.)
dw.add_cpm_model(predictor_method='similar_brightness')

CPM's main idea is to model a single pixel light curve as a linear combination of many other pixel light curves.  
The default setting uses `n=256` other pixel light curves, so the model is very prone to overfitting.  
One way to prevent overfitting is to constrain the flexibility of the model through regularization.  
Currently we use L2 regularization, where a larger regularization value means stronger regularization. 
We set the regularization value using `set_regs`. We still need to figure out a good way to determine the regularization values, but for this example we'll just use `0.1`.

In [None]:
# Inspect one pixel's CPM model; models[0][0] is presumably the first aperture pixel (TODO confirm).
dw.models[0][0].plot_model();

In [None]:
# L2 regularization strength for the CPM. It is passed as a list because some model
# setups take more than one regularization parameter.
dw.set_regs([0.1])

We can now perform least-squares regression to model the pixel light curve with the `holdout_fit_predict` method.
In addition to regularization, we also use a train-and-test framework to prevent overfitting. In this framework we split the light curve into __k__ contiguous chunks and predict the __i-th__ chunk with the parameters obtained from regressing on all the other chunks.

In [None]:
# Fit with k=100 contiguous holdout chunks: each chunk is predicted from a fit to the others.
dw.holdout_fit_predict(k=100)
# Per-pixel CPM-subtracted flux; split=True presumably draws each holdout chunk separately.
dw.plot_pix_by_pix(data_type="cpm_subtracted_flux", split=True);

In [None]:
# Overlay the aperture's normalized flux with the CPM prediction for the same aperture.
fig, ax = plt.subplots(figsize=(16, 10))
aperture_normalized_flux = dw.get_aperture_lc(data_type="normalized_flux")
aperture_cpm_prediction = dw.get_aperture_lc(data_type="cpm_prediction")

ax.plot(dw.time, aperture_normalized_flux, ".", c="k", ms=8, label="Normalized Flux")
ax.plot(dw.time, aperture_cpm_prediction, "-", lw=3, c="C3", alpha=0.8, label="CPM Prediction")
ax.set_xlabel("Time - 2457000 [Days]", fontsize=30)
ax.set_ylabel("Normalized Flux", fontsize=30)
ax.tick_params(labelsize=20)
ax.legend(fontsize=30)

In [None]:
# Aperture CPM-subtracted light curve, one array per holdout section (split=True), with the
# pixels combined using weighting="median" (see the tess_cpm docs for the exact scheme).
weighted_detrended_lc = dw.get_aperture_lc(
    data_type="cpm_subtracted_flux", split=True, weighting="median"
)
# Optional comparison of weighted (solid) vs unweighted (dashed) sections. Loop names are
# chosen so that uncommenting does not overwrite the notebook globals `time` and `lc`:
# for sec_time, sec_lc in zip(dw.split_times, weighted_detrended_lc):
#     plt.plot(sec_time, sec_lc, "-")
# detrended_lc = dw.get_aperture_lc(split=True, weighting=None, data_type="cpm_subtracted_flux")
# for sec_time, sec_lc in zip(dw.split_times, detrended_lc):
#     plt.plot(sec_time, sec_lc, "--")
# plt.xlabel("Time - 2457000 [Days]", fontsize=30)
# plt.ylabel("CPM Flux", fontsize=30)
# plt.tick_params(labelsize=20)

In [None]:
# Single (unsplit) aperture light curve after subtracting the CPM prediction.
cpm_lc = dw.get_aperture_lc(data_type="cpm_subtracted_flux")

In [None]:
# Detrended (CPM-subtracted) aperture light curve.
fig, ax = plt.subplots()
ax.plot(dw.time, cpm_lc, "-", c="k")
ax.set_xlabel("Time - 2457000 [Days]", fontsize=30)
ax.set_ylabel("CPM Flux", fontsize=30)
ax.tick_params(labelsize=20)

In [None]:
# Outlier mask using an upper threshold of 3 sigma; used below as a boolean index.
outliers = dw.get_outliers(sigma_upper=3)

In [None]:
# Detrended light curve with the flagged outliers marked.
fig, ax = plt.subplots()
ax.plot(dw.time, cpm_lc, "-", c="k", label="Detrended Light curve")
ax.plot(dw.time[outliers], cpm_lc[outliers], "x", ms=10, c="C3", label="Outliers")
ax.set_xlabel("Time - 2457000 [Days]", fontsize=30)
ax.set_ylabel("CPM Flux", fontsize=30)
ax.tick_params(labelsize=20)
ax.legend(fontsize=30)

In [None]:
# Detrended light curve with the outliers removed.
fig, ax = plt.subplots()
ax.plot(dw.time[~outliers], cpm_lc[~outliers], "-", c="k", label="Detrended Light curve")
# ax.plot(dw.time[outliers], cpm_lc[outliers], "x", ms=10, c="C3")
ax.set_xlabel("Time - 2457000 [Days]", fontsize=30)
ax.set_ylabel("CPM Flux", fontsize=30)
ax.tick_params(labelsize=20)

In [None]:
# Wrap the CPM light curve in a lightkurve object to use its plotting and periodogram tools.
lc = lk.TessLightCurve(flux=cpm_lc, time=dw.time)
lc.plot()

In [None]:
# Sigma-clip (upper 3 sigma) the CPM light curve. Rebuild from `cpm_lc` rather than clipping
# `lc` in place, so re-running this cell gives the same result instead of clipping an
# already-clipped light curve again.
lc = lk.TessLightCurve(time=dw.time, flux=cpm_lc).remove_outliers(sigma_upper=3)
lc.plot()

In [None]:
# Periodogram of the clipped CPM light curve, using lightkurve's default method.
pg = lc.to_periodogram()

In [None]:
# Periodogram in frequency (top) and period (bottom) views.
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(16, 16))
pg.plot(ax=axs[0], c='k')
pg.plot(ax=axs[1], c='k', view='period')
# fig.suptitle("Periodogram", fontsize=20, y=0.95)

# Fold the light curve at the period with the highest power.
period = pg.period_at_max_power
print(f"Max Power Period: {period}")
lc.fold(period.value).scatter()
plt.title(f"Folded Lightcurve with Period: {period:.4f}", fontsize=20)

## Periodogram for Original Light Curve

In [None]:
# Baseline for comparison: the raw (non-CPM) aperture light curve, flattened with
# lightkurve before its periodogram is computed. `data_type` is passed by keyword, as in
# every other get_aperture_lc call in this notebook, rather than relying on argument order.
lc_og = lk.TessLightCurve(time=dw.time, flux=dw.get_aperture_lc(data_type="raw"))
lc_og.plot()
lc_og = lc_og.flatten()
lc_og.plot()
pg_og = lc_og.to_periodogram()

In [None]:
# Periodogram of the flattened original light curve (frequency view on top, period view below).
fig, axs = plt.subplots(2, 1, figsize=(16, 16))
pg_og.plot(ax=axs[0], c='k')
pg_og.plot(ax=axs[1], c='k', view='period')
# fig.suptitle("Periodogram", fontsize=20, y=0.95)
period_og = pg_og.period_at_max_power
print(f"Max Power Period: {period_og}")
# Fold the light curve this periodogram was computed from (lc_og), not the CPM light curve `lc`.
lc_og.fold(period_og.value).scatter()
plt.title(f"Folded Lightcurve with Period: {period_og:.4f}", fontsize=20)

In [None]:
# cpm_regs = 10.0 ** np.arange(-9, 9)
# min_cdpp_reg, cdpps = dw.calc_min_cpm_reg(cpm_regs, k=5)

In [None]:
# print(min_cdpp_reg)
# # dw.set_regs([min_cdpp_reg])
# dw.set_regs([min_cdpp_reg])

# dw.holdout_fit_predict(k=10);

In [None]:
# outliers = dw.get_outliers()

In [None]:
# plt.plot(dw.time, dw.get_aperture_lc(split=False, data_type="normalized_flux"), "-", c="k")
# plt.plot(dw.time, dw.get_aperture_lc(split=False, data_type="cpm_prediction"), "-", c="r")
# plt.plot(dw.time[~outliers], dw.get_aperture_lc(split=False, data_type="cpm_subtracted_flux")[~outliers], "-", c="k")

In [None]:
# def cpm_periodogram(fits_file, t_row=50, t_col=50):
#     cpm = tess_cpm.CPM(fits_file, remove_bad=True)
#     cpm.set_target(t_row, t_col)
#     cpm.set_exclusion(10)
#     cpm.set_predictor_pixels(256, method='cosine_similarity')
#     cpm.lsq(0.1, rescale=True, polynomials=False)
#     tess_cpm.summary_plot(cpm, 10)
#     aperture_lc, lc_matrix = cpm.get_aperture_lc(box=1, show_pixel_lc=True, show_aperture_lc=True)
#     lc = lk.LightCurve(time=cpm.time, flux=aperture_lc)
#     pg = lc.to_periodogram(oversample_factor=100)
#     fig, axs = plt.subplots(2, 1, figsize=(15, 8))
#     pg.plot(ax=axs[0], c='k')
#     pg.plot(ax=axs[1], c='k', view='period')
#     fig.suptitle("Periodogram", fontsize=20, y=0.95)
#     period = pg.period_at_max_power
#     print(f"Max Power Period: {period}")
#     lc.fold(period.value*4).scatter()
#     plt.title(f"Folded Lightcurve with Period: {period:.4f}", fontsize=20)
#     return cpm

## TIC 395130640

In [None]:
## FFI Data

# Read the Sector 11 80x80 cutout of TIC 395130640 and pull out the per-cadence columns.
sec11_FFI_data_file = "dwarfs/tess-s0011-3-4_169.234200_-80.464300_80x80_astrocut.fits"  # TIC 395130640
with fits.open(sec11_FFI_data_file, mode="readonly") as hdu:
    time, flux, err, quality = (
        hdu[1].data[column] for column in ("TIME", "FLUX", "FLUX_ERR", "QUALITY")
    )

In [None]:
# if removing flagged points
flagged = quality != 0
time = time[~flagged]
flux = flux[~flagged]

In [None]:
# First cadence of the FFI cutout, next to a zoom on indices 35-44. The two images go on
# separate axes; two plt.imshow calls on one axes would draw the zoom on top of the full image.
# `extent` keeps the zoom's tick labels in full-cutout pixel coordinates instead of 0-9.
fig, (ax_full, ax_zoom) = plt.subplots(1, 2, figsize=(12, 5))
ax_full.imshow(flux[0], origin="lower")
ax_full.set_title("First cadence")
ax_zoom.imshow(flux[0, 35:45, 35:45], origin="lower", extent=(34.5, 44.5, 34.5, 44.5))
ax_zoom.set_title("Zoom: pixels 35-44");

In [None]:
# Quick look at the single pixel [41, 39] at every cadence (a 1x1 window, summed away).
flux[:, 41:42, 39:40].sum(axis=(1, 2))

In [None]:
# FFI aperture light curve: at every cadence, sum the 2x3 block flux[:, 40:42, 40:43]
# (axis-1 indices 40-41, axis-2 indices 40-42).
ffi_apt_lc = flux[:, 40:42, 40:43].sum(axis=(1, 2))
# To include one more pixel:
# ffi_apt_lc += flux[:, 41:42, 39:40].sum((1,2))

ffi_apt_lc.shape

In [None]:
# FFI aperture light curve over the sector. Axes are labeled so the figure stands alone,
# and ';' suppresses the Line2D repr.
plt.plot(time, ffi_apt_lc)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Summed aperture flux")
plt.title("Sector 11 FFI aperture light curve");

In [None]:
## 2-minute Data

# Restrict the search to Sector 11 so the TPF overlaps the Sector 11 FFI cutout loaded above
# (and the 1603-1610 comparison window used below). Without `sector`, result [0] is whichever
# sector the archive lists first, which may not match the `tpf11` name.
tpf_search_result = lk.search_targetpixelfile(target="169.2342 -80.4643", mission="TESS", sector=11)
tpf11 = tpf_search_result[0].download()

In [None]:
# Light curve from the TPF, using to_lightcurve's default aperture mask.
tpf11_lc = tpf11.to_lightcurve()
tpf11_lc.plot()

In [None]:
# Bin the 2-minute light curve to 30 minutes (30 / 1440 days, the Sector 11 FFI cadence)
# so it can be compared cadence-for-cadence with the FFI data. This replaces the truncated
# literal 0.02083333333.
binned_tpf11_lc = tpf11_lc.bin(time_bin_size=30 / 1440)

In [None]:
# Median-normalized comparison of the 2-minute TPF light curve (raw and binned) with the
# FFI aperture light curve.
plt.plot(tpf11_lc.time.value, tpf11_lc.flux / np.nanmedian(tpf11_lc.flux), label="2-minute data")
plt.plot(binned_tpf11_lc.time.value, binned_tpf11_lc.flux / np.nanmedian(binned_tpf11_lc.flux), label="Binned 2-minute data")
# Hand-tuned vertical offset on the FFI curve for visual comparison.
plt.plot(time, ffi_apt_lc / np.nanmedian(ffi_apt_lc) + 0.013, label="FFI")
plt.legend()
plt.xlim(1603.5, 1610)
plt.ylim(0.97, 1.05)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Normalized flux")
plt.title("Sector 11: 2-minute TPF vs FFI aperture");

## Central Pixel

In [None]:
# First-cadence image of the full FFI cutout, for locating the central pixel.
plt.imshow(flux[0, :, :], origin="lower")

In [None]:
# Light curve of TPF pixel [5, 5] (presumably the central pixel of the TPF stamp — TODO confirm
# the stamp size), plus a copy binned to 30 min (30 / 1440 days, the FFI cadence).
central_tpf_pixel = lk.TessLightCurve(time=tpf11.time.value, flux=tpf11.flux[:, 5, 5])
binned_central_tpf_pixel = central_tpf_pixel.bin(time_bin_size=30 / 1440)

In [None]:
# Raw central-pixel flux: 2-minute TPF (raw and binned) vs FFI pixel [40, 40]. The FFI is also
# shown shifted down by 110, a hand-picked offset for visual comparison.
plt.plot(central_tpf_pixel.time.value, central_tpf_pixel.flux, label="2-minute data")
plt.plot(binned_central_tpf_pixel.time.value, binned_central_tpf_pixel.flux, label="Binned 2-minute data")
# plt.plot(tpf11.time.value, tpf11.flux[:,5,5], label="2-minute data")

plt.plot(time, flux[:, 40, 40] - 110, label="FFI-110")
plt.plot(time, flux[:, 40, 40], label="FFI")

plt.legend()
plt.xlim(1600, 1610)
# plt.ylim(700,800)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Pixel flux")
plt.title("Central pixel: 2-minute TPF vs FFI");

In [None]:
# Median-normalized central-pixel comparison. The additive offsets (+0.009, +0.006) are
# hand-tuned vertical shifts for visual comparison.
# plt.plot(central_tpf_pixel.time.value, central_tpf_pixel.flux / np.nanmedian(central_tpf_pixel.flux), label="2-minute data")
plt.plot(binned_central_tpf_pixel.time.value, binned_central_tpf_pixel.flux / np.nanmedian(binned_central_tpf_pixel.flux),
         ".-", label="Binned 2-minute data")
# plt.plot(tpf11.time.value, tpf11.flux[:,5,5], label="2-minute data")

plt.plot(time, (flux[:, 40, 40] - 110) / np.nanmedian(flux[:, 40, 40] - 110) + 0.009, ".-", label="FFI-110")
plt.plot(time, flux[:, 40, 40] / np.nanmedian(flux[:, 40, 40]) + 0.006, ".-", label="FFI")
plt.legend()
plt.xlim(1603, 1608)
plt.ylim(0.96, 1.03)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Normalized flux")
plt.title("Central pixel (normalized): binned TPF vs FFI");

In [None]:
# Boolean masks selecting 1603 < t < 1608 in each time series. Element-wise `&` on boolean
# arrays gives the same result as multiplying them.
tpf_time = central_tpf_pixel.time.value
binned_tpf_time = binned_central_tpf_pixel.time.value
tpf_bool = (tpf_time > 1603) & (tpf_time < 1608)
binned_tpf_bool = (binned_tpf_time > 1603) & (binned_tpf_time < 1608)
ffi_bool = (time > 1603) & (time < 1608)

In [None]:
# Central-pixel comparison restricted to the 1603-1608 window, with each series normalized by
# its median over that window only.
plt.plot(central_tpf_pixel.time.value[tpf_bool], central_tpf_pixel.flux[tpf_bool] / np.nanmedian(central_tpf_pixel.flux[tpf_bool]), label="2-minute data")
plt.plot(binned_central_tpf_pixel.time.value[binned_tpf_bool], binned_central_tpf_pixel.flux[binned_tpf_bool] / np.nanmedian(binned_central_tpf_pixel.flux[binned_tpf_bool]),
         ".-", label="Binned 2-minute data")
# plt.plot(tpf11.time.value, tpf11.flux[:,5,5], label="2-minute data")

# plt.plot(time[ffi_bool], (flux[:,40,40][ffi_bool]-110) / np.nanmedian(flux[:,40,40][ffi_bool]-110), ".-", label="FFI-110")
plt.plot(time[ffi_bool], flux[:, 40, 40][ffi_bool] / np.nanmedian(flux[:, 40, 40][ffi_bool]), ".-", label="FFI")
plt.legend()
plt.xlim(1603, 1608)
plt.ylim(0.96, 1.03)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Normalized flux")
plt.title("Central pixel, 1603-1608: TPF vs FFI");

In [None]:
# Phase-fold the windowed, median-normalized binned TPF and FFI central-pixel light curves at
# one shared period, so their shapes can be compared. fold_period is in days (lightkurve's
# unit for a bare float); it is presumably the periodogram peak found earlier — TODO confirm.
fold_period = 0.413

binned_tpf_window = binned_central_tpf_pixel.flux[binned_tpf_bool]
folded_binned_tpf = lk.TessLightCurve(
    time=binned_central_tpf_pixel.time.value[binned_tpf_bool],
    flux=binned_tpf_window / np.nanmedian(binned_tpf_window),
).fold(period=fold_period)

offset = 0  # additive flux offset applied to the FFI pixel before normalizing (0 = none)
ffi_pixel_window = flux[:, 40, 40][ffi_bool] + offset
folded_ffi = lk.TessLightCurve(
    time=time[ffi_bool],
    flux=ffi_pixel_window / np.nanmedian(ffi_pixel_window),
).fold(period=fold_period)

In [None]:
# Overlay the two folded light curves. Phase is in days because fold() was not asked to
# normalize it.
plt.plot(folded_binned_tpf.phase.value, folded_binned_tpf.flux, ".", label="Binned 2-minute data")
plt.plot(folded_ffi.phase.value, folded_ffi.flux, ".", label="FFI")
plt.xlabel("Phase [Days]")
plt.ylabel("Normalized flux")
plt.legend();

## Other Pixels

In [None]:
# Pick a pixel relative to the assumed center ([5, 5] in the TPF, [40, 40] in the FFI cutout).
row_offset_from_center = 1
col_offset_from_center = 0

tpf_pixel = lk.TessLightCurve(
    time=tpf11.time.value,
    flux=tpf11.flux[:, 5 + row_offset_from_center, 5 + col_offset_from_center],
)
# Bin to 30 min (30 / 1440 days, the FFI cadence) for comparison with the FFI pixel.
binned_tpf_pixel = tpf_pixel.bin(time_bin_size=30 / 1440)

In [None]:
# Raw flux of the offset pixel: 2-minute TPF (raw and binned) vs the matching FFI pixel. The FFI
# is also shown shifted down by 110, a hand-picked offset for visual comparison.
plt.plot(tpf_pixel.time.value, tpf_pixel.flux, label="2-minute data")
plt.plot(binned_tpf_pixel.time.value, binned_tpf_pixel.flux, label="Binned 2-minute data")
# plt.plot(tpf11.time.value, tpf11.flux[:,5,5], label="2-minute data")

ffi_row = 40 + row_offset_from_center
ffi_col = 40 + col_offset_from_center
plt.plot(time, flux[:, ffi_row, ffi_col] - 110, label="FFI-110")
plt.plot(time, flux[:, ffi_row, ffi_col], label="FFI")

plt.legend()
# plt.xlim(1600,1610)
# plt.ylim(1600,2100)
# plt.ylim(300,400)
plt.xlabel("Time - 2457000 [Days]")
plt.ylabel("Pixel flux")
plt.title(f"Pixel offset ({row_offset_from_center}, {col_offset_from_center}) from center: TPF vs FFI");