### A deterministic (reproducible) example of a custom subclassed Keras model with a simple custom Gymnasium environment
Issues:
1. With the 'tf2' framework and eager tracing enabled, reading `input_dict["is_training"]` in the model's `forward` raises
   `KeyError: 'is_training'`
2. In 'tf' mode, the policy cannot be restored from a checkpoint.
3. In 'tf' mode, `policy.export_model` always saves the model with its initial (untrained) weights.

Derived from https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py

<a href="https://colab.research.google.com/github/shmyak-ai/rllib-env-model/blob/main/rllib_env_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Train and save

In [None]:
# Install Ray with RLlib only when running on Google Colab; other environments
# are expected to provide it already.
if 'google.colab' in str(get_ipython()):
  print('Running on CoLab')
  # IPython `%pip` magic: output is redirected to /dev/null and a message is
  # echoed if the install command fails.
  %pip install 'ray[default,rllib]' &>/dev/null || echo "Ray install failed!"
else:
  print('Not running on CoLab')

In [None]:
from dataclasses import dataclass

import gymnasium as gym
from gymnasium.spaces import Discrete, Box
import numpy as np
import tensorflow as tf
import ray
from tensorflow import keras
from tensorflow.keras import layers
from ray.rllib.env.env_context import EnvContext
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import NotProvided
from ray.tune.registry import get_trainable_cls

In [None]:
@dataclass
class Arguments:
    """Experiment settings (a notebook stand-in for command-line flags)."""

    # Algorithm name passed to ``get_trainable_cls``.
    run: str = 'PPO'
    # RLlib framework: 'tf' (graph mode) or 'tf2' (eager).
    framework: str = 'tf'
    # Stop criteria for the manual training loop.
    stop_iters: int = 10
    stop_timesteps: int = 100_000
    stop_reward: float = 0.1
    # Ray / rollout settings.
    local_mode: bool = False
    num_workers: int = 0
    num_envs_per_worker: int = 1
    # Global seed; a non-int value disables the deterministic setup.
    seed: int = 42


args: Arguments = Arguments()

In [None]:
# Reproducibility setup, only when an integer seed is configured (``bool`` and
# other int subclasses are deliberately excluded by the exact type check).
if type(args.seed) is int:
    # Seeds Python's `random`, NumPy and TensorFlow in a single call.
    tf.keras.utils.set_random_seed(args.seed)
    # Force deterministic TF op implementations (may be slower).
    tf.config.experimental.enable_op_determinism()
    print("Tensorflow determinism is enabled.")

