Skip to content

Commit

Permalink
test passed for dynamic dataset
Browse files — browse the repository at this point in the history
  • Loading branch information
Ming Zhou committed Nov 14, 2023
1 parent 8ddc763 commit 9c3af38
Show file tree
Hide file tree
Showing 20 changed files with 512 additions and 628 deletions.
53 changes: 21 additions & 32 deletions malib/backend/dataset_server/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Type, Any

import socket
import threading
import grpc
import socket

from concurrent import futures
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset

from malib.utils.general import find_free_port
from malib.backend.dataset_server.utils import service_wrapper

from .service import DatasetServer
from . import data_pb2_grpc
from .feature import BaseFeature


Expand All @@ -23,43 +21,34 @@ def __init__(
self,
grpc_thread_num_workers: int,
max_message_length: int,
feature_handler_caller: Type,
feature_handler_cls: Type[BaseFeature],
**feature_handler_kwargs,
) -> None:
super().__init__()

# start a service as thread
self.feature_handler: BaseFeature = feature_handler_caller()
self.server = self._start_servicer(
self.feature_handler: BaseFeature = feature_handler_cls(
**feature_handler_kwargs
)
self.server_port = find_free_port()
self.server = service_wrapper(
grpc_thread_num_workers,
max_message_length,
find_free_port(),
)
self.host = socket.gethostbyname(socket.gethostbyname())

def _start_servicer(
self, max_workers: int, max_message_length: int, grpc_port: int
):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
options=[
("grpc.max_send_message_length", max_message_length),
("grpc.max_receive_message_length", max_message_length),
],
)
servicer = DatasetServer(self.feature_handler)
data_pb2_grpc.add_SendDataServicer_to_server(servicer, server)

server.add_insecure_port(f"[::]:{grpc_port}")
server.start()

return server
self.server_port,
)(self.feature_handler)
self.server.start()
self.host = socket.gethostbyname(socket.gethostname())

@property
def entrypoint(self) -> str:
return f"{self.host}:{self.server._state.port}"
return f"{self.host}:{self.server_port}"

@property
def readable_block_size(self) -> int:
    """Number of samples currently readable from the backing feature handler.

    Returns:
        int: The current readable size (``len`` of the feature handler).
    """
    # NOTE: annotation corrected from `str` — len() always returns an int.
    return len(self.feature_handler)

def __len__(self):
    """Total capacity (block size) of the underlying feature handler.

    Returns:
        int: Maximum number of timesteps the shared-memory block can hold.
    """
    # The handler instance (not the old `feature_handler_caller` class ref)
    # owns the block size.
    return self.feature_handler.block_size

def __getitem__(self, index) -> Any:
if index >= len(self):
Expand All @@ -71,4 +60,4 @@ def __getitem__(self, index) -> Any:
return self.feature_handler.safe_get(index)

def close(self):
self.server.stop()
self.server.wait_for_termination(3)
107 changes: 91 additions & 16 deletions malib/backend/dataset_server/feature.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,105 @@
from typing import Any
from typing import Any, Dict
from abc import ABC, abstractmethod

import copy
import numpy as np
import torch

from gym import spaces
from readerwriterlock import rwlock


class BaseFeature:
def __init__(self) -> None:
# Lookup table translating NumPy scalar types to the equivalent torch dtypes.
# Used when allocating torch shared-memory tensors from numpy-described
# feature spaces. NOTE(review): assumes the installed torch build exposes the
# complex dtypes (torch >= 1.6) — confirm against project requirements.
numpy_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}


class BaseFeature(ABC):
def __init__(
self,
spaces: Dict[str, spaces.Space],
np_memory: Dict[str, np.ndarray],
block_size: int = None,
device: str = "cpu",
) -> None:
self.rw_lock = rwlock.RWLockFair()
self._readable_index = []
self._writable_index = []
self._device = device
self._spaces = spaces
self._block_size = block_size or list(np_memory.values())[0].shape[0]
self._available_size = 0
self._flag = 0
self._shared_memory = {
k: torch.from_numpy(v).to(device).share_memory_()
for k, v in np_memory.items()
}

def get(self, index: int):
    """Read the timestep stored at ``index`` from shared memory.

    Args:
        index (int): Position in the shared-memory block.

    Returns:
        Dict[str, Any]: Mapping from feature key to the tensor slice at
            ``index``.
    """
    return {key: buffer[index] for key, buffer in self._shared_memory.items()}

def write(self, data: Dict[str, Any], start: int, end: int):
    """Copy a batch of values into the shared-memory block over ``[start, end)``.

    Each value is converted to a tensor and cast to the device and dtype of
    its destination buffer before the slice assignment.
    """
    for key, value in data.items():
        destination = self._shared_memory[key]
        converted = torch.as_tensor(value).to(self._device, dtype=destination.dtype)
        destination[start:end] = converted

def generate_timestep(self) -> Dict[str, np.ndarray]:
    """Draw one random sample from every feature space."""
    timestep = {}
    for key, space in self.spaces.items():
        timestep[key] = space.sample()
    return timestep

def generate_batch(self, batch_size: int = 1) -> Dict[str, np.ndarray]:
    """Sample a random batch from every feature space.

    Args:
        batch_size (int, optional): Number of timesteps sampled per key.
            Defaults to 1.

    Returns:
        Dict[str, np.ndarray]: Mapping from feature key to an array of shape
            ``(batch_size, *space.shape)`` cast to the space dtype.
    """
    batch = {}
    for k, space in self.spaces.items():
        samples = [space.sample() for _ in range(batch_size)]
        # np.stack's `dtype=` keyword requires NumPy >= 1.24; stacking then
        # casting produces the same result on older NumPy as well.
        batch[k] = np.stack(samples).astype(space.dtype, copy=False)
    return batch

