In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

import gymnasium as gym

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

2025-10-31 23:20:31.924191: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-31 23:20:31.929550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761970831.935765   31438 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761970831.937955   31438 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761970831.943298   31438 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

# Simulate data

In [2]:
# Problem size: Nf distinct fields, each appearing nv_max times in the target schedule.
Nf = 6
# NOTE(review): T is not referenced by any cell shown here - presumably a time
# horizon reserved for later use; TODO confirm.
T = 100
nv_max = 2

# Fix the stdlib RNG so the shuffled schedule is the same on every run.
random.seed(2)

# Ordered schedule 0, 0, 1, 1, ... (every field id repeated nv_max times, as
# plain Python ints), then shuffled in place to obtain the target sequence.
true_sequence = [field for field in range(Nf) for _ in range(nv_max)]
random.shuffle(true_sequence)
true_sequence

[4, 5, 1, 2, 3, 3, 4, 1, 2, 5, 0, 0]

# Set up environment

In [3]:
from gymnasium.spaces import Dict, Box, Discrete

from typing import Optional
import numpy as np


class SimpleTelEnv(gym.Env):
    """Toy scheduling environment: reproduce a fixed sequence of field visits.

    An episode walks through ``target_sequence`` one position at a time.
    ``reset()`` visits the first target field automatically; each ``step()``
    then advances the position by one and the action selects which field is
    visited there. The reward compares the chosen field with the target field
    at that position. The episode terminates once the last position of the
    sequence has been played.

    Observation (``gym.spaces.Dict``):
        field_id: id of the most recently visited field, in ``[0, Nf)``.
        nvisits:  per-field visit counts accumulated during the episode.
        index:    current position in ``target_sequence``.

    Action: ``Discrete(Nf)`` - the id of the field to visit next.
    """

    def __init__(self, Nf, target_sequence, nv_max):
        """Build the observation/action spaces for a given target schedule.

        Args:
            Nf: Number of distinct fields (also the number of actions).
            target_sequence: Iterable of field ids in ``[0, Nf)`` that the agent
                should reproduce. A private list copy is stored.
            nv_max: Intended maximum number of visits per field. It is stored
                but not enforced anywhere in this class.

        Raises:
            ValueError: If ``target_sequence`` has fewer than two entries (no
                step could ever be taken) or contains an id outside ``[0, Nf)``.
        """
        self.Nf = Nf  # number of fields
        self.nv_max = nv_max
        self.target_sequence = [int(field) for field in target_sequence]

        n_positions = len(self.target_sequence)
        if n_positions < 2:
            raise ValueError(
                "target_sequence must contain at least 2 entries, "
                f"got {n_positions}"
            )
        if any(not 0 <= field < Nf for field in self.target_sequence):
            raise ValueError(
                f"target_sequence entries must lie in [0, {Nf}), "
                f"got {self.target_sequence}"
            )

        # Placeholder state; the real initial values are assigned in reset().
        self._field_id = -1
        self._nvisits = np.full(shape=(Nf,), fill_value=-1, dtype=np.int32)
        self._index = -1
        self._sequence = []

        # What the agent can observe. A Dict space keeps the observation
        # structured and human-readable.
        # TODO: planned additions that are not implemented yet - elapsed time
        # "t", measured/predicted effective exposure "Teff_meas"/"Teff_pred",
        # and the filter.
        self.observation_space = gym.spaces.Dict(
            {
                "field_id": gym.spaces.Discrete(n=Nf, start=0),
                # Nothing stops the agent from choosing the same field at every
                # position, so one field can collect up to n_positions visits
                # (1 in reset() + n_positions - 1 steps). The bound must cover
                # that for observations to stay inside the declared space.
                "nvisits": gym.spaces.Box(
                    0, n_positions, shape=(Nf,), dtype=np.int32
                ),
                "index": gym.spaces.Discrete(n=n_positions, start=0),
            }
        )

        # Map action numbers to field ids (currently the identity mapping).
        self._action_to_field_id = {i: i for i in range(Nf)}

        self.action_space = gym.spaces.Discrete(self.Nf)

        # TODO: remove fields that have already reached nv_max visits from the
        # available actions (nv_max is not enforced yet).

    def _get_obs(self):
        """Convert the internal state to the observation format.

        Returns:
            dict: ``field_id`` (last visited field), ``nvisits`` (per-field
            visit counts) and ``index`` (position in the target sequence).
        """
        return {
            "field_id": self._field_id,
            # Return a copy: step() updates the counts in place, which would
            # otherwise silently rewrite observations the caller has stored.
            "nvisits": self._nvisits.copy(),
            "index": self._index,
        }

    def _get_info(self, chosen_field_id=None, correct=None):
        """Assemble auxiliary debugging information.

        Args:
            chosen_field_id: Field selected by the last action, if any.
            correct: Whether that field matched the target, if known.

        Returns:
            dict: ``{'chosen_field_id': ..., 'correct': ...}``; both values are
            ``None`` when called without arguments (as reset() does).
        """
        return {'chosen_field_id': chosen_field_id, 'correct': correct}

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode at the first position of the target sequence.

        Args:
            seed: Random seed forwarded to ``gym.Env.reset``. The episode
                itself is deterministic and does not draw random numbers.
            options: Additional configuration (unused).

        Returns:
            tuple: (observation, info) for the initial state.
        """
        # IMPORTANT: must be called first to seed the random number generator.
        super().reset(seed=seed)

        # Fresh visit counters for the new episode.
        self._nvisits = np.full(shape=(self.Nf,), fill_value=0, dtype=np.int32)

        # The episode always starts on the first target field, which counts as
        # one visit.
        self._field_id = self.target_sequence[0]
        self._nvisits[self._field_id] += 1
        self._index = 0

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """Visit the field selected by ``action`` at the next sequence position.

        Args:
            action: Field id in ``[0, Nf)``. Plain ints, NumPy integer scalars
                and size-1 arrays (e.g. the output of ``model.predict``) are
                accepted.

        Returns:
            tuple: (observation, reward, terminated, truncated, info). The
            reward is 1.0 when the chosen field equals the target field at the
            new position, 0.25 when the two ids differ by exactly one, and 0.0
            otherwise. ``terminated`` is True once the last position of the
            target sequence has been played; ``truncated`` is always False.

        Raises:
            IndexError: If called after the episode has terminated, i.e. there
                is no target position left. Call ``reset()`` first.
            KeyError: If ``action`` is not a valid field id.
        """
        next_index = self._index + 1
        if next_index >= len(self.target_sequence):
            # Fail before touching any state so the environment stays consistent.
            raise IndexError(
                "episode is over: call reset() before calling step() again"
            )

        # .item() unwraps NumPy scalars / size-1 arrays (which are unhashable
        # as dict keys) into a plain Python scalar.
        field_id = self._action_to_field_id[np.asarray(action).item()]

        self._index = next_index
        self._field_id = field_id
        self._nvisits[self._field_id] += 1

        # Reward: full credit for the exact field, partial credit for a
        # neighbouring field id, nothing otherwise.
        target_field = self.target_sequence[self._index]
        distance = abs(self._field_id - target_field)
        correct = distance == 0
        if correct:
            reward = 1.0
        elif distance == 1:
            reward = 0.25
        else:
            reward = 0.0

        # The survey is complete once the last target position has been played.
        terminated = self._index == len(self.target_sequence) - 1
        truncated = False

        observation = self._get_obs()
        info = self._get_info(self._field_id, correct)

        return observation, reward, terminated, truncated, info

In [4]:
# Make the environment available to gym.make() under a namespaced id.
_simple_tel_registration = {
    "id": "gymnasium_env/SimpleTel-v0",
    "entry_point": SimpleTelEnv,
    # Prevent infinite episodes
    "max_episode_steps": 300,
}
gym.register(**_simple_tel_registration)

In [5]:
# Build the registered environment; the keyword arguments match the parameters
# of SimpleTelEnv.__init__.
env = gym.make(
    "gymnasium_env/SimpleTel-v0",
    Nf=Nf,
    target_sequence=true_sequence,
    nv_max=nv_max,
)
# Create multiple environments for parallel training
# (note: the id below still names GridWorld-v0, not the SimpleTel-v0 id registered in this notebook)
# vec_env = gym.make_vec("gymnasium_env/GridWorld-v0", num_envs=3)

In [6]:
from gymnasium.utils.env_checker import check_env

# Run the Gymnasium checker on the unwrapped environment. Any problem is
# printed instead of raised, so the notebook keeps running either way.
try:
    check_env(env.unwrapped)
except Exception as exc:
    print(f"Environment has issues: {exc}")
else:
    print("Environment passes all checks!")

Environment passes all checks!


In [7]:
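# Quick check to make sure the environment is doing what's intended (currently disabled).
# NOTE: if re-enabled as written, `reward_list.append(reward)` below runs before any
# step() call has defined `reward`, and `true_actions` is not defined in the cells
# shown here (presumably `true_sequence` was meant).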

# observation_list = []
# reward_list = []
# terminated_list = []
# truncated_list = []
# info_list = []

# observation, info = env.reset()
# observation_list.append(observation)
# reward_list.append(reward)

# for i in range(20):
#     try:
#         observation, reward, terminated, truncated, info = env.step(true_actions[i+1])
#         observation_list.append(observation)
#         reward_list.append(reward)
#         terminated_list.append(terminated)
#         truncated_list.append(truncated)
#         info_list.append(info)
#     except:
#         continue
# [observation_list[i]['field_id'] for i in range(len(observation_list))]

# Learn with RL model

In [19]:
# DQN agent on the dict-observation environment; "MultiInputPolicy" is the
# policy name passed for the Dict observation space defined above.
model = DQN(policy="MultiInputPolicy", env=env, verbose=1)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


## How does the model do with no training?

In [20]:
# Seed passed to env.reset() in the rollout cell that follows.
seed = 16

In [21]:
# Roll out one episode with the (still untrained) model and log everything
# the environment returns.
obs, info = env.reset(seed=seed)

# The observation/info logs start with the reset() result, so they end up one
# entry longer than the reward/terminated/truncated logs.
observation_list, info_list = [obs], [info]
reward_list, terminated_list, truncated_list = [], [], []

for i in range(100):
    # Greedy action; .tolist() turns the value returned by predict()
    # (presumably a NumPy array) into a plain Python value for env.step().
    action, _states = model.predict(obs, deterministic=True)
    step_result = env.step(action.tolist())
    obs, reward, terminated, truncated, info = step_result

    observation_list.append(obs)
    info_list.append(info)
    reward_list.append(reward)
    terminated_list.append(terminated)
    truncated_list.append(truncated)

    # Report the index of the final step and stop once the episode is over.
    if terminated or truncated:
        print(i)
        break

10


In [22]:
# Sequence of fields actually visited: the starting field from reset() followed
# by the field chosen at every step.
proposed_survey = [logged_obs['field_id'] for logged_obs in observation_list]
proposed_survey

[4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [23]:
# Target sequence, displayed for comparison with proposed_survey from the previous cell.
true_sequence

[4, 5, 1, 2, 3, 3, 4, 1, 2, 5, 0, 0]

## Now train and predict

In [24]:
# Train for 10,000 timesteps, then write the model to "simpleTel-v0-model".
# learn() returns the model itself, so the save call can be chained onto it.
model.learn(total_timesteps=10_000, log_interval=4).save("simpleTel-v0-model")

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11       |
|    ep_rew_mean      | 3.56     |
|    exploration_rate | 0.958    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 21953    |
|    time_elapsed     | 0        |
|    total_timesteps  | 44       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11       |
|    ep_rew_mean      | 2.66     |
|    exploration_rate | 0.916    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 21937    |
|    time_elapsed     | 0        |
|    total_timesteps  | 88       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11       |
|    ep_rew_mean      | 2.85     |
|    exploration_rate | 0.875    |
| time/               |          |
|    episodes       

In [25]:
# Roll out one episode with the trained model and log everything the
# environment returns.
observation_list = []
reward_list = []
terminated_list = []
truncated_list = []
info_list = []

# Same seeded reset as the untrained rollout so both runs start identically.
obs, info = env.reset(seed=seed)
observation_list.append(obs)
info_list.append(info)
for i in range(100):
    # Evaluate greedily. With deterministic=False, DQN's predict() keeps using
    # epsilon-greedy exploration, so the logged survey could contain random
    # actions that say nothing about what the policy has learned.
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action.tolist())
    observation_list.append(obs)
    reward_list.append(reward)
    terminated_list.append(terminated)
    truncated_list.append(truncated)
    info_list.append(info)
    # Report the index of the final step and stop once the episode is over.
    if terminated or truncated:
        print(i)
        break

10


In [26]:
# Sequence of fields visited by the trained model: the starting field from
# reset() followed by the field chosen at every step.
proposed_survey = [logged_obs['field_id'] for logged_obs in observation_list]
proposed_survey

[4, 5, 1, 2, 3, 3, 5, 1, 2, 5, 0, 0]

In [27]:
# Target sequence, displayed for comparison with proposed_survey from the previous cell.
true_sequence

[4, 5, 1, 2, 3, 3, 4, 1, 2, 5, 0, 0]