In [1]:
import argparse
import logging
import os
from pathlib import Path
import ray
from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.examples.tsp.tsp_model import TSPModel, TSPQModel
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState
from graphenv.examples.tsp.tsp_state import TSPState
from graphenv.graph_env import GraphEnv
from networkx.algorithms.approximation.traveling_salesman import greedy_tsp
from ray import tune
from ray.rllib.algorithms.a3c import A3CConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.tune.registry import register_env
import networkx as nx
from ray.rllib.algorithms.algorithm import Algorithm
from ray.tune import ExperimentAnalysis

In [2]:
tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    choices=["PPO", "DQN", "A3C", "MARWIL"],
    help="The RLlib-registered algorithm to use.",
)
parser.add_argument("--N", type=int, default=5, help="Number of nodes in TSP network")
parser.add_argument(
    "--use-gnn", action="store_true", help="use the nfp state and gnn model"
)
parser.add_argument(
    "--max-num-neighbors",
    type=int,
    default=5,
    help="Number of nearest neighbors for the gnn model",
)
parser.add_argument(
    "--seed", type=int, default=0, help="Random seed used to generate networkx graph"
)
parser.add_argument(
    "--num-workers", type=int, default=1, help="Number of rllib workers"
)
parser.add_argument("--num-gpus", type=int, default=0, help="Number of GPUs")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument(
    "--entropy-coeff", type=float, default=0.0, help="entropy coefficient"
)
parser.add_argument(
    "--rollouts-per-worker",
    type=int,
    default=1,
    help="Number of rollouts for each worker to collect",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=0.0, help="Reward at which we stop training."
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)
parser.add_argument("--log-level", type=str, default="INFO")


_StoreAction(option_strings=['--log-level'], dest='log_level', nargs=None, const=None, default='INFO', type=<class 'str'>, choices=None, required=False, help=None, metavar=None)

In [3]:
args = parser.parse_args(args=[])

In [4]:
print(f"Running with following CLI options: {args}")
logging.basicConfig(level=args.log_level.upper())
ray.init(local_mode=args.local_mode)

Running with following CLI options: Namespace(run='PPO', N=5, use_gnn=False, max_num_neighbors=5, seed=0, num_workers=1, num_gpus=0, lr=0.0001, entropy_coeff=0.0, rollouts_per_worker=1, stop_iters=100, stop_timesteps=100000, stop_reward=0.0, local_mode=False, log_level='INFO')


2025-05-01 18:13:09,601	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.9.21
Ray version:,2.3.1