In [None]:
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    The agent starts at position 0. Action 0 moves one step left (never below
    0), action 1 moves one step right. Every step yields -0.1 until the agent
    reaches ``corridor_length``, which yields +1 and terminates the episode.

    Config keys:
        corridor_length: Position of the goal (also the observation upper bound).
        seed: Optional base seed passed to ``reset``. When ``config`` is an
            ``EnvContext`` it is offset by the worker index and worker count so
            that different rollout workers use different seeds.
    """

    def __init__(self, config: EnvContext):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)

        seed = config.get("seed")
        if seed is not None and isinstance(config, EnvContext):
            seed = seed + config.worker_index + config.num_workers
        self.reset(seed=seed)

    def _obs(self):
        """Return the current position as an observation matching ``observation_space``."""
        return np.array([self.cur_pos], dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        # Let gymnasium seed ``self.np_random``; the dynamics here are
        # deterministic, but this keeps the ``seed`` argument meaningful.
        super().reset(seed=seed)
        self.cur_pos = 0
        return self._obs(), {}

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        terminated = self.cur_pos >= self.end_pos
        # Reaching the goal is a genuine termination; there is no time limit,
        # so the episode is never truncated.
        truncated = False
        reward = 1.0 if terminated else -0.1
        return self._obs(), reward, terminated, truncated, {}

In [None]:
# Start Ray with one CPU per rollout worker plus one for the driver, and one GPU
# when TensorFlow can see a GPU.
is_gpu = bool(tf.config.list_physical_devices('GPU'))
num_gpus = 1 if is_gpu else NotProvided
print("Use GPU" if is_gpu else "Use CPU")

ray_init_kwargs = dict(local_mode=args.local_mode, num_cpus=args.num_workers + 1)
if is_gpu:
    ray_init_kwargs["num_gpus"] = num_gpus
ray.init(**ray_init_kwargs)

In [None]:
# Algorithm config: the corridor env, the custom Keras model registered as
# "my_model", and the shared seed for the env, the model and RLlib itself.
config = (
    get_trainable_cls(args.run)
    .get_default_config()
    .environment(
        SimpleCorridor,  # the env class itself; a registered env name would also work
        env_config=dict(corridor_length=5, seed=args.seed),
    )
    # Eager tracing only applies to the eager ('tf2') framework.
    .framework(framework=args.framework, eager_tracing=args.framework == 'tf2')
    .rollouts(
        num_rollout_workers=args.num_workers,  # 0 means the driver samples itself
        num_envs_per_worker=args.num_envs_per_worker,  # must be at least 1
    )
    .training(
        model=dict(
            custom_model="my_model",
            custom_model_config=dict(seed=args.seed),
        )
    )
    .debugging(seed=args.seed)
    .resources(num_gpus=num_gpus)
)

In [None]:
class DenseBlock(layers.Layer):
    """A stack of SiLU ``Dense`` layers followed by a linear output layer.

    Args:
        n_features: Width of every hidden ``Dense`` layer.
        n_layers: Number of hidden layers.
        num_outputs: Width of the final, activation-free layer.
        seed: Base initializer seed. Hidden layer ``i`` uses ``seed + i`` and
            the output layer uses ``seed + n_layers``; any non-``int`` value
            (e.g. ``None``) leaves the initializers unseeded.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, n_features, n_layers, num_outputs, seed, **kwargs):
        super().__init__(**kwargs)

        def make_init(offset):
            # He-normal style (truncated normal, fan_in, scale 2) with a
            # per-layer seed so each layer gets a distinct, reproducible draw.
            return keras.initializers.VarianceScaling(
                scale=2.0,
                mode='fan_in',
                distribution='truncated_normal',
                seed=seed + offset if type(seed) is int else None,
            )

        # Build order (hidden layers first, output last) fixes Keras' auto
        # generated layer names, which end up in variable/checkpoint names.
        hidden = []
        for idx in range(n_layers):
            hidden.append(
                layers.Dense(n_features, activation=tf.nn.silu, kernel_initializer=make_init(idx))
            )
        self._dense = hidden
        self._out = layers.Dense(
            num_outputs, activation=None, kernel_initializer=make_init(n_layers)
        )

    def call(self, input_tensor, *args, **kwargs):
        # Extra arguments (e.g. a training flag) are accepted but not used.
        hidden = input_tensor
        for layer in self._dense:
            hidden = layer(hidden)
        return self._out(hidden)


class DenseNet(keras.Model):
    """Keras actor-critic network made of two independent ``DenseBlock`` heads.

    Args:
        num_outputs: Number of logits produced by the actor head.
        seed: Initializer seed forwarded to both heads (see ``DenseBlock``).
        n_features: Hidden layer width of each head (default 256).
        n_layers: Number of hidden layers in each head (default 2).
        **kwargs: Forwarded to ``keras.Model``.

    Calling the model returns ``(logits, value)``: logits with ``num_outputs``
    units and a value with a single unit.
    """

    def __init__(self, num_outputs, seed=None, n_features=256, n_layers=2, **kwargs):
        super().__init__(**kwargs)

        self._n_features = n_features
        self._n_layers = n_layers
        # NOTE(review): both heads receive the same seed, so their hidden kernels
        # presumably start out identical; kept to preserve existing reproducible runs.
        self._actor = DenseBlock(self._n_features, self._n_layers, num_outputs, seed)
        self._critic = DenseBlock(self._n_features, self._n_layers, 1, seed)

    def call(self, input_tensor, training=False):
        # Pass the flag by keyword so Keras recognises it as the training
        # argument instead of an anonymous extra positional argument.
        logits = self._actor(input_tensor, training=training)
        value = self._critic(input_tensor, training=training)
        return logits, value


