In [5]:
import gym
import gym_gridverse
import torch
import numpy as np
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env import DummyVecEnv

#Setup environment

class FlattenObservationWrapper(gym.ObservationWrapper):
	def __init__(self, env):
		super().__init__(env)
		total_size = sum(np.prod(env.observation_space.spaces[key].shape) for key in env.observation_space.spaces)
		self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(total_size,), dtype=np.float32)

	def observation(self, observation):
		# Flatten each part of the observation and then concatenate
		flattened_obs = np.concatenate([observation[key].flatten() for key in observation])
		return flattened_obs
		
def make_env():
	env= gym.make("GV-FourRooms-9x9-v0")
	env = FlattenObservationWrapper(env)
	return env

	


#num_envs = 4 
#env = SubprocVecEnv([make_env for _ in range(num_envs)],start_method='spawn')
#num_envs = 8  # Number of parallel environments
#env = DummyVecEnv([make_env for i in range(num_envs)])

env = make_env()
#env = FlattenObservationWrapper(env)
#model = PPO("MultiInputPolicy", env,verbose=1)




In [7]:
#Training function
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
#model.learn(total_timesteps=1800000)

def train_model(env, total_timesteps):
    model = RecurrentPPO("MlpLstmPolicy", env, verbose=2, device='cuda',n_steps=2048,tensorboard_log="./ppo_tensorboard/")

    # Save the model every 100k steps
    checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./models/', name_prefix='ppo_model_9x9')

    # Evaluation and logging
    #eval_callback = EvalCallback(env, best_model_save_path='./models/', log_path='./logs/', eval_freq=50000)

    model.learn(total_timesteps=total_timesteps, callback=[checkpoint_callback])

    return model

model = train_model(env, total_timesteps=1600000)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./ppo_tensorboard/RecurrentPPO_3


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 453      |
|    ep_rew_mean     | -16.6    |
| time/              |          |
|    fps             | 288      |
|    iterations      | 1        |
|    time_elapsed    | 7        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 299          |
|    ep_rew_mean          | -8.89        |
| time/                   |              |
|    fps                  | 179          |
|    iterations           | 2            |
|    time_elapsed         | 22           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0074336547 |
|    clip_fraction        | 0.0128       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.79        |
|    explained_variance   | 0.00259      |
|    learning_r

KeyboardInterrupt: 

In [11]:
#Test and save full trained model
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=20)
print(f"Mean reward: {mean_reward}, Std reward: {std_reward}")
model.save("ppo_gridworld_9x9_raytracing_1_8M")

Mean reward: 5.3600001353770494, Std reward: 0.3907044812480032


In [9]:
#Visualize agent
env= gym.make("GV-FourRooms-9x9-v0")
env = FlattenObservationWrapper(env)
obs = env.reset()
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)

    obs, rewards, dones, info = env.step(action)
    episode_starts = dones
    env.render()
    if dones:
        obs = env.reset()

KeyboardInterrupt: 

In [12]:
model.policy

RecurrentActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=6, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
  (lstm_actor): LSTM(123, 256)
  (lstm_critic): LSTM(123, 256)
)