Skip to content

Commit

Permalink
Merge 9fbcd07 into 8944594
Browse files Browse the repository at this point in the history
  • Loading branch information
Rodrigo Martins de Oliveira committed Jul 1, 2017
2 parents 8944594 + 9fbcd07 commit 3b490fc
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 27 deletions.
50 changes: 50 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,56 @@ Take a look at `examples/validation.py` for more information.

All validation options can be found at http://json-schema.org/latest/json-schema-validation.html

### Custom validation

By default Flasgger will use [python-jsonschema](https://python-jsonschema.readthedocs.io/en/latest/)
to perform validation.

Custom validation functions are supported as long as they take two
positional arguments: the data to be validated as:
- the first and the schema to validate against as the second argument; and
- raise any kind of exception when validation fails.

Any return value is discarded.


Providing the function to the Swagger instance will make it the default:

```python
from flasgger import Swagger

swagger = Swagger(app, validation_function=my_validation_function)
```

Providing the function as parameter of `swag_from` or `swagger.validate`
annotations or directly to the `validate` function will force it's use
over the default validation function for Swagger:

```python
from flasgger import swag_from

@swag_from('spec.yml', validation=True, validation_function=my_function)
...
```

```python
from flasgger import Swagger

swagger = Swagger(app)

@swagger.validate('Pet', validation_function=my_function)
...
```

```python
from flasgger import validate

...

validate(
request.json, 'Pet', 'defs.yml', validation_function=my_function)
```

# HTML sanitizer

By default Flasgger will try to sanitize the content in YAML definitions
Expand Down
269 changes: 269 additions & 0 deletions examples/custom_validation_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import time
import jsonschema
try:
import simplejson as json
except ImportError:
import json
try:
from http import HTTPStatus
except ImportError:
import httplib as HTTPStatus
from flask import Flask, jsonify
from flask import request
from flasgger import Swagger


def validate(data, schema):
"""
Custom validation function which drops parameter '_id' if present
in data
"""
jsonschema.validate(data, schema)
if data.get('_id') is not None:
del data['_id']


def timestamping_validate(data, schema):
"""
Custom validation function which inserts a timestamp for when the
validation occurred
"""
jsonschema.validate(data, schema)
data['timestamp'] = str(time.time())


def special_validate(data, schema):
"""
Custom validation function which marks inserts an special flag
depending on the cat's name
"""
jsonschema.validate(data, schema)
data['special'] = str(data['name'] == 'Garfield').lower()


def regular_validate(data, schema):
"""
Regular validation function
"""
jsonschema.validate(data, schema)


app = Flask(__name__)
swag = Swagger(app, validation_function=validate)


@app.route('/cat', methods=['POST'])
@swag.validate('Cat')
def create_cat():
"""
Cat creation endpoint which drops '_id' parameters when present in
request data
---
tags:
- cat
summary: Creates a new cat
consumes:
- application/json
produces:
- application/json
parameters:
- in: body
name: body
description:
Cat object that needs to be persisted to the database
required: true
schema:
id: Cat
required:
- name
- address
properties:
name:
description: Cat's name
type: string
example: Sylvester
address:
description: Cat's house address
type: string
example: 4000 Warner Blvd., Burbank, CA 91522
responses:
200:
description: Successful operation
400:
description: Invalid input
"""
return jsonify(request.json), HTTPStatus.OK


@app.route('/timestamped/cat', methods=['POST'])
@swag.validate('Cat', validation_function=timestamping_validate)
def create_timestamped_cat():
"""
Cat creation endpoint which timestamps validated data
---
tags:
- cat
summary: Creates a new cat
consumes:
- application/json
produces:
- application/json
parameters:
- in: body
name: body
description:
Cat object that needs to be persisted to the database
required: true
schema:
$ref: '#/definitions/Cat'
responses:
200:
description: Successful operation
schema:
$ref: '#/definitions/Cat'
400:
description: Invalid input
"""
return jsonify(request.json), HTTPStatus.OK


@app.route('/special/cat', methods=['POST'])
@swag.validate('Cat', validation_function=special_validate)
def create_special_cat():
"""
Cat creation endpoint which timestamps validated data
---
tags:
- cat
summary: Creates a new cat
consumes:
- application/json
produces:
- application/json
parameters:
- in: body
name: body
description:
Cat object that needs to be persisted to the database
required: true
schema:
$ref: '#/definitions/Cat'
responses:
200:
description: Successful operation
schema:
$ref: '#/definitions/Cat'
400:
description: Invalid input
"""
return jsonify(request.json), HTTPStatus.OK


@app.route('/regular/cat', methods=['POST'])
@swag.validate('Cat', validation_function=regular_validate)
def create_regular_cat():
"""
Cat creation endpoint
---
tags:
- cat
summary: Creates a new cat
consumes:
- application/json
produces:
- application/json
parameters:
- in: body
name: body
description:
Cat object that needs to be persisted to the database
required: true
schema:
$ref: '#/definitions/Cat'
responses:
200:
description: Successful operation
schema:
$ref: '#/definitions/Cat'
400:
description: Invalid input
"""
return jsonify(request.json), HTTPStatus.OK


def test_swag(client, specs_data):
"""
This test is runs automatically in Travis CI
:param client: Flask app test client
:param specs_data: {'url': {swag_specs}} for every spec in app
"""
cat = \
"""
{
"_id": "594dba7b2879334e411f3dcc",
"name": "Tom",
"address": "MGM, 245 N. Beverly Drive, Beverly Hills, CA 90210"
}
"""
with client.post(
'/cat', data=cat, content_type='application/json') as response:
assert response.status_code == HTTPStatus.OK

sent = json.loads(cat)
received = json.loads(response.data.decode('utf-8'))
assert received.get('_id') is None
assert received.get('timestamp') is None
assert received.get('special') is None
try:
assert received.viewitems() < sent.viewitems()
except AttributeError:
assert received.items() < sent.items()

with client.post(
'/timestamped/cat', data=cat,
content_type='application/json') as response:
assert response.status_code == HTTPStatus.OK

sent = json.loads(cat)
received = json.loads(response.data.decode('utf-8'))
assert received.get('_id') == sent.get('_id')
assert received.get('timestamp') is not None
assert received.get('special') is None
try:
assert received.viewitems() > sent.viewitems()
except AttributeError:
assert received.items() > sent.items()

with client.post(
'/special/cat', data=cat,
content_type='application/json') as response:
assert response.status_code == HTTPStatus.OK

sent = json.loads(cat)
received = json.loads(response.data.decode('utf-8'))
assert received.get('_id') == sent.get('_id')
assert received.get('timestamp') is None
assert received.get('special') is not None
try:
assert received.viewitems() > sent.viewitems()
except AttributeError:
assert received.items() > sent.items()

with client.post(
'/regular/cat', data=cat,
content_type='application/json') as response:
assert response.status_code == HTTPStatus.OK

sent = json.loads(cat)
received = json.loads(response.data.decode('utf-8'))
assert received.get('_id') == sent.get('_id')
assert received.get('timestamp') is None
assert received.get('special') is None
try:
assert received.viewitems() == sent.viewitems()
except AttributeError:
assert received.items() == sent.items()

if __name__ == "__main__":
app.run(debug=True)
23 changes: 16 additions & 7 deletions flasgger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import simplejson as json
except ImportError:
import json

from functools import wraps
from collections import defaultdict
from flask import Blueprint
Expand Down Expand Up @@ -287,9 +286,9 @@ class Swagger(object):
"specs_route": "/apidocs/"
}

