Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ellar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""Ellar - Python ASGI web framework for building fast, efficient, and scalable RESTful APIs and server-side applications."""

__version__ = "0.8.4"

21 changes: 17 additions & 4 deletions ellar/common/params/resolvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:

@abstractmethod
@t.no_type_check
def create_raw_data(self, data: t.Any) -> t.Dict:
def create_raw_data(
self, data: t.Any, field_name: t.Optional[str] = None
) -> t.Dict:
"""
Creates the raw data for the parameter.

Args:
data: The resolved value of the parameter.
field_name: The name of the field.

Returns:
`dict`: A dictionary containing the raw data.
Expand All @@ -62,8 +65,10 @@ def __init__(self, model_field: ModelField, *args: t.Any, **kwargs: t.Any) -> No
RouteParameterModelField, model_field
)

def create_raw_data(self, data: t.Any) -> t.Dict:
return {self.model_field.name: data}
def create_raw_data(
self, data: t.Any, field_name: t.Optional[str] = None
) -> t.Dict:
return {field_name or self.model_field.name: data}

def assert_field_info(self) -> None:
"""
Expand Down Expand Up @@ -91,13 +96,21 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:

@abstractmethod
@t.no_type_check
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
async def resolve_handle(
self,
*args: t.Any,
alias: t.Optional[str] = None,
name: t.Optional[str] = None,
**kwargs: t.Any,
) -> ResolverResult:
"""
Resolves the value of the parameter during request processing.

Args:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
alias: The alias of the parameter. Optional.
name: The name of the parameter. Optional.

Returns:
`ResolverResult`: A named tuple containing the resolved value, any errors, and the raw data.
Expand Down
120 changes: 74 additions & 46 deletions ellar/common/params/resolvers/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,47 @@ def get_received_parameter(
return connection.headers

async def resolve_handle(
self, ctx: IExecutionContext, *args: t.Any, **kwargs: t.Any
self,
ctx: IExecutionContext,
*args: t.Any,
alias: t.Optional[str] = None,
name: t.Optional[str] = None,
**kwargs: t.Any,
) -> ResolverResult:
alias = alias or self.model_field.alias
name = name or self.model_field.name
request_logger.debug(
f"Resolving Header Parameters - '{self.__class__.__name__}'"
)
received_params = self.get_received_parameter(ctx=ctx)
if is_sequence_field(self.model_field):
value = (
received_params.getlist(self.model_field.alias)
or self.model_field.default
)
value = received_params.getlist(alias) or self.model_field.default
else:
value = received_params.get(self.model_field.alias)
value = received_params.get(alias)
self.assert_field_info()
field_info = self.model_field.field_info
values = {}
if value is None:
if self.model_field.required:
errors = [
self.create_error(
loc=(field_info.in_.value, self.model_field.alias)
)
]
return ResolverResult({}, errors, raw_data=self.create_raw_data(value))
errors = [self.create_error(loc=(field_info.in_.value, alias))]
return ResolverResult(
{}, errors, raw_data=self.create_raw_data(value, field_name=name)
)
else:
value = copy.deepcopy(self.model_field.default)
values[self.model_field.name] = value
return ResolverResult(values, [], raw_data=self.create_raw_data(value))
values[name] = value
return ResolverResult(
values, [], raw_data=self.create_raw_data(value, field_name=name)
)

v_, errors_ = self.model_field.validate(
value, values, loc=(field_info.in_.value, self.model_field.alias)
value, values, loc=(field_info.in_.value, alias)
)

return ResolverResult(
data={self.model_field.name: v_},
data={name: v_},
errors=self.validate_error_sequence(errors_),
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)


Expand All @@ -99,22 +103,28 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any]
return connection.path_params

async def resolve_handle(
self, ctx: IExecutionContext, **kwargs: t.Any
self,
ctx: IExecutionContext,
alias: t.Optional[str] = None,
name: t.Optional[str] = None,
**kwargs: t.Any,
) -> ResolverResult:
alias = alias or self.model_field.alias
name = name or self.model_field.name
request_logger.debug(f"Resolving Path Parameters - '{self.__class__.__name__}'")
received_params = self.get_received_parameter(ctx=ctx)
value = received_params.get(str(self.model_field.alias))
value = received_params.get(str(alias))
self.assert_field_info()

v_, errors_ = self.model_field.validate(
value,
{},
loc=(self.model_field.field_info.in_.value, self.model_field.alias),
loc=(self.model_field.field_info.in_.value, alias),
)
return ResolverResult(
data={self.model_field.name: v_},
data={name: v_},
errors=self.validate_error_sequence(errors_),
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)


Expand All @@ -127,44 +137,54 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any]

