Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions openapi_spec_validator/decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""OpenAPI spec validator decorators module."""
from functools import wraps
import logging

from openapi_spec_validator.managers import VisitingManager
Expand Down Expand Up @@ -43,3 +44,21 @@ def _attach_scope(self, instance):
return

instance['x-scope'] = list(self.instance_resolver._scopes_stack)


class ValidationErrorWrapper(object):

def __init__(self, error_class):
self.error_class = error_class

def __call__(self, f):
@wraps(f)
def wrapper(*args, **kwds):
errors = f(*args, **kwds)
for err in errors:
if not isinstance(err, self.error_class):
# wrap other exceptions with library specific version
yield self.error_class.create_from(err)
else:
yield err
return wrapper
14 changes: 14 additions & 0 deletions openapi_spec_validator/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@

from openapi_spec_validator.exceptions import (
ParameterDuplicateError, ExtraParametersError, UnresolvableParameterError,
OpenAPIValidationError
)
from openapi_spec_validator.decorators import ValidationErrorWrapper
from openapi_spec_validator.factories import Draft4ExtendedValidatorFactory
from openapi_spec_validator.managers import ResolverManager

log = logging.getLogger(__name__)

wraps_errors = ValidationErrorWrapper(OpenAPIValidationError)


def is_ref(spec):
return isinstance(spec, dict) and '$ref' in spec
Expand Down Expand Up @@ -43,6 +47,7 @@ def validate(self, spec, spec_url=''):
for err in self.iter_errors(spec, spec_url=spec_url):
raise err

@wraps_errors
def iter_errors(self, spec, spec_url=''):
spec_resolver = self._get_resolver(spec_url, spec)
dereferencer = self._get_dereferencer(spec_resolver)
Expand Down Expand Up @@ -81,6 +86,7 @@ class ComponentsValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, components):
components_deref = self.dereferencer.dereference(components)

Expand All @@ -97,6 +103,7 @@ class SchemasValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, schemas):
schemas_deref = self.dereferencer.dereference(schemas)
for name, schema in iteritems(schemas_deref):
Expand All @@ -112,6 +119,7 @@ class SchemaValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, schema, require_properties=True):
schema_deref = self.dereferencer.dereference(schema)

Expand Down Expand Up @@ -152,6 +160,7 @@ class PathsValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, paths):
paths_deref = self.dereferencer.dereference(paths)
for url, path_item in iteritems(paths_deref):
Expand All @@ -167,6 +176,7 @@ class PathValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, url, path_item):
path_item_deref = self.dereferencer.dereference(path_item)

Expand All @@ -186,6 +196,7 @@ class PathItemValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, url, path_item):
path_item_deref = self.dereferencer.dereference(path_item)

Expand Down Expand Up @@ -214,6 +225,7 @@ class OperationValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, url, name, operation, path_parameters=None):
path_parameters = path_parameters or []
operation_deref = self.dereferencer.dereference(operation)
Expand Down Expand Up @@ -255,6 +267,7 @@ class ParametersValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, parameters):
seen = set()
for parameter in parameters:
Expand All @@ -278,6 +291,7 @@ class ParameterValidator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer

@wraps_errors
def iter_errors(self, parameter):
if 'schema' in parameter:
schema = parameter['schema']
Expand Down
11 changes: 5 additions & 6 deletions tests/integration/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest

from jsonschema.exceptions import ValidationError

from openapi_spec_validator import validate_spec, validate_spec_url
from openapi_spec_validator import validate_v2_spec, validate_v2_spec_url
from openapi_spec_validator.exceptions import OpenAPIValidationError


class BaseTestValidValidteV2Spec:
Expand All @@ -15,7 +14,7 @@ def test_valid(self, spec):
class BaseTestFaliedValidateV2Spec:

def test_failed(self, spec):
with pytest.raises(ValidationError):
with pytest.raises(OpenAPIValidationError):
validate_v2_spec(spec)


Expand All @@ -28,7 +27,7 @@ def test_valid(self, spec):
class BaseTestFaliedValidateSpec:

def test_failed(self, spec):
with pytest.raises(ValidationError):
with pytest.raises(OpenAPIValidationError):
validate_spec(spec)


Expand All @@ -41,7 +40,7 @@ def test_valid(self, spec_url):
class BaseTestFaliedValidateV2SpecUrl:

def test_failed(self, spec_url):
with pytest.raises(ValidationError):
with pytest.raises(OpenAPIValidationError):
validate_v2_spec_url(spec_url)


Expand All @@ -54,7 +53,7 @@ def test_valid(self, spec_url):
class BaseTestFaliedValidateSpecUrl:

def test_failed(self, spec_url):
with pytest.raises(ValidationError):
with pytest.raises(OpenAPIValidationError):
validate_spec_url(spec_url)


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_validate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from jsonschema.exceptions import ValidationError
from openapi_spec_validator.exceptions import OpenAPIValidationError


class BaseTestValidOpeAPIv3Validator(object):
Expand All @@ -20,7 +20,7 @@ def spec_url(self):
return ''

def test_failed(self, validator, spec, spec_url):
with pytest.raises(ValidationError):
with pytest.raises(OpenAPIValidationError):
validator.validate(spec, spec_url=spec_url)


Expand Down
16 changes: 7 additions & 9 deletions tests/integration/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from jsonschema.exceptions import ValidationError

from openapi_spec_validator.exceptions import (
ExtraParametersError, UnresolvableParameterError,
ExtraParametersError, UnresolvableParameterError, OpenAPIValidationError,
)


Expand All @@ -13,11 +11,11 @@ def test_empty(self, validator):
errors = validator.iter_errors(spec)

errors_list = list(errors)
assert errors_list[0].__class__ == ValidationError
assert errors_list[0].__class__ == OpenAPIValidationError
assert errors_list[0].message == "'openapi' is a required property"
assert errors_list[1].__class__ == ValidationError
assert errors_list[1].__class__ == OpenAPIValidationError
assert errors_list[1].message == "'info' is a required property"
assert errors_list[2].__class__ == ValidationError
assert errors_list[2].__class__ == OpenAPIValidationError
assert errors_list[2].message == "'paths' is a required property"

def test_info_empty(self, validator):
Expand All @@ -30,7 +28,7 @@ def test_info_empty(self, validator):
errors = validator.iter_errors(spec)

errors_list = list(errors)
assert errors_list[0].__class__ == ValidationError
assert errors_list[0].__class__ == OpenAPIValidationError
assert errors_list[0].message == "'title' is a required property"

def test_minimalistic(self, validator):
Expand Down Expand Up @@ -200,7 +198,7 @@ def test_default_value_wrong_type(self, validator):

errors_list = list(errors)
assert len(errors_list) == 1
assert errors_list[0].__class__ == ValidationError
assert errors_list[0].__class__ == OpenAPIValidationError
assert errors_list[0].message == (
"'invaldtype' is not of type 'integer'"
)
Expand Down Expand Up @@ -235,7 +233,7 @@ def test_parameter_default_value_wrong_type(self, validator):

errors_list = list(errors)
assert len(errors_list) == 1
assert errors_list[0].__class__ == ValidationError
assert errors_list[0].__class__ == OpenAPIValidationError
assert errors_list[0].message == (
"'invaldtype' is not of type 'integer'"
)