def __init__(self, app=None, config=None,
sanitizer=None, template=None, template_file=None,
decorators=None):
def __init__(
self, app=None, config=None, sanitizer=None, template=None,
template_file=None, decorators=None, validation_function=None):
self._configured = False
self.endpoints = []
self.definition_models = [] # not in app, so track here
Expand All @@ -298,6 +297,7 @@ def __init__(self, app=None, config=None,
self.template = template
self.template_file = template_file
self.decorators = decorators
self.validation_function = validation_function
if app:
self.init_app(app)

Expand Down Expand Up @@ -437,7 +437,7 @@ def after_request(response): # noqa
response.headers[header] = value
return response

def validate(self, schema_id):
def validate(self, schema_id, validation_function=None):
"""
A decorator that is used to validate incoming requests data
against a schema
Expand All @@ -458,9 +458,16 @@ def post():
be the outermost annotation
:param schema_id: the id of the schema with which the data will
be validated
be validated
:param validation_function: custom validation function which
takes the positional arguments: data to be validated at
first and schema to validate against at second
"""

if validation_function is None:
validation_function = self.validation_function

def decorator(func):

@wraps(func)
Expand All @@ -486,7 +493,9 @@ def wrapper(*args, **kwargs):
if d.get('schema', {}).get('id') == schema_id:
specs = swag

validate(schema_id=schema_id, specs=specs)
validate(
schema_id=schema_id, specs=specs,
validation_function=validation_function)
return func(*args, **kwargs)

return wrapper
Expand Down
Loading

0 comments on commit 3b490fc

Please sign in to comment.