In [1]:
import pandas as pd
# import jax.numpy as jnp
import numpy as np
# from jax import jit
from matplotlib import pyplot as plt

In [2]:
# Load the light-curve table and order it by object, passband and time so that
# every (object_id, passband) group is chronologically sorted.
# The compression is passed explicitly: pandas cannot infer it from a ".csv"
# extension (presumably the file is gzip-compressed despite its name -- confirm).
lcs = pd.read_csv(
    "data/plasticc_test_lightcurves_01.csv", compression="gzip"
).sort_values(["object_id", "passband", "mjd"])

# Global time span of the file, widened by 10 on each side and rounded down
# to an integer value.
mjd_min = np.floor(lcs.mjd.min() - 10)
mjd_max = np.floor(lcs.mjd.max() + 10)

In [3]:
def linear_interpolation(x, xp, fp, x_start=59572, x_end=60684):
    """Piecewise-linear interpolation of the samples ``(xp, fp)`` at ``x``.

    Before interpolating, the samples are padded with two anchor points,
    ``(x_start, median(fp))`` and ``(x_end, median(fp))``. Queries that fall
    between an anchor and the nearest sample are therefore interpolated
    towards the median of ``fp``. Queries outside ``[x_start, x_end]`` are
    extrapolated along the first / last segment.

    Parameters
    ----------
    x : array_like
        Positions at which the interpolant is evaluated.
    xp : array_like
        Sample positions. They must be sorted in increasing order and lie
        inside ``[x_start, x_end]``; otherwise the padded array is not sorted
        and ``np.searchsorted`` returns meaningless segment indices.
    fp : array_like
        Sample values, one per element of ``xp``.
    x_start, x_end : float, optional
        Positions of the leading and trailing anchor points. The defaults are
        the values that used to be hard-coded in this function.

    Returns
    -------
    numpy.ndarray
        Interpolated values, one per element of ``x``.
    """
    x = np.asarray(x, dtype=float)
    # Work in float throughout: inserting the (float) median into an integer
    # array would silently truncate it.
    xp = np.asarray(xp, dtype=float).ravel()
    fp = np.asarray(fp, dtype=float).ravel()

    md = np.median(fp)
    xp = np.concatenate(([x_start], xp, [x_end]))
    fp = np.concatenate(([md], fp, [md]))

    # Index of the segment [xp[i], xp[i + 1]] used for each query; the clip
    # makes out-of-range queries reuse the first / last segment.
    indices = np.clip(np.searchsorted(xp, x) - 1, 0, len(xp) - 2)

    x0, x1 = xp[indices], xp[indices + 1]
    y0, y1 = fp[indices], fp[indices + 1]

    # A sample located exactly on an anchor produces a zero-width segment.
    # Use a zero slope there (result is y0) instead of dividing by zero.
    dx = x1 - x0
    zero_width = dx == 0
    slope = np.where(zero_width, 0.0, (y1 - y0) / np.where(zero_width, 1.0, dx))

    return y0 + slope * (x - x0)

In [4]:
# Preview the first two rows of the sorted light-curve table.
lcs.head(n=2)

Unnamed: 0,object_id,mjd,passband,flux,flux_err,detected_bool
10,13,59818.274,0,1.962846,1.795587,0
11,13,59819.2541,0,-1.697929,2.433431,0


### 1. Begin

In [5]:
# One group per light curve, i.e. per (object_id, passband) pair.
old_gp = lcs.groupby(by=["object_id", "passband"], group_keys=False)

#### Linear filling

In [6]:
# Common time grid shared by every light curve: 72 evenly spaced points,
# inset by 5 from each end of the padded [mjd_min, mjd_max] span.
t_fill = np.linspace(start=mjd_min + 5, stop=mjd_max - 5, num=72)

# Resample each (object_id, passband) group onto the grid; the result holds
# one flux array per group.
new_lcs = old_gp.apply(
    lambda group: linear_interpolation(
        t_fill, group["mjd"].values, group["flux"].values
    ),
)

#### Into dataframes

In [7]:
# %%time
# Turn the per-group flux arrays into one long table with a row per
# (object_id, passband, mjd). ``new_lcs`` is indexed by (object_id, passband)
# and each value is the flux resampled on ``t_fill``.
# Iterating over ``new_lcs.items()`` reads each array directly instead of
# rebuilding a DataFrame from ``new_lcs`` for every group.
new_lc_dfs = [
    pd.DataFrame(
        {
            # Scalars are broadcast to the length of the array columns.
            "object_id": object_id,
            "passband": passband,
            "mjd": t_fill,
            "flux": flux,
        }
    )
    for (object_id, passband), flux in new_lcs.items()
]
# Each per-curve frame keeps its own 0..len(t_fill)-1 index, so the
# concatenated index repeats for every light curve.
new_df = pd.concat(new_lc_dfs)

In [8]:
# Drop the references to the raw table and its groupby so their memory can be
# reclaimed.
del old_gp
del lcs

#### Add Normalized flux

In [10]:
new_gp = new_df.groupby(["object_id", "passband"], group_keys=False)

# Min-max scale the flux of every (object_id, passband) light curve to [0, 1].
# ``transform`` returns one value per input row, aligned positionally with
# ``new_df``, so it does not depend on re-aligning the repeating row index
# the way ``apply`` does.
# A light curve with constant flux has max == min and yields NaN (0 / 0).
flux_min = new_gp["flux"].transform("min")
flux_max = new_gp["flux"].transform("max")
new_df["flux_norm"] = (new_df["flux"] - flux_min) / (flux_max - flux_min)

#### Save

In [13]:
# Preview the first five rows of the resampled table.
new_df.head(n=5)

Unnamed: 0,object_id,passband,mjd,flux,flux_norm
0,13,0,59577.0,0.125404,0.562753
1,13,0,59592.521127,0.243606,0.579247
2,13,0,59608.042254,0.361809,0.59574
3,13,0,59623.56338,0.480011,0.612234
4,13,0,59639.084507,0.598214,0.628728


In [14]:
# Write the resampled table to disk in Parquet format.
new_df.to_parquet(path='data/lc1.parquet')

In [12]:
# old_lc = old_gp.get_group(key)
# new_lc = new_gp.get_group(key)

# new_lc

In [36]:
# for key in new_lcs.index[:30]:
#     fig = plt.figure()
#     old_lc = old_gp.get_group(key)
#     plt.scatter(old_lc.mjd, old_lc.flux)

#     new_lc = new_gp.get_group(key)
#     plt.scatter(new_lc.mjd, new_lc.flux_norm)
#     plt.title(key)