# In [None]:
import logging
import os

import numpy as np
import pandas as pd
from astroquery.ipac.nexsci.nasa_exoplanet_archive import NasaExoplanetArchive
from lightkurve import search_lightcurve
from tqdm import tqdm

# Batch configuration: BIN_POINTS phase bins per light curve, BATCH_SIZE planets
# per saved batch file, and the directory the batch files are written to.
BIN_POINTS = 200
BATCH_SIZE = 40
SAVE_DIR = os.path.join(os.getcwd(), "planet_batches")
os.makedirs(SAVE_DIR, exist_ok=True)

# Planet metadata from the NASA Exoplanet Archive composite-parameters table.
# Rows are sorted by pl_name because the batch loop resumes by row position
# (batch_<id> holds rows id*BATCH_SIZE onward), and this query has no ORDER BY,
# so the archive's row order is not guaranteed to be the same between runs.
# NOTE: batch files written before this ordering was introduced may not line up;
# clear SAVE_DIR once if a previous run used the unsorted order.
df = (
    NasaExoplanetArchive.query_criteria(
        table="pscomppars",
        select="pl_name,ra,dec,pl_orbper,pl_tranmid,pl_rade,pl_bmasse",
        where="pl_orbper is not null and (pl_rade is not null or pl_bmasse is not null)",
    )
    .to_pandas()
    .sort_values("pl_name", kind="stable")
    .reset_index(drop=True)
)

def assign_class(row):
    """Map a planet's radius to an integer size class.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas Series or a dict)
        Must support ``.get``; only the ``"pl_rade"`` entry is read
        (the archive's planet radius, in Earth radii).

    Returns
    -------
    int or None
        0 earth-like (r < 1.25), 1 super-earth (1.25 <= r < 2.0),
        2 neptune-like (2.0 <= r < 6.0), 3 jupiter-like (r >= 6.0);
        None when the radius is missing.
    """
    r = row.get("pl_rade", np.nan)
    # pd.isna covers NaN, None and pd.NA alike; np.isnan raises TypeError
    # for None / pd.NA, which can appear in object-typed columns.
    if pd.isna(r):
        return None
    if r < 1.25:
        return 0  # earth-like
    if r < 2.0:
        return 1  # super-earth
    if r < 6.0:
        return 2  # neptune-like
    return 3  # jupiter-like

# Days to add to a lightkurve time value of each format to obtain a full Julian Date
# (BTJD = BJD - 2457000, BKJD = BJD - 2454833). Used to translate archive epochs
# into the light curve's own time system before folding.
_JD_OFFSETS = {"btjd": 2457000.0, "bkjd": 2454833.0, "jd": 0.0}
_logger = logging.getLogger(__name__)


def _epoch_in_lc_units(t0, time_format):
    """Return *t0* expressed in the light curve's *time_format*, or None if unusable.

    Values above 2,400,000 are treated as full Julian Dates (the archive's
    ``pl_tranmid`` is a BJD) and shifted by the format's offset; smaller values
    are assumed to already be in the light curve's units and are returned as-is.
    """
    if t0 is None or pd.isna(t0) or not np.isfinite(t0):
        return None
    t0 = float(t0)
    if t0 < 2_400_000:
        return t0
    offset = _JD_OFFSETS.get(time_format)
    return None if offset is None else t0 - offset


