# Draft

In [None]:
from pprint import pprint
from ray.rllib.algorithms.ppo import PPOConfig
from ray import train, tune
from ray.rllib.algorithms.algorithm import Algorithm
import gymnasium as gym
import pathlib
import numpy as np
import torch
from ray.rllib.core.rl_module import RLModule

In [None]:
# Build a reproducible N x N matrix of normalised, non-positive values.
N = 5
rand = np.random.RandomState(1)  # fixed seed so the matrix is the same on every run

# Draw from N(mean=5, std=1.5), keep two decimals, and zero the diagonal.
a = rand.normal(5, 1.5, size=(N, N)).round(2)
np.fill_diagonal(a, 0)

# Divide by the largest entry and flip the sign.
a = np.negative(a / a.max())
a

In [None]:
class MyEnv(gym.Env):
    """Environment in which the agent picks entries of a vector one by one.

    The observation is a length-``n`` float32 vector initialised from row 0
    of the matrix ``a``. Taking action ``k`` returns the current value at
    index ``k`` as the reward and then overwrites that entry with ``-2.0``.
    Re-selecting an already marked index therefore yields a reward of
    ``-2.0``. The episode terminates once every entry is ``<= -2.0``.

    Args:
        env_config: Mapping with the keys ``"n"`` (number of actions and
            observation length) and ``"a"`` (array whose row 0 seeds the
            observation; assumed to be 2-D with ``n`` columns and values
            inside the ``[-2, 0]`` observation bounds -- TODO confirm).
    """

    # Sentinel written into the observation once an index has been chosen;
    # it coincides with the lower bound of the observation space.
    _VISITED = -2.0

    def __init__(self, env_config):
        self.n = env_config['n']
        self.a = env_config['a']
        self.action_space = gym.spaces.Discrete(self.n)
        self.observation_space = gym.spaces.Box(
            low=self._VISITED,
            high=0.0,
            shape=(self.n,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``."""
        # Let gymnasium seed ``self.np_random`` as its reset contract requires.
        super().reset(seed=seed)
        # ``np.array`` always copies, so ``self.a`` is never mutated, and the
        # cast makes the observation match the declared float32 space.
        self.obs = np.array(self.a[0], dtype=np.float32)
        # Hand out a copy: ``self.obs`` is modified in place by ``step``.
        return self.obs.copy(), {}

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, terminated, truncated, info)``."""
        reward = float(self.obs[action])
        self.obs[action] = self._VISITED
        terminated = bool(np.all(self.obs <= self._VISITED))
        truncated = False
        # TODO: the original draft planned a "permute" step for the
        # non-terminal case; it is not implemented yet.
        # Return a copy so observations already handed to the caller are not
        # changed retroactively by later in-place updates.
        return self.obs.copy(), reward, terminated, truncated, {}

In [None]:
# Assemble the PPO configuration step by step on the new API stack.
config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.training(
    gamma=0.9,
    lr=0.01,
    mini_batch_size_per_learner=128,
    train_batch_size_per_learner=1024,
)
config = config.environment(env=MyEnv, env_config={"n": N, "a": a})
config = config.env_runners(num_env_runners=1)

# Run six training iterations, saving a checkpoint on iterations 0 and 5.
algo = config.build()
for i in range(6):
    result = algo.train()
    if i % 5 != 0:
        continue
    print(i)
    checkpoint_dir = algo.save_to_path()
    print(f"Checkpoint saved in directory {checkpoint_dir}")

In [None]:
# PPO configuration for the Tune run: same environment as above, with the
# learning rate swept over three values by grid search.
config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.environment(env=MyEnv, env_config={"n": N, "a": a})
config = config.training(lr=tune.grid_search([0.01, 0.001, 0.0001]))

In [None]:
# Stop each trial after 20000 sampled environment steps and checkpoint once
# at the end of the trial.
stop_criteria = {"num_env_steps_sampled_lifetime": 20000}
checkpoint_config = train.CheckpointConfig(checkpoint_at_end=True)
run_config = train.RunConfig(
    stop=stop_criteria,
    checkpoint_config=checkpoint_config,
)

tuner = tune.Tuner("PPO", param_space=config, run_config=run_config)
results = tuner.fit()

# Select the trial with the highest mean episode return ...
best_result = results.get_best_result(
    metric="env_runners/episode_return_mean",
    mode="max",
)
# ... and keep a handle on its checkpoint.
best_checkpoint = best_result.checkpoint

In [None]:
# Local environment instance used for the manual rollout below.
env = MyEnv({"n": N, "a": a})

In [None]:
# Load the trained policy module from the checkpoint written by the manual
# training loop (swap in best_checkpoint.path to use the Tune result instead).
rl_module_dir = pathlib.Path(
    checkpoint_dir, "learner_group", "learner", "rl_module"
)
rl_module = RLModule.from_checkpoint(rl_module_dir)["default_policy"]

# Greedy rollout of one episode, capped at 15 steps.
episode_return = 0
terminated = truncated = False
obs, info = env.reset()
for i in range(1, 16):
    # Batch of size one, cast to float32 for the module's forward pass.
    torch_obs_batch = torch.from_numpy(np.array([obs], dtype=np.float32))
    outputs = rl_module.forward_inference({"obs": torch_obs_batch})
    action_logits = outputs["action_dist_inputs"]
    # Pick the action with the largest logit.
    action = torch.argmax(action_logits[0]).numpy()
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
    print(action, terminated, truncated)
    if terminated or truncated:
        break
print(f"Reached episode return of {episode_return}.")