In [1]:
import gymnasium as gym
import numpy as np

class pushButton(gym.Env):
    """Toy "press the lit button" environment.

    At every step one of ``num_buttons`` buttons is lit and the observation is
    the index of that button. The agent earns reward 1 when its action equals
    the button that was lit *before* the step, otherwise 0. After scoring, a
    new button is lit uniformly at random using the env's own RNG
    (``self.np_random``), so ``seed()`` / ``reset(seed=...)`` make episodes
    reproducible. The episode is reported as terminated once
    ``max_episode_steps`` steps have been taken.

    Args:
        config: Optional mapping of settings. When the env is built through
            ``register_env`` this is presumably RLlib's env config (dict-like).
            Recognised keys: ``"num_buttons"`` (default 5) and
            ``"max_episode_steps"`` (default 100); other keys are ignored.
    """

    def __init__(self, config=None):
        settings = dict(config or {})
        self.num_buttons = int(settings.get("num_buttons", 5))
        self._max_episode_steps = int(settings.get("max_episode_steps", 100))
        self.current_steps = 0
        # Not used by the dynamics; kept so existing references keep working.
        self.grid_size = 4
        self.action_space = gym.spaces.Discrete(self.num_buttons)
        self.observation_space = gym.spaces.Discrete(self.num_buttons)
        self.seed()
        self.reset()

    def seed(self, seed=None):
        """Legacy seeding hook: rebuild ``self.np_random`` from ``seed``.

        Returns:
            A one-element list with the seed actually used.
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def reset(self, *, seed=None, options=None):
        """Start a new episode with button 0 lit.

        Args:
            seed: If given, re-seeds ``self.np_random`` (via the Gymnasium base
                class) so the following episode is reproducible.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where observation is the lit button index.
        """
        # Gymnasium's base Env.reset re-seeds self.np_random when seed is not None.
        super().reset(seed=seed)
        self.current_steps = 0
        self.activated_button = 0
        return self.activated_button, {}

    def step(self, action):
        """Score ``action`` against the lit button, then light a new one.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The step
            limit is reported through ``terminated`` (``truncated`` is always
            False), matching how callers of this env detect episode end.
        """
        self.current_steps += 1
        reward = 1 if action == self.activated_button else 0
        # Use the env's own generator (not the global np.random) so seeding works;
        # int() keeps the observation a plain Python int.
        self.activated_button = int(self.np_random.integers(0, self.num_buttons))
        # >= rather than ==: keep reporting the end if stepped past the limit.
        terminated = self.current_steps >= self._max_episode_steps
        return self.activated_button, reward, terminated, False, {}
        
        
        


In [2]:
from ray.tune.registry import register_env
# Register the env under the name "pb" so the training config's "env": "pb"
# can refer to it; the creator is presumably called with the env config,
# which pushButton.__init__ accepts as `config`.
register_env("pb", pushButton)

In [None]:
import ray
from ray import tune,air
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
from ray.tune.logger import LoggerCallback
from ray.air import session


class RewardLoggerCallback(LoggerCallback):
    """Tune logger callback that keeps every reported trial result in memory.

    ``self.results`` collects the complete result dict from each
    ``log_trial_result`` call, in arrival order. Metrics such as
    ``"episode_reward_mean"`` or ``"training_iteration"`` can be read from
    those dicts afterwards.
    """

    def __init__(self) -> None:
        super().__init__()
        self.results = []

    def log_trial_result(self, iteration: int, trial, result):
        """Record the full ``result`` dict reported for ``trial``.

        Args:
            iteration: Iteration number supplied by Tune (not used here).
            trial: The reporting trial (not used; with several trials, results
                from all of them are interleaved in ``self.results``).
            result: The metrics dict reported by the trial; stored as-is.
        """
        self.results.append(result)

callback = RewardLoggerCallback()

stopping_criterion = {"training_iteration": 50}
ray.init(ignore_reinit_error=True, log_to_driver=False)
try:
    tuner = tune.Tuner(
        "PPO",
        param_space={
            "env": "pb",
            "framework": "torch",
            "num_gpus": 0,
            "num_workers": 8,
            # NOTE(review): confirm the PPO config actually accepts/needs this key;
            # checkpointing itself is driven by CheckpointConfig below.
            "should_checkpoint": True,
        },
        run_config=air.RunConfig(
            stop=stopping_criterion,
            callbacks=[callback],
            # Save a checkpoint every training iteration so later cells can restore one.
            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=1),
        ),
    )
    results = tuner.fit()
finally:
    # Release the local Ray instance even if training raises or is interrupted.
    ray.shutdown()

2023-10-15 03:06:16,811	INFO worker.py:1642 -- Started a local Ray instance.


In [None]:
import matplotlib.pyplot as plt

# Use Tune's own `training_iteration` counter (the stop-criterion key) for the
# x axis instead of the list position, so the axis matches "Iteration".
iterations = [result["training_iteration"] for result in callback.results]
rewards = [result["episode_reward_mean"] for result in callback.results]

fig, ax = plt.subplots()
ax.plot(iterations, rewards)
ax.set_xlabel("Iteration")
ax.set_ylabel("Mean Reward")
ax.set_title("Reward vs. Iteration")
plt.show()

In [18]:
# Pick the best trial explicitly by mean episode reward: the Tuner above sets no
# default metric/mode, so a bare get_best_result() only works with one trial.
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")

# Checkpoint attached to the best trial's result -- presumably its most recent
# one, not necessarily its highest-reward checkpoint (TODO confirm if it matters).
best_checkpoint = best_result.checkpoint

In [22]:
# Display the selected checkpoint; its repr includes the local checkpoint path.
best_checkpoint

Checkpoint(local_path=/Users/spacefarers/ray_results/PPO/PPO_pb_8767a_00000_0_2023-08-06_01-07-16/checkpoint_000050)

In [19]:
from ray.rllib.algorithms.algorithm import Algorithm
# Register "pb" again before restoring: the checkpointed algorithm was trained
# with "env": "pb", so the name must resolve to pushButton in this session.
register_env("pb", pushButton)
# Restore the trained algorithm from the selected checkpoint.
algo = Algorithm.from_checkpoint(best_checkpoint)

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-08-06 01:09:26,567	INFO algorithm.py:536 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2023-08-06 01:09:28,401	INFO worker.py:1636 -- Started a local Ray instance.


In [32]:
from time import sleep

# Seed for the manual rollout so the sequence of lit buttons is reproducible
# (takes effect only if pushButton.reset forwards `seed` to its RNG).
ROLLOUT_SEED = 0

# Fresh env instance for a step-by-step rollout of the trained policy.
env = pushButton()
obs, reset_info = env.reset(seed=ROLLOUT_SEED)
progression = []
# Loop flag for the rollout cell; re-run this cell to start a new episode.
done = False

In [33]:
# Initial observation: index of the lit button (pushButton.reset lights button 0).
obs

0

In [34]:
from IPython.display import clear_output

# Pause between steps so each frame stays readable. With the default 100-step
# episode, 1 s/step means a full rollout takes well over a minute.
STEP_DELAY_S = 1.0

while not done:
    # explore=False: evaluate the policy's non-exploratory action instead of sampling.
    action = algo.compute_single_action(obs, explore=False)
    prev_obs = obs
    # Gymnasium step API: (obs, reward, terminated, truncated, info).
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    print(f"action: {action}")
    print(f"obs: {obs}")
    print(f"reward: {reward}")
    # A win means the action matched the button that was lit before this step.
    if prev_obs == action:
        print("WIN")
    sleep(STEP_DELAY_S)
    clear_output(wait=True)

action: 4
obs: 1
reward: 1
WIN


KeyboardInterrupt: 