From 63d5abf145d45fb45eb88f3af6e7cca90b1eb325 Mon Sep 17 00:00:00 2001 From: Ezeudoh Tochukwu Date: Fri, 18 Oct 2024 11:30:41 +0100 Subject: [PATCH] Added support for route resolvers to supply custom alias to be used as resolving alias or name --- ellar/__init__.py | 1 - ellar/common/params/resolvers/base.py | 21 +++- ellar/common/params/resolvers/parameter.py | 120 +++++++++++++-------- 3 files changed, 91 insertions(+), 51 deletions(-) diff --git a/ellar/__init__.py b/ellar/__init__.py index 55204455..616fc2d6 100644 --- a/ellar/__init__.py +++ b/ellar/__init__.py @@ -1,4 +1,3 @@ """Ellar - Python ASGI web framework for building fast, efficient, and scalable RESTful APIs and server-side applications.""" __version__ = "0.8.4" - diff --git a/ellar/common/params/resolvers/base.py b/ellar/common/params/resolvers/base.py index 70639883..5c9829a0 100644 --- a/ellar/common/params/resolvers/base.py +++ b/ellar/common/params/resolvers/base.py @@ -44,12 +44,15 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult: @abstractmethod @t.no_type_check - def create_raw_data(self, data: t.Any) -> t.Dict: + def create_raw_data( + self, data: t.Any, field_name: t.Optional[str] = None + ) -> t.Dict: """ Creates the raw data for the parameter. Args: data: The resolved value of the parameter. + field_name: The name of the field. Returns: `dict`: A dictionary containing the raw data. @@ -62,8 +65,10 @@ def __init__(self, model_field: ModelField, *args: t.Any, **kwargs: t.Any) -> No RouteParameterModelField, model_field ) - def create_raw_data(self, data: t.Any) -> t.Dict: - return {self.model_field.name: data} + def create_raw_data( + self, data: t.Any, field_name: t.Optional[str] = None + ) -> t.Dict: + return {field_name or self.model_field.name: data} def assert_field_info(self) -> None: """ @@ -91,13 +96,21 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult: @abstractmethod @t.no_type_check - async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult: + async def resolve_handle( + self, + *args: t.Any, + alias: t.Optional[str] = None, + name: t.Optional[str] = None, + **kwargs: t.Any, + ) -> ResolverResult: """ Resolves the value of the parameter during request processing. Args: *args: Additional positional arguments. **kwargs: Additional keyword arguments. + alias: The alias of the parameter. Optional. + name: The name of the parameter. Optional. Returns: `ResolverResult`: A named tuple containing the resolved value, any errors, and the raw data. diff --git a/ellar/common/params/resolvers/parameter.py b/ellar/common/params/resolvers/parameter.py index aecbc9eb..f2ff81c4 100644 --- a/ellar/common/params/resolvers/parameter.py +++ b/ellar/common/params/resolvers/parameter.py @@ -43,43 +43,47 @@ def get_received_parameter( return connection.headers async def resolve_handle( - self, ctx: IExecutionContext, *args: t.Any, **kwargs: t.Any + self, + ctx: IExecutionContext, + *args: t.Any, + alias: t.Optional[str] = None, + name: t.Optional[str] = None, + **kwargs: t.Any, ) -> ResolverResult: + alias = alias or self.model_field.alias + name = name or self.model_field.name request_logger.debug( f"Resolving Header Parameters - '{self.__class__.__name__}'" ) received_params = self.get_received_parameter(ctx=ctx) if is_sequence_field(self.model_field): - value = ( - received_params.getlist(self.model_field.alias) - or self.model_field.default - ) + value = received_params.getlist(alias) or self.model_field.default else: - value = received_params.get(self.model_field.alias) + value = received_params.get(alias) self.assert_field_info() field_info = self.model_field.field_info values = {} if value is None: if self.model_field.required: - errors = [ - self.create_error( - loc=(field_info.in_.value, self.model_field.alias) - ) - ] - return ResolverResult({}, errors, raw_data=self.create_raw_data(value)) + errors = [self.create_error(loc=(field_info.in_.value, alias))] + return ResolverResult( + {}, errors, raw_data=self.create_raw_data(value, field_name=name) + ) else: value = copy.deepcopy(self.model_field.default) - values[self.model_field.name] = value - return ResolverResult(values, [], raw_data=self.create_raw_data(value)) + values[name] = value + return ResolverResult( + values, [], raw_data=self.create_raw_data(value, field_name=name) + ) v_, errors_ = self.model_field.validate( - value, values, loc=(field_info.in_.value, self.model_field.alias) + value, values, loc=(field_info.in_.value, alias) ) return ResolverResult( - data={self.model_field.name: v_}, + data={name: v_}, errors=self.validate_error_sequence(errors_), - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) @@ -99,22 +103,28 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any] return connection.path_params async def resolve_handle( - self, ctx: IExecutionContext, **kwargs: t.Any + self, + ctx: IExecutionContext, + alias: t.Optional[str] = None, + name: t.Optional[str] = None, + **kwargs: t.Any, ) -> ResolverResult: + alias = alias or self.model_field.alias + name = name or self.model_field.name request_logger.debug(f"Resolving Path Parameters - '{self.__class__.__name__}'") received_params = self.get_received_parameter(ctx=ctx) - value = received_params.get(str(self.model_field.alias)) + value = received_params.get(str(alias)) self.assert_field_info() v_, errors_ = self.model_field.validate( value, {}, - loc=(self.model_field.field_info.in_.value, self.model_field.alias), + loc=(self.model_field.field_info.in_.value, alias), ) return ResolverResult( - data={self.model_field.name: v_}, + data={name: v_}, errors=self.validate_error_sequence(errors_), - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) @@ -127,44 +137,54 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any] class WsBodyParameterResolver(BaseRouteParameterResolver): async def resolve_handle( - self, ctx: IExecutionContext, *args: t.Any, body: t.Any, **kwargs: t.Any + self, + ctx: IExecutionContext, + *args: t.Any, + body: t.Any, + alias: t.Optional[str] = None, + name: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Tuple: + alias = alias or self.model_field.alias + name = name or self.model_field.name request_logger.debug( f"Resolving Websocket Body Parameters - '{self.__class__.__name__}'" ) embed = getattr(self.model_field.field_info, "embed", False) - received_body = {self.model_field.alias: body} + received_body = {alias: body} loc = ("body",) if embed: received_body = body - loc = ("body", self.model_field.alias) # type:ignore + loc = ("body", alias) # type:ignore try: - value = received_body.get(self.model_field.alias) + value = received_body.get(alias) if value is None: if self.model_field.required: return ResolverResult( None, [self.create_error(loc=loc)], - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) else: value = copy.deepcopy(self.model_field.default) return ResolverResult( - {self.model_field.name: value}, + {name: value}, [], - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) v_, errors_ = self.model_field.validate(value, {}, loc=loc) return ResolverResult( - data={self.model_field.name: v_}, + data={name: v_}, errors=self.validate_error_sequence(errors_), - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) except AttributeError: errors = [self.create_error(loc=loc)] - return ResolverResult(None, errors, raw_data=self.create_raw_data(None)) + return ResolverResult( + None, errors, raw_data=self.create_raw_data(None, field_name=name) + ) class BodyParameterResolver(WsBodyParameterResolver): @@ -228,10 +248,10 @@ async def resolve_handle( class FormParameterResolver(BodyParameterResolver): async def process_and_validate( - self, *, values: t.Dict, value: t.Any, loc: t.Tuple + self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str ) -> t.Tuple: v_, errors_ = self.model_field.validate(value, values, loc=loc) - values[self.model_field.name] = v_ + values[field_name] = v_ return ResolverResult( data=values, errors=self.validate_error_sequence(errors_), @@ -257,22 +277,26 @@ async def resolve_handle( ctx: IExecutionContext, *args: t.Any, body: t.Optional[t.Any] = None, + alias: t.Optional[str] = None, + name: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Tuple: + alias = alias or self.model_field.alias + name = name or self.model_field.name _body = body or await self.get_request_body(ctx) embed = getattr(self.model_field.field_info, "embed", False) - received_body = {self.model_field.alias: _body} + received_body = {alias: _body} loc = ("body",) if embed: received_body = _body - loc = ("body", self.model_field.alias) # type:ignore + loc = ("body", alias) # type:ignore if is_sequence_field(self.model_field) and isinstance(_body, FormData): - loc = ("body", self.model_field.alias) # type: ignore - value = _body.getlist(self.model_field.alias) + loc = ("body", alias) # type: ignore + value = _body.getlist(alias) else: - value = received_body.get(self.model_field.alias) # type: ignore + value = received_body.get(alias) # type: ignore if ( value is None @@ -281,17 +305,21 @@ async def resolve_handle( ): if self.model_field.required: return ResolverResult( - None, [self.create_error(loc=loc)], self.create_raw_data(value) + None, + [self.create_error(loc=loc)], + self.create_raw_data(value, field_name=name), ) else: value = copy.deepcopy(self.model_field.default) return ResolverResult( - {self.model_field.name: value}, + {name: value}, [], - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=name), ) - return await self.process_and_validate(values={}, value=value, loc=loc) + return await self.process_and_validate( + values={}, value=value, loc=loc, field_name=name + ) class FileParameterResolver(FormParameterResolver): @@ -302,7 +330,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any): self._is_byte_list = is_bytes_sequence_annotation(self.model_field.type_) async def process_and_validate( - self, *, values: t.Dict, value: t.Any, loc: t.Tuple + self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str ) -> t.Tuple: if self._is_byte and isinstance(value, StarletteUploadFile): value = await value.read() @@ -321,10 +349,10 @@ async def process_fn( value = serialize_sequence_value(field=self.model_field, value=results) v_, errors_ = self.model_field.validate(value, values, loc=loc) - values[self.model_field.name] = v_ + values[field_name] = v_ return ResolverResult( data=values, errors=self.validate_error_sequence(errors_), - raw_data=self.create_raw_data(value), + raw_data=self.create_raw_data(value, field_name=field_name), )