Skip to content

Commit

Permalink
🐛 Fix: webhook parsing when having nested union types (#57)
Browse files Browse the repository at this point in the history
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
  • Loading branch information
frankie567 and yanyongyu authored Dec 28, 2023
1 parent 9449b1a commit e650ec3
Show file tree
Hide file tree
Showing 81 changed files with 176 additions and 58 deletions.
49 changes: 16 additions & 33 deletions codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .config import Config
from .source import get_source
from .log import logger as logger
from .parser.schemas import ModelSchema
from .parser.schemas import ModelSchema, UnionSchema
from .parser import (
WebhookData,
EndpointData,
Expand All @@ -26,24 +26,17 @@
lstrip_blocks=True,
extensions=["jinja2.ext.loopcontrols"],
)
env.globals.update(
{
"repr": repr,
"sanitize": sanitize,
"snake_case": snake_case,
"pascal_case": pascal_case,
"kebab_case": kebab_case,
}
)
env.filters.update(
{
"repr": repr,
"sanitize": sanitize,
"snake_case": snake_case,
"pascal_case": pascal_case,
"kebab_case": kebab_case,
}
)

_funcs = {
"repr": repr,
"sanitize": sanitize,
"snake_case": snake_case,
"pascal_case": pascal_case,
"kebab_case": kebab_case,
"is_union_schema": lambda x: isinstance(x, UnionSchema),
}
env.globals.update(_funcs)
env.filters.update(_funcs)


def load_config() -> Config:
Expand Down Expand Up @@ -198,9 +191,7 @@ def build_legacy_rest_models(
logger.info("Successfully generated legacy rest models!")


def build_versions(
dir: Path, output_module: str, versions: dict[str, str], latest_version: str
):
def build_versions(dir: Path, versions: dict[str, str], latest_version: str):
logger.info("Start generating versions...")

# build __init__.py
Expand All @@ -216,23 +207,15 @@ def build_versions(
rest_template = env.get_template("versions/rest.py.jinja")
rest_path = dir / "rest.py"
rest_path.write_text(
rest_template.render(
output_module=output_module,
versions=versions,
latest_version=latest_version,
)
rest_template.render(versions=versions, latest_version=latest_version)
)

# build webhooks.py
logger.info("Building versions webhooks.py...")
webhooks_template = env.get_template("versions/webhooks.py.jinja")
webhooks_path = dir / "webhooks.py"
webhooks_path.write_text(
webhooks_template.render(
output_module=output_module,
versions=versions,
latest_version=latest_version,
)
webhooks_template.render(versions=versions, latest_version=latest_version)
)

logger.info("Successfully generated versions!")
Expand Down Expand Up @@ -332,7 +315,7 @@ def build():
latest_model_names,
latest_event_names,
)
build_versions(config.output_dir, output_module, versions, latest_version)
build_versions(config.output_dir, versions, latest_version)
build_legacy_rest_models(
config.legacy_rest_models,
output_module,
Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/versions/rest.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class RestVersionSwitcher(_VersionProxy):
def __call__(self, version: VERSION_TYPE = LATEST_VERSION) -> Any:
if version in self._cached_namespaces:
return self._cached_namespaces[version]
module = importlib.import_module(f"{{ output_module }}.{VERSIONS[version]}.rest", __name__)
module = importlib.import_module(f".{VERSIONS[version]}.rest", __package__)
namespace = module.RestNamespace(self._github)
self._cached_namespaces[version] = namespace
return namespace
2 changes: 1 addition & 1 deletion codegen/templates/versions/webhooks.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class WebhooksVersionSwitcher(_VersionProxy):
def __call__(self, version: VERSION_TYPE = LATEST_VERSION) -> Any:
if version in self._cached_namespaces:
return self._cached_namespaces[version]
module = importlib.import_module(f"{{ output_module }}.{VERSIONS[version]}.webhooks", __name__)
module = importlib.import_module(f".{VERSIONS[version]}.webhooks", __package__)
namespace = module.WebhookNamespace()
self._cached_namespaces[version] = namespace
return namespace
4 changes: 2 additions & 2 deletions codegen/templates/webhooks/_namespace.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class WebhookNamespace:
"""
if name not in VALID_EVENT_NAMES:
raise WebhookTypeNotFound(name)
module = importlib.import_module(f".{name}", __name__)
module = importlib.import_module(f".{name}", __package__)
Event = getattr(module, "Event")
return type_validate_json(Event, payload)

Expand Down Expand Up @@ -106,7 +106,7 @@ class WebhookNamespace:

if name not in VALID_EVENT_NAMES:
raise WebhookTypeNotFound(name)
module = importlib.import_module(f".{name}", __name__)
module = importlib.import_module(f".{name}", __package__)
Event = getattr(module, "Event")
return type_validate_python(Event, payload)

Expand Down
12 changes: 12 additions & 0 deletions codegen/templates/webhooks/event.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from typing_extensions import TypeAlias, Annotated

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

{% for webhook in webhooks %}
Expand All @@ -18,7 +19,18 @@ from githubkit.compat import GitHubModel
{% if webhooks | length > 1 %}
Event: TypeAlias = Annotated[Union[
{% for webhook in webhooks %}
{% if is_union_schema(webhook.event_schema) %}
Annotated[
{{ webhook.event_schema.get_type_string() }},
TaggedUnion(
{{ webhook.event_schema.get_type_string() }},
"action",
"{{ webhook.action }}"
)
],
{% else %}
{{ webhook.event_schema.get_type_string() }},
{% endif %}
{% endfor %}
], Field(discriminator="action")]
{% else %}
Expand Down
10 changes: 5 additions & 5 deletions githubkit/lazy_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from importlib.machinery import ModuleSpec, PathFinder, SourceFileLoader

LAZY_MODULES = (
r"githubkit\.rest",
r"githubkit\.versions\.v[^.]+\.webhooks",
r"githubkit\.versions\.latest\.models",
r"githubkit\.versions\.latest\.types",
r"githubkit\.versions\.latest\.webhooks",
r"^githubkit\.rest$",
r"^githubkit\.versions\.v[^.]+\.webhooks$",
r"^githubkit\.versions\.latest\.models$",
r"^githubkit\.versions\.latest\.types$",
r"^githubkit\.versions\.latest\.webhooks$",
)


Expand Down
40 changes: 38 additions & 2 deletions githubkit/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import inspect
from enum import Enum
from typing import Any, Dict, Literal, TypeVar, final
from typing import Any, Dict, Type, Generic, Literal, TypeVar, final

from pydantic import BaseModel
from pydantic_core import to_jsonable_python

from .compat import custom_validation
from .compat import PYDANTIC_V2, custom_validation, type_validate_python

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler

T = TypeVar("T")

Expand Down Expand Up @@ -75,3 +80,34 @@ def obj_to_jsonable(obj: Any) -> Any:
return obj

return to_jsonable_python(obj)


class TaggedUnion(Generic[T]):
__slots__ = ("type_", "discriminator", "tag")

def __init__(self, type_: Type[T], discriminator: str, tag: str) -> None:
self.type_ = type_
self.discriminator = discriminator
self.tag = tag

def _validate(self, value: Any) -> T:
return type_validate_python(self.type_, value)

if PYDANTIC_V2:

def __get_pydantic_core_schema__(
self, _source_type: Any, _handler: "GetCoreSchemaHandler"
) -> "core_schema.CoreSchema":
return core_schema.no_info_before_validator_function(
self._validate,
core_schema.model_schema(
BaseModel,
schema=core_schema.model_fields_schema(
{
self.discriminator: core_schema.model_field(
core_schema.literal_schema([self.tag])
)
}
),
),
)
4 changes: 1 addition & 3 deletions githubkit/versions/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def __call__(self) -> "V20221128RestNamespace":
def __call__(self, version: VERSION_TYPE = LATEST_VERSION) -> Any:
if version in self._cached_namespaces:
return self._cached_namespaces[version]
module = importlib.import_module(
f"githubkit.versions.{VERSIONS[version]}.rest", __name__
)
module = importlib.import_module(f".{VERSIONS[version]}.rest", __package__)
namespace = module.RestNamespace(self._github)
self._cached_namespaces[version] = namespace
return namespace
4 changes: 2 additions & 2 deletions githubkit/versions/v2022_11_28/webhooks/_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def parse(name: EventNameType, payload: Union[str, bytes]) -> "WebhookEvent":
"""
if name not in VALID_EVENT_NAMES:
raise WebhookTypeNotFound(name)
module = importlib.import_module(f".{name}", __name__)
module = importlib.import_module(f".{name}", __package__)
Event = getattr(module, "Event")
return type_validate_json(Event, payload)

Expand Down Expand Up @@ -1223,7 +1223,7 @@ def parse_obj(name: EventNameType, payload: Dict[str, Any]) -> "WebhookEvent":

if name not in VALID_EVENT_NAMES:
raise WebhookTypeNotFound(name)
module = importlib.import_module(f".{name}", __name__)
module = importlib.import_module(f".{name}", __package__)
Event = getattr(module, "Event")
return type_validate_python(Event, payload)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/check_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/check_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/commit_comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookCommitCommentCreated
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookCreate
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/custom_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookCustomPropertyValuesUpdated
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookDelete
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/deploy_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookDeployKeyCreated, WebhookDeployKeyDeleted
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookDeploymentCreated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookDeploymentProtectionRuleRequested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import WebhookDeploymentStatusCreated
Expand Down
1 change: 1 addition & 0 deletions githubkit/versions/v2022_11_28/webhooks/discussion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pydantic import Field

from githubkit.utils import TaggedUnion
from githubkit.compat import GitHubModel

from ..models import (
Expand Down
Loading

0 comments on commit e650ec3

Please sign in to comment.