In [2]:
from gym_minigrid.wrappers import *
from gym_minigrid.minigrid import *
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy

import numpy as np
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
%matplotlib notebook

# Create custom environment

An environment in which a ball and the agent are placed at random positions in a square grid. The ball color and the grid size can be changed.

In [3]:
class EmptyBallRoom(MiniGridEnv):
    """Walled square room containing a single ball; the task is to pick it up.

    On every grid generation the agent and the ball are each placed at a
    random free cell.

    :param size: side length of the grid passed to MiniGridEnv as ``grid_size``
    :param ball_color: colour name used to construct the ``Ball``
    """

    def __init__(self, size=5, ball_color="green"):
        # Stored before calling the base constructor because _gen_grid reads it;
        # presumably MiniGridEnv.__init__ already generates a grid — verify.
        self.ball_color = ball_color
        step_budget = 4 * size * size
        super().__init__(
            grid_size=size,
            max_steps=step_budget,
            see_through_walls=True,  # True gives maximum speed
        )

    def _gen_grid(self, width, height):
        """Build a walled empty room, then randomly place the agent and the ball."""
        room = Grid(width, height)
        room.wall_rect(0, 0, width, height)
        self.grid = room

        # Agent first, then ball: keeps the same sequence of random placements.
        self.place_agent()

        self.ball = Ball(self.ball_color)
        self.place_obj(self.ball)
        self.mission = f"Pick up {self.ball_color} {self.ball.type}"

    def step(self, action):
        """Delegate to MiniGridEnv.step, then end the episode with a reward once the ball is held."""
        obs, reward, done, info = super().step(action)

        holding_ball = (
            action == self.actions.pickup
            and self.carrying
            and self.carrying == self.ball
        )
        if holding_ball:
            reward = self._reward()
            done = True

        return obs, reward, done, info

In [17]:
# Factory for one environment; the vectorized env wrappers need a list of zero-argument constructors.
def make_env(ball_color, room_size, rank, seed=0):
    """Return a zero-argument callable that builds one seeded EmptyBallRoom env.

    :param ball_color: (str) colour of the ball the agent has to pick up
    :param room_size: (int) side length of the square grid, walls included
    :param rank: (int) index of the environment; added to ``seed`` so each
        environment gets its own seed
    :param seed: (int) the base seed for the RNGs
    :return: (callable) function that creates and seeds the wrapped environment
    """
    def _init():
        # Previously the size was hard-coded to 16, so room_size was ignored.
        room = EmptyBallRoom(size=room_size, ball_color=ball_color)
        # Image observations of the agent's partial view.
        env = ImgObsWrapper(RGBImgPartialObsWrapper(room))
        env.seed(seed + rank)
        return env

    # Seeds the global RNGs of the calling process (stable_baselines3 helper).
    set_random_seed(seed)
    return _init

In [4]:
# Ball colour and room size used for the training environments.
BALL_COLOR, ROOM_SIZE = "green", 16

# Reinforcement learning

### Training the first agent (can be skipped if you only want to load the already trained agent)

In [13]:
# Number of parallel environments ("cores") to run; one SubprocVecEnv worker is created per environment.
N_CPU = 12

In [20]:
# Vectorized training environment: one factory per worker, each run by SubprocVecEnv in its own process to speed up learning.
env_factories = [
    make_env(ball_color=BALL_COLOR, room_size=ROOM_SIZE, rank=2 * i)
    for i in range(N_CPU)
]
env = SubprocVecEnv(env_factories)

In [11]:
# Activation function for the policy/value network layers (forwarded to SB3 through policy_kwargs).
# NOTE(review): per SB3 docs the default NatureCNN feature extractor has its own fixed ReLUs — confirm.
policy_kwargs = {"activation_fn": nn.ReLU}

In [21]:
# PPO with a convolutional policy, because the observations are RGB images.
# Docs: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html?highlight=PPO
model = PPO(policy="CnnPolicy", env=env, policy_kwargs=policy_kwargs, verbose=1)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [8]:
# 2.5 million environment steps in total, counted across all parallel environments.
model.learn(total_timesteps=int(2.5e6))

------------------------------
| time/              |       |
|    fps             | 1072  |
|    iterations      | 1     |
|    time_elapsed    | 22    |
|    total_timesteps | 24576 |
------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 563         |
|    iterations           | 2           |
|    time_elapsed         | 87          |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.009247157 |
|    clip_fraction        | 0.0423      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.94       |
|    explained_variance   | -0.224      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00111     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00531    |
|    value_loss           | 0.00283     |
-----------------------------------------
---------------------------

