In [2]:
import minigrid
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO
import gymnasium as gym
import matplotlib.pyplot as plt
from features_extractor import MinigridFeaturesExtractor
import os

## Prepare Workspace

In [3]:
# Checkpoint locations for the two agents being compared.
CONTROL_MODEL_PATH = "./models/ppo_control"
TRANSFER_MODEL_PATH = "./models/ppo_transfer"

# Output folders for checkpoints and evaluation plots; creating them is a no-op
# when they already exist, so this cell is safe to re-run.
for _output_dir in ("./models", "./results"):
    os.makedirs(_output_dir, exist_ok=True)

## Train Control Model

In [4]:
# Baseline: train PPO from scratch directly on the target task (UnlockPickup).
# Integer step budget: SB3's `learn()` is typed to take an int, not a float like 1e5.
CONTROL_TIMESTEPS = 100_000

policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)
# ImgObsWrapper drops the mission text so the CNN sees only the image observation.
env = ImgObsWrapper(gym.make("MiniGrid-UnlockPickup-v0", render_mode="rgb_array"))
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
model.learn(CONTROL_TIMESTEPS, progress_bar=True)
model.save(CONTROL_MODEL_PATH)
env.close()

Output()

## Train Transfer Model

In [5]:
# Transfer setup: pretrain on the simpler Unlock task, then fine-tune the same
# weights on UnlockPickup (the task the control model trains on from scratch).
# NOTE(review): 100 pretraining steps is tiny compared with the fine-tuning budget.
# SB3's PPO collects a full rollout (n_steps) before checking the budget, so this
# is roughly a single update. Confirm this is intended (possibly meant 1e5?).
PRETRAIN_TIMESTEPS = 100
FINETUNE_TIMESTEPS = 100_000
# Keep the pretrained checkpoint instead of overwriting it with the fine-tuned one.
PRETRAINED_MODEL_PATH = f"{TRANSFER_MODEL_PATH}_pretrained"

policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

# pretraining on simpler env
env = ImgObsWrapper(gym.make("MiniGrid-Unlock-v0", render_mode="rgb_array"))
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
model.learn(PRETRAIN_TIMESTEPS, progress_bar=True)
model.save(PRETRAINED_MODEL_PATH)
env.close()

# fine-tune on final env; loading with `env=` rebinds the pretrained policy to the new task
env = ImgObsWrapper(gym.make("MiniGrid-UnlockPickup-v0", render_mode="rgb_array"))
model = PPO.load(PRETRAINED_MODEL_PATH, env=env)
model.learn(FINETUNE_TIMESTEPS, progress_bar=True)
model.save(TRANSFER_MODEL_PATH)
env.close()

## Evaluation Tool

In [7]:
def eval(
    env: ImgObsWrapper,
    model: PPO,
    eps: int,
    graph_label: str,
    color: str,
    deterministic: bool = False,
) -> None:
    """Roll out ``model`` in ``env`` for ``eps`` episodes and plot cumulative reward.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook namespace;
    it is kept so existing calls keep working.

    Args:
        env: Gymnasium-API env. ``reset()`` returns ``(obs, info)`` and ``step()``
            returns ``(obs, reward, terminated, truncated, info)``.
        model: trained SB3 model; only ``predict`` is used.
        eps: number of episodes to run.
        graph_label: plot title. With spaces replaced by underscores, it also names
            the saved file ``./results/<label>_eval_rewards.png``.
        color: matplotlib line color for the curve.
        deterministic: forwarded to ``model.predict``. False (SB3's default) samples
            actions from the policy; True uses the greedy action.

    Prints the total reward over all episodes and saves the plot. Returns None.
    """
    # cumulative[i] is the total reward after the first i episodes, so the curve
    # starts at 0 and has eps + 1 points.
    cumulative = [0]
    for _ in range(eps):
        obs, _info = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, _info = env.step(action)
            # Sum every step's reward, not only the terminal step's, so no reward is dropped.
            episode_reward += reward
            done = terminated or truncated
        cumulative.append(cumulative[-1] + episode_reward)

    total_rewards = cumulative[-1]
    print(f'\nfinished evaluating {eps} eps. total_rewards: {total_rewards}')

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.set_xlim(0, eps)
    ax.plot(cumulative, linestyle="-", color=color)
    ax.set_title(graph_label, fontname='sans-serif', fontsize=14, fontstyle='italic')
    ax.set_xlabel("Episode", fontname='sans-serif', fontweight="semibold")
    ax.set_ylabel("Cumulative reward", fontname='sans-serif', fontweight="semibold")
    fig.savefig(f'./results/{graph_label.replace(" ", "_")}_eval_rewards.png')

## Evaluate

In [9]:
# Evaluate both agents on the same target task, each with its own loaded weights.
# "rgb_array" avoids "human" mode, which draws every step to a window and makes
# 100-episode evaluations very slow. The observations fed to the policy come from
# ImgObsWrapper, not from rendering.
EVAL_EPISODES = 100

env = ImgObsWrapper(gym.make("MiniGrid-UnlockPickup-v0", render_mode="rgb_array"))
model_ctrl = PPO.load(CONTROL_MODEL_PATH, env=env)
model_transfer = PPO.load(TRANSFER_MODEL_PATH, env=env)

# eval() resets the env at the start of every episode, so no manual reset is needed here.
eval(env, model_ctrl, EVAL_EPISODES, "PPO Control", "red")
# Distinct label: it names the saved plot, so reusing "PPO Control" overwrote the control result.
eval(env, model_transfer, EVAL_EPISODES, "PPO Transfer", "blue")

KeyboardInterrupt: 