In [47]:
import numpy as np
import matplotlib.pyplot as plt
import random

import gymnasium as gym

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

In [51]:
Nf = 6
T = 100
nv_max = 4
random.seed(2)
true_actions = np.array([np.full(4, i) for i in range(Nf)]).flatten()
true_actions = true_actions.tolist()
random.shuffle(true_actions)
true_actions

[3, 1, 0, 3, 3, 0, 1, 4, 4, 5, 3, 4, 5, 2, 4, 5, 1, 2, 2, 1, 2, 5, 0, 0]

In [59]:
true_actions == true_actions[:4]

False

In [52]:
from gymnasium.spaces import Dict, Box, Discrete

from typing import Optional
import numpy as np


class SimpleTelEnv(gym.Env):
    def __init__(self, Nf, T, target_sequence, nv_max, starting_field_id):
        self.Nf = Nf # number of fields
        self.T = 1 # 28800sec = 8hrs
        self.nv_max = nv_max
        self.target_sequence = target_sequence if target_sequence is not None else np.arange(0, Nf, 1, dtype=int)
        # "Teff_meas": Box(0, 1, shape=(Nf,), dtype=np.float32),

        #TODO
        # Initialize positions - will be set in reset()
        self._field_id = -1
        # self._t = -1
        self._nvisits = np.full(shape=(Nf,), fill_value=-1, dtype=int)
        # self._Teff_pred = np.full(shape=(Nf,), fill_value=-1, dtype=np.float32)
        self._step_num = -1
        self._sequence = []

        # Map action numbers to actual movements on the grid
        self._possible_actions = [i for i in range(Nf)]

        #TODO
        # Define what the agent can observe
        # Dict space gives us structured, human-readable observations
        self.observation_space = gym.spaces.Dict(
            {
                # "t": Box(0, T, shape=None, dtype=np.float32),
                "field_id": Discrete(n=Nf, start=0),
                "nvisits": Box(0, 4, shape=(Nf,), dtype=int),
                # "Teff_pred": Box(0, 1, shape=(Nf,), dtype=np.float32),
                    #filter
            }
        )

        self.action_space = gym.spaces.Discrete(self.Nf)

        # remove fields that have 
        
    def _get_obs(self,):
        """Convert internal state to observation format.
    
        Returns:
            dict: Observation with agent and target positions
        """
        return {
            # "t": self._t,
            "field_id": self._field_id,
            "nvisits": self._nvisits,
            ""
            # "Teff_pred": self._Teff_pred,
        }

    def _get_info(self):
        """Compute auxiliary information for debugging.

        Returns:
            
        """
        return {}

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode.

        Args:
            seed: Random seed for reproducible episodes
            options: Additional configuration (unused in this example)

        Returns:
            tuple: (observation, info) for the initial state
        """
        # IMPORTANT: Must call this first to seed the random number generator
        super().reset(seed=seed)

        # Randomly choose initial field id and set time to 0
        self._field_id = self.starting_field_id
        # self._field_id = self.np_random.integers(0, self.Nf, size=1, dtype=int)[0]
        # self._t = np.array([0.0], dtype=np.float32)
        self._nvisits = np.full(shape=(self.Nf,), fill_value=0, dtype=int)
        # self._Teff_pred = np.linspace(0.1, .98, num=self.Nf, dtype=np.float32)
        self._step_num = 0

        self._sequence = [self._field_id]

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """Execute one timestep within the environment.

        Args:

        Returns:
        """
        # choose random field for next observation
        list_idx = self.np_random.integers(low=0, high=len(self._possible_actions), dtype=int)
        proposed_field = self._possible_actions[list_idx]
        self._nvisits[proposed_field] += 1
        self._field_id = proposed_field

        # if this is the last required observation for this field, remove it from possible actions list
        if self._nvisits[proposed_field] == self.nv_max:
            self._possible_actions.pop(list_idx)

        survey_complete = (len(self._sequence) == len(self.target_sequence))
        
        terminated = self._sequence == self.target_sequence
        
        if survey_complete and not terminated:
            truncated = True
        
        # Simple reward structure: +1 for reaching target, 0 otherwise
        is_exact = self._field_id == self.target_sequence[self._step_num]
        is_off_by_one = np.abs(self._field_id - self.target_sequence[self._step_num]) == 1

        if not terminated:
            if is_exact:
                
        if is_exact:
            reward = .5
        elif is_off_by_one:
            reward = .1
        else:
            reward = 0

        # automatically update possible actions dict to remove this field if it has been visited 4 times
        self.step_num += 1
        
        if (self._nvisits == self.nv_max).all()
            self.reset()


        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

SyntaxError: ':' expected after dictionary key (2875554107.py, line 54)

In [32]:
true_actions

array([5, 0, 5, 1, 3, 4, 0, 3, 2, 1, 2, 1, 4, 0, 0, 1, 5, 4, 4, 2, 5, 3,
       3, 2])

In [29]:
# Register the environment so we can create it with gym.make()
gym.register(
    id="gymnasium_env/SimpleTel-v0",
    entry_point=SimpleTelEnv,
    max_episode_steps=300,  # Prevent infinite episodes
)

In [30]:
env = gym.make("gymnasium_env/SimpleTel-v0", Nf=Nf, T=T, target_sequence=true_actions, nv_max=nv_max)
# Create multiple environments for parallel training
# vec_env = gym.make_vec("gymnasium_env/GridWorld-v0", num_envs=3)

In [31]:
from gymnasium.utils.env_checker import check_env

# This will catch many common issues
try:
    check_env(env.unwrapped)
    print("Environment passes all checks!")
except Exception as e:
    print(f"Environment has issues: {e}")

Environment passes all checks!