class CustomModel(TFModelV2):
    """RLlib ``TFModelV2`` that delegates to a Keras ``DenseNet``.

    ``model_config["custom_model_config"]["seed"]`` is optional; when absent
    (or when there is no ``custom_model_config``), the initializers are unseeded.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        custom_config = model_config.get("custom_model_config") or {}
        self.base_model = DenseNet(num_outputs, seed=custom_config.get("seed"))

        # Subclassed Keras models create their variables on the first call, so
        # run one dummy forward pass now. This also gives value_function() a
        # tensor to return before forward() is called.
        fake_obs = tf.random.uniform(shape=obs_space.shape)
        fake_obs = tf.expand_dims(fake_obs, 0)
        _, self._value = self.base_model(fake_obs, training=False)

    def forward(self, input_dict, state, seq_lens):
        # The training flag is not forwarded: reading input_dict["is_training"]
        # raises KeyError under framework 'tf2' with eager tracing.
        logits, self._value = self.base_model(input_dict["obs"])
        # logits, self._value = self.base_model(input_dict["obs"], input_dict["is_training"])
        return logits, state

    def value_function(self):
        # Flatten the critic's single-unit output to one value per batch item.
        return tf.reshape(self._value, [-1])

In [None]:
# Only PPO is wired up for this manual (no Ray Tune) training loop; fail fast
# before registering or building anything.
if args.run != "PPO":
    raise ValueError(
        f"Only 'PPO' is supported by this manual training loop, got {args.run!r}."
    )

# The config refers to the model as "my_model", so register it before building.
ModelCatalog.register_custom_model("my_model", CustomModel)
# Tune-style stop criteria; nothing in this cell passes them anywhere.
stop = {
    "training_iteration": args.stop_iters,
    "timesteps_total": args.stop_timesteps,
    "episode_reward_mean": args.stop_reward,
}

print("Running manual train loop without Ray Tune.")
# Fixed learning rate instead of a grid search (which would need Tune).
config.training(lr=1e-3)
algo = config.build()

In [None]:
# Manual training loop: after each iteration print progress, save an Algorithm
# checkpoint and export the policy's model next to it.
checkpoint_paths = []
try:
    for _ in range(args.stop_iters):
        result = algo.train()
        print(f"Timesteps total: {result['agent_timesteps_total']}")
        print(f"Episode reward: {result['episode_reward_mean']}")
        path_to_checkpoint = algo.save()
        checkpoint_paths.append(path_to_checkpoint)
        print(
            "An Algorithm checkpoint has been created inside directory: "
            f"'{path_to_checkpoint}'."
        )
        policy = algo.get_policy()
        policy.export_model(path_to_checkpoint + "/keras_model")
        # Stop training once the target number of timesteps or the target
        # reward is reached.
        if (
            result["timesteps_total"] >= args.stop_timesteps
            or result["episode_reward_mean"] >= args.stop_reward
        ):
            break
finally:
    # Release the algorithm's resources even if an iteration raises.
    algo.stop()

In [None]:
# Shut down the Ray runtime started earlier with ray.init().
ray.shutdown()

#### Try to restore a policy or a model

In [None]:
# Compare the first and the last checkpoint of the run. Unpacking the whole
# list only worked when training stopped after exactly two iterations.
path_to_checkpoint_1, path_to_checkpoint_2 = checkpoint_paths[0], checkpoint_paths[-1]

In [None]:
# ModelCatalog.register_custom_model("my_model", CustomModel)

In [None]:
# Restoring through RLlib fails in 'tf' mode (issue 2), so those attempts stay
# commented out for reference:
# rllib_algorithm = Algorithm.from_checkpoint(path_to_checkpoint)
# rllib_policy = rllib_algorithm.get_policy("default_policy")
# del rllib_algorithm
# rllib_policy = Policy.from_checkpoint(path_to_checkpoint + "/policies/default_policy")

# Load the SavedModels written by policy.export_model() inside each checkpoint.
keras_model_1, keras_model_2 = (
    tf.saved_model.load(checkpoint_dir + "/keras_model/")
    for checkpoint_dir in (path_to_checkpoint_1, path_to_checkpoint_2)
)

In [None]:
# policy_weights = rllib_policy.get_weights()
# for key in policy_weights.keys():
#     print(key)

In [None]:
# policy_weights['default_policy/dense_net/dense_block/dense/kernel'].shape

In [None]:
# for var in keras_model.trainable_variables:
#     print(var.name)

In [None]:
# keras_model.trainable_variables[6].numpy().shape

In [None]:
# policy_weights['default_policy/dense_net/dense_block/dense/kernel']

In [None]:
# First trainable variable of the model exported with the first checkpoint.
# NOTE(review): per issue 3, in 'tf' mode this presumably matches the second
# checkpoint's value (initial weights) — compare with the next cell.
keras_model_1.trainable_variables[0].numpy()

In [None]:
# First trainable variable of the model exported with the second checkpoint.
keras_model_2.trainable_variables[0].numpy()

In [None]:
# env = SimpleCorridor(config.env_config)
# obs, _ = env.reset()
# rllib_policy.compute_single_action(obs)