In [37]:
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from mlagents_envs.environment import ActionTuple, UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import \
    EngineConfigurationChannel
from mlagents_envs.exception import (
    UnityEnvironmentException,
    UnityCommunicationException,
    UnityCommunicatorStoppedException,
)
import numpy as np
import matplotlib.pyplot as plt

In [107]:
try:
    env.close()
except Exception:
    pass

env = UnityEnvironment(file_name=None)

In [108]:
# Unity環境のリセット
env.reset()

# BehaviorNameのリストの取得
behavior_names = list(env.behavior_specs.keys())
print('behavior_names:', behavior_names)

# BehaviorSpecの取得
behavior_spec = env.behavior_specs[behavior_names[0]]

# BehaviorSpecの情報の確認
print('\n== BehaviorSpecの情報の確認 ==')
print('observation_specs:', behavior_spec.observation_specs)
print('action_spec:', behavior_spec.action_spec)

behavior_names: ['GridFoodCollector?team=0']

== BehaviorSpecの情報の確認 ==
observation_specs: [ObservationSpec(shape=(40, 40, 5), dimension_property=(<DimensionProperty.TRANSLATIONAL_EQUIVARIANCE: 2>, <DimensionProperty.TRANSLATIONAL_EQUIVARIANCE: 2>, <DimensionProperty.NONE: 1>), observation_type=<ObservationType.DEFAULT: 0>, name='GridSensor-OneHot')]
action_spec: Continuous: 3, Discrete: (2,)


In [40]:
behavior_spec

BehaviorSpec(observation_specs=[ObservationSpec(shape=(40, 40, 5), dimension_property=(<DimensionProperty.TRANSLATIONAL_EQUIVARIANCE: 2>, <DimensionProperty.TRANSLATIONAL_EQUIVARIANCE: 2>, <DimensionProperty.NONE: 1>), observation_type=<ObservationType.DEFAULT: 0>, name='GridSensor-OneHot')], action_spec=ActionSpec(continuous_size=3, discrete_branches=(2,)))

In [41]:
# 現在のステップの情報の取得
decision_steps, terminal_steps = env.get_steps(behavior_names[0])

# DecisionStepsの情報の確認
print('\n== DecisionStepsの情報の確認 ==')
print('obj:', decision_steps.obs)
print('reward:', decision_steps.reward)
print('agent_id:', decision_steps.agent_id)
print('action_mask:', decision_steps.action_mask)

# TerminalStepsの情報の確認
print('\n== TerminalStepsの情報の確認 ==')
print('obs:', terminal_steps.obs)
print('reward:', terminal_steps.reward)
print('agent_id:', terminal_steps.agent_id)
print('interrupted:', terminal_steps.interrupted)


