In [1]:
"""
Train hierarchical RNN (minLSTM/minGRU) autoencoder on light curve time series.

Uses parallelizable RNN variants from "Were RNNs All We Needed?" (arXiv:2410.01201)
Supports training with block masking and time-aware positional encoding.
"""

import numpy as np
import argparse
import sys
from pathlib import Path
from functools import partial
from tqdm import tqdm
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import pickle


# Add project root to path
sys.path.append('/Users/philvanlane/Documents/lc_ae/')
from src.lcgen.models.rnn import HierarchicalRNN, RNNConfig
from src.lcgen.data.masking import dynamic_block_mask


### One time formattings to be consistent with code

In [2]:
with open("data/real_lightcurves/pk_star_sector_lc_cf.pickle", "rb") as file:
    lc_by_sector = pickle.load(file)
with open("data/real_lightcurves/pk_star_day_lc_cf_norm.pickle", "rb") as file:
    lc_by_day = pickle.load(file)
with open("data/real_lightcurves/star_sector_lc_formatted.pickle", "rb") as file:
    lc_by_sector_formatted = pickle.load(file)

In [10]:
lc_by_sector['405461319_42'].keys()

dict_keys(['TIC_ID', 'sector', 'time', 'time_adj', 'flux', 'flux_err', 'flux_norm', 'flux_mean', 'asinh_mean', 'norm_asinh_mean', 'time_norm', 'flux_norm_standard', 'flux_err_norm_standard', 'flux_norm_absTmag', 'flux_err_norm_absTmag'])

In [9]:
list(lc_by_sector.keys())

['405461319_42',
 '405461319_43',
 '177990792_18',
 '177990792_58',
 '177990792_85',
 '456945485_43',
 '456945485_44',
 '456945485_70',
 '456945485_71',
 '118677684_43',
 '118677684_44',
 '118681038_43',
 '118681038_44',
 '118681038_70',
 '118681038_71',
 '456946694_43',
 '456946694_44',
 '456946694_70',
 '456946694_71',
 '118521932_70',
 '118521932_71',
 '118521871_70',
 '118521871_71',
 '61230756_70',
 '61230756_71',
 '118521730_70',
 '118521730_71',
 '118555907_70',
 '118555907_71',
 '118555867_70',
 '118555867_71',
 '118520413_43',
 '118520413_44',
 '118520413_70',
 '118520413_71',
 '17336696_43',
 '17336696_44',
 '17336696_71',
 '17455866_70',
 '17455866_71',
 '17339419_70',
 '17339419_71',
 '17339381_70',
 '17339381_71',
 '18310799_44',
 '18310799_70',
 '18310799_71',
 '397285184_70',
 '397285184_71',
 '118446665_43',
 '118446665_44',
 '268324394_43',
 '268324394_44',
 '268218180_43',
 '268218180_44',
 '397357734_43',
 '397357734_44',
 '238408266_70',
 '238408266_71',
 '118678954

#### Lightcurves by star and sector

In [None]:
# Use sample light curve and sector
sample = lc_by_sector['405461319_42']

# Pre-process time
time = sample['time']
sector_time = (time - time[0])

# Pre-process flux
flux = np.array(sample['flux'].data)
mask = (np.isfinite(flux) & np.isfinite(time))
flux = flux[mask]
sector_time = sector_time[mask]
flux_norm = (flux - np.nanmean(flux)) / np.nanstd(flux)

# Pre-process flux error
flux_err = np.array(sample['flux_err'].data)
med_flux_error = np.nanmedian(flux_err)
flux_err = flux_err[mask]
flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)


# Metadata
metadata = {
    'tic': sample['TIC_ID'],
    'sector': sample['sector'],
    'duration': sector_time[-1] - sector_time[0],
    'med_flux_error': med_flux_error,
    'n_points': len(flux),
    'mean_flux': np.mean(flux),
    'std_flux': np.std(flux),
}

In [None]:
times = []
fluxes = []
flux_errs = []
metadatas = []

for i,k in enumerate(lc_by_sector.keys()):
    if i % 100 == 0:
        print(f"Processing light curve {i}")
    # Use sample light curve and sector
    sample = lc_by_sector[k]

    # Pre-process time
    time = sample['time']
    sector_time = (time - time[0])

    # Pre-process flux
    flux = np.array(sample['flux'].data)
    mask = (np.isfinite(flux) & np.isfinite(time))
    flux = flux[mask]
    sector_time = sector_time[mask]
    flux_norm = (flux - np.nanmean(flux)) / np.nanstd(flux)

    # Pre-process flux error
    flux_err = np.array(sample['flux_err'].data)
    med_flux_error = np.nanmedian(flux_err)
    flux_err = flux_err[mask]
    flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)


    # Metadata
    metadata = {
        'tic': sample['TIC_ID'],
        'sector': sample['sector'],
        'duration': sector_time[-1] - sector_time[0],
        'med_flux_error': med_flux_error,
        'n_points': len(flux),
        'mean_flux': np.nanmean(flux),
        'std_flux': np.nanstd(flux),
    }
    times.append(sector_time)
    fluxes.append(flux_norm)
    flux_errs.append(flux_err)
    metadatas.append(metadata)

