In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import pickle

import torch
import torch.nn.functional as F
import gymnasium as gym


%reload_ext autoreload
%autoreload 2

In [2]:
import survey_ops
from survey_ops.src.environments import ToyEnv  # to silence linting error
from survey_ops.utils import units, geometry, interpolate
from survey_ops.src.offline_dataset import TelescopeDatasetv0
from survey_ops.src.agents import Agent
from survey_ops.src.algorithms import DDQN, BehaviorCloning
from survey_ops.utils.pytorch_utils import seed_everything


# Load Data

In [3]:
import json
# Read the field table from JSON; the result is bound to old_id2pos as loaded.
with open("../data/2013-09-15_gband_fields.json") as fields_file:
    old_id2pos = json.load(fields_file)

In [4]:
# Rebuild the lookup with integer keys (each key of old_id2pos is passed
# through int()); the values are carried over unchanged.
id2pos = {int(field_key): position for field_key, position in old_id2pos.items()}

In [None]:
import pandas as pd
# Load the schedule table; 'next_field' is read with pandas' nullable integer
# dtype so that missing entries do not force the column to float.
schedule = pd.read_csv(
    '../data/2013-09-15_gband_schedule.csv',
    dtype={'next_field': 'Int64'},
)

In [None]:
# Keep a separate copy of the loaded schedule under its own name.
schedule_old = schedule.copy()

In [None]:
# Show the copied schedule as this cell's output.
schedule_old

# Visualize schedule

In [None]:

# Position of every scheduled field, in schedule order
# (presumably columns are RA and Dec in degrees -- TODO confirm).
radec = np.array([id2pos[fid] for fid in schedule.field_id.values])
first_coord, second_coord = radec[:, 0], radec[:, 1]
# Values above 180 in the first column are moved down by 360.
ra_shifted = np.where(first_coord > 180, first_coord - 360, first_coord)
# Colour each point by its index in the schedule.
c = plt.scatter(ra_shifted, second_coord, c=np.arange(len(ra_shifted)))
plt.colorbar(c)

# Configure dataset, model, and training setup

In [None]:
SEED = 10
train_size = 1

# Apply the seed through the project helper before anything else is built.
seed_everything(SEED)

torch.set_default_dtype(torch.float32)

# Use CUDA when torch reports it as available, otherwise the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Build the offline dataset from the schedule table and the field-id -> position
# lookup, with observation normalisation switched on.
dataset = TelescopeDatasetv0(schedule, id2pos, normalize_obs=True)

In [None]:
"""DESIRED Algorithm and Train setup"""

# Keyword arguments for the BehaviorCloning constructor; the observation and
# action dimensions are taken from the dataset.
alg_config = dict(
    obs_dim=dataset.obs_dim,
    num_actions=dataset.num_actions,
    hidden_dim=64,
    device=device,
    lr=1e-3,
    loss_fxn=None,  # torch.nn.cross_entropy,
    # use_double=True,
)
alg = BehaviorCloning(**alg_config)

# Keyword arguments that are later unpacked into agent.fit().
fit_config = dict(
    num_epochs=3000,
    batch_size=32,
    # learning_start=
)

env_name = 'TelescopeEnv-v0'

def create_exp_name(alg, env_name, dataset, fit_config, alg_config):
    """Build the experiment name that labels a run.

    The name is the algorithm name, the environment name and a few settings
    joined with '-'.

    Args:
        alg: algorithm object; only its ``name`` attribute is read.
        env_name: environment name.
        dataset: dataset object; only its ``_nfields`` attribute is read.
        fit_config: mapping providing 'num_epochs' and 'batch_size'.
        alg_config: mapping providing 'lr'.

    Returns:
        str: the assembled experiment name.
    """
    exp_name = f"{alg.name}"
    exp_name += f"-{env_name}"
    exp_name += f"-n_unique={dataset._nfields}"
    # exp_name += f"-use_double={alg_config['use_double']}"
    exp_name += f"-num_epochs={fit_config['num_epochs']}"
    exp_name += f"-batch_size={fit_config['batch_size']}"
    exp_name += f"-lr={alg_config['lr']}"
    # Without this return the caller receives None and every run ends up
    # under the same '../results/None/' directory.
    return exp_name

# Name this run from the algorithm, environment and training settings.
exp_name = create_exp_name(alg, env_name, dataset, fit_config, alg_config)

# Results for this run live under ../results/<exp_name>/ with figures in a
# sub-folder. Both paths keep their trailing '/' because other cells build file
# names by plain string concatenation (e.g. outdir + 'train_metrics.pkl').
outdir = f'../results/{exp_name}/'
fig_outdir = outdir + 'figures/'
# exist_ok=True makes the cell safe to re-run and removes the gap between
# checking for the directory and creating it.
os.makedirs(outdir, exist_ok=True)
os.makedirs(fig_outdir, exist_ok=True)