----------------------------------------
| time/                   |            |
|    fps                  | 402        |
|    iterations           | 13         |
|    time_elapsed         | 793        |
|    total_timesteps      | 319488     |
| train/                  |            |
|    approx_kl            | 0.01858892 |
|    clip_fraction        | 0.194      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.42      |
|    explained_variance   | 0.695      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.00242    |
|    n_updates            | 120        |
|    policy_gradient_loss | -0.0127    |
|    value_loss           | 0.0212     |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 400         |
|    iterations           | 14          |
|    time_elapsed         | 858         |
|    total_timesteps      | 344064      |
| train/  

-----------------------------------------
| time/                   |             |
|    fps                  | 391         |
|    iterations           | 24          |
|    time_elapsed         | 1505        |
|    total_timesteps      | 589824      |
| train/                  |             |
|    approx_kl            | 0.019860806 |
|    clip_fraction        | 0.148       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.9        |
|    explained_variance   | 0.643       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00384     |
|    n_updates            | 230         |
|    policy_gradient_loss | -0.00879    |
|    value_loss           | 0.00962     |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 390        |
|    iterations           | 25         |
|    time_elapsed         | 1572       |
|    total_timesteps      | 614400     

-----------------------------------------
| time/                   |             |
|    fps                  | 387         |
|    iterations           | 35          |
|    time_elapsed         | 2221        |
|    total_timesteps      | 860160      |
| train/                  |             |
|    approx_kl            | 0.018244931 |
|    clip_fraction        | 0.106       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.499      |
|    explained_variance   | 0.607       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.019       |
|    n_updates            | 340         |
|    policy_gradient_loss | -0.00792    |
|    value_loss           | 0.00634     |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 386         |
|    iterations           | 36          |
|    time_elapsed         | 2286        |
|    total_timesteps      | 884736

-----------------------------------------
| time/                   |             |
|    fps                  | 384         |
|    iterations           | 46          |
|    time_elapsed         | 2939        |
|    total_timesteps      | 1130496     |
| train/                  |             |
|    approx_kl            | 0.013262154 |
|    clip_fraction        | 0.0765      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.356      |
|    explained_variance   | 0.626       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0135     |
|    n_updates            | 450         |
|    policy_gradient_loss | -0.00477    |
|    value_loss           | 0.00516     |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 384         |
|    iterations           | 47          |
|    time_elapsed         | 3005        |
|    total_timesteps      | 115507

-----------------------------------------
| time/                   |             |
|    fps                  | 381         |
|    iterations           | 57          |
|    time_elapsed         | 3670        |
|    total_timesteps      | 1400832     |
| train/                  |             |
|    approx_kl            | 0.016204856 |
|    clip_fraction        | 0.0919      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.285      |
|    explained_variance   | 0.651       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0165     |
|    n_updates            | 560         |
|    policy_gradient_loss | -0.00618    |
|    value_loss           | 0.00321     |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 377          |
|    iterations           | 58           |
|    time_elapsed         | 3778         |
|    total_timesteps      | 1

-----------------------------------------
| time/                   |             |
|    fps                  | 343         |
|    iterations           | 68          |
|    time_elapsed         | 4868        |
|    total_timesteps      | 1671168     |
| train/                  |             |
|    approx_kl            | 0.023581317 |
|    clip_fraction        | 0.126       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.252      |
|    explained_variance   | 0.673       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00179    |
|    n_updates            | 670         |
|    policy_gradient_loss | -0.00963    |
|    value_loss           | 0.00304     |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 340         |
|    iterations           | 69          |
|    time_elapsed         | 4979        |
|    total_timesteps      | 169574

----------------------------------------
| time/                   |            |
|    fps                  | 320        |
|    iterations           | 79         |
|    time_elapsed         | 6056       |
|    total_timesteps      | 1941504    |
| train/                  |            |
|    approx_kl            | 0.16083813 |
|    clip_fraction        | 0.2        |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.319     |
|    explained_variance   | 0.669      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.00115    |
|    n_updates            | 780        |
|    policy_gradient_loss | 0.00258    |
|    value_loss           | 0.00282    |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 318         |
|    iterations           | 80          |
|    time_elapsed         | 6164        |
|    total_timesteps      | 1966080     |
| train/  

