In [None]:
import tess_cpm
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
import lightkurve as lc

In [None]:
# Default (width, height) in inches for every matplotlib figure in this notebook.
plt.rcParams["figure.figsize"] = (12, 8)

In [None]:
# Cutout file analysed throughout the notebook. The filename suggests an astrocut
# 100x100 cutout from TESS sector 10 near RA 173.9574, Dec -29.156
# (NOTE(review): inferred from the name only — confirm).
f = "exoplanets/tess-s0010-1-1_173.957400_-29.156000_100x100_astrocut.fits"

In [None]:
# Load the cutout into a tess_cpm Source; the first part of the notebook works on `exo`.
exo = tess_cpm.Source(f)

In [None]:
# Quick look at the whole cutout image.
exo.plot_cutout()

In [None]:
# Choose the target aperture, then show a zoom on rows/cols 45-55 with the aperture overlaid.
# NOTE(review): whether rowrange/colrange are inclusive or half-open is defined by
# tess_cpm — confirm how many pixels this aperture actually covers.
exo.set_aperture(rowrange=[49, 50], colrange=[50, 51])
exo.plot_cutout(rowrange=[45, 55], colrange=[45, 55], show_aperture=True)

In [None]:
# Per-pixel view of the aperture (presumably one light curve per aperture pixel — confirm).
exo.plot_pix_by_pix()

In [None]:
# Attach a CPM model, set a single regularization value, and run the holdout fit/predict
# with k=10 sections.
# NOTE(review): 0.162 is hand-picked; the regularization sweeps later in the notebook
# search for a better value.
exo.add_cpm_model()
exo.set_regs([0.162])
exo.holdout_fit_predict(k=10);

In [None]:
# Per-pixel view of the CPM-subtracted light curve. The commented-out calls show the
# other `data_type` options (raw data and the model predictions) with split=True.
# exo.plot_pix_by_pix(split=True, data_type="raw")
# exo.plot_pix_by_pix(split=True, data_type="prediction")
# exo.plot_pix_by_pix(split=True, data_type="cpm_prediction")
# exo.plot_pix_by_pix(split=True, data_type="poly_model_prediction")
exo.plot_pix_by_pix(split=False, data_type="cpm_subtracted_lc")

In [None]:
# Aperture-level light curves, aggregated over the aperture pixels by tess_cpm:
# the CPM prediction and the CPM-subtracted flux.
# NOTE(review): `apt_cpm` is only used by a commented-out plot.
apt_cpm = exo.get_aperture_lc(data_type="cpm_prediction")
apt_lc = exo.get_aperture_lc(data_type="cpm_subtracted_lc")
# exo.get_aperture_lc(split=False, data_type="cpm_subtracted_lc")

In [None]:
# Plot the CPM-subtracted aperture light curve to look for outliers.
# plt.plot(exo.time, apt_cpm, ".")
plt.plot(exo.time, apt_lc, ".")
# Outlier sketch: flag points more than 1.5 standard deviations from zero.
# outliers = np.abs(apt_lc) > 1.5*np.std(apt_lc)
# print(np.sum(outliers))
# plt.plot(exo.time[outliers], apt_lc[outliers], "x", c="k")
# plt.plot(exo.time[~outliers], apt_lc[~outliers], "x", c="k")

# Iterative-clip sketch.
# NOTE(review): `pre_clip` is not defined in any visible cell, and the first line
# would overwrite `apt_lc` in place; these lines fail if uncommented as-is.
# apt_lc[pre_clip] = np.nan
# clip = np.abs(apt_lc) > 1*np.std(apt_lc[pre_clip])
# plt.plot(exo.time[clip], apt_lc[clip], "x", c="r")
# np.any(clip)

In [None]:
# Refit with k=20 holdout sections, excluding outliers from the fit.
# Outliers are points of the CPM-subtracted aperture light curve more than
# 1.5 standard deviations from zero. The mask is computed here so the cell runs
# on a fresh kernel instead of depending on a commented-out definition.
outliers = np.abs(apt_lc) > 1.5 * np.std(apt_lc)
exo.holdout_fit_predict(k=20, mask=~outliers);

In [None]:
# Overlay the aperture light curve from the outlier-masked refit on the unmasked one.
clipped_apt_lc = exo.get_aperture_lc(data_type="cpm_subtracted_lc")
for series, fmt, legend_label in (
    (clipped_apt_lc, ".", 'Clipped Prediction'),
    (apt_lc, "-", 'no clipping'),
):
    plt.plot(exo.time, series, fmt, label=legend_label)
plt.legend()

In [None]:
# Log-spaced regularization grid: 1e-5, 1e-4, ..., 1e9.
cpm_regs = 10.0 ** np.arange(-5, 10)

In [None]:
# Regularization sweep over `cpm_regs` with k=2 holdout sections; returns the selected
# value and the per-value CDPPs (NOTE(review): exact semantics are defined by
# tess_cpm.Source.calc_min_cpm_reg — confirm).
min_cpm_reg, cdpps = exo.calc_min_cpm_reg(cpm_regs, k=2)