In [None]:
dict = {
    'time': times,
    'flux': fluxes,
    'flux_err': flux_errs,
    'metadatas': metadatas,
}
with open("data/real_lightcurves/star_sector_lc_formatted.pickle", "wb") as file:
    pickle.dump(dict, file)

#### Light curves by star and day

In [11]:
list(lc_by_day.keys())[0]

'405461319_2447'

In [None]:
sample = lc_by_sector['405461319_42']

# Get all days in the sample light curve
days = np.unique(np.round(sample['time'])).astype(int)

# Pre-processing
time = sample['time']
flux = np.array(sample['flux'].data)
mask = (np.isfinite(flux) & np.isfinite(time))
flux = flux[mask]
time = time[mask]

# Pre-process flux error
flux_err = np.array(sample['flux_err'].data)
flux_err = flux_err[mask]
med_flux_error = np.nanmedian(flux_err)
flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)

for d in days:
    day_mask = (np.floor(time).astype(int) == d)
    day_time = time[day_mask]
    day_time = (day_time - day_time[0])
    day_flux = flux[day_mask]
    day_flux_err = flux_err[day_mask]
    day_flux_norm = (day_flux - np.nanmean(day_flux)) / np.nanstd(day_flux)

    # Metadata
    metadata = {
        'tic': sample['TIC_ID'],
        'sector': sample['sector'],
        'day': d,
        'duration': day_time[-1] - day_time[0],
        'med_sector_flux_error': med_flux_error,
        'n_points': len(day_flux_norm),
        'mean_flux': np.mean(day_flux),
        'std_flux': np.std(day_flux),
    }

In [16]:
times = []
raw_fluxes = []
raw_flux_errs = []
fluxes = []
flux_errs = []
metadatas = []

for i,k in enumerate(lc_by_sector.keys()):
    if i % 100 == 0:
        print(f"Processing light curve {i}")
    # Use sample light curve and sector
    sample = lc_by_sector[k]

    # Pre-processing
    time = sample['time']
    flux = np.array(sample['flux'].data)
    mask = (np.isfinite(flux) & np.isfinite(time))
    flux = flux[mask]
    time = time[mask]

    # Pre-process flux error
    flux_err = np.array(sample['flux_err'].data)
    flux_err = flux_err[mask]
    med_flux_error = np.nanmedian(flux_err)
    flux_err = np.nan_to_num(flux_err, nan=med_flux_error, posinf=med_flux_error, neginf=med_flux_error)

    # Get all days in the sample light curve
    days = np.unique(np.floor(time)).astype(int)

    for d in days:
        day_mask = (np.floor(time).astype(int) == d)
        day_time = time[day_mask]
        day_time = (day_time - day_time[0])
        day_flux = flux[day_mask]
        day_flux_err = flux_err[day_mask]
        day_flux_norm = (day_flux - np.nanmean(day_flux)) / np.nanstd(day_flux)
        day_flux_err_norm = day_flux_err / np.nanstd(day_flux)

        # Metadata
        metadata = {
            'tic': sample['TIC_ID'],
            'sector': sample['sector'],
            'day': d,
            'duration': day_time[-1] - day_time[0],
            'med_sector_flux_error': med_flux_error,
            'n_points': len(day_flux_norm),
            'mean_flux': np.mean(day_flux),
            'std_flux': np.std(day_flux),
        }
        times.append(day_time)
        fluxes.append(day_flux_norm)
        raw_fluxes.append(day_flux)
        raw_flux_errs.append(day_flux_err)
        flux_errs.append(day_flux_err)
        metadatas.append(metadata)

Processing light curve 0
Processing light curve 100


  day_flux_norm = (day_flux - np.nanmean(day_flux)) / np.nanstd(day_flux)
  day_flux_err_norm = day_flux_err / np.nanstd(day_flux)


Processing light curve 200
Processing light curve 300
Processing light curve 400
Processing light curve 500
Processing light curve 600
Processing light curve 700
Processing light curve 800
Processing light curve 900
Processing light curve 1000
Processing light curve 1100
Processing light curve 1200
Processing light curve 1300


  day_flux_norm = (day_flux - np.nanmean(day_flux)) / np.nanstd(day_flux)
  day_flux_err_norm = day_flux_err / np.nanstd(day_flux)


