In [1]:
import gymnasium as gym
import qwertyenv
from stable_baselines3 import PPO
from stable_baselines3.ppo.policies import MultiInputPolicy
from stable_baselines3.common.evaluation import evaluate_policy

In [2]:
# Module-level slot written by the `another_action_taken` callback; stays
# None until that callback is invoked for the first time.
action = None

# Base environment, created through the Gymnasium registry with the
# `pieces` keyword forwarded to the environment.
env = gym.make(
    'qwertyenv/CollectCoins-v0',
    pieces=['rock', 'rock'],
)

def another_action_taken(action_taken):
    """Store ``action_taken`` in the module-level ``action`` variable.

    This function is handed to ``qwertyenv.EnsureValidAction`` as a callback.
    Presumably the wrapper calls it when it executes a different action than
    the one requested (confirm against the wrapper's implementation).

    Args:
        action_taken: The action value reported by the caller.
    """
    # Same effect as `global action; action = action_taken`.
    globals()["action"] = action_taken

# Wrap the original environment to make sure a valid action will be taken.
#
# `gym.make` returns the environment inside Gymnasium's own wrappers, and
# implicit attribute forwarding through wrappers (`env.some_method`) is
# deprecated in Gymnasium 0.29 and removed in 1.0. The validity helpers are
# therefore looked up explicitly on the base environment via `env.unwrapped`.
# NOTE(review): assumes both helpers are defined on the base environment
# itself (not on an intermediate wrapper) - confirm against qwertyenv.
env = qwertyenv.EnsureValidAction(
    env,
    env.unwrapped.check_action_valid,
    env.unwrapped.provide_alternative_valid_action,
    another_action_taken,  # callback that records the reported action in `action`
)

# PPO agent acting on the wrapped environment. `MultiInputPolicy` is the
# Stable-Baselines3 policy class intended for Dict observation spaces.
agent_w = PPO(MultiInputPolicy, env)

In [3]:
def play(num_episodes: int = 1):
    """Roll out the agent in the environment and render every step.

    Uses the module-level ``env`` and ``agent_w``. For each episode the
    environment is reset and rendered, then the agent's deterministic
    prediction is stepped until the episode is terminated or truncated.
    After every step the environment is rendered again and the reward,
    done flag and info dict are printed.

    Args:
        num_episodes: Number of episodes to play.
    """
    for episode in range(num_episodes):
        print(f'{episode=}', "-------", sep="\n")
        observation, _reset_info = env.reset()
        env.render()

        done = False
        while not done:
            # `action` is local to this function: it is the action the agent
            # predicted, and does not touch the module-level `action`.
            action, _ = agent_w.predict(observation, deterministic=True)
            observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            env.render()
            # Blank line, the step summary, then another blank line.
            print("", f'{reward=}, {done=}, {info=}', "", sep="\n")

        # Blank line separating episodes.
        print()

In [4]:
# Play a single episode with the agent as it is at this point in the notebook.
play(num_episodes=1)

episode=0
-------
---------------------------------
|wR | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ |bR |
---------------------------------

0/0
---------------------------------
|   |wR | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ |

  logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")
  logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}")


---------------------------------
|   |   |   |   |   |   |   | $ |
---------------------------------
|   |   |   |   |   |   |   | $ |
---------------------------------
|   |   |wR |   |   |   |   |   |
---------------------------------
|   |   |   |   |   |   |   |   |
---------------------------------
|   |   |   |   |   |   |   |   |
---------------------------------
|   |bR |   |   |   |   |   |   |
---------------------------------
|   |   |   |   |   |   |   |   |
---------------------------------
|   |   |   |   |   |   |   |   |
---------------------------------

28/32

reward=0.0, done=False, info={}

---------------------------------
|   |   |   |   |   |   |   | $ |
---------------------------------
|   |   |   |   |   |   |   | $ |
---------------------------------
|   |   |   |   |   |   |   |   |
---------------------------------
|   |   |wR |   |   |   |   |   |
---------------------------------
|   |bR |   |   |   |   |   |   |
---------------------------------
|   |  

In [5]:
# Train the agent for 200,000 timesteps with a progress bar; the call stays
# the cell's last expression so its return value is still displayed.
agent_w.learn(progress_bar=True, total_timesteps=200_000)

Output()

<stable_baselines3.ppo.ppo.PPO at 0x7f2c086f5ac0>

In [6]:
# Play a single episode again, now that `agent_w.learn` has been run.
play(num_episodes=1)

episode=0
-------
---------------------------------
|wR | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ |bR |
---------------------------------

0/0
---------------------------------
|   |wR | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ | $ | $ |
---------------------------------
| $ | $ | $ | $ | $ | $ |

In [7]:
# Evaluate the agent on the wrapped environment using deterministic actions
# and report the mean and standard deviation of the episode rewards.
mean_reward, std_reward = evaluate_policy(
    model=agent_w,
    env=env,
    deterministic=True,
)
print(f'{mean_reward=}, {std_reward=}')



mean_reward=0.3919999934732914, std_reward=0.7381029729677139
