In [1]:
import argparse
import logging
import os
from pathlib import Path
import ray
from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.examples.tsp.tsp_model import TSPModel, TSPQModel
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState
from graphenv.examples.tsp.tsp_state import TSPState
from graphenv.graph_env import GraphEnv
from networkx.algorithms.approximation.traveling_salesman import greedy_tsp
from ray import tune
from ray.rllib.algorithms.a3c import A3CConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.tune.registry import register_env
import networkx as nx
from ray.rllib.algorithms.algorithm import Algorithm
from ray.tune import ExperimentAnalysis

In [2]:
tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    choices=["PPO", "DQN", "A3C", "MARWIL"],
    help="The RLlib-registered algorithm to use.",
)
parser.add_argument("--N", type=int, default=5, help="Number of nodes in TSP network")
parser.add_argument(
    "--use-gnn", action="store_true", help="use the nfp state and gnn model"
)
parser.add_argument(
    "--max-num-neighbors",
    type=int,
    default=5,
    help="Number of nearest neighbors for the gnn model",
)
parser.add_argument(
    "--seed", type=int, default=0, help="Random seed used to generate networkx graph"
)
parser.add_argument(
    "--num-workers", type=int, default=1, help="Number of rllib workers"
)
parser.add_argument("--num-gpus", type=int, default=0, help="Number of GPUs")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument(
    "--entropy-coeff", type=float, default=0.0, help="entropy coefficient"
)
parser.add_argument(
    "--rollouts-per-worker",
    type=int,
    default=1,
    help="Number of rollouts for each worker to collect",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=0.0, help="Reward at which we stop training."
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)
parser.add_argument("--log-level", type=str, default="INFO")


_StoreAction(option_strings=['--log-level'], dest='log_level', nargs=None, const=None, default='INFO', type=<class 'str'>, choices=None, required=False, help=None, metavar=None)

In [3]:
args = parser.parse_args(args=[])
args.use_gnn=True
#args.num_gpus=0
#args.num_workers = 2


In [4]:
print(f"Running with following CLI options: {args}")
logging.basicConfig(level=args.log_level.upper())
ray.init(local_mode=args.local_mode)

Running with following CLI options: Namespace(run='PPO', N=5, use_gnn=True, max_num_neighbors=5, seed=0, num_workers=1, num_gpus=0, lr=0.0001, entropy_coeff=0.0, rollouts_per_worker=1, stop_iters=100, stop_timesteps=100000, stop_reward=0.0, local_mode=False, log_level='INFO')


2025-05-10 18:55:29,811	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.9.21
Ray version:,2.3.1