== DecisionStepsの情報の確認 ==
obj: [array([[[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        ...,

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
       

In [42]:
len(decision_steps.obs)

1

In [43]:
decision_steps.obs[0].shape

(5, 40, 40, 5)

In [44]:
decision_steps.obs[0].shape

(5, 40, 40, 5)

In [45]:
np.concatenate([decision_steps.obs[0], terminal_steps.obs[0]], axis=0).shape

(5, 40, 40, 5)

In [46]:
terminal_steps.obs[0].shape

(0, 40, 40, 5)

In [47]:
N_AGENTS = len(decision_steps.obs[0])

In [48]:
decision_steps.reward.shape

(5,)

In [49]:
decision_steps.reward

array([0., 0., 0., 0., 0.], dtype=float32)

In [50]:
decision_steps.agent_id

array([0, 1, 2, 3, 4], dtype=int32)

In [51]:
decision_steps.action_mask

[array([[False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False]])]

In [52]:
terminal_steps.reward

array([], dtype=float32)

In [53]:
terminal_steps.agent_id

array([], dtype=int32)

In [54]:
terminal_steps.obs

[array([], shape=(0, 40, 40, 5), dtype=float32)]

In [55]:
behavior_spec.action_spec

ActionSpec(continuous_size=3, discrete_branches=(2,))

In [56]:
behavior_spec.action_spec.discrete_branches

(2,)

In [57]:
len(behavior_spec.action_spec.discrete_branches)

1

In [58]:
behavior_spec.action_spec.continuous_size

3

In [31]:
# env.set_action_for_agent(behavior_names[0], 0, ActionTuple(continuous=np.zeros(3).reshape(1, 3), discrete=np.array([0]).reshape(1, 1)))

In [132]:
class EnvWrapper:

    def __init__(self, env):
        self._env = env
        behavior_names = list(env.behavior_specs.keys())
        behavior_spec = env.behavior_specs[behavior_names[0]]
        self._behavior_name = list(env.behavior_specs.keys())[0]
        self._action_continuous_dim = behavior_spec.action_spec.continuous_size
        self._action_discrete_dim = len(behavior_spec.action_spec.discrete_branches)
        self._n_discrete_actions = behavior_spec.action_spec.discrete_branches
        self._OBSERVATION_IDX = 0
        decision_steps, terminal_steps = env.get_steps(behavior_names[0])
        self._n_agents = len(decision_steps.obs[0])

    def set_action(self, agent_id, continuous_action: np.ndarray = None, discrete_action: np.ndarray = None):
        if continuous_action is None and discrete_action is None:
            raise ValueError('either continuous_action or discrete_action must be specified')
        action_tuple = self.format_action(continuous_action, discrete_action)
        self._env.set_action_for_agent(self._behavior_name, agent_id, action_tuple)

    def format_action(self, continuous_action: np.ndarray, discrete_action: np.ndarray):
        if continuous_action is not None:
            assert continuous_action.ndim == 1
            assert len(continuous_action) == self._action_continuous_dim
        if discrete_action is not None:
            assert discrete_action.ndim == 1
            assert len(discrete_action) == self._action_discrete_dim
        return ActionTuple(
            continuous=continuous_action.reshape(1, self._action_continuous_dim),
            discrete=discrete_action.reshape(1, self._action_discrete_dim)
        )

    def get_state(self):
        decision_steps, terminal_steps = self._env.get_steps(self._behavior_name)
        
        agent_id_diff = list(set(decision_steps.agent_id)-set(terminal_steps.agent_id))
        agent_id_filt = [ai in agent_id_diff for ai in decision_steps.agent_id]

        states = np.concatenate([
            decision_steps.obs[self._OBSERVATION_IDX][agent_id_filt],
            terminal_steps.obs[self._OBSERVATION_IDX]
        ], axis=0)
        agent_ids = np.concatenate([decision_steps.agent_id[agent_id_filt], terminal_steps.agent_id], axis=0)
        rewards = np.concatenate([decision_steps.reward[agent_id_filt], terminal_steps.reward], axis=0)
        dones = [False]*len(decision_steps.reward[agent_id_filt]) + [True]*len(terminal_steps.reward)

        assert states.shape[0] == self._n_agents
        assert len(rewards) == self._n_agents
        assert len(dones) == self._n_agents

        return agent_ids, states, rewards, dones

    def step(self):
        self._env.step()
        agent_ids, states, rewards, dones = self.get_state()
        return agent_ids, states, rewards, dones
    
    def reset(self):
        self._env.reset()
        agent_ids, states, rewards, dones = self.get_state()
        return states

    def random_action(self):
        return (
            2.0*np.random.rand(self._action_continuous_dim) - 1.0 if self._action_continuous_dim > 0 else None,
            np.array([np.random.randint(0, n_act) for n_act in self._n_discrete_actions]) if self._action_discrete_dim > 0 else None,
        )

In [133]:
class Agent:
    
    def __init__(self, w_env, agent_id):
        self._w_env = w_env
        self._agent_id = agent_id

    def set_action(self, continuous_action: np.ndarray = None, discrete_action: np.ndarray = None):
        self._w_env.set_action(self._agent_id, continuous_action, discrete_action)

In [134]:
w_env = EnvWrapper(env)
agents = [Agent(w_env=w_env, agent_id=ai) for ai in range(N_AGENTS)]

In [135]:
c_acts, d_acts = w_env.random_action()

In [136]:
agents[0].set_action(c_acts, d_acts)

In [137]:
agent_ids, states, rewards, dones = w_env.step()

In [138]:
agent_ids

array([0, 1, 2, 3, 4], dtype=int32)

In [81]:
states.shape

(5, 40, 40, 5)

In [82]:
rewards

array([0., 0., 0., 0., 0.], dtype=float32)

In [83]:
dones

[False, False, False, False, False]

In [139]:
import time
import warnings
from tqdm import tqdm

warnings.simplefilter('ignore')

In [140]:
for i in tqdm(range(5000)):
    for agent in agents:
        c_acts, d_acts = w_env.random_action()
        agent.set_action(c_acts, d_acts)
    agent_ids, states, rewards, dones = w_env.step()
    if any(dones):
        print(dones)
        w_env.reset()
    # time.sleep(0.1)

 20%|█████████████████████▌                                                                                      | 1000/5000 [01:39<07:40,  8.69it/s]

[True, True, True, True, True]


 40%|███████████████████████████████████████████▏                                                                | 2001/5000 [03:19<05:44,  8.70it/s]

[True, True, True, True, True]


 60%|████████████████████████████████████████████████████████████████▊                                           | 3002/5000 [04:59<04:12,  7.90it/s]

[True, True, True, True, True]


 80%|██████████████████████████████████████████████████████████████████████████████████████▍                     | 4004/5000 [06:39<01:42,  9.72it/s]

[True, True, True, True, True]


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [08:19<00:00, 10.01it/s]


In [104]:
env.close()