@property
def spaces(self) -> Dict[str, spaces.Space]:
return copy.deepcopy(self._spaces)

@property
def block_size(self) -> int:
raise NotImplementedError
return self._block_size

def __len__(self):
return len(self._readable_index)

def _get(self, index: int):
raise NotImplementedError
return self._available_size

def safe_get(self, index: int):
with self.rw_lock.gen_rlock():
return self._get(index)

def _write(self, data: Any):
raise NotImplementedError
if len(self) == 0:
raise IndexError(f"index:{index} exceeds for available size is 0")
elif index >= len(self):
# re-sampling
index = index % self._available_size
return self.get(index)

def safe_put(self, data: Any, batch_size: int):
    """Write a batch into the circular buffer under the writer lock.

    Args:
        data (Any): Mapping from feature key to a batch of values; each value
            must be sliceable along its first axis and have length
            ``batch_size``.
        batch_size (int): Number of timesteps contained in ``data``.
    """
    with self.rw_lock.gen_wlock():
        start = self._flag
        end = start + batch_size
        if end <= self._block_size:
            # request segment assessment: the whole batch fits contiguously.
            self.write(data, start, end)
        else:
            # Wrap-around: a single slice assignment past `_block_size` would
            # fail (destination slice is shorter than the batch), so split the
            # batch into a tail segment and a head segment.
            tail = self._block_size - start
            self.write({k: v[:tail] for k, v in data.items()}, start, self._block_size)
            self.write({k: v[tail:] for k, v in data.items()}, 0, end - self._block_size)
        self._flag = end % self._block_size
        self._available_size = min(
            self._available_size + batch_size, self._block_size
        )
4 changes: 3 additions & 1 deletion malib/backend/dataset_server/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ def __init__(
def Collect(self, request, context):
try:
data = pickle.loads(request.data)
self.feature_handler.safe_put(data)
batch_size = len(list(data.values())[0])
self.feature_handler.safe_put(data, batch_size)
message = "success"
except Exception as e:
message = traceback.format_exc()
print(message)

return data_pb2.Reply(message=message)
36 changes: 34 additions & 2 deletions malib/backend/dataset_server/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, Union
from concurrent import futures

import pickle
import grpc
import sys
import os
import pickle
import grpc

sys.path.append(os.path.dirname(__file__))

from .service import DatasetServer
from . import data_pb2
from . import data_pb2_grpc

Expand All @@ -25,3 +27,33 @@ def send_data(data: Any, host: str = None, port: int = None, entrypoint: str = N
reply = stub.Collect(data_pb2.Data(data=data))

return reply.message


def service_wrapper(max_workers: int, max_message_length: int, grpc_port: int):
    """Build a factory that wires a feature handler into a gRPC dataset server.

    Args:
        max_workers (int): Thread-pool size for the gRPC server.
        max_message_length (int): Cap on send/receive message sizes, in bytes.
        grpc_port (int): Port to bind (insecure, all interfaces).

    Returns:
        Callable: A function mapping a feature handler to an un-started
            ``grpc.Server``; the caller is responsible for ``start()``.
    """

    def func(feature_handler):
        channel_options = [
            ("grpc.max_send_message_length", max_message_length),
            ("grpc.max_receive_message_length", max_message_length),
        ]
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=max_workers),
            options=channel_options,
        )
        data_pb2_grpc.add_SendDataServicer_to_server(
            DatasetServer(feature_handler), server
        )
        server.add_insecure_port(f"[::]:{grpc_port}")
        return server

    return func


def start_server(
    max_workers: int, max_message_length: int, grpc_port: int, feature_handler
):
    """Build, start, and block forever on a dataset gRPC server.

    Intended as a process/thread entry point; does not return until the
    server terminates.
    """
    make_server = service_wrapper(
        max_workers=max_workers,
        max_message_length=max_message_length,
        grpc_port=grpc_port,
    )
    server = make_server(feature_handler)
    server.start()
    server.wait_for_termination()
2 changes: 1 addition & 1 deletion malib/common/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Task:
class RolloutTask(Task):
strategy_specs: Dict[str, Any] = field(default_factory=dict())
stopping_conditions: Dict[str, Any] = field(default_factory=dict())
data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict())
data_entrypoints: Dict[str, Any] = field(default_factory=dict())

@classmethod
def from_raw(
Expand Down
6 changes: 4 additions & 2 deletions malib/common/training_config.py → malib/learner/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Dict, Any, Union
from typing import Dict, Any, Union, Type

from dataclasses import dataclass, field

from malib.learner.learner import Learner


# TODO(ming): rename it as LearnerConfig
@dataclass
class TrainingConfig:
trainer_config: Dict[str, Any]
learner_type: str
learner_type: Type[Learner]
custom_config: Dict[str, Any] = field(default_factory=dict())

@classmethod
Expand Down
7 changes: 1 addition & 6 deletions malib/learner/indepdent_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, Tuple, Any, Callable, List, Type, Union

import gym

from gym import spaces
from malib.backend.dataset_server.data_loader import DynamicDataset
from typing import Dict, Tuple, Any, List, Union

from malib.utils.typing import AgentID
from malib.utils.tianshou_batch import Batch
Expand Down
Loading

0 comments on commit 9c3af38

Please sign in to comment.