Skip to content

Commit

Permalink
Concept for parsing all args in a request
Browse files Browse the repository at this point in the history
  • Loading branch information
toonalbers committed Jul 22, 2019
1 parent de061e0 commit b433ea5
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 15 deletions.
66 changes: 51 additions & 15 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,24 @@ def parse_arg(self, name, field, req, locations=None):
return value
return missing

def _parse_argdict(self, argdict, schema, req, locations):
parsed = {}
for argname, field_obj in iteritems(argdict):
if MARSHMALLOW_VERSION_INFO[0] < 3:
parsed_value = self.parse_arg(argname, field_obj, req, locations)
# If load_from is specified on the field, try to parse from that key
if parsed_value is missing and field_obj.load_from:
parsed_value = self.parse_arg(
field_obj.load_from, field_obj, req, locations
)
argname = field_obj.load_from
else:
argname = field_obj.data_key or argname
parsed_value = self.parse_arg(argname, field_obj, req, locations)
if parsed_value is not missing:
parsed[argname] = parsed_value
return parsed

def _parse_request(self, schema, req, locations):
"""Return a parsed arguments dictionary for the current request."""
if schema.many:
Expand All @@ -264,21 +282,26 @@ def _parse_request(self, schema, req, locations):
parsed = []
else:
argdict = schema.fields
parsed = {}
for argname, field_obj in iteritems(argdict):
if MARSHMALLOW_VERSION_INFO[0] < 3:
parsed_value = self.parse_arg(argname, field_obj, req, locations)
# If load_from is specified on the field, try to parse from that key
if parsed_value is missing and field_obj.load_from:
parsed_value = self.parse_arg(
field_obj.load_from, field_obj, req, locations
)
argname = field_obj.load_from
else:
argname = field_obj.data_key or argname
parsed_value = self.parse_arg(argname, field_obj, req, locations)
if parsed_value is not missing:
parsed[argname] = parsed_value

parsed = self._parse_argdict(argdict, schema, req, locations)
new_parsed = {}
# From version 3, Marshmallow contains logic for handling missing fields
# Could also check a bool flag for backwards compatibility
if MARSHMALLOW_VERSION_INFO[0] >= 3:
for location, argnames in iteritems(
self.get_available_fields(req, locations)
):
newargs = {}
for argname in argnames:
if argname not in parsed.keys():
newargs[argname] = ma.fields.Raw()

location_parsed = self._parse_argdict(
newargs, schema, req, (location,)
)
new_parsed.update(location_parsed)
parsed = new_parsed.update(parsed)

return parsed

def _on_validation_error(
Expand Down Expand Up @@ -542,6 +565,19 @@ def handle_error(error, req, schema, status_code, headers):

# Abstract Methods

def get_available_fields(self, req, locations):
"""Pull the names of all fields from the request, separated by their
location, and only for the provided locations
:param req: The request object to parse.
:param tuple locations: Where on the request to search for fields.
Can include one or more of ``('json', 'querystring', 'form',
'headers', 'cookies', 'files')``.
:return: a dict with all provided fields in a request, keyed by the
location in which they occur
"""
return {}

def parse_json(self, req, name, arg):
"""Pull a JSON value from a request object or return `missing` if the
value cannot be found.
Expand Down
33 changes: 33 additions & 0 deletions src/webargs/flaskparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ def parse_view_args(self, req, name, field):
"""Pull a value from the request's ``view_args``."""
return core.get_value(req.view_args, name, field)

def get_available_fields(self, req, locations):
"""Pull the names of all fields from the request, separated by their
location, and only for the provided locations
:param req: The request object to parse.
:param tuple locations: Where on the request to search for fields.
Can include one or more of ``('json', 'querystring', 'form',
'headers', 'cookies', 'files')``.
:return: a dict with all provided fields in a request, keyed by the
location in which they occur
"""
fields = dict()
if "json" in locations:
# TODO copied from parse_json, refactor to separate method
json_data = self._cache.get("json")
if json_data is None:
# We decode the json manually here instead of
# using req.get_json() so that we can handle
# JSONDecodeErrors consistently
data = req.get_data(cache=True)
try:
self._cache["json"] = json_data = core.parse_json(data)
except json.JSONDecodeError as e:
if e.doc == "":
return core.missing
else:
return self.handle_invalid_json_error(e, req)
# Now get all keys from the json
json_keys = json_data.keys()
fields["json"] = json_keys
# TODO etc for querystring, form, headers, cookies, files
return fields

def parse_json(self, req, name, field):
"""Pull a json value from the request."""
json_data = self._cache.get("json")
Expand Down

0 comments on commit b433ea5

Please sign in to comment.