[2m[36m(pid=42601)[0m I0000 00:00:1746903367.000684   42601 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(PPO pid=42601)[0m 2025-05-10 18:56:08,645	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
[2m[36m(pid=42734)[0m I0000 00:00:1746903370.370971   42734 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(pid=42735)[0m I0000 00:00:1746903370.402369   42735 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(pid=42736)[0m I0000 00:00:1746903370.390215   42736 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(pid=42733)[0m I0000 00:00:1746903370

In [5]:
N = args.N
G = make_complete_planar_graph(N=N, seed=args.seed)

In [6]:
tsp_approx = nx.approximation.traveling_salesman_problem
path = tsp_approx(G, cycle=True)
reward_baseline = -sum([G[path[i]][path[i + 1]]["weight"] for i in range(0, N)])
print(f"Networkx heuristic reward: {reward_baseline:1.3f}")
print(path)
path = tsp_approx(G, cycle=True, method=greedy_tsp)
reward_baseline = -sum([G[path[i]][path[i + 1]]["weight"] for i in range(0, N)])
print(f"Networkx greedy reward: {reward_baseline:1.3f}")
print(path)

Networkx heuristic reward: -1.700
[0, 4, 1, 3, 2, 0]
Networkx greedy reward: -1.996
[0, 2, 1, 3, 4, 0]


In [7]:
# Algorithm-specific config, common ones are in the main config dict below
if args.run == "PPO":
    run_config = PPOConfig()
    train_batch_size = args.rollouts_per_worker * N * args.num_workers
    sgd_minibatch_size = 16 if train_batch_size > 16 else 2
    run_config.training(entropy_coeff=args.entropy_coeff,
                        sgd_minibatch_size=sgd_minibatch_size,
                        num_sgd_iter=5,
    )
elif args.run in ["DQN"]:
    run_config = DQNConfig()
    # Update here with custom config
    run_config.training(hiddens=False,
                    dueling=False
    )
    run_config.exploration(exploration_config={"epsilon_timesteps": 250000})
elif args.run == "A3C":
    run_config = A3CConfig()
elif args.run == "MARWIL":
    run_config = MARWILConfig()
else:
    raise ValueError(f"Import agent {args.run} and try again")

In [8]:
# Define custom_model, config, and state based on GNN yes/no
if args.use_gnn:
    custom_model = "TSPGNNModel"
    custom_model_config = {"num_messages": 3, "embed_dim": 32}
    print('use_gnn')
    ModelCatalog.register_custom_model(custom_model, TSPGNNModel)
    _tag = "gnn"
    state = TSPNFPState(
        lambda: G,
        max_num_neighbors=args.max_num_neighbors,
    )
else:
    custom_model_config = {"hidden_dim": 256, "embed_dim": 256, "num_nodes": N}
    custom_model = "TSPModel"
    Model = TSPQModel if args.run in ["DQN", "R2D2"] else TSPModel
    ModelCatalog.register_custom_model(custom_model, Model)
    _tag = f"basic{args.run}"
    state = TSPState(lambda: G)

use_gnn


In [9]:
# Register env name with hyperparams that will help tracking experiments
# via tensorboard
env_name = f"mygraphenv-v0" #_{N}_{_tag}_lr={args.lr}
register_env(env_name, lambda config: GraphEnv(config))

In [10]:
args.num_gpus = 0
args.num_workers = 4
run_config = (
    run_config
    .resources(num_gpus=args.num_gpus) 
    .framework("tf2") #tf ?
    .rollouts(num_rollout_workers=args.num_workers, 
              # a multiple of N (collect whole episodes)
              rollout_fragment_length=N
             )
    .environment(env=env_name,
                 env_config={"state": state, 
                             "max_num_children": G.number_of_nodes()}
              )
    .training(lr=args.lr,
              train_batch_size=args.rollouts_per_worker * N * args.num_workers,
              model={"custom_model": custom_model, 
                     "custom_model_config": custom_model_config}
              )
    .evaluation(evaluation_config={"explore": False},
                evaluation_interval=1, 
                evaluation_duration=100,
              )
    .debugging(log_level=args.log_level)
    .framework(eager_tracing=True)
)

stop = {
    "training_iteration": args.stop_iters,
    "timesteps_total": args.stop_timesteps,
    "episode_reward_mean": args.stop_reward,
}

In [11]:
my_path = Path("/home/vladimir/work/graph_test/scratch/ray_results")

In [12]:
res = tune.run(
    args.run,
    config=run_config.to_dict(),
    stop=stop,
    local_dir=my_path,
    checkpoint_freq = 10,
    checkpoint_at_end=True
)

0,1
Current time:,2025-05-10 18:58:23
Running for:,00:02:18.25
Memory:,5.0/125.7 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_mygraphenv-v0_6d1f5_00000,TERMINATED,192.168.0.126:42601,100,124.098,2000,-1.71081,-1.58714,-1.94072,5


2025-05-10 18:56:05,584	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
2025-05-10 18:56:05,585	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.


Trial name,agent_timesteps_total,connector_metrics,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,evaluation,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_mygraphenv-v0_6d1f5_00000,2000,"{'ObsPreprocessorConnector_ms': 0.05004119873046875, 'StateBufferConnector_ms': 0.007983207702636719, 'ViewRequirementAgentConnector_ms': 0.10598039627075195}","{'num_env_steps_sampled': 2000, 'num_env_steps_trained': 2000, 'num_agent_steps_sampled': 2000, 'num_agent_steps_trained': 2000}",{},2025-05-10_18-58-23,True,5,{},-1.58714,-1.71081,-1.94072,4,400,"{'episode_reward_max': -1.5871424799869809, 'episode_reward_min': -1.5871424799869809, 'episode_reward_mean': -1.587142479986981, 'episode_len_mean': 5.0, 'episode_media': {}, 'episodes_this_iter': 100, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809], 'episode_lengths': [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.4983578717021242, 'mean_inference_ms': 1.2904544708330687, 'mean_action_processing_ms': 0.08943228670979561, 'mean_env_wait_ms': 0.06638082303356374, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.04060626029968262, 'StateBufferConnector_ms': 0.00610804557800293, 'ViewRequirementAgentConnector_ms': 0.08330035209655762}, 'num_agent_steps_sampled_this_iter': 500, 'num_env_steps_sampled_this_iter': 500, 'timesteps_this_iter': 500, 'num_healthy_workers': 0, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}",a7ceabf2b41c4819b8a6fc26ffe882af,srv5,"{'learner': {'default_policy': {'learner_stats': {'cur_kl_coeff': 1.0463794097859136e-07, 'cur_lr': 9.999999747378752e-05, 'total_loss': 0.1437203, 'policy_loss': 7.5178745e-05, 'vf_loss': 0.14364511, 'vf_explained_var': 0.31599823, 'kl': 0.0024756247, 'entropy': 0.23174913, 'entropy_coeff': 0.0}, 'custom_metrics': {}, 'num_agent_steps_trained': 2.0, 'num_grad_updates_lifetime': 4975.5, 'diff_num_grad_updates_vs_sampler_policy': 24.5}}, 'num_env_steps_sampled': 2000, 'num_env_steps_trained': 2000, 'num_agent_steps_sampled': 2000, 'num_agent_steps_trained': 2000}",100,192.168.0.126,2000,2000,2000,20,2000,20,0,4,0,0,20,"{'cpu_util_percent': 6.85, 'ram_util_percent': 4.0}",42601,{},{},{},"{'mean_raw_obs_processing_ms': 0.6515747097674068, 'mean_inference_ms': 3.7383319936540165, 'mean_action_processing_ms': 0.1107845738046373, 'mean_env_wait_ms': 0.08886962517409602, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -1.5871424799869809, 'episode_reward_min': -1.9407205546191035, 'episode_reward_mean': -1.7108066392581531, 'episode_len_mean': 5.0, 'episode_media': {}, 'episodes_this_iter': 4, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-1.893276848138505, -1.7000894302075489, -1.893276848138505, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.893276848138505, -1.893276848138505, -1.7000894302075489, -1.9407205546191035, -1.6842902103837378, -1.6842902103837378, -1.6842902103837378, -1.893276848138505, -1.5871424799869809, -1.6842902103837378, -1.6842902103837378, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.893276848138505, -1.893276848138505, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.893276848138505, -1.9407205546191035, -1.7000894302075489, -1.924921334795293, -1.7000894302075489, -1.893276848138505, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.9407205546191035, -1.893276848138505, -1.7000894302075489, -1.893276848138505, -1.5871424799869809, -1.7000894302075489, -1.893276848138505, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.893276848138505, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.893276848138505, -1.5871424799869809, -1.5871424799869809, -1.5871424799869809, -1.7000894302075489, -1.5871424799869809, -1.6842902103837378, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.5871424799869809, -1.5871424799869809, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.5871424799869809, -1.7000894302075489, -1.7000894302075489, -1.5871424799869809, -1.7000894302075489, -1.5871424799869809], 'episode_lengths': [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.6515747097674068, 'mean_inference_ms': 3.7383319936540165, 'mean_action_processing_ms': 0.1107845738046373, 'mean_env_wait_ms': 0.08886962517409602, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.05004119873046875, 'StateBufferConnector_ms': 0.007983207702636719, 'ViewRequirementAgentConnector_ms': 0.10598039627075195}}",124.098,1.17975,124.098,"{'training_iteration_time_ms': 186.545, 'learn_time_ms': 162.604, 'learn_throughput': 122.998, 'synch_weights_time_ms': 7.3}",1746903503,0,2000,100,6d1f5_00000,8.70584


2025-05-10 18:58:24,159	INFO tune.py:798 -- Total run time: 138.66 seconds (138.21 seconds for the tuning loop).


In [13]:
# Загружаем результаты
analysis = ExperimentAnalysis(my_path)
best_trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
best_checkpoint = analysis.get_best_checkpoint(best_trial, mode = "max")
#best_checkpoint = analysis.get_last_checkpoint()

2025-05-10 18:58:24,187	INFO experiment_analysis.py:789 -- No `self.trials`. Drawing logdirs from checkpoint file. This may result in some information that is out of sync, as checkpointing is periodic.


In [14]:
env = GraphEnv({"state": state, 
          "max_num_children": G.number_of_nodes()})

In [15]:
algo = Algorithm.from_checkpoint(best_checkpoint)

2025-05-10 18:58:39,865	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
2025-05-10 18:58:44,379	INFO worker_set.py:310 -- Inferred observation/action spaces from remote worker (local worker has no env): {'default_policy': (Repeated(Dict('connectivity': Box(0, 5, (20, 2), int64), 'current_node': Box(0, 5, (), int64), 'distance': Box(0.0, 1.4142135623730951, (), float64), 'edge_weights': Box(0.0, 1.4142135623730951, (20,), float64), 'node_visited': Box(0, 2, (5,), int64)), 6), Discrete(5)), '__env__': (Repeated(Dict('connectivity': Box(0, 5, (20, 2), int64), 'current_node': Box(0, 5, (), int64), 'distance': Box(0.0, 1.4142135623730951, (), float64), 'edge_weights': Box(0.0, 1.4142135623730951, (20,), float64), 'node_visited': Box(0, 2, (5,), int64)), 6

In [16]:
episode_reward = 0
terminated = truncated = False
obs, info = env.reset(G=G)
i = 0
path = [] #[obs[0]['node_idx'][0]]
nn = {ob['current_node'] for ob in obs}
print(nn)
while not terminated and not truncated and i < 20: 
    action = algo.compute_single_action(obs, explore = False)
    obs, reward, terminated, truncated, info = env.step(action)
    nn_new = {ob['current_node'] for ob in obs}
    cc = list(nn - nn_new)[0]
    print(i, action, reward, cc)
    episode_reward += reward
    path.append(cc)
 #   path.append(obs[0]['node_idx'][0])
    i += 1
    nn = nn_new
print( episode_reward, path)

{0, 1, 2, 3, 4}
0 3 -0.5311840924120236 0
1 0 -0.3953628417186544 4
2 0 -0.20562852490057235 1
3 0 -0.24627330250392146 2
4 0 -0.20869371845180915 3
-1.5871424799869809 [0, 4, 1, 2, 3]


In [17]:
env.state.tour

[0, 4, 1, 2, 3, 0]

In [None]:
from python_tsp.exact import solve_tsp_dynamic_programming

In [None]:
solve_tsp_dynamic_programming(nx.to_numpy_array(G))

In [None]:
ray.shutdown()