In [1]:
import pathlib
from pprint import pprint

import gymnasium as gym
import numpy as np
import ray
import torch
from python_tsp.exact import solve_tsp_branch_and_bound
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.env.env_context import EnvContext
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN
from ray.rllib.utils.typing import MultiAgentDict, PolicyID
from ray.tune.registry import register_env

E0000 00:00:1747585013.112253  221195 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747585013.115834  221195 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747585013.124880  221195 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747585013.124888  221195 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747585013.124889  221195 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747585013.124890  221195 computation_placer.cc:177] computation placer already registered. Please check linka

In [2]:
class TspObsEnv(gym.Env):
    """Open-path TSP over a shrinking distance matrix.

    The observation is an (n, n) float32 matrix:
    - Row/column 0 belong to the current city.
    - Rows/columns 1..remaining-1 belong to the unvisited cities, in their
      original relative order.
    - Trailing rows/columns freed by dropped cities are zero-filled.

    Action ``k`` moves from the current city to the unvisited city at index
    ``k``:
    - A valid move yields ``-a[0, k]`` and makes that city the new current
      one. This assumes ``a[i, j]`` is the cost of going from city i to
      city j.
    - Any other index yields ``INVALID_ACTION_REWARD`` and leaves the state
      unchanged.

    The episode terminates once every city has been visited. No return leg
    to the start city is charged.

    Config keys:
        n: number of cities; only the top-left ``a[:n, :n]`` block is used.
        a: 2-D distance matrix.
        max_steps: optional step limit. When given, the episode is truncated
            after that many steps. When omitted there is no limit.

    NOTE(review): ``observation_space`` declares values in [0, 1], which
    assumes normalized distances. Confirm this for the data in use.
    NOTE(review): the action space is ``Discrete(n - 1, start=1)``. A trainer
    that ignores ``start`` would emit 0..n-2: 0 is always invalid here and
    n-1 is never chosen. Confirm how actions are mapped during training.
    """

    # Reward for choosing an index that is not an unvisited city.
    INVALID_ACTION_REWARD = -0.1

    def __init__(self, config: EnvContext):
        super().__init__()
        self.n = config['n']
        self.pa = config['a'][:self.n, :self.n]
        self.max_steps = config.get('max_steps')
        self.action_space = gym.spaces.Discrete(self.n - 1, start=1)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(self.n, self.n), dtype=np.float32
        )
        self._reset_state()

    def _reset_state(self):
        """Restore the full distance matrix and the episode counters."""
        self.a = np.array(self.pa, dtype=np.float32)
        # Cities still present in the matrix, including the current one at
        # index 0; the valid actions are 1..num_remaining-1.
        self.num_remaining = self.n
        self.num_steps = 0

    def reset(self, seed=None, options=None):
        """Start a new episode at city 0 with the full distance matrix."""
        super().reset(seed=seed)
        self._reset_state()
        # Return a copy so the in-place updates in step() do not mutate
        # observations the caller is still holding.
        return self.a.copy(), {}

    def step(self, action):
        """Move to the unvisited city at index ``action``.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated,
        info)``; ``obs`` is a copy of the internal matrix.
        """
        action = int(action)  # accept Python ints, NumPy scalars and 0-d arrays
        self.num_steps += 1
        terminated = False
        # Validity is tracked explicitly instead of being inferred from
        # `a[0, action] > 0`. That test treated zero-distance moves as
        # invalid, accepted negative indices, and raised IndexError for
        # indices >= n.
        if 1 <= action < self.num_remaining:
            reward = -self.a[0, action]
            # The chosen city becomes the current one (index 0)...
            self.a[:, 0] = self.a[:, action]
            self.a[0, :] = self.a[action, :]
            # ...and its old slot is removed by shifting later rows/columns.
            self.a[:, action:-1] = self.a[:, action + 1:]
            self.a[action:-1, :] = self.a[action + 1:, :]
            self.a[:, -1] = 0.0
            self.a[-1, :] = 0.0
            self.num_remaining -= 1
            # Only the current city is left: the path is complete.
            terminated = self.num_remaining == 1
        else:
            reward = self.INVALID_ACTION_REWARD
        truncated = (
            not terminated
            and self.max_steps is not None
            and self.num_steps >= self.max_steps
        )
        return self.a.copy(), reward, terminated, truncated, {}
# Register the environment under two IDs: "TSPObsEnv" for RLlib (used by
# PPOConfig.environment) and "gymnasium_env/TspObsEnv" for gym.make (used by
# the manual rollouts).
def env_creator(config):
    """Build a TspObsEnv from the env config RLlib passes in."""
    env = TspObsEnv(config)
    return env


register_env(name="TSPObsEnv", env_creator=env_creator)
gym.register(id="gymnasium_env/TspObsEnv", entry_point=TspObsEnv)

In [3]:
# Load the distance-matrix data and keep the first instance along axis 0.
# NOTE(review): the file name suggests a stack of 20x20 matrices; confirm.
X = np.load('X_20x20_fixed.npy')[0]
n = 5  # number of cities used from X (only X[:n, :n] is read below)
# X[:n, :n] would silently shrink on a smaller array, so fail fast here.
assert X.ndim == 2 and min(X.shape) >= n, (
    f"X has shape {X.shape}; need a 2-D matrix of at least ({n}, {n})"
)

