In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
import gymnasium as gym

from astropy.io import fits
from datetime import datetime, timezone
import json
import fitsio
import pandas as pd
import time
import pickle
import re

%reload_ext autoreload
%autoreload 2

In [2]:
import survey_ops
from survey_ops.utils import units, geometry, interpolate
from survey_ops.coreRL.offline_dataset import OfflineDECamDataset
from survey_ops.coreRL.agents import Agent
from survey_ops.algorithms import DDQN, BehaviorCloning
from survey_ops.utils.sys_utils import seed_everything
from survey_ops.coreRL.data_loading import load_raw_data_to_dataframe
from survey_ops.utils.config import Config

In [3]:
from survey_ops.utils import ephemerides
from tqdm import tqdm
from pathlib import Path

In [4]:
SEED = 10
seed_everything(SEED)
torch.set_default_dtype(torch.float32)
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "cpu"   
)

In [5]:
from survey_ops.utils.config import Config

In [6]:
cfg = Config('../configs/template_config.json')

In [7]:
cfg.set('experiment.data.specific_months', [1])
cfg.set('experiment.data.specific_days', np.arange(10))


In [8]:
cfg.set('experiment.data.include_bin_features', True)
cfg.set('experiment.data.default_bin_features', True)
cfg.set('experiment.data.additional_bin_features', ['num_visits_hist'])

In [9]:
cfg.get('feature_names'), cfg.get('experiment.data')

