In [433]:
from gym_minigrid.wrappers import *
from gym_minigrid.minigrid import *
import gym

from stable_baselines3 import A2C, PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy

import numpy as np
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
%matplotlib inline

# Define possible colors

In [372]:
# Palette of named RGB triples (0-255 per channel), stored as NumPy arrays
# so they can be used directly in array arithmetic and pixel assignment.
COLORS = {
    name: np.array(rgb)
    for name, rgb in (
        ('red',    (255, 0, 0)),
        ('green',  (0, 255, 0)),
        ('blue',   (0, 0, 255)),
        ('purple', (112, 39, 195)),
        ('yellow', (255, 255, 0)),
        ('grey',   (100, 100, 100)),
    )
}

# Custom environment for first agent

In [175]:
class EmptyRandomEnv(MiniGridEnv):
    """
    Obstacle-free square grid containing only the agent and one goal tile.

    The grid is enclosed by a wall border. Agent and goal positions are
    delegated to the base-class placement helpers, which are called without
    arguments (presumably yielding random positions, as the class name
    suggests -- confirm against the gym_minigrid version in use).

    Parameters
    ----------
    size : int
        Side length of the square grid, wall border included. The episode
        step limit is set to ``4 * size * size``.
    """

    def __init__(self, size=16):
        # Step budget scales with the grid area.
        step_budget = 4 * size * size

        super().__init__(
            grid_size=size,
            max_steps=step_budget,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

    def _gen_grid(self, width, height):
        """Build a fresh walled grid, then place the agent and the goal."""
        # Blank grid surrounded by walls on all four sides.
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # The agent is placed first, the goal second; keep this order so the
        # sequence of placement calls is unchanged.
        self.place_agent()
        self.place_obj(Goal())

        self.mission = "get to the green goal square"

# Wrapper to change goal color in observation and to render it

In [487]:
class RGBCustomColorPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Render the agent's partial view as an RGB image with a recolored goal.

    The wrapped environment's ``obs['image']`` is rendered to pixels through
    ``get_obs_render`` and every pixel that carries the goal's native
    highlighted color is replaced by the highlighted version of
    ``goal_color``. ``render_img`` applies the same recoloring to the full
    rendered frame.

    Parameters
    ----------
    env : gym.Env
        Environment to wrap; its unwrapped form must provide
        ``get_obs_render`` and its observations must contain an ``'image'``
        entry.
    goal_color : str
        Key into the module-level ``COLORS`` table.
    tile_size : int
        Pixel size of one grid cell in the returned observation. Defaults to
        20, the size observations were previously fixed to, so callers that
        do not pass it get the same image size as before.

    Raises
    ------
    KeyError
        If ``goal_color`` is not a key of ``COLORS``.

    NOTE(review): this wrapper does not update ``observation_space`` even
    though ``observation`` returns an image array instead of the original
    observation -- confirm downstream consumers do not rely on the space.
    """

    # Colors the goal tile is searched for in rendered images: the plain
    # goal color and the same color after the 30 % white highlight blend.
    _NATIVE_GOAL_RGB = (0, 255, 0)
    _NATIVE_GOAL_HIGHLIGHT_RGB = (76, 255, 76)

    def __init__(self, env, goal_color="green", tile_size=20):
        super().__init__(env)
        # Fail at construction rather than on the first observation.
        if goal_color not in COLORS:
            raise KeyError(goal_color)
        self.goal_color = goal_color
        self.tile_size = tile_size

    @staticmethod
    def _highlighted(color):
        """Return ``color`` blended 30 % towards white, as a uint8 array."""
        white = np.array((255, 255, 255), dtype=np.uint8)
        blended = color + 0.3 * (white - color)
        return blended.clip(0, 255).astype(np.uint8)

    def observation(self, obs):
        """
        Convert a raw observation into a recolored RGB image.

        Returns the array produced by ``get_obs_render`` for ``obs['image']``
        with goal pixels painted in the highlighted ``goal_color``.
        """
        highlight_color = self._highlighted(COLORS[self.goal_color])

        rgb_img_partial = self.unwrapped.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        # Goal pixels in the partial view carry the highlighted native color.
        goal_mask = np.all(
            rgb_img_partial == self._NATIVE_GOAL_HIGHLIGHT_RGB, axis=-1
        )
        rgb_img_partial[goal_mask] = highlight_color

        return rgb_img_partial

    def render_img(self):
        """
        Render the full frame as an RGB array with the goal recolored.

        Pixels with the plain native goal color become ``goal_color``;
        pixels with the highlighted native goal color become the highlighted
        ``goal_color``.
        """
        color = COLORS[self.goal_color]
        highlight_color = self._highlighted(color)

        img = self.render(mode="rgb_array")

        # Compute both masks before writing so the first assignment cannot
        # influence the second lookup.
        plain_mask = np.all(img == self._NATIVE_GOAL_RGB, axis=-1)
        highlight_mask = np.all(img == self._NATIVE_GOAL_HIGHLIGHT_RGB, axis=-1)

        img[plain_mask] = color
        img[highlight_mask] = highlight_color
        return img

In [489]:
# Size-5 empty grid (wall border included) whose goal is shown in purple.
env = RGBCustomColorPartialObsWrapper(
    EmptyRandomEnv(size=5),
    goal_color="purple",
)