----------------------------------------
| time/                   |            |
|    fps                  | 315        |
|    iterations           | 90         |
|    time_elapsed         | 7010       |
|    total_timesteps      | 2211840    |
| train/                  |            |
|    approx_kl            | 0.03021534 |
|    clip_fraction        | 0.113      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.287     |
|    explained_variance   | 0.659      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0106    |
|    n_updates            | 890        |
|    policy_gradient_loss | -0.00857   |
|    value_loss           | 0.0035     |
----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 316         |
|    iterations           | 91          |
|    time_elapsed         | 7075        |
|    total_timesteps      | 2236416     |
| train/  

-----------------------------------------
| time/                   |             |
|    fps                  | 320         |
|    iterations           | 101         |
|    time_elapsed         | 7734        |
|    total_timesteps      | 2482176     |
| train/                  |             |
|    approx_kl            | 0.028017094 |
|    clip_fraction        | 0.0791      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.227      |
|    explained_variance   | 0.666       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00708    |
|    n_updates            | 1000        |
|    policy_gradient_loss | -0.00154    |
|    value_loss           | 0.00296     |
-----------------------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 321        |
|    iterations           | 102        |
|    time_elapsed         | 7800       |
|    total_timesteps      | 2506752    

<stable_baselines3.ppo.ppo.PPO at 0x27076feb370>

In [9]:
# Persist the trained agent (stable_baselines3 stores it as a .zip archive).
model.save(path="rgb_ball_agent")

### Loading the trained agent

In [23]:
# Give only the base name: the ".zip" extension of the saved file is omitted.
# Loading on CPU is enough for running the agent interactively.
model = PPO.load(path="rgb_ball_agent", device="cpu")

# Displaying the agent in the environment

In [10]:
def display_agent(env, pov="agent", n_eval=10, agent=None, deterministic=False):
    """Animate ``n_eval`` episodes of an agent playing in ``env``.

    Two panels are drawn: the left one shows the observation the agent
    receives, the right one shows ``env.render("rgb_array")`` of the whole
    environment. Both titles show the last chosen action.

    :param env: single (non-vectorized) environment whose observations are images
    :param pov: currently unused; kept so existing calls keep working
    :param n_eval: number of episodes to play
    :param agent: object with a stable-baselines3 style ``predict(obs, deterministic=...)``
        method; defaults to the notebook-level ``model``
    :param deterministic: forwarded to ``agent.predict`` (SB3's default is False)
    """
    if agent is None:
        agent = model  # fall back to the globally trained/loaded model

    # Two pictures. Left side = agent's view, right side = whole environment.
    fig, (ax_obs, ax_env) = plt.subplots(1, 2, figsize=(10, 5))
    obs_image = env_image = None

    for _ in range(n_eval):
        state = env.reset()
        frame = env.render("rgb_array")
        if obs_image is None:
            # First episode: create the image artists once, then only update their data.
            obs_image = ax_obs.imshow(state)
            env_image = ax_env.imshow(frame)
        else:
            obs_image.set_data(state)
            env_image.set_data(frame)

        done = False
        while not done:
            action, _ = agent.predict(state, deterministic=deterministic)
            state, _, done, _ = env.step(action)
            obs_image.set_data(state)
            # Render the env being played; this used to reference an undefined global `eval_env`.
            env_image.set_data(env.render("rgb_array"))

            ax_obs.set_title(f"Action = {action}")
            ax_env.set_title(f"Action = {action}")
            fig.canvas.draw()

In [18]:
# Separate, non-vectorized testing environment.
# Don't reuse the vectorized training environment here: it froze the notebook.
make_test_env = make_env(ball_color="green", room_size=20, rank=5)
test_env = make_test_env()

In [38]:
# Play the trained agent in the testing environment created above (`eval_env` is not defined in this notebook).
display_agent(test_env)

<IPython.core.display.Javascript object>

# 2nd agent

This requires changing our environment, because stable_baselines3 gets the action indices from the environment. See:
https://github.com/maximecb/gym-minigrid/blob/1d4edcfd585f6b627ebf1cfd54ea5e70260fa1ee/gym_minigrid/minigrid.py#L629

https://github.com/maximecb/gym-minigrid/blob/1d4edcfd585f6b627ebf1cfd54ea5e70260fa1ee/gym_minigrid/minigrid.py#L1099
Redefining the `Actions` class and the `step` function should be sufficient.
One possible approach is to subclass `EmptyBallRoom`, since it already has the correct reward function.