# Testing Ray[rllib] usage

In [7]:
import ray
from ray import tune
import os

In [None]:
# Start (or attach to) the local Ray runtime. ignore_reinit_error lets this
# cell be re-run in the same kernel instead of raising because Ray is
# already initialised.
ray.init(ignore_reinit_error=True)

## Train an agent on our custom environment

In [37]:
from ray.tune.registry import register_env
from pathlib import Path
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from FpfEnv_v2 import FeuParFeuEnv
from ray.rllib.algorithms.ppo import PPOConfig

In [41]:
# Project root, assumed to be the parent of the notebook's working directory.
# NOTE(review): inferred from the original "..\data" relative path — confirm.
base_path = Path(os.path.abspath('')).parent

# SUMO network/route files, relative to base_path. Built from path parts
# rather than a backslash-separated string so they also resolve on POSIX.
# base_path / net_file designates the same file that the cwd-relative
# "..\data\..." string did.
net_file = Path("data", "networks", "2way-single-intersection", "single-intersection.net.xml")
route_file = Path("data", "networks", "2way-single-intersection", "single-intersection.rou.xml")

# Absolute paths: Ray Tune may run trials / rollout workers from another
# working directory, where relative paths would no longer resolve.
net_file_path = (base_path / net_file).resolve()
route_file_path = (base_path / route_file).resolve()


def _make_fpf_env(env_config):
    """Create a FeuParFeuEnv wrapped in RLlib's ParallelPettingZooEnv adapter.

    Args:
        env_config: env config passed by RLlib (dict-like, may be None when
            called by hand). An optional ``with_gui`` key overrides the GUI
            flag, which defaults to True as before; setting it to False
            presumably avoids opening one GUI per environment instance.

    Returns:
        The wrapped environment, registered below under the name 'fpf'.
    """
    env_config = env_config or {}
    return ParallelPettingZooEnv(
        FeuParFeuEnv(
            net_file=str(net_file_path),
            route_file=str(route_file_path),
            reward_fn_name='arrived_vehicles',
            with_gui=env_config.get("with_gui", True),
            num_seconds=3600,
        )
    )


register_env('fpf', _make_fpf_env)

In [None]:
# Sampling layout: every rollout worker contributes one fragment of
# `rollout_fragment_length` steps to each training batch.
num_rollout_workers = 4
rollout_fragment_length = 360
train_batch_size = rollout_fragment_length * num_rollout_workers

# PPO on the registered 'fpf' env with the torch backend. The GPU count
# comes from the RLLIB_NUM_GPUS environment variable (0 when unset).
config = (
    PPOConfig()
    .environment(env='fpf')
    .rollouts(
        num_rollout_workers=num_rollout_workers,
        rollout_fragment_length=rollout_fragment_length,
    )
    .training(train_batch_size=train_batch_size)
    .experimental(_disable_preprocessor_api=True)
    .debugging(log_level="ERROR")
    .framework(framework="torch")
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)

In [None]:
# Launch PPO training through Ray Tune using the config built above.
tune.run(
    "PPO",
    name="PPO",
    # RLlib's (old API stack) result dict reports the cumulative episode count
    # as "episodes_total"; Tune rejects stop criteria missing from the result.
    stop={"episodes_total": 5 if not os.environ.get("CI") else 2},
    checkpoint_freq=50,
    # A run this short ends long before iteration 50, so also save a final
    # checkpoint.
    checkpoint_at_end=True,
    # Absolute output dir next to the project's data folder. The former
    # r"\..\outputs\fpf" was relative to the drive root on Windows.
    local_dir=str((Path(os.path.abspath('')) / ".." / "outputs" / "fpf").resolve()),
    config=config.to_dict(),
)

## Test the custom environment

In [None]:
# Machine-specific absolute paths to the SUMO network/route files used for
# this manual random-action test; adjust them on another machine.
net_file = r"D:\mesdocuments\paul\gertrude\travail\R&D\projets\generation_diag\travail\data\networks\2way-single-intersection\single-intersection.net.xml"
route_file = r"D:\mesdocuments\paul\gertrude\travail\R&D\projets\generation_diag\travail\data\networks\2way-single-intersection\single-intersection.rou.xml"

env = FeuParFeuEnv(
    net_file=net_file,
    route_file=route_file,
    reward_fn_name='green',
    with_gui=True,
    num_seconds=3600,
)

n_episodes = 1
try:
    for _ in range(n_episodes):
        obs, infos = env.reset()
        while True:
            # Each agent samples a random action allowed by its *current* mask.
            # NOTE(review): assumes the PettingZoo parallel convention that
            # observations are keyed by agent and each carries 'action_mask'.
            actions = {
                agent: env.action_space(agent).sample(obs[agent]['action_mask'])
                for agent in env.agents
            }
            # Keep the new observations so the next step's masks are up to date
            # (previously the reset observation was reused for every step).
            obs, rewards, terminations, truncations, infos = env.step(actions)
            if all(terminations.values()) or any(truncations.values()):
                break
finally:
    # Release the environment's resources (presumably the SUMO process/GUI)
    # even if an episode raises.
    env.close()