({'TIME_DEPENDENT_FEATURE_NAMES': ['az',
   'el',
   'ha',
   'time_fraction_since_start'],
  'CYCLICAL_FEATURE_NAMES': ['ra', 'az', 'ha'],
  'MAX_NORM_FEATURE_NAMES': ['el', 'dec'],
  'DEFAULT_BIN_FEATURE_NAMES': ['ha', 'airmass', 'ang_dist_to_moon'],
  'DEFAULT_PNTG_FEATURE_NAMES': ['ra',
   'dec',
   'az',
   'el',
   'airmass',
   'ha',
   'sun_ra',
   'sun_dec',
   'sun_az',
   'sun_el',
   'moon_ra',
   'moon_dec',
   'moon_az',
   'moon_el',
   'time_fraction_since_start']},
 {'remove_large_time_diffs': True,
  'do_cyclical_norm': True,
  'do_max_norm': True,
  'do_inverse_norm': True,
  'include_bin_features': True,
  'include_default_features': True,
  'binning_method': 'healpix',
  'nside': 16,
  'bin_space': 'radec',
  'specific_years': [2018],
  'specific_months': [1],
  'specific_days': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
  'specific_filters': None,
  'num_bins_1d': None,
  'additional_pointing_features': [],
  'additional_bin_features': ['num_visits_hist'],
  'objects_

In [21]:
with open('../configs/global_config.json', 'r') as f:
    glob_cfg = json.load(f)

In [29]:
a = np.array(['one', 'two', 'three'])

In [33]:
np.where(a == 'one')[0]

array([0])

In [23]:
glob_cfg

{'features': {'TIME_DEPENDENT_FEATURE_NAMES': ['az',
   'el',
   'ha',
   'time_fraction_since_start'],
  'CYCLICAL_FEATURE_NAMES': ['ra', 'az', 'ha'],
  'MAX_NORM_FEATURE_NAMES': ['el', 'dec'],
  'DEFAULT_BIN_FEATURE_NAMES': ['ha', 'airmass', 'ang_dist_to_moon'],
  'DEFAULT_PNTG_FEATURE_NAMES': ['ra',
   'dec',
   'az',
   'el',
   'airmass',
   'ha',
   'sun_ra',
   'sun_dec',
   'sun_az',
   'sun_el',
   'moon_ra',
   'moon_dec',
   'moon_az',
   'moon_el',
   'time_fraction_since_start']},
 'files': {'DECFITS': 'decam-exposures-20251211.fits',
  'DECJSON': 'decam-exposures-20251211.json',
  'FIELD2RADEC': 'field2radec.json',
  'FIELD2NAME': 'field2name.json',
  'FIELD2NVISITS': 'field2nvisits.json'},
 'paths': {'LOOKUP_DIR': '/home/hurra/Projects/survey-ops/data/lookups',
  'DATA_DIR': '/home/hurra/Projects/survey-ops/data/',
  'SURVEY_OPS': '/home/hurra/Projects/survey-ops/'}}

In [10]:
with open(cfg.get('paths.lookup_dir') + '/' + cfg.get('paths')['FIELD2NVISITS'], 'r') as f:
    field2nvisits = json.load(f)
with open(cfg.get('paths.lookup_dir') + '/' + cfg.get('paths')['FIELD2NAME'], 'r') as f:
    field2name = json.load(f)
with open(cfg.get('paths.lookup_dir') + '/' + cfg.get('paths')['FIELD2RADEC'], 'r') as f:
    field2radec = json.load(f)

In [11]:
fits_path = Path(cfg.get('paths.DFITS')).resolve().parents[1] / 'data' / cfg.get('paths.DFITS')
json_path = Path(cfg.get('paths.DJSON')).resolve().parents[1] / 'data' / cfg.get('paths.DJSON')

df = load_raw_data_to_dataframe(fits_path, json_path)

In [15]:
df = df.sort_values(by='timestamp')

In [16]:
df

Unnamed: 0,expnum,ra,dec,exptime,filter,propid,program,object,teff,fwhm,datetime,az,zd,ha,airmass,qc_fwhm,qc_cloud,qc_sky,qc_teff,timestamp
59348,601499,63.980504,-27.300694,90,z,2012B-0001,survey,DES survey hex 645-278 tiling 6,0.59,0.95,1970-01-03 02:35:42,262.5920,39.63,45.419583,1.30,0.95,0.10,0.38,0.59,182142
59349,601500,65.681625,-28.315138,90,z,2012B-0001,survey,DES survey hex 662-288 tiling 6,0.62,0.94,1970-01-03 02:37:41,261.4612,38.33,44.196458,1.27,0.94,0.12,0.36,0.62,182261
59350,601501,62.035621,-28.315110,90,z,2012B-0001,survey,DES survey hex 626-288 tiling 6,0.56,0.97,1970-01-03 02:39:45,260.0598,41.89,48.367417,1.34,0.97,0.09,0.40,0.56,182385
59351,601502,63.730875,-29.329527,90,z,2012B-0001,survey,DES survey hex 643-298 tiling 6,0.69,0.89,1970-01-03 02:41:42,258.9518,40.63,47.167250,1.32,0.89,0.09,0.36,0.69,182502
59352,601503,65.456129,-30.343999,90,z,2012B-0001,survey,DES survey hex 660-308 tiling 6,0.67,0.91,1970-01-03 02:43:40,257.7724,39.35,45.928667,1.29,0.91,0.10,0.34,0.67,182620
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
88613,810366,58.930033,-42.303110,90,r,2012B-0001,des_celeb,NGC 1487,0.55,1.35,2019-01-10 04:27:33,238.7702,38.66,46.306458,1.28,1.35,0.09,-0.15,0.55,1547094453
88614,810367,58.929917,-42.303138,90,i,2012B-0001,des_celeb,NGC 1487,0.51,1.16,2019-01-10 04:29:31,238.7969,39.02,46.803625,1.29,1.16,0.08,0.19,0.51,1547094571
88615,810370,53.414579,-36.057110,45,i,2012B-0001,des_celeb,NGC 1365,0.30,1.41,2019-01-10 04:35:39,247.6877,44.83,53.840917,1.41,1.41,0.08,0.34,0.30,1547094939
88616,810371,53.389708,-36.223610,45,g,2012B-0001,des_celeb,NGC 1365,0.46,1.59,2019-01-10 04:36:53,247.3923,45.09,54.204458,1.42,1.59,0.12,-0.18,0.46,1547095013


In [15]:
train_dataset = OfflineDECamDataset(
    df,
    cfg=cfg
)

Calculating sun and moon ra/dec and az/el: 100%|█| 403/403 [00:00<00:00, 1268.85
Calculating zenith states: 100%|████████████████| 6/6 [00:00<00:00, 1134.16it/s]
Calculating bin features for all healpix bins and timestamps: 100%|█| 409/409 [0
Normalizing bin features: 100%|████████| 18432/18432 [00:00<00:00, 94109.68it/s]


In [None]:
nvisits = list(field2nvisits.values())

In [None]:
max_visited_field = str(np.argmax(nvisits))
max_visited_field

In [160]:
max_obj_df = raw_data_df[raw_data_df['object'].str.contains(field2name[max_visited_field])]

In [161]:
raw_data_df.keys()

Index(['expnum', 'ra', 'dec', 'exptime', 'filter', 'propid', 'program',
       'object', 'teff', 'fwhm', 'datetime', 'az', 'zd', 'ha', 'airmass',
       'qc_fwhm', 'qc_cloud', 'qc_sky', 'qc_teff'],
      dtype='object')

# radec

In [174]:
bin2nvisits = {i: 0 for i in range(dataset.num_actions)}
for i, field_id in enumerate(dataset._df['field_id'].values):
    if field_id == -1:
        bin2nvisits 