def fetch_folded_binned_flux(ra, dec, period, t0=None, bin_points=BIN_POINTS, mission_hint=("TESS", "Kepler")):
    """Download, fold and bin the light curve at (ra, dec) into a fixed-length flux vector.

    Missions in *mission_hint* are searched in order; the first one with any
    results is used, and all of its light curves are downloaded and stitched.

    Parameters
    ----------
    ra, dec : float
        Target coordinates, passed to lightkurve as the string "ra dec"
        (presumably degrees -- confirm against the archive columns).
    period : float
        Orbital period in days (lightkurve treats a bare float period as days).
    t0 : float, optional
        Transit midpoint. A full Julian Date (e.g. the archive's BJD) is converted
        to the light curve's time system so the transit lands at phase 0; a
        missing or non-finite value folds without an explicit epoch.
    bin_points : int
        Number of equal-width phase bins in the returned vector.
    mission_hint : tuple of str
        Missions to try, in priority order.

    Returns
    -------
    numpy.ndarray or None
        ``bin_points`` binned flux values, or None when no data was found, any bin
        is non-finite (e.g. empty), or any step raised (such failures are logged).
    """
    try:
        search_result = None
        for mission in mission_hint:
            candidates = search_lightcurve(f"{ra} {dec}", mission=mission)
            if len(candidates) > 0:
                search_result = candidates
                break
        if search_result is None:
            return None

        collection = search_result.download_all()
        if collection is None or len(collection) == 0:
            return None
        lc = collection.stitch().normalize().remove_nans()

        try:
            lc = lc.flatten(window_length=401)
        except Exception:
            # Best-effort detrending: on any failure keep the un-flattened curve.
            pass

        epoch = _epoch_in_lc_units(t0, lc.time.format)
        if epoch is None:
            folded = lc.fold(period=period)
        else:
            # epoch_time is the lightkurve 2.x name; `t0` is its deprecated alias.
            folded = lc.fold(period=period, epoch_time=epoch)

        # The first positional argument of LightCurve.bin is `time_bin_size` (days),
        # not a bin count; `bins=` asks for exactly `bin_points` equal-width bins.
        binned = folded.bin(bins=bin_points)
        flux = binned.flux.value
        # Downstream code stacks these vectors, so they must all have the same length.
        if len(flux) != bin_points or np.any(~np.isfinite(flux)):
            return None
        return flux
    except Exception as exc:
        # Deliberately best-effort so one bad target does not stop the batch run,
        # but record why it was dropped.
        _logger.warning("Light curve processing failed for ra=%s dec=%s: %s", ra, dec, exc)
        return None

# Process the catalogue in fixed-size batches, writing one batch_<id>.npz per batch
# so an interrupted run can resume: batches whose file already exists are skipped.
# Resuming relies on `df` having the same row order as in the earlier run.
for start in range(0, len(df), BATCH_SIZE):
    batch_df = df.iloc[start:start + BATCH_SIZE]
    batch_id = start // BATCH_SIZE
    batch_file = os.path.join(SAVE_DIR, f"batch_{batch_id}.npz")

    if os.path.exists(batch_file):
        print(f"Skipping batch {batch_id} (already processed)")
        continue

    Xb, yb = [], []
    for _, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_id}"):
        # Label first: rows without a usable radius (kept by the query's
        # pl_bmasse branch) can never be saved, so skip their slow download.
        lbl = assign_class(row)
        if lbl is None:
            continue
        flux = fetch_folded_binned_flux(row["ra"], row["dec"], row["pl_orbper"], t0=row.get("pl_tranmid", None))
        if flux is not None:
            Xb.append(flux)
            yb.append(lbl)

    if Xb:
        np.savez(batch_file, X=np.array(Xb), y=np.array(yb))
        print(f"Saved batch {batch_id} with {len(Xb)} samples")
    else:
        print(f"No valid data in batch {batch_id}")

# Merge all per-batch files into one dataset (the catalogue is too large to
# process in a single pass, so it was saved batch by batch).
def _batch_id_from_name(fname):
    """Return the integer id of a ``batch_<id>.npz`` file name, or None for any other file."""
    prefix, suffix = "batch_", ".npz"
    if fname.startswith(prefix) and fname.endswith(suffix):
        stem = fname[len(prefix):-len(suffix)]
        if stem.isdigit():
            return int(stem)
    return None


# Only batch files are read, so a final_dataset.npz written to the same directory
# by an earlier run is not merged back in; files are ordered by numeric batch id
# (a plain string sort would give batch_1, batch_10, ..., batch_2).
batch_files = sorted(
    (f for f in os.listdir(SAVE_DIR) if _batch_id_from_name(f) is not None),
    key=_batch_id_from_name,
)
X_list, y_list = [], []
for fname in batch_files:
    # np.load returns an NpzFile that keeps the file open until closed.
    with np.load(os.path.join(SAVE_DIR, fname)) as data:
        X_list.append(data["X"])
        y_list.append(data["y"])

if X_list:
    X = np.vstack(X_list)
    y = np.concatenate(y_list)
    np.savez(os.path.join(SAVE_DIR, "final_dataset.npz"), X=X, y=y)
    print("Final dataset saved:", X.shape, y.shape)
else:
    print(f"No batch files found in {SAVE_DIR}; final dataset not written")