In [48]:
import numpy as np

class Test:
    """Prototype of a traffic-light ("Feu") state tracker and action masker.

    For each light state ('g' green, 'y' yellow, 'r' red) it counts the
    consecutive update steps spent in that state and the cumulated number of
    steps. From these counts it derives which next states are allowed. Masks
    follow the order of ``possible_actions``: [g, y, r], 1 = allowed.

    Args:
        min_green: consecutive green steps required before yellow is allowed.
        yellow_duration: number of consecutive yellow steps after which red
            is forced; more than this is an error.
        max_red: consecutive red steps after which green is forced.
    """

    def __init__(self, *, min_green: int = 6, yellow_duration: int = 3, max_red: int = 120) -> None:
        print("init")
        self.possible_actions = ['g', 'y', 'r']
        self.min_green = min_green
        self.yellow_duration = yellow_duration
        self.max_red = max_red
        self.consecutive_durations = {state : 0 for state in self.possible_actions}
        self.cumulated_durations = {state : 0 for state in self.possible_actions}
        print(f"self.consecutive_durations : {self.consecutive_durations}")
        print(f"self.cumulated_durations : {self.cumulated_durations}")

    def update(self, state: str) -> None:
        """Record one step spent in ``state``.

        The consecutive counter of ``state`` is incremented, every other
        consecutive counter is reset to 0, and the cumulated counter of
        ``state`` is incremented.

        Args:
            state: should be 'g', 'y' or 'r'.

        Raises:
            KeyError: if ``state`` is not one of ``possible_actions``. This is
                checked before any counter changes, so the tracker is left
                unchanged instead of half-reset.
        """
        print("----------------UPDATE-----------------")
        if state not in self.cumulated_durations:
            raise KeyError(state)
        new = {key : value + 1 if key == state else 0 for key, value in self.consecutive_durations.items()}
        self.consecutive_durations.update(new)
        self.cumulated_durations[state] += 1
        print(f"new self.consecutive_durations : {self.consecutive_durations} ")
        print(f"new self.cumulated_durations : {self.cumulated_durations} ")

    def compute_action_mask(self, state: str) -> np.ndarray:
        """Return the int8 action mask [g, y, r] allowed from ``state``.

        Args:
            state: the light's current state, 'g', 'y' or 'r'.

        Returns:
            An int8 array where 1 marks a valid next state and 0 an invalid one.

        Raises:
            RuntimeError: if the light has been yellow longer than
                ``yellow_duration`` steps.
            ValueError: if ``state`` is not 'g', 'y' or 'r'. Previously any
                unknown state was silently treated as red.
        """
        print("----------------COMPUTE ACTION MASK-----------------")
        current_state = state
        print(f"current_state : {current_state}")
        print(f"self.consecutive_durations : {self.consecutive_durations}")
        if current_state == "g":
            # Minimum-green constraint: stay green until min_green is reached.
            if self.consecutive_durations["g"] < self.min_green:
                action_mask = np.array([1, 0, 0], dtype=np.int8)
            else:
                action_mask = np.array([1, 1, 0], dtype=np.int8)
        elif current_state == "y":
            # Yellow constraint: stay yellow for exactly yellow_duration steps, then red.
            if self.consecutive_durations["y"] < self.yellow_duration:
                action_mask = np.array([0, 1, 0], dtype=np.int8)
            elif self.consecutive_durations["y"] == self.yellow_duration:
                action_mask = np.array([0, 0, 1], dtype=np.int8)
            else:
                raise RuntimeError(f"Impossible d'avoir plus de {self.yellow_duration} sec de orange.\nself.consecutive_durations : {self.consecutive_durations['y']}")
        elif current_state == "r":
            # Maximum-red constraint: force green once max_red is reached.
            if self.consecutive_durations["r"] >= self.max_red:
                action_mask = np.array([1, 0, 0], dtype=np.int8)
            else:
                action_mask = np.array([1, 0, 1], dtype=np.int8)
        else:
            raise ValueError(f"Unknown light state {current_state!r}; expected one of {self.possible_actions}")
        print(f"action_mask : {action_mask}")
        return action_mask

    def process(self, state: str) -> np.ndarray:
        """Record one step in ``state`` and return the resulting action mask."""
        self.update(state=state)
        return self.compute_action_mask(state=state)


In [81]:
# Create a fresh tracker; its __init__ prints the initial all-zero counters.
t = Test()

init
self.consecutive_durations : {'g': 0, 'y': 0, 'r': 0}
self.cumulated_durations : {'g': 0, 'y': 0, 'r': 0}


In [87]:
# Record one yellow step, then print the counters and the resulting mask.
# NOTE(review): the captured output below comes from a different call (it
# reports current_state 'g'); re-run the cell to refresh it.
t.process(state='y')

----------------UPDATE-----------------
new self.consecutive_durations : {'g': 6, 'y': 0, 'r': 0} 
new self.cumulated_durations : {'g': 6, 'y': 0, 'r': 0} 
----------------COMPUTE ACTION MASK-----------------
current_state : g
self.consecutive_durations : {'g': 6, 'y': 0, 'r': 0}
action_mask : [1 1 0]
