In [None]:
# Change the working directory to the parent directory (IPython line magic).
# NOTE(review): presumably the notebook lives in a sub-folder and the remaining
# cells expect to run from the project root — confirm.
%cd ..

In [1]:
import gym
from gym.spaces.box import Box
import torch
import numpy as np
import random

from baselines import bench
from baselines.common.vec_env import VecEnvWrapper
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

In [2]:
# Seed every random number generator used directly by this notebook so that
# results drawn in this process are repeatable.
seed = 0
random.seed(seed)        # Python's built-in RNG
# `torch.manual_seed` is the same function as `torch.random.manual_seed`,
# so a single call is enough to seed PyTorch.
torch.manual_seed(seed)
np.random.seed(seed)     # NumPy's legacy global RNG

In [3]:
# import datetime
# now = datetime.datetime.now().strftime('_%d:%m_%H:%M:%S')
# import wandb
# wandb.init(name='rl_env_test'+now, 
#            project='ucl_msc_proj')

In [4]:
class TransposeImage(gym.ObservationWrapper):
    """Observation wrapper that moves the channel axis of an image to the front.

    Channel-last observations of shape ``(H, W, C)`` are returned as
    channel-first arrays of shape ``(C, H, W)``, and ``observation_space`` is
    rewritten to advertise that same shape.
    """

    def __init__(self, env=None):
        super(TransposeImage, self).__init__(env)
        obs_shape = self.observation_space.shape
        # The advertised shape has to agree with what `observation` returns,
        # i.e. the original axes permuted as (2, 0, 1) -> (C, H, W).
        # The scalar bounds are taken from the first element of low/high.
        self.observation_space = Box(self.observation_space.low[0, 0, 0],
                                     self.observation_space.high[0, 0, 0],
                                     [obs_shape[2], obs_shape[0], obs_shape[1]],
                                     dtype=self.observation_space.dtype)

    def observation(self, observation):
        # Three-argument `transpose` is the array-style axis permutation
        # (torch's Tensor.transpose only swaps two dims), so `observation`
        # is expected to be a NumPy array here.
        return observation.transpose(2, 0, 1)

In [5]:
class VecPyTorch(VecEnvWrapper):
    """Vectorised-env wrapper that exchanges PyTorch tensors with the caller.

    Observations and rewards coming out of the wrapped env are converted from
    NumPy arrays to float tensors; actions going in are converted from tensors
    to NumPy arrays.
    """

    def __init__(self, venv, device):
        """
        Args:
            venv: The vectorised environment to wrap.
            device: Torch device on which observation tensors are placed.
        """
        super(VecPyTorch, self).__init__(venv)
        self.device = device

    def reset(self):
        """Reset the wrapped env and return the observations as a float tensor on `self.device`."""
        obs = self.venv.reset()
        # convert obs to torch tensor
        obs = torch.from_numpy(obs).float().to(self.device)
        return obs

    def step_async(self, actions):
        """Send a tensor of actions to the wrapped env as a NumPy array."""
        # Drop the singleton second dimension of (num_envs, 1) action tensors.
        # Tensors with fewer than two dimensions have no dim 1 to squeeze, so
        # they are passed through instead of raising.
        if actions.dim() > 1:
            actions = actions.squeeze(1)
        actions = actions.cpu().numpy()
        self.venv.step_async(actions)

    def step_wait(self):
        """Collect the step result, returning `(obs, reward, done, info)`.

        `obs` is a float tensor on `self.device`. `reward` is a float tensor
        with an extra trailing dimension and is left on the CPU. `done` and
        `info` are returned exactly as produced by the wrapped env.
        """
        obs, reward, done, info = self.venv.step_wait()
        # convert obs to torch tensor
        obs = torch.from_numpy(obs).float().to(self.device)
        # convert reward to torch tensor (not moved to `self.device`)
        reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
        return obs, reward, done, info

In [6]:
# Pool of candidate seeds (0..99) from which an environment's `rand_seed` is drawn.
train_seeds = np.arange(100)

In [7]:
def make_env(env_id, rank):
    """Build a thunk that creates a single monitored environment.

    Args:
        env_id: Gym environment id. Only ids starting with ``'procgen'`` are
            supported.
        rank: Index of the environment within the vectorised set. Currently
            unused; kept so the call signature stays stable.

    Returns:
        A zero-argument callable that constructs and returns the environment,
        wrapped in ``bench.Monitor`` and, for channel-last image observations,
        in ``TransposeImage``.

    Raises:
        NotImplementedError: Raised by the returned thunk (not by ``make_env``
            itself) when ``env_id`` does not start with ``'procgen'``.
    """
    # Draw the environment seed here, in the calling process, rather than
    # inside the thunk. The thunk may be executed in a SubprocVecEnv worker
    # process whose `random` state is not the one seeded by this notebook,
    # which would make the chosen seed irreproducible.
    rand_seed = int(random.choice(train_seeds))

    def _thunk():
        if not env_id.startswith('procgen'):
            raise NotImplementedError(
                "only procgen environments are supported, got {!r}".format(env_id))

        env = gym.make(env_id,
                       start_level=100,
                       num_levels=100,
                       distribution_mode='easy',
                       rand_seed=rand_seed)

        env = bench.Monitor(env=env,
                            filename=None,
                            allow_early_resets=False)

        # If the input has shape (H,W,3), wrap for PyTorch convolutions (3,H,W)
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env)

        return env
    return _thunk

def make_vec_envs(env_name, num_processes, gamma):
    """Create a vectorised environment made of `num_processes` copies of `env_name`.

    Args:
        env_name: Environment id forwarded to `make_env`.
        num_processes: Number of environment copies to create.
        gamma: Accepted for interface compatibility; not used by this function.

    Returns:
        A `SubprocVecEnv` when more than one environment is requested,
        otherwise a `DummyVecEnv`.
    """
    env_fns = []
    for rank in range(num_processes):
        env_fns.append(make_env(env_id=env_name, rank=rank))

    # Only go multi-process when there is more than one environment to run.
    vec_env_cls = DummyVecEnv if len(env_fns) <= 1 else SubprocVecEnv
    return vec_env_cls(env_fns)

In [8]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

env_name = "procgen:procgen-coinrun-v0"
num_processes = 2
gamma = 0.99  # discount factor

envs = make_vec_envs(env_name, num_processes, gamma)
# Observation normalisation is switched off (ob=False). `gamma` is passed
# explicitly so the wrapper always uses the value configured above instead of
# silently falling back to its own default.
envs = VecNormalize(envs, ob=False, gamma=gamma)
envs = VecPyTorch(envs, device)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [9]:
# Reset the environments and keep the initial observations. Going through
# VecPyTorch.reset, the result is a float tensor placed on `device`.
frame = envs.reset()

In [10]:
# Sum of the element-wise difference between the first two entries of `frame`;
# a non-zero value shows that the two entries are not identical.
(frame[0] - frame[1]).sum()

tensor(-5700.)