# The agent wraps the algorithm, reuses the dataset's normalisation flag and is
# given this run's output directory.
agent_config = {
    'algorithm': alg,
    'normalize_obs': dataset.normalize_obs,
    'outdir': outdir
    }

agent = Agent(**agent_config)



In [None]:
from gymnasium.utils.env_checker import check_env

env_name = 'TelescopeEnv-v0'
env_id = f"gymnasium_env/{env_name}"

# Register ToyEnv under the namespaced id so gym.make can construct it.
gym.register(
    id=env_id,
    entry_point=ToyEnv,
    # Prevent infinite episodes. Here just set to 300 even though the episode
    # will terminate when stepping to the last element of the sequence.
    max_episode_steps=300,
)

# for eval step only
env_config = dict(id=env_id, dataset=dataset)
env = gym.make(**env_config)
# Create multiple environments for parallel training
# vec_env = gym.make_vec("gymnasium_env/SimpleTel-v0", num_envs=5, vectorization_mode='sync', Nf=Nf, target_sequence=true_sequence, nv_max=nv_max)

# This will catch many common issues; a failure is reported, not raised.
try:
    check_env(env.unwrapped)
except Exception as err:
    print(f"Environment has issues: {err}")
else:
    print("Environment passes all checks!")

# Train

In [None]:
# Train the agent on the dataset and print the elapsed wall-clock seconds.
start_time = time.time()
agent.fit(dataset=dataset, **fit_config)
end_time = time.time()

train_time = end_time - start_time
print(train_time)

In [None]:
# Read the training metrics pickle from this run's output directory
# (presumably written during agent.fit -- TODO confirm).
train_metrics_path = outdir + 'train_metrics.pkl'
with open(train_metrics_path, 'rb') as metrics_file:
    train_metrics = pickle.load(metrics_file)

In [None]:
# Two stacked panels sharing the x axis: loss on top, accuracy below.
fig, axs = plt.subplots(2, sharex=True, figsize=(5, 5))

loss_history = train_metrics['loss_history']
test_acc_history = train_metrics['test_acc_history']

axs[0].plot(loss_history)
# axhline spans the whole axis width, so the reference line no longer forces
# the shared x range out to a fixed 5000 nor stops short on longer runs.
axs[0].axhline(y=0, color='red', linestyle='--')
axs[0].set_ylabel('Loss', fontsize=12)

# The accuracy history can have a different length than the loss history, so
# its samples are spread evenly over the same train-step range.
acc_steps = np.linspace(0, len(loss_history), len(test_acc_history))
axs[1].plot(acc_steps, test_acc_history)
axs[1].axhline(y=1, color='red', linestyle='--')
axs[1].set_xlabel('Train step', fontsize=12)
axs[1].set_ylabel('Accuracy', fontsize=12)

fig.tight_layout()
fig.savefig(fig_outdir + 'train_history.png')

# Evaluate

In [None]:
# Run one evaluation episode, then read the evaluation metrics pickle from
# this run's output directory (presumably written by agent.evaluate -- TODO confirm).
agent.evaluate(env=env, num_episodes=1)

eval_metrics_path = outdir + 'eval_metrics.pkl'
with open(eval_metrics_path, 'rb') as metrics_file:
    eval_metrics = pickle.load(metrics_file)

In [None]:
# Reference sequence: first entry of the dataset's scheduled field ids.
target_sequence = dataset._schedule_field_ids[0]
# First column of the episode-0 observations
# (presumably the visited field id -- TODO confirm against the environment).
eval_sequence = eval_metrics['observations']['ep-0'][:, 0]

fig, axs = plt.subplots(2)

# Top panel: evaluated and reference sequences overlaid.
axs[0].plot(eval_sequence, marker='o', label='pred')
axs[0].plot(target_sequence, marker='o', linestyle='dashed', label='true')
axs[0].legend()
axs[0].set_xlabel('obs index')
axs[0].set_ylabel('field id')

# Bottom panel: element-wise difference. The label gives legend() an entry;
# without a labelled artist matplotlib warns and draws an empty legend.
axs[1].plot(eval_sequence - target_sequence, marker='o', label='pred - true')
axs[1].legend()
axs[1].set_xlabel('obs index')
axs[1].set_ylabel('pred - true')


# NOTE(review): saved into outdir, while the training-history figure goes to
# fig_outdir -- confirm which location is intended.
fig.savefig(outdir + 'learned_sequence.png')