In [81]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
import gymnasium as gym

from astropy.io import fits
from datetime import datetime, timezone
import pandas as pd
import json
import fitsio
import time
import pickle
import re

%reload_ext autoreload
%autoreload 2

In [82]:
import survey_ops
from survey_ops.utils import units, geometry, interpolate
from survey_ops.coreRL.offline_dataset import OfflineDELVEDataset
from survey_ops.coreRL.agents import Agent
from survey_ops.algorithms import DDQN, BehaviorCloning
from survey_ops.utils.sys_utils import seed_everything
from survey_ops.coreRL.data_processing import load_raw_data_to_dataframe


In [83]:
from survey_ops.utils import ephemerides
from tqdm import tqdm
from pathlib import Path

In [84]:
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from collections import Counter
from scipy.stats import entropy

In [85]:
# Reproducibility and numeric defaults for the whole notebook.
SEED = 10
seed_everything(SEED)  # project helper imported from survey_ops.utils.sys_utils
torch.set_default_dtype(torch.float32)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load config and lookup files

In [86]:
# Output directory of the experiment under inspection; its config.json is read from here.
results_outdir = "../experiment_results/bc-azel-grid-test/"

In [87]:
# Repository-wide configuration; later cells read its 'paths' and 'files' sections.
with open('../configs/global_config.json', 'r') as global_cfg_file:
    gcfg = json.load(global_cfg_file)

In [88]:
# Configuration stored next to the experiment results.
experiment_cfg_path = results_outdir + 'config.json'
with open(experiment_cfg_path, 'r') as experiment_cfg_file:
    cfg = json.load(experiment_cfg_file)

In [89]:
# Override the data-selection entries of the loaded experiment config for this session.
cfg['data'].update({
    'specific_months': [12],
    'specific_days': [1, 2, 3, 4, 5],
    'bin_space': 'azel',
    'additional_bin_features': ['ha', 'moon_distance', 'angular_distance_to_pointing', 'airmass'],
})

In [90]:
# Resolution key from the experiment config; it also selects the
# nside-specific lookup files read from ../data/lookups/.
nside = cfg['data']['nside']