In [4]:
# Sanity-check the env with a seeded random policy, mirroring every accepted
# move in `path` (path[curr_i] is the current city, path[curr_i+1:] the
# unvisited ones in the env's index order).
env = gym.make("gymnasium_env/TspObsEnv", config={'n': n, 'a': X})
env.action_space.seed(1)
observation, info = env.reset(seed=1)
episode_over = False
path = list(range(n))
curr_i = 0
i = 0
while i < 30 and not episode_over:  # hard cap on steps for this demo
    action = env.action_space.sample()
    target = curr_i + action
    if target < n:
        # Move the chosen city right after the current one; it becomes current.
        curr_i += 1
        path.insert(curr_i, path.pop(target))
    print(i, action, path)
    observation, reward, terminated, truncated, info = env.step(action)
    episode_over = terminated or truncated
    i += 1
env.close()

0 2 [0, 2, 1, 3, 4]
1 3 [0, 2, 4, 1, 3]
2 4 [0, 2, 4, 1, 3]
3 4 [0, 2, 4, 1, 3]
4 1 [0, 2, 4, 1, 3]
5 1 [0, 2, 4, 1, 3]


In [5]:
# Configure PPO on the RLlib-registered TSP environment.
config = (
    PPOConfig()
    .environment("TSPObsEnv", env_config = {'n': n, 'a': X})
    .env_runners(
        num_env_runners=7,
        # Observations are an (n, n) float32 Box matrix, not discrete ints.
        # They are flattened into a 1-D vector (no one-hot encoding),
        # presumably because the default model expects 1-D input.
        env_to_module_connector=lambda env: FlattenObservations(),
    )
)

In [6]:
# Build the PPO Algorithm from `config`; per this cell's output, doing so also
# starts a local Ray instance.
algo = config.build_algo()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-05-18 16:16:58,651	INFO worker.py:1879 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
[36m(pid=221971)[0m E0000 00:00:1747585019.793778  221971 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN wh

In [7]:
# Run a few PPO iterations, reporting the mean episode return and saving a
# checkpoint after each one (only the last `checkpoint_dir` is used below).
# The env never truncates on its own, so if the policy keeps choosing invalid
# actions no episode may complete and the mean return can be missing (None).
NUM_TRAIN_ITERS = 5
for i in range(NUM_TRAIN_ITERS):
    result = algo.train()
    mean_return = result.get(ENV_RUNNER_RESULTS, {}).get(EPISODE_RETURN_MEAN)
    print(i, mean_return)
    checkpoint_dir = algo.save_to_path()
    print(f"Checkpoint saved in directory {checkpoint_dir}")

0
Checkpoint saved in directory /tmp/8cf2a34c-bbd5-48cb-81d6-3c6322137b1e
1
Checkpoint saved in directory /tmp/8309267a-d8c7-4e12-8400-33306c568434
2
Checkpoint saved in directory /tmp/2b6b5fcb-a692-42a8-875a-c3c3b17a3fa5
3
Checkpoint saved in directory /tmp/99283a3a-b1df-4007-a202-7e3c20e610af
4
Checkpoint saved in directory /tmp/2821c964-2d79-4db3-97c4-4dadd370abfe


In [8]:
# Restore the trained module from the last checkpoint written by the training
# loop and select its "default_policy" sub-module.
# (The in-memory alternative would be `algo.get_module()`.)
module_dir = pathlib.Path(checkpoint_dir).joinpath("learner_group", "learner", "rl_module")
rl_module = RLModule.from_checkpoint(module_dir)["default_policy"]

In [9]:
# Greedy rollout of the trained policy. `path` mirrors the visiting order
# (path[curr_i] is the current city) and `s` sums the rewards of valid moves,
# i.e. the negated path length.
env = gym.make("gymnasium_env/TspObsEnv", config = {'n': n, 'a': X})
observation, info = env.reset()
episode_over = False
i = 0
path = list(range(n))
curr_i = 0
s = 0
while not episode_over and i < 30:  # hard cap: the env itself never truncates
    obs_batch = torch.from_numpy(
        gym.spaces.flatten(env.observation_space, observation)
    ).unsqueeze(0)
    with torch.no_grad():
        action_logits = rl_module.forward_inference({'obs': obs_batch})[
            "action_dist_inputs"
        ]
    # Plain Python int instead of a 0-d NumPy array.
    action = torch.argmax(action_logits[0]).item()
    observation, reward, terminated, truncated, info = env.step(action)
    # Mirror the env's notion of a valid move. Index 0 is the current city,
    # and indices >= n - curr_i have already been shifted out. The argmax
    # can be 0, which the env penalises (assuming a zero diagonal); it must
    # not be applied to `path` or counted in `s`.
    if 0 < action < n - curr_i:
        v = path.pop(curr_i + action)
        curr_i += 1
        path.insert(curr_i, v)
        s += reward
    print(i, action, path, reward)
    i += 1
    episode_over = terminated or truncated
env.close()
print(s)

0 3 [0, 3, 1, 2, 4] -0.45384774
1 3 [0, 3, 4, 1, 2] -0.6959624
2 2 [0, 3, 4, 2, 1] -0.6491428
3 1 [0, 3, 4, 2, 1] -0.48258537
-2.2815385


In [10]:
# Length of the greedy path: X summed over consecutive (from, to) city pairs;
# compare with the negated `s` printed by the rollout.
sum(X[src, dst] for src, dst in zip(path, path[1:]))

np.float64(2.2815383570027787)

In [11]:
# Exact reference solution. Zeroing column 0 makes the return leg to city 0
# free, so the closed-tour solver yields the optimal open path from city 0.
distance_matrix = np.array(X[:n, :n], copy=True)
distance_matrix[:, 0] = 0
solve_tsp_branch_and_bound(distance_matrix)

([0, 3, 4, 2, 1], np.float64(2.281538357002779))