[2m[36m(pid=19144)[0m I0000 00:00:1746123285.100181   19144 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(PPO pid=19144)[0m 2025-05-01 18:14:46,748	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
[2m[36m(pid=19277)[0m I0000 00:00:1746123288.468627   19277 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(pid=19279)[0m I0000 00:00:1746123288.532544   19279 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers
[2m[36m(RolloutWorker pid=19277)[0m 2025-05-01 18:14:51,201	INFO eager_tf_policy_v2.py:75 -- Creating TF-eager policy running on CPU.
[2m[36m(RolloutWorker pid=19277)[0m 2025-05-01 18:14:51,217	

In [5]:
N = args.N
G = make_complete_planar_graph(N=N, seed=args.seed)
nx.to_numpy_array(G)

array([[0.        , 0.1786471 , 0.14306129, 0.20869372, 0.53118409],
       [0.1786471 , 0.        , 0.20562852, 0.3842079 , 0.39536284],
       [0.14306129, 0.20562852, 0.        , 0.2462733 , 0.60040816],
       [0.20869372, 0.3842079 , 0.2462733 , 0.        , 0.73154383],
       [0.53118409, 0.39536284, 0.60040816, 0.73154383, 0.        ]])

In [6]:
tsp_approx = nx.approximation.traveling_salesman_problem
path = tsp_approx(G, cycle=True)
reward_baseline = -sum([G[path[i]][path[i + 1]]["weight"] for i in range(0, N)])
print(f"Networkx heuristic reward: {reward_baseline:1.3f}")
print(path)
path = tsp_approx(G, cycle=True, method=greedy_tsp)
reward_baseline = -sum([G[path[i]][path[i + 1]]["weight"] for i in range(0, N)])
print(f"Networkx greedy reward: {reward_baseline:1.3f}")
print(path)

Networkx heuristic reward: -1.700
[0, 4, 1, 3, 2, 0]
Networkx greedy reward: -1.996
[0, 2, 1, 3, 4, 0]


In [7]:
# Algorithm-specific config, common ones are in the main config dict below
if args.run == "PPO":
    run_config = PPOConfig()
    train_batch_size = args.rollouts_per_worker * N * args.num_workers
    sgd_minibatch_size = 16 if train_batch_size > 16 else 2
    run_config.training(entropy_coeff=args.entropy_coeff,
                        sgd_minibatch_size=sgd_minibatch_size,
                        num_sgd_iter=5,
    )
elif args.run in ["DQN"]:
    run_config = DQNConfig()
    # Update here with custom config
    run_config.training(hiddens=False,
                    dueling=False
    )
    run_config.exploration(exploration_config={"epsilon_timesteps": 250000})
elif args.run == "A3C":
    run_config = A3CConfig()
elif args.run == "MARWIL":
    run_config = MARWILConfig()
else:
    raise ValueError(f"Import agent {args.run} and try again")

In [8]:
# Define custom_model, config, and state based on GNN yes/no
if args.use_gnn:
    custom_model = "TSPGNNModel"
    custom_model_config = {"num_messages": 3, "embed_dim": 32}
    ModelCatalog.register_custom_model(custom_model, TSPGNNModel)
    _tag = "gnn"
    state = TSPNFPState(
        lambda: make_complete_planar_graph(N=N),
        max_num_neighbors=args.max_num_neighbors,
    )
else:
    custom_model_config = {"hidden_dim": 256, "embed_dim": 256, "num_nodes": N}
    custom_model = "TSPModel"
    Model = TSPQModel if args.run in ["DQN", "R2D2"] else TSPModel
    ModelCatalog.register_custom_model(custom_model, Model)
    _tag = f"basic{args.run}"
    state = TSPState(lambda: make_complete_planar_graph(N=N))

In [9]:
# Register env name with hyperparams that will help tracking experiments
# via tensorboard
env_name = f"mygraphenv-v0" #_{N}_{_tag}_lr={args.lr}
register_env(env_name, lambda config: GraphEnv(config))

In [11]:
args.num_gpus = 0
args.num_workers = 4
run_config = (
    run_config
    .resources(num_gpus=args.num_gpus) 
    .framework("tf2") #tf ?
    .rollouts(num_rollout_workers=args.num_workers, 
              # a multiple of N (collect whole episodes)
              rollout_fragment_length=N
             )
    .environment(env=env_name,
                 env_config={"state": state, 
                             "max_num_children": G.number_of_nodes()}
              )
    .training(lr=args.lr,
              train_batch_size=args.rollouts_per_worker * N * args.num_workers,
              model={"custom_model": custom_model, 
                     "custom_model_config": custom_model_config}
              )
    .evaluation(evaluation_config={"explore": False},
                evaluation_interval=1, 
                evaluation_duration=100,
              )
    .debugging(log_level=args.log_level)
    .framework(eager_tracing=True)
)

stop = {
    "training_iteration": args.stop_iters,
    "timesteps_total": args.stop_timesteps,
    "episode_reward_mean": args.stop_reward,
}

In [12]:
my_path = Path("/home/vladimir/work/graph_test/scratch/ray_results")

In [34]:
res = tune.run(
    args.run,
    config=run_config.to_dict(),
    stop=stop,
    local_dir=my_path,
    checkpoint_freq = 10,
    checkpoint_at_end=True
)

0,1
Current time:,2025-05-01 18:26:01
Running for:,00:01:34.35
Memory:,8.0/125.7 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_mygraphenv-v0_83b0c_00000,TERMINATED,192.168.0.126:21126,100,84.8138,2000,-2.32887,-1.14646,-3.20093,5


2025-05-01 18:24:26,900	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
2025-05-01 18:24:26,906	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.


Trial name,agent_timesteps_total,connector_metrics,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,evaluation,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_mygraphenv-v0_83b0c_00000,2000,"{'ObsPreprocessorConnector_ms': 0.03676104545593262, 'StateBufferConnector_ms': 0.007767200469970703, 'ViewRequirementAgentConnector_ms': 0.10059976577758789}","{'num_env_steps_sampled': 2000, 'num_env_steps_trained': 2000, 'num_agent_steps_sampled': 2000, 'num_agent_steps_trained': 2000}",{},2025-05-01_18-26-01,True,5,{},-1.14646,-2.32887,-3.20093,4,400,"{'episode_reward_max': -1.0382531494150822, 'episode_reward_min': -3.741914818388316, 'episode_reward_mean': -2.2476783066819817, 'episode_len_mean': 5.0, 'episode_media': {}, 'episodes_this_iter': 100, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-2.1254908512281365, -1.2917786344157192, -2.7794127442422445, -2.3737772654242146, -2.096694781312544, -1.998214312998111, -2.5031531048397353, -1.915698498585109, -2.492899001981852, -2.7103124590930316, -2.3879840768367107, -2.9381158688073308, -1.9883952161542804, -2.3683534235219295, -2.827053691594947, -2.4857539936880473, -2.1648105267079463, -1.9267901337690134, -1.943836528223068, -1.575827524040641, -2.419939636433561, -2.817876793314558, -2.865244315516788, -2.3417808895906087, -2.064721903229385, -2.691461603935197, -3.3073526065781946, -1.3276743579954244, -3.15256238668833, -2.587069990529321, -2.543093937930438, -2.516395615074977, -2.49882346959784, -2.2483607726886032, -2.481960482672434, -1.7122295100717926, -2.6059026230999534, -3.449432646215185, -2.053056284368886, -2.6166380872115487, -2.681262272944611, -1.1027037534154638, -2.457549285413978, -1.8922633421590538, -1.9041564216870932, -2.446745009318397, -2.270821133236576, -2.656279150808281, -2.4138248192742107, -2.0067714038070976, -1.5486310102993088, -2.2365595581046227, -2.416607170455187, -2.070211251324009, -1.7638408062713098, -2.186394603007687, -2.367315016414315, -1.9641248247851797, -2.330072955996229, -1.935339928115452, -2.1108316426959535, -2.243735216693254, -2.08342640986155, -2.2640335570799333, -1.9770719114226822, -2.4504604926538853, -1.822147607052899, -2.5927899303820747, -1.9921400692703162, -1.4294876508309025, -1.7488732134226983, -2.8486038253220842, -1.8252003529119836, -2.011604199985138, -1.6943493959966518, -2.2511309430721624, -2.4694389292117567, -2.508040027721897, -2.402603738873854, -3.741914818388316, -2.383153718361382, -2.6293828271289494, -1.9496939720364428, -2.1746490763015425, -2.3377224658771896, -1.618912174975453, -1.8332610480950264, -2.940507453537559, -1.657488937451832, -2.4303967489875085, -2.4270049804930887, -2.45808027110558, -1.745605392215356, -1.9439206874240376, -1.0382531494150822, -2.6206897621060303, -1.5084642019451242, -2.24917657690128, -2.594247917273677, -1.9119311066703504], 'episode_lengths': [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.43354376595329425, 'mean_inference_ms': 0.8059981555877301, 'mean_action_processing_ms': 0.08907841862693998, 'mean_env_wait_ms': 0.06784589974246658, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.03167915344238281, 'StateBufferConnector_ms': 0.0060787200927734375, 'ViewRequirementAgentConnector_ms': 0.08473324775695801}, 'num_agent_steps_sampled_this_iter': 500, 'num_env_steps_sampled_this_iter': 500, 'timesteps_this_iter': 500, 'num_healthy_workers': 0, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}",d5b151283ff14efb98ca5dda411140ae,srv5,"{'learner': {'default_policy': {'learner_stats': {'cur_kl_coeff': 3.1554436679038213e-31, 'cur_lr': 9.999999747378752e-05, 'total_loss': 0.21769394, 'policy_loss': -0.010065695, 'vf_loss': 0.22775963, 'vf_explained_var': 0.2493176, 'kl': 0.0006022395, 'entropy': 0.38405335, 'entropy_coeff': 0.0}, 'custom_metrics': {}, 'num_agent_steps_trained': 2.0, 'num_grad_updates_lifetime': 4975.5, 'diff_num_grad_updates_vs_sampler_policy': 24.5}}, 'num_env_steps_sampled': 2000, 'num_env_steps_trained': 2000, 'num_agent_steps_sampled': 2000, 'num_agent_steps_trained': 2000}",100,192.168.0.126,2000,2000,2000,20,2000,20,0,4,0,0,20,"{'cpu_util_percent': 6.4, 'ram_util_percent': 6.4}",21126,{},{},{},"{'mean_raw_obs_processing_ms': 0.5756774617410609, 'mean_inference_ms': 1.6284009398993817, 'mean_action_processing_ms': 0.10941484692081475, 'mean_env_wait_ms': 0.08529641176247782, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -1.1464611892376801, 'episode_reward_min': -3.200930589558242, 'episode_reward_mean': -2.3288708628100623, 'episode_len_mean': 5.0, 'episode_media': {}, 'episodes_this_iter': 4, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-2.388984540232118, -1.2956705627551097, -3.0269021765014408, -2.871336713078322, -1.8946789544887843, -2.6087670941878898, -2.3238982518758657, -2.421402961723609, -2.6534302124778737, -2.170424256420173, -1.850432168817969, -2.872134814653089, -1.619673165283086, -1.884573291573757, -1.2478918736153988, -2.4630825977324444, -2.639806891409734, -2.631024335774482, -2.0551456059717945, -2.572691997123049, -3.0948187063354244, -2.589696414182098, -2.645068841072532, -2.7708974674344455, -1.7945784544626255, -2.6847378306814784, -2.758691704691069, -2.858921751355586, -2.8270625664793, -2.419992161290342, -3.0120276115173055, -2.0664348305856906, -1.7347696573832805, -1.518139172591099, -1.9618898065292618, -2.3408693989287443, -2.057995338647973, -1.8091140151459562, -2.171533539484203, -2.784003195736911, -2.1446224222476364, -2.325353927212247, -2.482228180553632, -1.8544583243307287, -2.8477146182633266, -2.6855138608530713, -2.551306000398193, -2.663511801666885, -2.0547490319469803, -2.607279097006721, -2.140749199445432, -1.844831633056356, -2.36156222869029, -2.8578718450433707, -2.195219611959812, -2.1840371228333746, -2.7532372986882994, -2.826235550772615, -1.1464611892376801, -2.4569316649473745, -2.5244708987019373, -1.220607916161287, -2.6590120320878188, -1.941129238033662, -2.4453914320090386, -1.844006983571282, -2.0412584849088224, -1.8138912879173614, -2.4725418200347002, -2.4529466664723256, -2.8270457775119224, -2.006987611698472, -2.9196046306312278, -2.3941057671609722, -3.200930589558242, -1.8312028062300985, -2.3424384910953187, -1.852359658782794, -2.9454913675595726, -1.8426851468945133, -2.7121143482345733, -2.2455369947903345, -2.2842135804408317, -2.7461724579675977, -2.5958119408773785, -2.20368351686906, -2.765672967055498, -1.8816303769746123, -2.258179347184148, -2.0186826363423784, -2.1278566607935687, -1.8720596347276826, -2.7538756225761896, -2.7231737799677105, -2.31942062890263, -2.1189018381162237, -2.3435131970577623, -1.8571882253364007, -2.3198228195391737, -2.7843975608457794], 'episode_lengths': [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.5756774617410609, 'mean_inference_ms': 1.6284009398993817, 'mean_action_processing_ms': 0.10941484692081475, 'mean_env_wait_ms': 0.08529641176247782, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.03676104545593262, 'StateBufferConnector_ms': 0.007767200469970703, 'ViewRequirementAgentConnector_ms': 0.10059976577758789}}",84.8138,0.834555,84.8138,"{'training_iteration_time_ms': 107.769, 'learn_time_ms': 92.838, 'learn_throughput': 215.428, 'synch_weights_time_ms': 3.094}",1746123961,0,2000,100,83b0c_00000,4.04257


2025-05-01 18:26:01,611	INFO tune.py:798 -- Total run time: 94.75 seconds (94.31 seconds for the tuning loop).


In [None]:
ray.shutdown()

In [35]:
# Загружаем результаты
analysis = ExperimentAnalysis(my_path)
best_trial = analysis.get_best_trial(metric="episode_reward_max", mode="max")
best_checkpoint = analysis.get_best_checkpoint(best_trial, mode = "max")

2025-05-01 18:27:09,104	INFO experiment_analysis.py:789 -- No `self.trials`. Drawing logdirs from checkpoint file. This may result in some information that is out of sync, as checkpointing is periodic.


In [36]:
env = GraphEnv({"state": state, 
          "max_num_children": G.number_of_nodes()})

In [37]:
algo = Algorithm.from_checkpoint(best_checkpoint)

2025-05-01 18:27:12,799	INFO algorithm_config.py:2888 -- Executing eagerly (framework='tf2'), with eager_tracing=tf2. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
2025-05-01 18:27:16,136	INFO worker_set.py:310 -- Inferred observation/action spaces from remote worker (local worker has no env): {'default_policy': (Repeated(Dict('nbr_dist': Box(0.0, 1.4142135623730951, (1,), float64), 'node_idx': Box(0, 5, (1,), int64), 'node_obs': Box(0.0, 1.0, (2,), float64), 'parent_dist': Box(0.0, 1.4142135623730951, (1,), float64)), 6), Discrete(5)), '__env__': (Repeated(Dict('nbr_dist': Box(0.0, 1.4142135623730951, (1,), float64), 'node_idx': Box(0, 5, (1,), int64), 'node_obs': Box(0.0, 1.0, (2,), float64), 'parent_dist': Box(0.0, 1.4142135623730951, (1,), float64)), 6), Discrete(5))}
2025-05-01 18:27:16,140	INFO eager_tf_policy_v2.py:75 -- Creating TF-

In [59]:
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
i = 0
path = [obs[0]['node_idx'][0]]
while not terminated and not truncated and i < 20: 
    action = algo.compute_single_action(obs, explore = False)
    obs, reward, terminated, truncated, info = env.step(action)
    print(i, action, reward, info)
    episode_reward += reward
    path.append(obs[0]['node_idx'][0])
    i += 1
print(path, episode_reward )

0 3 -0.4066376678985208 {}
1 2 -0.2750684250004176 {}
2 0 -0.1859166479415223 {}
3 0 -0.6496229391255659 {}
4 0 -0.24954859274920732 {}
[0, 4, 3, 1, 2, 0] -1.766794272715234