In [97]:
def _read_json(path):
    """Return the parsed contents of the JSON file at `path`."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Lookup tables used by later cells.
# NOTE(review): field2nvisits is located through the global config, while the
# bin2* tables use a hard-coded '../data/lookups/' prefix — confirm both point
# at the same directory.
# The field2name / field2radec lookups were disabled in this cell and are not loaded.
field2nvisits = _read_json(gcfg['paths']['LOOKUP_DIR'] + '/' + gcfg['files']['FIELD2NVISITS'])
bin2azel = _read_json(f'../data/lookups/nside{nside}_bin2azel.json')
bin2radec = _read_json(f'../data/lookups/nside{nside}_bin2radec.json')


# Load in data

In [98]:
from survey_ops.coreRL.data_processing import load_raw_data_to_dataframe, drop_rows_in_DECam_data, get_zenith_features

In [99]:
# Path of the FITS file named by the global config.
# NOTE(review): directory and file name are concatenated without a separator,
# so FITS_DIR is presumably expected to end with '/'.
decam_fits_path = gcfg['paths']['FITS_DIR'] + gcfg['files']['DECFITS']
df = load_raw_data_to_dataframe(fits_path=decam_fits_path)

In [94]:
# Object labels handed to the filter for removal (matching rule lives in the helper — TODO confirm).
excluded_objects = [
    "guide",
    "DES vvds",
    "J0'",
    "gwh",
    "DESGW",
    "Alhambra-8",
    "cosmos",
    "COSMOS hex",
    "TMO",
    "LDS",
    "WD0",
    "DES supernova hex",
    "NGC",
    "ec",
    "(outlier)",
]

# Restrict the frame to the years/months/days listed in cfg; specific_filters is left as None.
# NOTE(review): the execution counts show this cell (In [94]) was run before the
# load cell that precedes it (In [99]); re-run top to bottom to get the filtered frame.
df = drop_rows_in_DECam_data(
    df=df,
    specific_years=cfg['data']['specific_years'],
    specific_months=cfg['data']['specific_months'],
    specific_days=cfg['data']['specific_days'],
    specific_filters=None,
    objects_to_remove=excluded_objects,
)

In [95]:
# Build the zenith feature table from the exposure frame via the project helper.
zenith_df = get_zenith_features(df)


Calculating zenith states: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 39593.81it/s]


In [100]:
# Selection arguments passed explicitly to the dataset.
# NOTE(review): month/day are None here while cfg['data'] also carries
# 'specific_months' / 'specific_days' — confirm which source takes precedence.
dataset_selection = dict(
    specific_years=[2016],
    specific_months=None,
    specific_days=None,
    specific_filters=None,
)

# Build the offline dataset from the exposure frame and both configs.
d = OfflineDELVEDataset(df=df, cfg=cfg, gcfg=gcfg, **dataset_selection)

Calculating zenith states: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16244.40it/s]
Calculating sun and moon ra/dec and az/el: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1047/1047 [00:00<00:00, 1232.97it/s]
Calculating bin features for all healpix bins and timestamps: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1047/1047 [00:05<00:00, 199.70it/s]


In [101]:
# Inspect the `night` column of the dataset's internal frame (`_df` is a private attribute).
d._df['night']

0      2016-12-01 00:00:00+00:00
1      2016-12-01 00:00:00+00:00
2      2016-12-01 00:00:00+00:00
3      2016-12-01 00:00:00+00:00
4      2016-12-01 00:00:00+00:00
                  ...           
1042   2016-12-05 00:00:00+00:00
1043   2016-12-05 00:00:00+00:00
1044   2016-12-05 00:00:00+00:00
1045   2016-12-05 00:00:00+00:00
1046   2016-12-05 00:00:00+00:00
Name: night, Length: 1047, dtype: datetime64[ns, UTC]

In [77]:
# Validation metrics of the 'bc-azel-test-4years' run (not the run in `results_outdir`).
# NOTE: unpickling can execute arbitrary code — only load files produced by a trusted run.
with open('../experiment_results/bc-azel-test-4years/val_metrics.pkl', 'rb') as metrics_file:
    val_metrics = pickle.load(metrics_file)

In [80]:
# Report how many values were recorded for each metric.
# NOTE(review): in the recorded output `epoch` has 348 entries while every
# other metric has 174.
for metric_name, history in val_metrics.items():
    print(metric_name, len(history))

val_loss 174
logp_expert_action 174
action_margin 174
entropy 174
ang_sep 174
unique_bins 174
accuracy 174
epoch 348


In [36]:
# Disabled one-off preprocessing, kept for reference only (nothing in this cell runs).
# If re-enabled, its last line writes '../data/fits/decam-exposures-20251211-full.json'.
# d = fitsio.read('../data/fits/decam-exposures-20251211.fits')
# df = pd.DataFrame(d.astype(d.dtype.newbyteorder('=')))
# df['datetime'] = pd.to_datetime(df['datetime'], utc=True, errors='coerce')
# # df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
# df['night'] = (df['datetime'] - pd.Timedelta(hours=12)).dt.normalize()
# df = df[df['datetime'].dt.year != 1970]
# df = df[df['datetime'].dt.year != 1972]
# df = df[df['datetime'].dt.year != 1973]
# df = df[~df['datetime'].isna()]

# # # Add timestamp col
# utc = pd.to_datetime(df['datetime'], utc=True)
# timestamps = (utc.astype('int64') // 10**9).values
# df['timestamp'] = timestamps.copy()
# df.to_json('../data/fits/decam-exposures-20251211-full.json')

In [37]:
# Peek at the per-exposure `night` column of the frame currently bound to `df`.
df['night']

0        2012-10-06 00:00:00+00:00
1        2012-10-24 00:00:00+00:00
2        2012-10-24 00:00:00+00:00
3        2012-10-24 00:00:00+00:00
4        2012-10-24 00:00:00+00:00
                    ...           
577139   2025-11-15 00:00:00+00:00
577140   2025-11-15 00:00:00+00:00
577141   2025-11-15 00:00:00+00:00
577142   2025-11-15 00:00:00+00:00
577143   2025-11-16 00:00:00+00:00
Name: night, Length: 576165, dtype: datetime64[ns, UTC]

In [43]:
# Read the exposure catalogue and convert it to native byte order before building the frame.
# NOTE(review): `d` is also the name bound to the OfflineDELVEDataset elsewhere in this notebook.
d = fitsio.read('../data/fits/decam-exposures-20251211.fits')
df = pd.DataFrame(d.astype(d.dtype.newbyteorder('=')))

# Row selection: one proposal id, exptime strictly between 40 and 100, teff not NaN.
sel = (
    (df['propid'] == '2012B-0001')
    & df['exptime'].gt(40)
    & df['exptime'].lt(100)
    & ~np.isnan(df['teff'])
)
df = df.loc[sel].copy()

# Parse the timestamps as UTC, then label each exposure with its night:
# subtracting 12 hours before normalising maps post-midnight exposures to the prior date.
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
shifted_datetime = df['datetime'] - pd.Timedelta(hours=12)
df['night'] = shifted_datetime.dt.normalize()

# Discard rows whose datetime falls in 1970.
df = df.loc[df['datetime'].dt.year != 1970]

# Add timestamp col: integer seconds obtained from the nanosecond representation.
utc = pd.to_datetime(df['datetime'], utc=True)
timestamps = (utc.astype('int64') // 10**9).values
df['timestamp'] = timestamps.copy()

In [44]:
# Check the dtype of the `night` column.
df['night'].dtype

datetime64[ns, UTC]

In [45]:
# Check the dtype of the `datetime` column.
df['datetime'].dtype

datetime64[ns, UTC]

In [46]:
# Debugging probe: displays the `.dt` accessor object itself, not any data.
df['night'].dt

<pandas.core.indexes.accessors.DatetimeProperties object at 0x76ee38d6c0d0>

In [47]:
# List the distinct calendar years present in `night`.
df['night'].dt.year.unique()

array([2013, 2014, 2015, 2016, 2017, 2018, 2019], dtype=int32)

In [28]:
# Load the exposure table, handing the loader both the FITS file and a JSON path.
# Earlier drafts of this cell derived both paths from cfg and then built an
# OfflineDELVEDataset; those lines were disabled and are omitted here.
df = load_raw_data_to_dataframe(
    fits_path='../data/fits/decam-exposures-20251211.fits',
    json_path='../data/fits/decam-exposures-20251211.json',
)

In [29]:
# Inspect `night` for the frame currently bound to `df`.
# NOTE(review): the recorded output for this cell shows int64 values
# (e.g. 1377907200000) rather than datetime64[ns, UTC].
df['night']

32292     1377907200000
32293     1377907200000
32294     1377907200000
32295     1377907200000
32296     1377907200000
              ...      
290012    1546992000000
290013    1546992000000
290016    1546992000000
290017    1546992000000
290019    1546992000000
Name: night, Length: 88541, dtype: int64

In [None]:
# NOTE(review): a `.json` path is passed as `fits_path` with `json_path=None`, and the
# cell has no execution count — confirm the loader accepts this before relying on it.
full_df = load_raw_data_to_dataframe(fits_path='../data/fits/decam-exposures-20251211-full.json', json_path=None)

In [50]:
from survey_ops.coreRL.data_processing import get_nautical_twilight

In [55]:
# One slot per distinct night for the twilight set/rise times and their midpoint.
n_nights = df.groupby('night').ngroups
t_sunset = np.zeros(n_nights)
t_sunrise = np.zeros(n_nights)
midpoint = np.zeros_like(t_sunrise)

In [56]:
# Fill the nautical-twilight set/rise times and their midpoint for every night.
# The first exposure of each night (in the frame's row order) anchors the lookup;
# the group selection is done once here instead of again inside the loop.
first_exposures = df.groupby('night').head(1)
first_timestamps = first_exposures['timestamp'].values
first_datetimes = first_exposures['datetime'].values

for i, t in enumerate(first_timestamps):
    # The query time is offset by +100 from the first exposure
    # (presumably seconds — TODO confirm against get_nautical_twilight).
    t_sunset[i] = get_nautical_twilight(t+100, event_type='set')
    t_sunrise[i] = get_nautical_twilight(t+100, event_type='rise')
    # Diagnostic: report nights whose rise time comes back implausibly small.
    if t_sunrise[i] < 1e9:
        print(i, t)
        print('datetime', first_datetimes[i])
    midpoint[i] = t_sunset[i] + (t_sunrise[i] - t_sunset[i]) / 2
    

1972 963446833
datetime 2000-07-13T00:07:13.000000000
1973 963533683
datetime 2000-07-14T00:14:43.000000000


In [57]:
# Row count of the frame currently bound to `df`.
len(df)

576165

In [64]:
# Offset of each night's first exposure from the computed twilight-set time,
# divided by 60 (minutes, assuming both values are in seconds — TODO confirm).
first_exposure_ts = df.groupby('night').head(1)['timestamp'].values
offset_from_set = (first_exposure_ts - t_sunset) / 60
# Show the 500 smallest offsets.
np.sort(offset_from_set)[:500]

array([-1.54984322e+00, -1.46207607e+00, -1.21793857e+00, -1.20910710e+00,
       -1.15206200e+00, -1.02858857e+00, -1.02105321e+00, -9.84760765e-01,
       -9.61474832e-01, -9.59827201e-01, -9.37485500e-01, -8.90390750e-01,
       -7.61441700e-01, -7.37136666e-01, -7.33257302e-01, -7.07121301e-01,
       -7.04669933e-01, -7.04279351e-01, -6.24767617e-01, -6.13130351e-01,
       -6.11540401e-01, -5.87774118e-01, -5.40898466e-01, -5.14407484e-01,
       -4.66230249e-01, -4.23370683e-01, -4.04194566e-01, -3.82380382e-01,
       -3.80967585e-01, -3.77063068e-01, -3.76669665e-01, -3.41360235e-01,
       -3.40460316e-01, -3.37230984e-01, -3.06740002e-01, -2.99863752e-01,
       -2.75543567e-01, -2.38544718e-01, -2.14874518e-01, -2.14710800e-01,
       -1.70234517e-01, -1.67954135e-01, -1.52918565e-01, -1.23448133e-01,
       -1.19494132e-01, -7.41900484e-02, -6.80438320e-02, -5.67405661e-02,
       -4.72522815e-02, -4.40113147e-02, -3.45629017e-02, -1.87800169e-02,
       -1.29898389e-03,  