Skip to content

Commit

Permalink
implement new proto version
Browse files Browse the repository at this point in the history
  • Loading branch information
Belissimo-T committed May 8, 2024
1 parent 94963fc commit 88dcbe5
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 83 deletions.
15 changes: 8 additions & 7 deletions run_seekers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ def main():
if args.nogrpc and not args.ai_files:
raise ValueError("At least one AI file must be provided if gRPC is disabled.")

parsed_config_overrides = parse_config_overrides(args.config_override or [])

config_dict = Config.from_filepath(args.config).to_properties() | parsed_config_overrides
try:
config = Config.from_properties(config_dict, raise_key_error=True)
except KeyError as e:
raise ValueError(f"Invalid config option {e.args[0]!r}.") from e
config = Config.from_filepath(args.config)
for option, value in parse_config_overrides(args.config_override or []).items():
section, key = option.split(".", maxsplit=1)

try:
config.import_option(section, key, value)
except KeyError as e:
raise ValueError(f"Invalid config option {e.args[0]!r}.") from e

logging.basicConfig(level=args.loglevel, style="{", format=f"[{{name}}] {{levelname}}: {{message}}",
stream=sys.stdout)
Expand Down
2 changes: 1 addition & 1 deletion seekers/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, local_ai_locations: typing.Iterable[str], config: Config,
debug: bool = True, print_scores: bool = True, dont_kill: bool = False):
self._logger = logging.getLogger("SeekersGame")

self._logger.debug(f"Config: {config.to_properties()}")
self._logger.debug(f"Config: {config}")

self.config = config
self.debug = debug
Expand Down
15 changes: 8 additions & 7 deletions seekers/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from grpc._channel import _InactiveRpcError

from .converters import *
from .stubs.org.seekers.net.seekers_pb2 import *
from .stubs.org.seekers.net.seekers_pb2_grpc import *
from .stubs.org.seekers.grpc.service.seekers_pb2 import *
from .stubs.org.seekers.grpc.service.seekers_pb2_grpc import *

import seekers.colors

Expand All @@ -30,6 +30,7 @@ class GrpcSeekersServiceWrapper:
def __init__(self, address: str = "localhost:7777"):
self.name: str | None = None
self.token: str | None = None
self.config: list[Section] | None = None

self.channel = grpc.insecure_channel(address)
self.stub = SeekersStub(self.channel)
Expand All @@ -50,8 +51,11 @@ def join(self, name: str, color: seekers.Color = None) -> str:
self._logger.info(f"Joining game as {name!r} with color {color!r}.")

try:
reply = self.stub.Join(JoinRequest(details=dict(name=name, color=color_to_grpc(color))))
reply: JoinResponse = self.stub.Join(JoinRequest(name=name, color=color_to_grpc(color)))

self.token = reply.token
self.config = reply.sections

return reply.player_id
except _InactiveRpcError as e:
if e.code() in [grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.INVALID_ARGUMENT]:
Expand All @@ -66,9 +70,6 @@ def join(self, name: str, color: seekers.Color = None) -> str:
) from e
raise

def get_server_properties(self) -> dict[str, str]:
return self.stub.Properties(Empty()).entries

def send_commands(self, commands: list[Command]) -> CommandResponse:
if self.channel_connectivity_status != grpc.ChannelConnectivity.READY:
raise ServerUnavailableError("Channel is not ready.")
Expand Down Expand Up @@ -145,7 +146,7 @@ def run(self):

def get_config(self):
if self._server_config is None:
self._server_config = seekers.Config.from_properties(self.service_wrapper.get_server_properties())
self._server_config = config_to_seekers(self.service_wrapper.config)

return self._server_config