In [None]:
# Regularization chosen by the log-grid sweep.
print(min_cpm_reg)

In [None]:
# Sketch of a helper computing lightkurve's CDPP estimate for a flux array (left
# commented out).
# NOTE(review): the example call uses `cpm_subtracted_lc`, which is only assigned in a
# later cell, so uncommenting this cell as-is raises NameError.
# def calc_cdpp(flux):
#     return lc.TessLightCurve(flux=flux).estimate_cdpp()

# calc_cdpp(cpm_subtracted_lc)

In [None]:
# Draw each holdout section of the CPM-subtracted aperture light curve as its own line
# (matplotlib cycles the colour per call).
split_cpm_subtracted_lc = exo.get_aperture_lc(split=True, data_type="cpm_subtracted_lc")
cpm_subtracted_lc = exo.get_aperture_lc(data_type="cpm_subtracted_lc")

for section_time, section_flux in zip(exo.split_times, split_cpm_subtracted_lc):
    plt.plot(section_time, section_flux)

In [None]:
# Sweep the CPM regularization over a fine linear grid (0.01 ... 0.99) and record the
# CDPP of every holdout section of the CPM-subtracted aperture light curve.
k = 10
cpm_regs = 0.01 * np.arange(1, 100)
# Rows index the regularization value, columns index the holdout section.
cdpps = np.zeros((cpm_regs.size, k))
for idx, creg in enumerate(cpm_regs):
    exo.set_regs([creg])
    exo.holdout_fit_predict(k)
    split_cpm_subtracted_lc = exo.get_aperture_lc(split=True, data_type="cpm_subtracted_lc", verbose=False)
    # CDPP is computed directly with lightkurve so the cell does not rely on a helper
    # that only exists in commented-out code. The comprehension variable must not be
    # named `lc`, because that would hide the `lightkurve as lc` module alias.
    split_cdpp = [lc.TessLightCurve(flux=section).estimate_cdpp() for section in split_cpm_subtracted_lc]
    cdpps[idx] = np.array(split_cdpp)
cdpps.shape  # (regularization value, k-th section)

In [None]:
# x-axis: holdout section number (1..k); y-axis: CDPP.
# Each line is the light curve fitted with one regularization value. Labels are set,
# but the legend stays off because there is one entry per grid value.
section_numbers = np.arange(1, k + 1)
for reg_value, section_cdpps in zip(cpm_regs, cdpps):
    plt.plot(section_numbers, section_cdpps, label=f"Reg {reg_value}")
plt.xlabel("k-th section of lightcurve", fontsize=20)
plt.ylabel("CDPP", fontsize=20)

In [None]:
# x-axis: CPM regularization value; y-axis: CDPP. One line per holdout section.
for section_no, reg_curve in enumerate(cdpps.T, start=1):
    plt.plot(cpm_regs, reg_curve, label=f"{section_no} section")
plt.xlabel("CPM Regularization Values", fontsize=20)
plt.ylabel("CDPP", fontsize=20)
plt.legend();

In [None]:
# Average the CDPP over holdout sections for each regularization value, pick the
# regularization with the lowest average, and mark it on the curve.
section_averaged_cdpps = np.average(cdpps, axis=1)
best_idx = np.argmin(section_averaged_cdpps)
reg_at_min_cdpp = cpm_regs[best_idx]
print(reg_at_min_cdpp)
plt.plot(cpm_regs, section_averaged_cdpps)
plt.xlabel("CPM Regularization Values", fontsize=20)
plt.ylabel("CDPP", fontsize=20)
# Highlight the computed minimum rather than a hard-coded grid value: exact float
# equality against a literal (cpm_regs == 0.13) is fragile and goes stale whenever
# the grid or the data change.
plt.scatter(reg_at_min_cdpp, section_averaged_cdpps[best_idx], c="r");

In [None]:
# Regularization value whose section-averaged CDPP is smallest.
reg_at_min_cdpp = cpm_regs[section_averaged_cdpps.argmin()]
print(reg_at_min_cdpp)

In [None]:
# Refit using the regularization selected from the CDPP sweep.
# NOTE(review): the other fits pass k explicitly; this one relies on tess_cpm's default
# k — confirm that is intended.
exo.set_regs([reg_at_min_cdpp])
exo.holdout_fit_predict()

In [None]:
# `cdpps` has shape (n_regs, k), so this draws one marker series per holdout section
# against the regularization grid.
plt.plot(cpm_regs, cdpps, "o")

In [None]:
# Fit a CPM model for the single pixel at row 50, col 50 with regularization 0.1,
# then convert the prediction back to flux units.
pm = tess_cpm.PixelModel(exo.target_data, row=50, col=50)
pm.add_cpm_model()
pm.set_regs([0.1])
pm.fit()
# Undo the normalization (presumably the model predicts flux / median - 1 — confirm).
prediction = (pm.predict() + 1) * pm.median

