Skip to content

Commit

Permalink
tmp save
Browse files Browse the repository at this point in the history
  • Loading branch information
Ming Zhou committed Nov 8, 2023
1 parent c01e463 commit ab55fa1
Show file tree
Hide file tree
Showing 13 changed files with 335 additions and 67 deletions.
10 changes: 10 additions & 0 deletions malib/common/rollout_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
@dataclass
class RolloutConfig:
inference_server_type: str
"""Inference server type"""

num_workers: int = 1
"""Defines how many workers will be used for executing one rollout task, default is 1"""

n_envs_per_worker: int = 1
"""Indicates how many environments will be activated for a rollout task per rollout worker, default is 1"""

timelimit: int = 256
"""Specifies how many time steps will be collected for each rollout, default is 256"""

@classmethod
def from_raw(
Expand Down
64 changes: 35 additions & 29 deletions malib/common/strategy_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from typing import Dict, Any, Tuple, Type
from argparse import Namespace

from collections import namedtuple
import numpy as np

from malib.rl.common.policy import Policy
Expand All @@ -48,9 +48,21 @@ def validate_meta_data(policy_ids: Tuple[PolicyID], meta_data: Dict[str, Any]):
assert np.isclose(sum(meta_data["prob_list"]), 1.0)


import copy

from gym import spaces


class StrategySpec:
def __init__(
self, identifier: str, policy_ids: Tuple[PolicyID], meta_data: Dict[str, Any]
self,
policy_cls: Type,
observation_space: spaces.Space,
action_space: spaces.Space,
model_config: Dict[str, Any] = None,
identifier: str = None,
policy_ids: Tuple[PolicyID] = None,
**kwargs,
) -> None:
"""Construct a strategy spec.
Expand All @@ -60,10 +72,17 @@ def __init__(
meta_data (Dict[str, Any]): Meta data, for policy construction.
"""

validate_meta_data(policy_ids, meta_data)
self.id = identifier
self.policy_ids = tuple(policy_ids)
self.meta_data = meta_data
self.id = identifier or "StrategySpec"
self.policy_ids = tuple(policy_ids) if policy_ids else ()
self.meta_data = {
"policy_cls": policy_cls,
"init_kwargs": {
"observation_space": observation_space,
"action_space": action_space,
"model_config": model_config,
**kwargs,
},
}

def __str__(self):
return f"<StrategySpec: {self.policy_ids}>"
Expand All @@ -85,7 +104,8 @@ def register_policy_id(self, policy_id: PolicyID):
policy_id (PolicyID): Policy id to register.
"""

assert policy_id not in self.policy_ids, (policy_id, self.policy_ids)
if policy_id in self.policy_ids:
raise KeyError("repected policy id detected: {}".format(policy_id))
self.policy_ids = self.policy_ids + (policy_id,)

if "prob_list" in self.meta_data:
Expand All @@ -104,10 +124,11 @@ def update_prob_list(self, policy_probs: Dict[PolicyID, float]):
for pid, prob in policy_probs.items():
idx = self.policy_ids.index(pid)
self.meta_data["prob_list"][idx] = prob
assert np.isclose(sum(self.meta_data["prob_list"]), 1.0), (
self.meta_data["prob_list"],
sum(self.meta_data["prob_list"]),
)

if not np.isclose(sum(self.meta_data["prob_list"]), 1.0):
raise ValueError(
f"Prob list is not normalized: {self.meta_data['prob_list']}"
)

def get_meta_data(self) -> Dict[str, Any]:
"""Return meta data. Keys in meta-data contains
Expand All @@ -121,7 +142,7 @@ def get_meta_data(self) -> Dict[str, Any]:
Dict[str, Any]: A dict of meta data.
"""

return self.meta_data
return copy.deepcopy(self.meta_data)

def gen_policy(self, device=None) -> Policy:
"""Generate a policy instance with the given meta data.
Expand All @@ -131,23 +152,8 @@ def gen_policy(self, device=None) -> Policy:
"""

policy_cls: Type[Policy] = self.meta_data["policy_cls"]
plist = self.meta_data["kwargs"]
plist = Namespace(**plist)

custom_config = plist.custom_config.copy()

if device is not None and "cuda" in device:
custom_config["use_cuda"] = True
else:
custom_config["use_cuda"] = False

return policy_cls(
observation_space=plist.observation_space,
action_space=plist.action_space,
model_config=plist.model_config,
custom_config=custom_config,
**plist.kwargs,
)
policy = policy_cls(**self.meta_data["init_kwargs"])
return policy.to(device)

def sample(self) -> PolicyID:
"""Sample a policy instance. Use uniform sampling if there is no preset prob list in meta data.
Expand Down
25 changes: 9 additions & 16 deletions malib/rl/common/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,21 @@ def state_dict(self):


class Policy(metaclass=ABCMeta):
def __init__(
self, observation_space, action_space, model_config, custom_config, **kwargs
):
def __init__(self, observation_space, action_space, model_config, **kwargs):
_locals = locals()
_locals.pop("self")
self._init_args = _locals
self._observation_space = observation_space
self._action_space = action_space
self._model_config = model_config or {}
self._custom_config = custom_config or {}
self._model_config = model_config
self._custom_config = kwargs
self._state_handler_dict = {}
self._preprocessor = get_preprocessor(
observation_space,
mode=self._custom_config.get("preprocess_mode", "flatten"),
mode=kwargs.get("preprocess_mode", "flatten"),
)(observation_space)

self._device = torch.device(
"cuda" if self._custom_config.get("use_cuda") else "cpu"
)
self._device = torch.device("cuda" if kwargs.get("use_cuda") else "cpu")

self._registered_networks: Dict[str, nn.Module] = {}

Expand All @@ -95,16 +91,13 @@ def __init__(
)
)

