tmp save
Ming Zhou committed Dec 13, 2023
1 parent 6c60dc5 commit c9840a8
Showing 28 changed files with 309 additions and 209 deletions.
2 changes: 1 addition & 1 deletion examples/sarl/ppo_gym.py
@@ -7,7 +7,7 @@

import numpy as np

from malib.utils.episode import Episode
from malib.rollout.episode import Episode
from malib.learner import IndependentAgent
from malib.scenarios import sarl_scenario
from malib.rl.config import Algorithm
2 changes: 2 additions & 0 deletions malib/backend/dataset_server/data_loader.py
@@ -16,6 +16,8 @@ class EmptyError(Exception):
pass


# TODO(ming): consider determining the `max_message_length`
# via a FeatureHandler, as it is convenient for it to know the size of the data.
class DynamicDataset(Dataset):
def __init__(
self,
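As a rough illustration of the TODO in data_loader.py, the message limit could be derived from the feature handler's knowledge of the data layout instead of being hard-coded. The sketch below is hypothetical: `derive_max_message_length` and the `spaces` attribute are assumed names, not part of the current API.

import numpy as np

def derive_max_message_length(feature_handler, batch_size: int, overhead: int = 1024) -> int:
    # Hypothetical helper: sum the per-step byte size of every feature array the
    # handler knows about, scale by the batch size, and add a small serialization
    # overhead. Assumes the handler exposes its gym spaces under `spaces`.
    nbytes_per_step = sum(
        int(np.prod(space.shape)) * np.dtype(space.dtype).itemsize
        for space in feature_handler.spaces.values()
    )
    return nbytes_per_step * batch_size + overhead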
26 changes: 4 additions & 22 deletions malib/backend/dataset_server/feature.py
@@ -1,27 +1,13 @@
from typing import Any, Dict
from abc import ABC, abstractmethod
from abc import ABC

import copy
import numpy as np
import torch

from gym import spaces
from readerwriterlock import rwlock


numpy_to_torch_dtype_dict = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
from malib.utils.data import numpy_to_torch_dtype_dict


class BaseFeature(ABC):
@@ -35,15 +21,11 @@ def __init__(
self.rw_lock = rwlock.RWLockFair()
self._device = device
self._spaces = spaces
self._block_size = (
block_size
if block_size is not None
else list(np_memory.values())[0].shape[0]
)
self._block_size = min(block_size or np.iinfo(np.longlong).max, list(np_memory.values())[0].shape[0])
self._available_size = 0
self._flag = 0
self._shared_memory = {
k: torch.from_numpy(v).to(device).share_memory_()
k: torch.from_numpy(v[:self._block_size]).to(device).share_memory_()
for k, v in np_memory.items()
}

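For reference, the new `_block_size` expression clamps the requested block size to the leading dimension of the first array in `np_memory`; when `block_size` is None, the `or` falls back to a huge sentinel so the array length wins. A minimal sketch of the arithmetic with made-up shapes:

import numpy as np

np_memory = {"obs": np.zeros((1000, 4), dtype=np.float32)}
first_len = list(np_memory.values())[0].shape[0]

print(min(None or np.iinfo(np.longlong).max, first_len))  # 1000 (no explicit block size)
print(min(256 or np.iinfo(np.longlong).max, first_len))   # 256  (requested size is smaller)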
5 changes: 4 additions & 1 deletion malib/backend/dataset_server/service.py
@@ -1,6 +1,9 @@
from typing import Dict

import threading
import traceback
import pickle
import numpy as np

from . import data_pb2_grpc
from . import data_pb2
@@ -19,7 +22,7 @@ def __init__(

def Collect(self, request, context):
try:
data = pickle.loads(request.data)
data: Dict[str, np.ndarray] = pickle.loads(request.data)
batch_size = len(list(data.values())[0])
self.feature_handler.safe_put(data, batch_size)
message = "success"
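For context, Collect expects a pickled dict that maps column names to numpy arrays sharing the same leading (batch) dimension; the batch size is read off the first value. A hedged client-side sketch; the column names are illustrative, not required by the service:

import pickle
import numpy as np

payload = pickle.dumps({
    "obs": np.zeros((32, 4), dtype=np.float32),  # 32 transitions, 4-dim observations
    "rew": np.zeros((32,), dtype=np.float32),
    "done": np.zeros((32,), dtype=bool),
})
# On the server: batch_size = len(list(data.values())[0]) -> 32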
23 changes: 6 additions & 17 deletions malib/learner/indepdent_learner.py
@@ -22,26 +22,15 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, Tuple, Any, List, Union
from typing import Dict, Any

import torch

from malib.utils.typing import AgentID
from malib.utils.tianshou_batch import Batch
from malib.utils.data import to_torch
from malib.learner.learner import Learner


class IndependentAgent(Learner):
def multiagent_post_process(
self,
batch_info: Union[
Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]
],
) -> Dict[str, Any]:
if not isinstance(batch_info, Tuple):
raise TypeError(
"IndependentAgent support only a tuple of batch info as input."
)

batch = batch_info[0]
batch.to_torch(device=self.device)

return batch
def multiagent_post_process(self, batch: Dict[AgentID, Dict[str, torch.Tensor]]) -> Dict[str, Any]:
return to_torch(batch, device=self.device)
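With this change, IndependentAgent no longer special-cases a tuple of batch info; it just moves a nested per-agent dict of tensors onto the learner's device. A rough sketch of the expected input, assuming the data loader yields CPU tensors keyed by agent:

import torch

# Hypothetical per-agent batch as yielded by the learner's data loader:
batch = {
    "agent_0": {
        "obs": torch.zeros((32, 4)),
        "act": torch.zeros((32,), dtype=torch.long),
        "rew": torch.zeros((32,)),
    }
}
# multiagent_post_process(batch) reduces to to_torch(batch, device=self.device):
# the same nested structure, with every tensor moved to the learner's device.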
57 changes: 32 additions & 25 deletions malib/learner/learner.py
@@ -23,7 +23,7 @@
# SOFTWARE.


from typing import Dict, Any, Tuple, Callable, List, Union, Type
from typing import Dict, Any, Tuple, Callable, List, Union
from abc import ABC, abstractmethod

import time
@@ -50,6 +50,7 @@
from malib.rl.config import Algorithm


# TODO(ming): better to use a feature handler to determine the max_message_length
MAX_MESSAGE_LENGTH = 7309898


@@ -63,7 +64,6 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
algorithm: Algorithm,
agent_mapping_func: Callable[[AgentID], str],
governed_agents: Tuple[AgentID],
custom_config: Dict[str, Any] = None,
dataset: DynamicDataset = None,
@@ -106,7 +106,6 @@ def __init__(
self._algorithm = algorithm
self._governed_agents = governed_agents
self._strategy_spec = strategy_spec
self._agent_mapping_func = agent_mapping_func
self._custom_config = custom_config
self._policy = strategy_spec.gen_policy(device=device)

@@ -144,14 +143,12 @@ def __init__(
@abstractmethod
def multiagent_post_process(
self,
batch_info: Union[
Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]
],
batch: Dict[AgentID, Dict[str, torch.Tensor]],
) -> Dict[str, Any]:
"""Merge agent buffer here and return the merged buffer.
Args:
batch_info (Union[Dict[AgentID, Tuple[Batch, List[int]]], Tuple[Batch, List[int]]]): Batch info, could be a dict of agent batch info or a tuple.
batch (Dict[AgentID, Dict[str, torch.Tensor]]): A dict of agent batch.
Returns:
Dict[str, Any]: A merged buffer dict.
Expand Down Expand Up @@ -218,6 +215,33 @@ def get_interface_state(self) -> Dict[str, Any]:
"total_epoch": self._total_epoch,
"policy_num": len(self._strategy_spec),
}

def step(self, prints: bool = False):
while (
self.data_loader.dataset.readable_block_size
< self.data_loader.batch_size
):
time.sleep(1)

for data in self.data_loader:
batch_dict = self.multiagent_post_process(data)
batch = Batch(batch_dict)
# call trainer for one update step, and return training info
# since some algorithm may run multistep for one batch,
# then the returned training_info is a list of dict.
step_info_list = self.trainer(batch)
for step_info in step_info_list:
self._total_step += 1
write_to_tensorboard(
self._summary_writer,
info=step_info,
global_step=self._total_step,
prefix=f"Learner/{self._runtime_id}",
)
if prints:
print(self._total_step, step_info)

self._total_epoch += 1

def train(self, task: OptimizationTask) -> Dict[str, Any]:
"""Executes a optimization task and returns the final interface state.
Expand All @@ -233,25 +257,8 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]:
self.set_running(True)

try:
while (
self.data_loader.dataset.readable_block_size
< self.data_loader.batch_size
):
time.sleep(1)

while self.is_running():
for data in self.data_loader:
batch_info = self.multiagent_post_process(data)
step_info_list = self.trainer(batch_info)
for step_info in step_info_list:
self._total_step += 1
write_to_tensorboard(
self._summary_writer,
info=step_info,
global_step=self._total_step,
prefix=f"Learner/{self._runtime_id}",
)
self._total_epoch += 1
self.step()
except Exception as e:
Logger.warning(
f"training pipe is terminated. caused by: {traceback.format_exc()}"
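The waiting-and-iterating logic that used to live inside train() is now factored into step(): each call blocks until the dataset holds at least one batch, runs one pass over the data loader (one epoch), and logs per-step metrics. A minimal sketch of driving a learner with the new method, assuming an already-constructed Learner subclass instance:

# `learner` is assumed to be an already-constructed Learner subclass
learner.set_running(True)
while learner.is_running():
    learner.step(prints=True)  # blocks until a batch is readable, then runs one epoch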
2 changes: 1 addition & 1 deletion malib/rl/coma/critic.py
@@ -31,7 +31,7 @@
from torch import nn
from gym import spaces

from malib.utils.episode import Episode
from malib.rollout.episode import Episode
from malib.utils.tianshou_batch import Batch
from malib.models.torch import make_net

2 changes: 1 addition & 1 deletion malib/rl/coma/trainer.py
@@ -32,7 +32,7 @@
from malib.utils.typing import AgentID
from malib.utils.tianshou_batch import Batch
from malib.utils.data import Postprocessor
from malib.utils.episode import Episode
from malib.rollout.episode import Episode
from malib.rl.common import misc
from malib.rl.common.trainer import Trainer
from malib.rl.common.policy import Policy
3 changes: 2 additions & 1 deletion malib/rl/pg/__init__.py
@@ -24,7 +24,8 @@

from .policy import PGPolicy
from .trainer import PGTrainer
from .config import DEFAULT_CONFIG
from .config import Config

POLICY = PGPolicy
TRAINER = PGTrainer
DEFAULT_CONFIG = Config
17 changes: 10 additions & 7 deletions malib/rl/pg/config.py
@@ -22,19 +22,22 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

DEFAULT_CONFIG = {
"training_config": {

class Config:

TRAINING_CONFIG = {
"optimizer": "Adam",
"lr": 1e-4,
"reward_norm": None,
"n_repeat": 2,
"minibatch": 2,
"batch_size": 32,
"gamma": 0.99,
},
"model_config": {
}

CUSTOM_CONFIG = {}

MODEL_CONFIG = {
"preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}},
"hidden_sizes": [64],
},
"custom_config": {},
}
}
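The flat DEFAULT_CONFIG dict is replaced by a Config class whose sections live as class attributes, so call sites pick a section by attribute rather than by key. A small sketch of the new access pattern; the override values are arbitrary:

from malib.rl.pg.config import Config

lr = Config.TRAINING_CONFIG["lr"]            # 1e-4
model_defaults = Config.MODEL_CONFIG.copy()  # copy before mutating the shared defaults

# merging user overrides on top of the class-level defaults, as the policy and
# trainer diffs below do with merge_dicts:
training = {**Config.TRAINING_CONFIG, "lr": 3e-4, "batch_size": 64}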
6 changes: 3 additions & 3 deletions malib/rl/pg/policy.py
@@ -35,7 +35,7 @@
from malib.models.config import ModelConfig
from malib.rl.common import misc
from malib.rl.common.policy import Policy, PolicyReturn
from .config import DEFAULT_CONFIG
from .config import Config as DEFAULT_CONFIG


class PGPolicy(Policy):
@@ -60,9 +60,9 @@ def __init__(

# update model_config with default ones
model_config = merge_dicts(
DEFAULT_CONFIG["model_config"].copy(), model_config or {}
DEFAULT_CONFIG.MODEL_CONFIG.copy(), model_config or {}
)
kwargs = merge_dicts(DEFAULT_CONFIG["custom_config"].copy(), kwargs)
kwargs = merge_dicts(DEFAULT_CONFIG.CUSTOM_CONFIG.copy(), kwargs)

super().__init__(observation_space, action_space, model_config, **kwargs)

4 changes: 2 additions & 2 deletions malib/rl/pg/trainer.py
@@ -36,14 +36,14 @@
from malib.utils.general import merge_dicts
from malib.utils.typing import AgentID
from malib.utils.tianshou_batch import Batch
from .config import DEFAULT_CONFIG
from .config import Config


class PGTrainer(Trainer):
def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None):
# merge from default
training_config = merge_dicts(
DEFAULT_CONFIG["training_config"], training_config or {}
Config.TRAINING_CONFIG, training_config or {}
)
super().__init__(training_config, policy_instance)

8 changes: 7 additions & 1 deletion malib/rl/random/__init__.py
@@ -1,3 +1,9 @@
from .policy import RandomPolicy
from .random_trainer import RandomTrainer
from .config import DEFAULT_CONFIG
from .config import Config

Policy = RandomPolicy
Trainer = RandomTrainer
DEFAULT_CONFIG = Config

__all__ = ["Policy", "Trainer", "DEFAULT_CONFIG"]
13 changes: 7 additions & 6 deletions malib/rl/random/config.py
@@ -1,5 +1,6 @@
DEFAULT_CONFIG = {
"training_config": {
class Config:

TRAINING_CONFIG = {
"optimizer": "Adam",
"lr": 1e-4,
"reward_norm": None,
@@ -12,9 +13,9 @@
"entropy_coef": 1e-3,
"grad_norm": 5.0,
"use_cuda": False,
},
"model_config": {
}

MODEL_CONFIG = {
"preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}},
"hidden_sizes": [64],
},
}
}
15 changes: 14 additions & 1 deletion malib/rl/random/random_trainer.py
@@ -1,17 +1,30 @@
from typing import Any, Dict, Type
from typing import Any, Dict, Sequence, Type

import random
import time
import torch

from torch import optim

from malib.rl.common.policy import Policy
from malib.rl.pg.trainer import PGTrainer
from malib.utils.tianshou_batch import Batch
from malib.utils.typing import AgentID


class RandomTrainer(PGTrainer):
def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None):
super().__init__(training_config, policy_instance)

def post_process(self, batch: Batch, agent_filter: Sequence[AgentID]) -> Batch:
return batch

def train(self, batch: Batch) -> Dict[str, Any]:
time.sleep(random.random())
return {
"loss": random.random()
}

def setup(self):
self.optimizer: Type[optim.Optimizer] = getattr(
optim, self.training_config["optimizer"]
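These overrides turn RandomTrainer into a stub that skips post-processing and reports a random loss after a short sleep, which is useful for exercising the training pipeline without real gradients. A hypothetical call, assuming a trainer instance has already been constructed:

import torch
from malib.utils.tianshou_batch import Batch

# `trainer` is assumed to be an already-constructed RandomTrainer
batch = Batch(obs=torch.zeros((4, 3)))
info = trainer.train(batch)  # sleeps under a second, then returns e.g. {"loss": 0.42}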
2 changes: 1 addition & 1 deletion malib/rollout/envs/vector_env.py
@@ -40,7 +40,7 @@
PolicyID,
)
from malib.rollout.envs.env import Environment
from malib.utils.episode import Episode
from malib.rollout.episode import Episode


EnvironmentType = Type[Environment]