In [None]:
# Rebuild `exo` through a different tess_cpm entry point (CPM wrapping TargetData).
# NOTE(review): this API differs from tess_cpm.Source above and from
# tess_cpm.CPM(f, remove_bad=True) below; these cells look like they target different
# tess_cpm versions — confirm which one is current before running.
exo = tess_cpm.CPM(tess_cpm.TargetData(f))

In [None]:
# Set the target pixel (50, 50) and its exclusion region for predictor selection
# (NOTE(review): exact semantics are defined by tess_cpm — confirm).
exo.set_target_exclusion_predictors(50, 50)

In [None]:
# NOTE(review): the units of transit_duration (cadences? minutes?) are defined by
# tess_cpm — confirm before trusting the returned hyperparameters.
exo.get_hyperparameters(transit_duration=100)

In [None]:
# Rebuild `exo` yet again, directly from the file path, asking tess_cpm to drop bad data.
# NOTE(review): this is a third tess_cpm constructor style in this notebook — confirm API.
exo = tess_cpm.CPM(f, remove_bad=True)

In [None]:
# Configure the model: polynomial component, target pixel (50, 50), an exclusion
# setting of 10, and 256 predictor pixels.
# NOTE(review): the meaning of set_poly_model's arguments (1, 4, 0.5) and of the
# exclusion size is defined by tess_cpm — confirm.
exo.set_poly_model(1, 4, 0.5)
exo.set_target(50, 50)
exo.set_exclusion(10)
exo.set_predictor_pixels(256)

In [None]:
# Least-squares fit on rescaled fluxes without the polynomial terms
# (NOTE(review): 1.0 is presumably the regularization strength — confirm).
exo.lsq(1.0, rescale=True, polynomials=False)

In [None]:
# Summary plot of the fit, saved to disk (NOTE(review): meaning of the positional 20 is
# defined by tess_cpm.summary_plot — confirm).
tess_cpm.summary_plot(exo, 20, save=True)

In [None]:
# Sigma-clip with threshold 2.3 (NOTE(review): presumably updates `exo.valid`, which the
# residual plots below use — confirm).
exo.sigma_clip_process(2.3)

In [None]:
# Summary plot after clipping, this time without subtracting the polynomial model.
tess_cpm.summary_plot(exo, 20, subtract_polynomials=False)

In [None]:
# Residuals of the least-squares fit: valid cadences as a connected line,
# cadences excluded by `exo.valid` as grey crosses.
fig, ax = plt.subplots(figsize=(15, 6))
diff = exo.rescaled_target_fluxes - exo.lsq_prediction
is_valid = exo.valid
ax.plot(exo.time[is_valid], diff[is_valid], ".-", color="C0")
ax.plot(exo.time[~is_valid], diff[~is_valid], "x", color="gray")

In [None]:
# Residuals of the least-squares fit for every cadence, including excluded ones.
fig, ax = plt.subplots(figsize=(15, 6))
diff = exo.rescaled_target_fluxes - exo.lsq_prediction
ax.plot(exo.time, diff, ".-", color="C0")

In [None]:
from IPython.display import HTML
import matplotlib.animation as animation

# Fit every pixel of the cutout, then animate three panels side by side: the input
# flux, the difference image, and the difference image scaled by each pixel's median.
exo.entire_image(0.5, rescale=True, polynomials=True)
diff = exo.im_diff
upscaled_diff = exo.pixel_medians * diff

# Colour limits do not change between frames. Compute them once instead of re-running
# np.percentile over the full image cubes on every loop iteration.
flux_vmin, flux_vmax = np.percentile(exo.im_fluxes[0], [10, 90])
diff_vmin, diff_vmax = np.percentile(diff, [1, 99])
updiff_vmin, updiff_vmax = np.percentile(upscaled_diff, [1, 99])

fig, axes = plt.subplots(1, 3, figsize=(18, 18))

ims = []
for i in range(0, diff.shape[0], 10):  # every 10th frame keeps the animation small
    im1 = axes[0].imshow(exo.im_fluxes[i], origin="lower", animated=True,
                         vmin=flux_vmin, vmax=flux_vmax)
    im2 = axes[1].imshow(diff[i], origin="lower", animated=True,
                         vmin=diff_vmin, vmax=diff_vmax)
    im3 = axes[2].imshow(upscaled_diff[i], origin="lower", animated=True,
                         vmin=updiff_vmin, vmax=updiff_vmax)
    ims.append([im1, im2, im3])
for panel_ax, panel_im in zip(axes, (im1, im2, im3)):
    fig.colorbar(panel_im, ax=panel_ax, fraction=0.046)

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                repeat_delay=1000)
# Close the static figure so the notebook renders only the animation, not a duplicate still.
plt.close(fig)
HTML(ani.to_jshtml())