self.use_cuda = self._custom_config.get("use_cuda", False)
self.use_cuda = kwargs.get("use_cuda", False)
self.dist_fn: Distribution = make_proba_distribution(
action_space=action_space,
use_sde=custom_config.get("use_sde", False),
dist_kwargs=custom_config.get("dist_kwargs", None),
use_sde=kwargs.get("use_sde", False),
dist_kwargs=kwargs.get("dist_kwargs", None),
)
if kwargs.get("model_client"):
self.model = kwargs["model_client"]
else:
self.model = self.create_model()
self.model = kwargs.get("model_client", self.create_model())

def create_model(self):
raise NotImplementedError
Expand Down
15 changes: 5 additions & 10 deletions malib/rl/pg/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
model_config: Dict[str, Any],
custom_config: Dict[str, Any],
model_config: Dict[str, Any] = None,
**kwargs
):
"""Build a REINFORCE policy whose input and output dims are determined by observation_space and action_space, respectively.
Expand All @@ -52,23 +51,19 @@ def __init__(
observation_space (spaces.Space): The observation space.
action_space (spaces.Space): The action space.
model_config (Dict[str, Any]): The model configuration dict.
custom_config (Dict[str, Any]): The custom configuration dict.
is_fixed (bool, optional): Indicates fixed policy or trainable policy. Defaults to False.
Raises:
NotImplementedError: Does not support other action space type settings except Box and Discrete.
TypeError: Unexpected action space.
"""

# update model_config with default ones
model_config = merge_dicts(DEFAULT_CONFIG["model_config"].copy(), model_config)
custom_config = merge_dicts(
DEFAULT_CONFIG["custom_config"].copy(), custom_config
model_config = merge_dicts(
DEFAULT_CONFIG["model_config"].copy(), model_config or {}
)
kwargs = merge_dicts(DEFAULT_CONFIG["custom_config"].copy(), kwargs)

super().__init__(
observation_space, action_space, model_config, custom_config, **kwargs
)
super().__init__(observation_space, action_space, model_config, **kwargs)

def create_model(self):
# update model preprocess_net config here
Expand Down
5 changes: 1 addition & 4 deletions malib/rl/random/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
model_config: Dict[str, Any],
custom_config: Dict[str, Any],
**kwargs
):
super().__init__(
observation_space, action_space, model_config, custom_config, **kwargs
)
super().__init__(observation_space, action_space, model_config, **kwargs)
9 changes: 9 additions & 0 deletions malib/rollout/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, **configs):
self._configs = configs
self._current_players = []
self._state: Dict[str, np.ndarray] = None
self._deactivated = True

def record_episode_info_step(
self,
Expand Down Expand Up @@ -103,11 +104,19 @@ def action_spaces(self) -> Dict[AgentID, gym.Space]:

raise NotImplementedError

# Read-only accessor: returns the current value of self._deactivated.
# In the visible hunks of this file the flag is initialised to True in
# __init__, set to True by deactivate(), and set to False by reset() --
# presumably all on the same environment class; confirm against the full file.
@property
def is_deactivated(self) -> bool:
return self._deactivated

# Mark the environment as deactivated by setting self._deactivated to True.
# The visible reset() hunk sets the same flag back to False.
def deactivate(self):
self._deactivated = True

def reset(self, max_step: int = None) -> Union[None, Sequence[Dict[AgentID, Any]]]:
"""Reset environment and the episode info handler here."""

self.max_step = max_step or self.max_step
self.cnt = 0
self._deactivated = False

self.episode_metrics = {
"env_step": 0,
Expand Down
4 changes: 2 additions & 2 deletions malib/rollout/envs/mdp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, **configs):
)

self.env = scenarios[env_id]().to_env()
self._possible_agents = ["agent"]
self._possible_agents = ["default"]

@property
def possible_agents(self) -> List[AgentID]:
Expand All @@ -57,7 +57,7 @@ def time_step(
Dict[AgentID, bool],
Dict[AgentID, Any],
]:
obs, rew, done, info = self.env._step(actions["agent"])
obs, rew, done, info = self.env._step(actions["default"])

obs = dict.fromkeys(self.possible_agents, obs)
rew = dict.fromkeys(self.possible_agents, rew)
Expand Down
Loading

0 comments on commit ab55fa1

Please sign in to comment.