diff --git a/openapi_spec_validator/shortcuts.py b/openapi_spec_validator/shortcuts.py index 77ee0d8..9921d9d 100644 --- a/openapi_spec_validator/shortcuts.py +++ b/openapi_spec_validator/shortcuts.py @@ -10,21 +10,29 @@ from openapi_spec_validator.validation import OpenAPIV2SpecValidator from openapi_spec_validator.validation import OpenAPIV30SpecValidator from openapi_spec_validator.validation import OpenAPIV31SpecValidator -from openapi_spec_validator.validation.finders import SpecFinder -from openapi_spec_validator.validation.finders import SpecVersion +from openapi_spec_validator.validation.exceptions import ValidatorDetectError from openapi_spec_validator.validation.protocols import SupportsValidation from openapi_spec_validator.validation.types import SpecValidatorType from openapi_spec_validator.validation.validators import SpecValidator - -SPECS: Mapping[SpecVersion, SpecValidatorType] = { - SpecVersion("swagger", "2.0"): OpenAPIV2SpecValidator, - SpecVersion("openapi", "3.0"): OpenAPIV30SpecValidator, - SpecVersion("openapi", "3.1"): OpenAPIV31SpecValidator, +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound +from openapi_spec_validator.versions.shortcuts import get_spec_version + +SPEC2VALIDATOR: Mapping[SpecVersion, SpecValidatorType] = { + versions.OPENAPIV2: OpenAPIV2SpecValidator, + versions.OPENAPIV30: OpenAPIV30SpecValidator, + versions.OPENAPIV31: OpenAPIV31SpecValidator, } def get_validator_cls(spec: Schema) -> SpecValidatorType: - return SpecFinder(SPECS).find(spec) + try: + spec_version = get_spec_version(spec) + # backward compatibility + except OpenAPIVersionNotFound: + raise ValidatorDetectError + return SPEC2VALIDATOR[spec_version] def validate_spec( diff --git a/openapi_spec_validator/validation/finders.py b/openapi_spec_validator/validation/finders.py deleted file mode 100644 index 74d2573..0000000 --- a/openapi_spec_validator/validation/finders.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Mapping -from typing import NamedTuple - -from jsonschema_spec.typing import Schema - -from openapi_spec_validator.validation.exceptions import ValidatorDetectError -from openapi_spec_validator.validation.types import SpecValidatorType - - -class SpecVersion(NamedTuple): - name: str - version: str - - -class SpecFinder: - def __init__(self, specs: Mapping[SpecVersion, SpecValidatorType]) -> None: - self.specs = specs - - def find(self, spec: Schema) -> SpecValidatorType: - for v, classes in self.specs.items(): - if v.name in spec and spec[v.name].startswith(v.version): - return classes - raise ValidatorDetectError("Spec schema version not detected") diff --git a/openapi_spec_validator/versions/__init__.py b/openapi_spec_validator/versions/__init__.py new file mode 100644 index 0000000..2203413 --- /dev/null +++ b/openapi_spec_validator/versions/__init__.py @@ -0,0 +1,13 @@ +from openapi_spec_validator.versions.consts import OPENAPIV2 +from openapi_spec_validator.versions.consts import OPENAPIV30 +from openapi_spec_validator.versions.consts import OPENAPIV31 +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.shortcuts import get_spec_version + +__all__ = [ + "OPENAPIV2", + "OPENAPIV30", + "OPENAPIV31", + "SpecVersion", + "get_spec_version", +] diff --git a/openapi_spec_validator/versions/consts.py b/openapi_spec_validator/versions/consts.py new file mode 100644 index 0000000..6b5ea7d --- /dev/null +++ b/openapi_spec_validator/versions/consts.py @@ -0,0 +1,23 @@ +from typing import List + +from openapi_spec_validator.versions.datatypes import SpecVersion + +OPENAPIV2 = SpecVersion( + keyword="swagger", + major="2", + minor="0", +) + +OPENAPIV30 = SpecVersion( + keyword="openapi", + major="3", + minor="0", +) + +OPENAPIV31 = SpecVersion( + keyword="openapi", + major="3", + minor="1", +) + +VERSIONS: List[SpecVersion] = [OPENAPIV2, OPENAPIV30, OPENAPIV31] diff --git a/openapi_spec_validator/versions/datatypes.py b/openapi_spec_validator/versions/datatypes.py new file mode 100644 index 0000000..a869992 --- /dev/null +++ b/openapi_spec_validator/versions/datatypes.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SpecVersion: + """ + Spec version designates the OAS feature set. + """ + + keyword: str + major: str + minor: str + + def __str__(self) -> str: + return f"OpenAPIV{self.major}.{self.minor}" diff --git a/openapi_spec_validator/versions/exceptions.py b/openapi_spec_validator/versions/exceptions.py new file mode 100644 index 0000000..91a5cbb --- /dev/null +++ b/openapi_spec_validator/versions/exceptions.py @@ -0,0 +1,6 @@ +from openapi_spec_validator.exceptions import OpenAPIError + + +class OpenAPIVersionNotFound(OpenAPIError): + def __str__(self) -> str: + return "Specification version not found" diff --git a/openapi_spec_validator/versions/finders.py b/openapi_spec_validator/versions/finders.py new file mode 100644 index 0000000..f8b1ddc --- /dev/null +++ b/openapi_spec_validator/versions/finders.py @@ -0,0 +1,26 @@ +from re import compile +from typing import List + +from jsonschema_spec.typing import Schema + +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound + + +class SpecVersionFinder: + pattern = compile(r"(?P\d+)\.(?P\d+)(\..*)?") + + def __init__(self, versions: List[SpecVersion]) -> None: + self.versions = versions + + def find(self, spec: Schema) -> SpecVersion: + for v in self.versions: + if v.keyword in spec: + version_str = spec[v.keyword] + m = self.pattern.match(version_str) + if m: + version = SpecVersion(**m.groupdict(), keyword=v.keyword) + if v == version: + return v + + raise OpenAPIVersionNotFound diff --git a/openapi_spec_validator/versions/shortcuts.py b/openapi_spec_validator/versions/shortcuts.py new file mode 100644 index 0000000..73841e9 --- /dev/null +++ b/openapi_spec_validator/versions/shortcuts.py @@ -0,0 +1,10 @@ +from jsonschema_spec.typing import Schema + +from openapi_spec_validator.versions.consts import VERSIONS +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.finders import SpecVersionFinder + + +def get_spec_version(spec: Schema) -> SpecVersion: + finder = SpecVersionFinder(VERSIONS) + return finder.find(spec) diff --git a/tests/integration/test_versions.py b/tests/integration/test_versions.py new file mode 100644 index 0000000..891dd50 --- /dev/null +++ b/tests/integration/test_versions.py @@ -0,0 +1,43 @@ +import pytest + +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound +from openapi_spec_validator.versions.shortcuts import get_spec_version + + +class TestGetSpecVersion: + def test_no_keyword(self): + spec = {} + + with pytest.raises(OpenAPIVersionNotFound): + get_spec_version(spec) + + @pytest.mark.parametrize("keyword", ["swagger", "openapi"]) + @pytest.mark.parametrize("version", ["x.y.z", "xyz2.0.0", "2.xyz0.0"]) + def test_invalid(self, keyword, version): + spec = { + keyword: version, + } + + with pytest.raises(OpenAPIVersionNotFound): + get_spec_version(spec) + + @pytest.mark.parametrize( + "keyword,version,expected", + [ + ("swagger", "2.0", versions.OPENAPIV2), + ("openapi", "3.0.0", versions.OPENAPIV30), + ("openapi", "3.0.1", versions.OPENAPIV30), + ("openapi", "3.0.2", versions.OPENAPIV30), + ("openapi", "3.0.3", versions.OPENAPIV30), + ("openapi", "3.1.0", versions.OPENAPIV31), + ], + ) + def test_valid(self, keyword, version, expected): + spec = { + keyword: version, + } + + result = get_spec_version(spec) + + assert result == expected