class WsBodyParameterResolver(BaseRouteParameterResolver):
async def resolve_handle(
self, ctx: IExecutionContext, *args: t.Any, body: t.Any, **kwargs: t.Any
self,
ctx: IExecutionContext,
*args: t.Any,
body: t.Any,
alias: t.Optional[str] = None,
name: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Tuple:
alias = alias or self.model_field.alias
name = name or self.model_field.name
request_logger.debug(
f"Resolving Websocket Body Parameters - '{self.__class__.__name__}'"
)
embed = getattr(self.model_field.field_info, "embed", False)
received_body = {self.model_field.alias: body}
received_body = {alias: body}
loc = ("body",)
if embed:
received_body = body
loc = ("body", self.model_field.alias) # type:ignore
loc = ("body", alias) # type:ignore
try:
value = received_body.get(self.model_field.alias)
value = received_body.get(alias)

if value is None:
if self.model_field.required:
return ResolverResult(
None,
[self.create_error(loc=loc)],
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)
else:
value = copy.deepcopy(self.model_field.default)
return ResolverResult(
{self.model_field.name: value},
{name: value},
[],
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)

v_, errors_ = self.model_field.validate(value, {}, loc=loc)
return ResolverResult(
data={self.model_field.name: v_},
data={name: v_},
errors=self.validate_error_sequence(errors_),
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)
except AttributeError:
errors = [self.create_error(loc=loc)]
return ResolverResult(None, errors, raw_data=self.create_raw_data(None))
return ResolverResult(
None, errors, raw_data=self.create_raw_data(None, field_name=name)
)


class BodyParameterResolver(WsBodyParameterResolver):
Expand Down Expand Up @@ -228,10 +248,10 @@ async def resolve_handle(

class FormParameterResolver(BodyParameterResolver):
async def process_and_validate(
self, *, values: t.Dict, value: t.Any, loc: t.Tuple
self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str
) -> t.Tuple:
v_, errors_ = self.model_field.validate(value, values, loc=loc)
values[self.model_field.name] = v_
values[field_name] = v_
return ResolverResult(
data=values,
errors=self.validate_error_sequence(errors_),
Expand All @@ -257,22 +277,26 @@ async def resolve_handle(
ctx: IExecutionContext,
*args: t.Any,
body: t.Optional[t.Any] = None,
alias: t.Optional[str] = None,
name: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Tuple:
alias = alias or self.model_field.alias
name = name or self.model_field.name
_body = body or await self.get_request_body(ctx)
embed = getattr(self.model_field.field_info, "embed", False)
received_body = {self.model_field.alias: _body}
received_body = {alias: _body}
loc = ("body",)

if embed:
received_body = _body
loc = ("body", self.model_field.alias) # type:ignore
loc = ("body", alias) # type:ignore

if is_sequence_field(self.model_field) and isinstance(_body, FormData):
loc = ("body", self.model_field.alias) # type: ignore
value = _body.getlist(self.model_field.alias)
loc = ("body", alias) # type: ignore
value = _body.getlist(alias)
else:
value = received_body.get(self.model_field.alias) # type: ignore
value = received_body.get(alias) # type: ignore

if (
value is None
Expand All @@ -281,17 +305,21 @@ async def resolve_handle(
):
if self.model_field.required:
return ResolverResult(
None, [self.create_error(loc=loc)], self.create_raw_data(value)
None,
[self.create_error(loc=loc)],
self.create_raw_data(value, field_name=name),
)
else:
value = copy.deepcopy(self.model_field.default)
return ResolverResult(
{self.model_field.name: value},
{name: value},
[],
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=name),
)

return await self.process_and_validate(values={}, value=value, loc=loc)
return await self.process_and_validate(
values={}, value=value, loc=loc, field_name=name
)


class FileParameterResolver(FormParameterResolver):
Expand All @@ -302,7 +330,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
self._is_byte_list = is_bytes_sequence_annotation(self.model_field.type_)

async def process_and_validate(
self, *, values: t.Dict, value: t.Any, loc: t.Tuple
self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str
) -> t.Tuple:
if self._is_byte and isinstance(value, StarletteUploadFile):
value = await value.read()
Expand All @@ -321,10 +349,10 @@ async def process_fn(
value = serialize_sequence_value(field=self.model_field, value=results)

v_, errors_ = self.model_field.validate(value, values, loc=loc)
values[self.model_field.name] = v_
values[field_name] = v_

return ResolverResult(
data=values,
errors=self.validate_error_sequence(errors_),
raw_data=self.create_raw_data(value),
raw_data=self.create_raw_data(value, field_name=field_name),
)
Loading