In [23]:
import random

import gymnasium
import numpy as np
from gymnasium.spaces import Discrete, Box

class GraphTraversalEnv(gymnasium.Env):
    """Environment in which the agent proposes, one node per step, which graph node to visit next.

    ``config`` must provide:

    * ``"coordinates"``: a sized collection with one entry per node; only its
      length is used here.
    * ``"user_simulator"``: an object exposing ``evaluate(node_index)`` that
      returns the reward for visiting a node.

    NOTE(review): ``reset`` returns only the observation and ``step`` returns a
    4-tuple ``(obs, reward, done, info)``. Gymnasium's ``Env`` API specifies
    ``reset(seed=..., options=...) -> (obs, info)`` and a 5-tuple from ``step``;
    the return shapes are kept as-is here because existing callers unpack them.
    """

    def __init__(self, config):
        """Read the configuration, define the spaces and start the first episode."""
        super().__init__()

        self.coordinates = config["coordinates"]
        self.user_simulator = config["user_simulator"]
        self.num_nodes = len(self.coordinates)
        if self.num_nodes < 1:
            raise ValueError("config['coordinates'] must contain at least one node")
        # One step per node is enough to visit every node once.
        self.max_steps = self.num_nodes

        self.action_space = Discrete(self.num_nodes)  # Propose a next node
        self.observation_space = Box(
            low=0,
            high=1,
            shape=(self.num_nodes + self.max_steps,),
            dtype=np.float32,
        )

        self.reset()

    def reset(self):
        """Start a new episode at a uniformly random node and return the first observation."""
        # randrange excludes the upper bound, so the start node is always a
        # valid index in [0, num_nodes). randint(0, num_nodes) could return
        # num_nodes itself, which is not a node.
        self.current_node = random.randrange(self.num_nodes)
        self.visited = {self.current_node}
        self.path = [self.current_node]
        self.step_count = 0
        return self._get_obs()

    def _get_obs(self):
        """Build the flat float32 observation vector.

        Layout:

        * ``[0, num_nodes)``: one-hot encoding of the current node.
        * ``[num_nodes, num_nodes + max_steps)``: the path so far, one slot per
          visited node in visiting order, each stored as
          ``node_index / num_nodes``; unused slots stay 0.

        NOTE(review): node 0 is stored as 0.0, which is the same value as an
        unused path slot.
        """
        obs = np.zeros(self.num_nodes + self.max_steps, dtype=np.float32)
        obs[self.current_node] = 1.0
        path_end = self.num_nodes + len(self.path)
        obs[self.num_nodes:path_end] = [float(i) / self.num_nodes for i in self.path]
        return obs

    def step(self, action):
        """Apply the proposed node and return ``(obs, reward, done, info)``.

        Proposing an already visited node leaves the state unchanged and yields
        a reward of -1.0. Otherwise the agent moves to the node and the reward
        comes from ``user_simulator.evaluate``. ``done`` becomes True once
        ``max_steps`` steps have been taken. ``info["path"]`` is a copy of the
        nodes visited so far, in order.

        Raises:
            ValueError: if ``action`` is not a node index in ``[0, num_nodes)``.
        """
        action = int(action)
        if not 0 <= action < self.num_nodes:
            # An out-of-range index would otherwise be written into an
            # unrelated slot of the observation vector.
            raise ValueError(
                f"action must be in [0, {self.num_nodes}), got {action}"
            )

        if action in self.visited:
            reward = -1.0  # Penalty for revisiting
        else:
            self.current_node = action
            self.visited.add(action)
            self.path.append(action)

            # Simulate user preference
            reward = self.user_simulator.evaluate(self.current_node)

        self.step_count += 1
        done = self.step_count >= self.max_steps

        # Hand out a snapshot so stored info dicts do not change on later steps.
        return self._get_obs(), reward, done, {"path": list(self.path)}
    
class UserSimulator:
    """Stand-in for user feedback that scores a node by membership in a fixed liked set."""

    def __init__(self, liked_indices):
        """Store the liked node indices as a set."""
        self.liked = {*liked_indices}

    def evaluate(self, node_index):
        """Return 1.0 when ``node_index`` is liked, otherwise -0.5."""
        if node_index in self.liked:
            return 1.0
        return -0.5


In [24]:
import pickle
# Load the pickled object and keep its 'geo' entry as the node coordinates.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with open('../data/processed/df_01.pkl', 'rb') as pkl_file:
    coordinates = pickle.load(pkl_file)['geo']

# Node indices the simulated user responds to positively.
user_likes = [1, 3, 5, 10]
config = dict(
    coordinates=coordinates,
    user_simulator=UserSimulator(liked_indices=user_likes),
)

g_env = GraphTraversalEnv(config)

In [25]:
import json
import random

def generate_offline_dataset(env, episodes=100, output_path="offline_graph_dataset.json"):
    """Roll out random no-revisit episodes on ``env`` and save every transition.

    In each episode the next action is drawn uniformly from the nodes that are
    not yet in ``env.visited``. The episode ends when ``env.step`` reports
    ``done`` or when every node has been visited; in both cases the last stored
    transition carries ``"dones": True`` so episode boundaries can be recovered
    from the file.

    Args:
        env: object exposing ``num_nodes``, ``visited``, ``reset()`` returning an
            observation, and ``step(action)`` returning
            ``(obs, reward, done, info)``. Observations must provide
            ``tolist()``.
        episodes: number of episodes to roll out.
        output_path: file to write. Despite the default ``.json`` suffix the
            format is JSON Lines: one JSON object per line with the keys
            ``obs``, ``actions``, ``rewards``, ``new_obs`` and ``dones``.
    """
    dataset = []

    for _ in range(episodes):
        obs = env.reset()
        done = False

        while not done:
            unvisited = [i for i in range(env.num_nodes) if i not in env.visited]
            if not unvisited:
                # Nothing left to propose (e.g. a single-node graph);
                # random.choice would raise on an empty list.
                break

            action = random.choice(unvisited)
            next_obs, reward, done, _info = env.step(action)

            # Visiting the last node ends the episode even when the env's own
            # step budget is not used up, so mark that transition as terminal.
            done = bool(done) or len(env.visited) >= env.num_nodes

            dataset.append({
                "obs": obs.tolist(),
                "actions": action,
                "rewards": reward,
                "new_obs": next_obs.tolist(),
                "dones": done,
            })

            obs = next_obs

    # Save as JSON Lines (one record per line)
    with open(output_path, "w") as f:
        for record in dataset:
            f.write(json.dumps(record) + "\n")

In [27]:
# Roll out random episodes on g_env with the default episode count and write
# the collected transitions to offline_graph_dataset.json in the working directory.
generate_offline_dataset(g_env)

In [28]:
# Manual check: start a fresh episode, then propose node 1.
# reset() draws the start node at random, so the displayed path differs between
# runs; if node 1 happens to be the start node, the step is penalised with -1.0.
g_env.reset()
g_env.step(1)

(array([0.        , 1.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.8333333 , 0.08333334, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ], dtype=float32),
 1.0,
 False,
 {'path': [10, 1]})