Processing light curve 1400
Processing light curve 1500
Processing light curve 1600
Processing light curve 1700
Processing light curve 1800
Processing light curve 1900
Processing light curve 2000
Processing light curve 2100
Processing light curve 2200
Processing light curve 2300
Processing light curve 2400
Processing light curve 2500
Processing light curve 2600
Processing light curve 2700
Processing light curve 2800
Processing light curve 2900
Processing light curve 3000
Processing light curve 3100
Processing light curve 3200
Processing light curve 3300
Processing light curve 3400
Processing light curve 3500
Processing light curve 3600
Processing light curve 3700
Processing light curve 3800
Processing light curve 3900
Processing light curve 4000
Processing light curve 4100
Processing light curve 4200
Processing light curve 4300
Processing light curve 4400
Processing light curve 4500
Processing light curve 4600
Processing light curve 4700
Processing light curve 4800
Processing light cur

In [20]:
dict['metadatas'][0]

{'tic': 405461319,
 'sector': 42,
 'day': 2447,
 'duration': 0.30836485772033484,
 'med_sector_flux_error': 11.362072,
 'n_points': 222,
 'mean_flux': 5152.1094,
 'std_flux': 12.59021}

In [17]:
dict = {
    'time': times,
    'flux': fluxes,
    'flux_raw': raw_fluxes,
    'flux_err': flux_errs,
    'flux_err_raw': raw_flux_errs,
    'metadatas': metadatas,
}
with open("data/real_lightcurves/star_day_lc_formatted_withraw.pickle", "wb") as file:
    pickle.dump(dict, file)

### One time formatting of ps/acf/fstat to match code

In [1]:
with open("data/real_lightcurves/star_sector_lc_formatted.pickle", "rb") as file:
    data = pickle.load(file)
with open("pk_star_sector_f_p_fs.pickle", "rb") as file:
    data_ps = pickle.load(file)

NameError: name 'pickle' is not defined

In [79]:
len(dict['fluxes'][19])

15150

In [66]:
flux_err.shape

(13081,)

In [48]:
sample.keys()



dict_keys(['TIC_ID', 'sector', 'time', 'time_adj', 'flux', 'flux_err', 'flux_norm', 'flux_mean', 'asinh_mean', 'norm_asinh_mean', 'time_norm', 'flux_norm_standard', 'flux_err_norm_standard', 'flux_norm_absTmag', 'flux_err_norm_absTmag'])

In [None]:
metadata = {
    'tic': sample['TIC_ID'],
    'sector': sample['sector'],

405461319

In [12]:
lc['405461319_42'].keys()

dict_keys(['TIC_ID', 'sector', 'time', 'time_adj', 'flux', 'flux_err', 'flux_norm', 'flux_mean', 'asinh_mean', 'norm_asinh_mean', 'time_norm', 'flux_norm_standard', 'flux_err_norm_standard', 'flux_norm_absTmag', 'flux_err_norm_absTmag'])

In [2]:
with open('data/mock_lightcurves/mock_lightcurves.pkl', 'rb') as f:
    mock_lc = pickle.load(f)

In [67]:
mock_lc.keys()

dict_keys(['times', 'fluxes', 'flux_errs', 'metadatas'])

In [55]:
mock_lc['times'][0]

array([0.00000000e+00, 3.25398265e-01, 6.50796530e-01, 9.76194795e-01,
       1.30159306e+00, 1.62699132e+00, 1.95238959e+00, 2.60318612e+00,
       3.25398265e+00, 3.57938091e+00, 3.90477918e+00, 4.23017744e+00,
       4.55557571e+00, 4.88097397e+00, 5.20637224e+00, 5.85716877e+00,
       6.18256703e+00, 6.50796530e+00, 6.83336356e+00, 7.48416009e+00,
       8.13495662e+00, 8.46035489e+00, 8.78575315e+00, 9.43654968e+00,
       9.76194795e+00, 1.00873462e+01, 1.04127445e+01, 1.07381427e+01,
       1.10635410e+01, 1.13889393e+01, 1.17143375e+01, 1.20397358e+01,
       1.23651341e+01, 1.26905323e+01, 1.30159306e+01, 1.33413289e+01,
       1.36667271e+01, 1.39921254e+01, 1.43175237e+01, 1.46429219e+01,
       1.49683202e+01, 1.52937185e+01, 1.59445150e+01, 1.62699132e+01,
       1.65953115e+01, 1.69207098e+01, 1.75715063e+01, 1.78969046e+01,
       1.82223028e+01, 1.85477011e+01, 1.91984976e+01, 1.95238959e+01,
       1.98492942e+01, 2.01746924e+01, 2.05000907e+01, 2.08254890e+01,
      

In [19]:
with open('data/real_lightcurves/star_day_timeseries.pickle', 'rb') as f:
    lc = pickle.load(f)