In [1]:
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper

In [2]:
class ExtractPOVAndTranspose(gym.ObservationWrapper):
    """
    Basically what it says on the tin. Extracts only the POV observation out of the `obs` dict,
    and transposes those observations to be in the (C, H, W) format used by stable_baselines and imitation
    """
    def __init__(self, env):
        super().__init__(env)
        non_transposed_shape = self.env.observation_space['pov'].shape
        self.high = np.max(self.env.observation_space['pov'].high)
        transposed_shape = (non_transposed_shape[2],
                            non_transposed_shape[0],
                            non_transposed_shape[1])
        # Note: this assumes the Box is of the form where low/high values are vector but need to be scalar
        transposed_obs_space = gym.spaces.Box(low=np.min(self.env.observation_space['pov'].low),
                                              high=np.max(self.env.observation_space['pov'].high),
                                              shape=transposed_shape,
                                              dtype=np.uint8)
        self.observation_space = transposed_obs_space

    def observation(self, observation):
        # Minecraft returns shapes in NHWC by default
        return np.swapaxes(observation['pov'], -1, -3)

In [3]:
## These functions are available in the `basalt_utils` package 
## provided in the basalt template repo 


def optional_observation_map(env, inner_obs):
    """
    If the env implements the `observation` function (i.e. if one of the
    wrappers is an ObservationWrapper), call that `observation` transformation
    on the observation produced by the inner environment
    """
    if hasattr(env, 'observation'):
        return env.observation(inner_obs)
    else:
        return inner_obs


def optional_action_map(env, inner_action):
    """
    This is doing something slightly tricky that is explained in the documentation for
    RecursiveActionWrapper (which TODO should eventually be in MineRL)
    Basically, it needs to apply `reverse_action` transformations from the inside out
    when converting the actions stored and used in a dataset

    """
    if hasattr(env, 'wrap_action'):
        return env.wrap_action(inner_action)
    else:
        return inner_action


def recursive_squeeze(dictlike):
    """
    Take a possibly-nested dictionary-like object of which all leaf elements are numpy ar
    """
    out = {}
    for k, v in dictlike.items():
        if isinstance(v, dict):
            out[k] = recursive_squeeze(v)
        else:
            out[k] = np.squeeze(v)
    return out
    
def create_data_iterator(
        wrapped_dummy_env: gym.Env,
        data_pipeline: minerl.data.DataPipeline,
        batch_size: int,
        buffer_size: int = 15000,
        num_epochs: int = None,
        num_batches: int = None,
        remove_no_ops: bool = False,
) -> dict:
    """
    Construct a data iterator that (1) loads data from disk, and (2) wraps it in the set of
    wrappers that have been applied to `wrapped_dummy_env`.

    :param wrapped_dummy_env: An environment that mimics the base environment and wrappers we'll be using for training,
    but doesn't actually call Minecraft
    :param data_pipeline: A MineRL DataPipeline object that can handle loading data from disk
    :param batch_size: The batch size we want the iterator to produce
    :param num_epochs: The number of epochs we want the underlying iterator to run for
    :param num_batches: The number of batches we want the underlying iterator to run for
    :param remove_no_ops: Whether to remove transitions with no-op demonstrator actions from batches
    as they are generated. For now, this corresponds to all-zeros.

    :yield: Wrapped observations and actions in a dict with the keys "obs", "acts", "rews",
         "next_obs", "dones".
    """
    buffered_iterator = BufferedBatchIter(data_pipeline, buffer_target_size=buffer_size)
    for current_obs, action, reward, next_obs, done in buffered_iterator.buffered_batch_iter(batch_size=batch_size,
                                                                                             num_epochs=num_epochs,
                                                                                             num_batches=num_batches):
        wrapped_obs = optional_observation_map(wrapped_dummy_env,
                                               recursive_squeeze(current_obs))
        wrapped_next_obs = optional_observation_map(wrapped_dummy_env,
                                                    recursive_squeeze(next_obs))
        wrapped_action = optional_action_map(wrapped_dummy_env,
                                             recursive_squeeze(action))

        if remove_no_ops:
            # This definitely makes assumptions about the action space, namely that all-zeros corresponds to a no-op
            not_no_op_indices = wrapped_action.sum(axis=1) != 0
            wrapped_obs = wrapped_obs[not_no_op_indices]
            wrapped_next_obs = wrapped_next_obs[not_no_op_indices]
            wrapped_action = wrapped_action[not_no_op_indices]

        return_dict = dict(obs=wrapped_obs,
                           acts=wrapped_action,
                           rews=reward,
                           next_obs=wrapped_next_obs,
                           dones=done)

        yield return_dict

NameError: name 'minerl' is not defined