Expand Down
55 changes: 45 additions & 10 deletions seekers/grpc/converters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Functions that convert between the gRPC types and the internal types."""
from .stubs.org.seekers.game.camp_pb2 import Camp
from .stubs.org.seekers.game.goal_pb2 import Goal
from .stubs.org.seekers.game.physical_pb2 import Physical
from .stubs.org.seekers.game.player_pb2 import Player
from .stubs.org.seekers.game.seeker_pb2 import Seeker
from .stubs.org.seekers.game.vector2d_pb2 import Vector2D
import dataclasses
from collections import defaultdict

from .stubs.org.seekers.grpc.game.camp_pb2 import Camp
from .stubs.org.seekers.grpc.game.goal_pb2 import Goal
from .stubs.org.seekers.grpc.game.physical_pb2 import Physical
from .stubs.org.seekers.grpc.game.player_pb2 import Player
from .stubs.org.seekers.grpc.game.seeker_pb2 import Seeker
from .stubs.org.seekers.grpc.game.vector2d_pb2 import Vector2D
from .stubs.org.seekers.grpc.service.seekers_pb2 import Section

from .. import seekers_types as seekers

Expand Down Expand Up @@ -98,11 +102,10 @@ def color_to_grpc(color: tuple[int, int, int]):
def player_to_seekers(player: Player) -> seekers.Player:
out = seekers.Player(
id=player.id,
name=player.name,
name=str(player.id),
score=player.score,
seekers={}
)
out.color = color_to_seekers(player.color),

return out

Expand All @@ -111,9 +114,9 @@ def player_to_grpc(player: seekers.Player) -> Player:
return Player(
id=player.id,
seeker_ids=[seeker.id for seeker in player.seekers.values()],
name=player.name,
# name=player.name,
camp_id=player.camp.id,
color=color_to_grpc(player.color),
# color=color_to_grpc(player.color),
score=player.score,
)

Expand All @@ -138,3 +141,35 @@ def camp_to_grpc(camp: seekers.Camp) -> Camp:
width=camp.width,
height=camp.height
)


def config_to_grpc(config: seekers.Config) -> list[Section]:
out = defaultdict(dict)

for attribute_name, value in dataclasses.asdict(config).items():
section, key = config.get_section_and_key(attribute_name)

out[section][key] = config.value_to_str(value)

return [Section(name=section, entries=data) for section, data in out.items()]


def config_to_seekers(config: list[Section], ignore_missing: bool = True) -> seekers.Config:
config_field_types = {field.name: field.type for field in dataclasses.fields(seekers.Config) if field.init}

all_fields_as_none = {k: None for k in config_field_types}

kwargs = {}
for section in config:
for key, value in section.entries.items():
field_name = seekers.Config.get_attribute_name(section.name, key)

if field_name not in config_field_types:
if ignore_missing:
raise KeyError(section.name)
else:
continue

kwargs[field_name] = seekers.Config.value_from_str(value, config_field_types[field_name])

return seekers.Config(**(all_fields_as_none | kwargs))
43 changes: 16 additions & 27 deletions seekers/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from concurrent.futures import ThreadPoolExecutor

from .converters import *
from .stubs.org.seekers.net.seekers_pb2 import *
from .stubs.org.seekers.net.seekers_pb2_grpc import *
from .stubs.org.seekers.grpc.service.seekers_pb2 import *
from .stubs.org.seekers.grpc.service.seekers_pb2_grpc import *

from .. import colors
from .. import game


Expand All @@ -25,9 +24,6 @@ def __init__(self, seekers_game: game.SeekersGame, game_start_event: threading.E
self.next_game_tick_event = threading.Event()
self.tokens: set[str] = set()

def Properties(self, request: Empty, context) -> PropertiesResponse:
return PropertiesResponse(entries=self.game.config.to_properties())

def new_tick(self):
"""Invalidate the cached game status. Called by SeekersGame."""
# self._logger.debug("New tick!")
Expand Down Expand Up @@ -91,7 +87,7 @@ def Command(self, request: CommandRequest, context: grpc.ServicerContext) -> Com
self.generate_status()
return self.current_status

def join_game(self, name: str, color: seekers.Color) -> tuple[str, str]:
def join_game(self, name: str, color: seekers.Color | None) -> tuple[str, str]:
# add the player with a new name if the requested name is already taken
_requested_name = name
i = 2
Expand All @@ -117,41 +113,34 @@ def join_game(self, name: str, color: seekers.Color) -> tuple[str, str]:
return new_token, player.id

def Join(self, request: JoinRequest, context) -> JoinResponse | None:
self._logger.debug(f"Received JoinRequest: {request.details!r}")
self._logger.debug(f"Received JoinRequest: {request.name=} {request.color=}")

# validate requested name
try:
requested_name = request.details["name"].strip()
except KeyError:
context.abort(grpc.StatusCode.INVALID_ARGUMENT,
"No 'name' key was provided in JoinRequest.details.")
return
if request.name is None:
requested_name = "Player"
else:
requested_name = request.name.strip()

if not requested_name:
context.abort(grpc.StatusCode.INVALID_ARGUMENT,
f"Requested name must not be empty or only consist of whitespace.")
return
if not requested_name:
context.abort(grpc.StatusCode.INVALID_ARGUMENT,
f"Requested name must not be empty or only consist of whitespace.")
return

color = (
colors.string_hash_color(requested_name)
if request.details.get("color") is None
else color_to_seekers(request.details["color"])
)
color = color_to_seekers(request.color) if request.color is not None else None

# add player to game
try:
new_token, player_id = self.join_game(requested_name, color)
except seekers.game.GameFullError:
except game.GameFullError:
context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "Game is full.")
return

return JoinResponse(token=new_token, player_id=player_id)
return JoinResponse(token=new_token, player_id=player_id, sections=config_to_grpc(self.game.config))


class GrpcSeekersServer:
"""A wrapper around the GrpcSeekersServicer that handles the gRPC server."""

def __init__(self, seekers_game: seekers.game.SeekersGame, address: str = "localhost:7777"):
def __init__(self, seekers_game: game.SeekersGame, address: str = "localhost:7777"):
self._logger = logging.getLogger(self.__class__.__name__)
self.game_start_event = threading.Event()

Expand Down
48 changes: 18 additions & 30 deletions seekers/seekers_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def from_filepath(cls, filepath: str) -> "Config":
return cls.from_file(f)

@staticmethod
def _dump_value(value: typing.Any) -> str:
def value_to_str(value: bool | float | int | str) -> str:
if isinstance(value, bool):
return str(value).lower()
elif isinstance(value, float):
Expand All @@ -139,7 +139,7 @@ def _dump_value(value: typing.Any) -> str:
return str(value)

@staticmethod
def _load_value(value: str, type_: str):
def value_from_str(value: str, type_: typing.Literal["bool", "float", "int", "str"]) -> bool | float | int | str:
if type_ == "bool":
return value.lower() == "true"
elif type_ == "float":
Expand All @@ -149,40 +149,28 @@ def _load_value(value: str, type_: str):
else:
return value

def to_properties(self) -> dict[str, str]:
self_dict = dataclasses.asdict(self)

def convert_specifier(specifier: str) -> str:
specifier = specifier.replace("_", ".", 1)
specifier = specifier.replace("_", "-")
return specifier

return {convert_specifier(k): self._dump_value(v) for k, v in self_dict.items()}

@classmethod
def from_properties(cls, properties: dict[str, str], raise_key_error: bool = False) -> "Config":
"""Converts a dictionary of properties, as received by a gRPC client, to a Config object."""
all_kwargs = {field.name: field.type for field in dataclasses.fields(Config) if field.init}
@staticmethod
def get_section_and_key(attribute_name: str) -> tuple[str, str]:
"""Split an attribute name into the config header name and the key name."""

all_fields_as_none = {k: None for k in all_kwargs}
section, key = attribute_name.split("_", 1)

kwargs = {}
for key, value in properties.items():
# field.name-example -> field_name_example
field_name = key.replace(".", "_").replace("-", "_")
return section, key.replace("_", "-")

if field_name not in all_kwargs:
if raise_key_error:
raise KeyError(key)
else:
continue
@staticmethod
def get_attribute_name(section: str, key: str) -> str:
return f"{section}_{key.replace('-', '_')}"

# convert the value to the correct type
kwargs[field_name] = cls._load_value(value, all_kwargs[field_name])
@classmethod
def get_field_type(cls, field_name: str) -> typing.Literal["bool", "float", "int", "str"]:
field_types = {f.name: f.type for f in dataclasses.fields(cls)}
return field_types[field_name]

kwargs = all_fields_as_none | kwargs
def import_option(self, section: str, key: str, value: str):
field_name = self.get_attribute_name(section, key)
field_type = self.get_field_type(field_name)

return cls(**kwargs)
setattr(self, field_name, self.value_from_str(value, field_type))


class Vector:
Expand Down

0 comments on commit 88dcbe5

Please sign in to comment.