In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
import gymnasium as gym

from astropy.io import fits
from datetime import datetime, timezone
import json
import fitsio
import pandas as pd
import time
import pickle
import re

%reload_ext autoreload
%autoreload 2

In [2]:
import survey_ops
from survey_ops.utils import units, geometry, interpolate
from survey_ops.src.offline_dataset import OfflineDECamDataset
from survey_ops.src.agents import Agent
from survey_ops.src.algorithms import DDQN, BehaviorCloning
from survey_ops.utils.pytorch_utils import seed_everything


In [3]:
# Reproducibility and numeric defaults: seed via the project helper
# (presumably covers python/numpy/torch RNGs — confirm in survey_ops),
# make float32 the default tensor dtype, and prefer the GPU when present.
SEED = 10
seed_everything(SEED)
torch.set_default_dtype(torch.float32)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
from survey_ops.utils.config import Config

In [131]:
# Load the experiment configuration. The path is relative to the notebook's
# current working directory, so this cell assumes the notebook is run from its own folder.
cfg = Config('../experiment_results/default_config.json')

In [154]:
# Restrict the data selection to months 1 through 3 (presumably Jan–Mar of the
# configured `specific_years`). This mutates `cfg` in place for later cells.
cfg.set('experiment.data.specific_months', list(range(1, 4)))

In [155]:
# Load the per-field lookup tables whose paths are listed under `paths` in the config.
# JSON text is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (names in these tables may contain non-ASCII text).
def _load_json(path):
    """Read and parse the UTF-8 JSON file at `path`."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

_paths = cfg.get('paths')
field2nvisits = _load_json(_paths['FIELD2NVISITS'])
field2name = _load_json(_paths['FIELD2NAME'])
field2radec = _load_json(_paths['FIELD2RADEC'])

In [156]:
from survey_ops.utils.script_utils import load_raw_data_to_dataframe

In [157]:
# Build the raw exposure DataFrame from the two data files named in the config
# (presumably a FITS table and a JSON sidecar, per the DFITS/DJSON key names — confirm).
raw_data_df = load_raw_data_to_dataframe(cfg.get('paths.DFITS'), cfg.get('paths.DJSON'))


In [158]:
# Visit counts in the dict's insertion order (same order as its keys).
nvisits = [*field2nvisits.values()]

In [159]:
# Key of the field with the most visits. Taking the key itself, instead of the positional
# np.argmax index cast to str, stays correct even when the JSON keys are not the
# contiguous strings '0'..'N-1' in insertion order. As with np.argmax, ties resolve to
# the first such entry. The result is used below as a key into `field2name`.
max_visited_field = max(field2nvisits, key=field2nvisits.get)
max_visited_field

'6381'

In [160]:
# Exposures whose `object` string contains the name of the most-visited field.
# regex=False: match the name literally, since a field name containing characters such
#   as '+', '(' or '.' would otherwise be interpreted as a regular expression.
# na=False: rows with a missing `object` are excluded rather than yielding a NaN
#   mask entry, which cannot be used for boolean indexing.
max_obj_df = raw_data_df[
    raw_data_df['object'].str.contains(field2name[max_visited_field], regex=False, na=False)
]

In [161]:
# Column names of the raw exposure table (DataFrame.keys() is an alias for .columns).
raw_data_df.columns

Index(['expnum', 'ra', 'dec', 'exptime', 'filter', 'propid', 'program',
       'object', 'teff', 'fwhm', 'datetime', 'az', 'zd', 'ha', 'airmass',
       'qc_fwhm', 'qc_cloud', 'qc_sky', 'qc_teff'],
      dtype='object')

In [162]:
# Display the data-selection section of the config, including the month override set
# earlier. `cfg` is passed to OfflineDECamDataset below, which presumably reads this section.
cfg.get('experiment.data')

{'remove_large_time_diffs': True,
 'do_cyclical_norm': True,
 'do_max_norm': True,
 'do_inverse_norm': True,
 'include_bin_features': False,
 'include_default_features': True,
 'binning_method': 'healpix',
 'nside': 16,
 'bin_space': 'radec',
 'specific_years': [2018],
 'specific_months': [1, 2, 3],
 'specific_days': None,
 'specific_filters': None,
 'num_bins_1d': None,
 'additional_pointing_features': [],
 'additional_bin_features': [],
 'objects_to_remove': ['guide',
  'DES vvds',
  "J0'",
  'gwh',
  'DESGW',
  'Alhambra-8',
  'cosmos',
  'COSMOS hex',
  'TMO',
  'LDS',
  'WD0',
  'DES supernova hex',
  'NGC',
  'ec']}

In [163]:
# Build the offline dataset from the raw exposure table and the (modified) config.
# Construction computes sun/moon and zenith features, as shown by the progress bars.
dataset = OfflineDECamDataset(cfg=cfg, df=raw_data_df)

Calculating sun and moon ra/dec and az/el: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3605/3605 [00:07<00:00, 494.80it/s]
Calculating zenith states: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 396.09it/s]


In [164]:
# Columns of the dataset's internal (private) DataFrame. Compared with raw_data_df it
# adds derived columns such as night, sun/moon positions, cyclical encodings, bin and field_id.
dataset._df.columns

Index(['expnum', 'ra', 'dec', 'exptime', 'filter', 'propid', 'program',
       'object', 'teff', 'fwhm', 'datetime', 'az', 'zd', 'ha', 'airmass',
       'qc_fwhm', 'qc_cloud', 'qc_sky', 'qc_teff', 'night', 'timestamp',
       'sun_ra', 'sun_dec', 'sun_az', 'sun_el', 'moon_ra', 'moon_dec',
       'moon_az', 'moon_el', 'time_fraction_since_start', 'ra_cos', 'ra_sin',
       'az_cos', 'az_sin', 'ha_cos', 'ha_sin', 'sun_ra_cos', 'sun_ra_sin',
       'sun_az_cos', 'sun_az_sin', 'moon_ra_cos', 'moon_ra_sin', 'moon_az_cos',
       'moon_az_sin', 'el', 'bin', 'field_id'],
      dtype='object')

In [167]:
# Re-display of the data config after building the dataset. Its output matches the
# earlier display, so the dataset build did not change this section.
# NOTE(review): redundant with the earlier cell — consider deleting.
cfg.get('experiment.data')

{'remove_large_time_diffs': True,
 'do_cyclical_norm': True,
 'do_max_norm': True,
 'do_inverse_norm': True,
 'include_bin_features': False,
 'include_default_features': True,
 'binning_method': 'healpix',
 'nside': 16,
 'bin_space': 'radec',
 'specific_years': [2018],
 'specific_months': [1, 2, 3],
 'specific_days': None,
 'specific_filters': None,
 'num_bins_1d': None,
 'additional_pointing_features': [],
 'additional_bin_features': [],
 'objects_to_remove': ['guide',
  'DES vvds',
  "J0'",
  'gwh',
  'DESGW',
  'Alhambra-8',
  'cosmos',
  'COSMOS hex',
  'TMO',
  'LDS',
  'WD0',
  'DES supernova hex',
  'NGC',
  'ec']}

# radec

In [176]:
# Scratch: throwaway list used for a membership check; unrelated to the analysis.
# NOTE(review): leftover experiment — consider deleting.
a = 'awlek aweif'.split()

In [177]:
# Scratch membership check on the throwaway list `a`; evaluates to a bool.
a.count('awlek') > 0

True

In [174]:
bin2nvisits = {i: 0 for i in range(dataset.num_actions)}
for i, field_id in enumerate(dataset._df['field_id'].values):
    if field_id == -1:
        bin2nvisits 