Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
4 changes: 2 additions & 2 deletions codegen/parser/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_raw_definition(self) -> str:
required=self.required,
schema_data=self.body_schema,
)
return prop.get_param_defination()
return prop.get_param_definition()

def get_endpoint_definition(self) -> str:
prop = Property(
Expand All @@ -95,7 +95,7 @@ def get_endpoint_definition(self) -> str:
required=not bool(self.allowed_models),
schema_data=self.body_schema,
)
return prop.get_param_defination()
return prop.get_param_definition()


@dataclass(kw_only=True)
Expand Down
172 changes: 162 additions & 10 deletions codegen/parser/schemas/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ class SchemaData:
_type_string: ClassVar[str] = "Any"

def get_type_string(self, include_constraints: bool = True) -> str:
"""Get schema typing string in any place"""
"""Get schema typing string in any place.

Args:
include_constraints (bool):
whether to include field constraints by Annotated.
"""
if include_constraints and (args := self._get_field_args()):
return f"Annotated[{self._type_string}, {self._get_field_string(args)}]"
return self._type_string

def get_param_type_string(self) -> str:
"""Get type string used by client codegen"""
"""Get type string used by client request codegen"""
return self._type_string

def get_response_type_string(self) -> str:
"""Get type string used by client response codegen"""
return self._type_string

def get_model_imports(self) -> set[str]:
Expand All @@ -40,7 +49,11 @@ def get_param_imports(self) -> set[str]:
return set()

def get_using_imports(self) -> set[str]:
"""Get schema needed imports for client request codegen"""
"""Get schema needed imports for client request body codegen"""
return set()

def get_response_imports(self) -> set[str]:
"""Get schema needed imports for client response codegen"""
return set()

def _get_field_string(self, args: dict[str, str]) -> str:
Expand All @@ -66,7 +79,7 @@ class Property:
required: bool
schema_data: SchemaData

def get_type_string(self, include_constraints: bool = True) -> str:
def get_type_string(self, include_constraints: bool = False) -> str:
"""Get schema typing string in any place"""
type_string = self.schema_data.get_type_string(
include_constraints=include_constraints
Expand All @@ -78,24 +91,35 @@ def get_param_type_string(self) -> str:
type_string = self.schema_data.get_param_type_string()
return type_string if self.required else f"Missing[{type_string}]"

def get_model_defination(self) -> str:
"""Get defination used by model codegen"""
def get_response_type_string(self) -> str:
type_string = self.schema_data.get_response_type_string()
return type_string if self.required else f"Missing[{type_string}]"

def get_model_definition(self) -> str:
"""Get definition used by model codegen"""
# extract the outermost type constraints to the field
type_ = self.get_type_string(include_constraints=False)
args = self.schema_data._get_field_args()
args.update(self._get_field_args())
default = self._get_field_string(args)
return f"{self.prop_name}: {type_} = {default}"

def get_type_defination(self) -> str:
"""Get defination used by types codegen"""
def get_type_definition(self) -> str:
"""Get definition used by types codegen"""
type_ = self.schema_data.get_param_type_string()
return (
f"{self.prop_name}: {type_ if self.required else f'NotRequired[{type_}]'}"
)

def get_param_defination(self) -> str:
"""Get defination used by client codegen"""
def get_response_type_definition(self) -> str:
"""Get definition usede by response types codegen"""
type_ = self.schema_data.get_response_type_string()
return (
f"{self.prop_name}: {type_ if self.required else f'NotRequired[{type_}]'}"
)

def get_param_definition(self) -> str:
"""Get definition used by client codegen"""
type_ = self.get_param_type_string()
return (
(
Expand Down Expand Up @@ -177,6 +201,12 @@ def get_using_imports(self) -> set[str]:
imports.add("from typing import Any")
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from typing import Any")
return imports


@dataclass(kw_only=True)
class NoneSchema(SchemaData):
Expand Down Expand Up @@ -264,6 +294,12 @@ def _get_field_args(self) -> dict[str, str]:
class DateTimeSchema(SchemaData):
_type_string: ClassVar[str] = "datetime"

@override
def get_response_type_string(self) -> str:
# datetime field is ISO string in response
# https://github.com/yanyongyu/githubkit/issues/246
return "str"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand All @@ -288,11 +324,23 @@ def get_using_imports(self) -> set[str]:
imports.add("from datetime import datetime")
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from datetime import datetime")
return imports


@dataclass(kw_only=True)
class DateSchema(SchemaData):
_type_string: ClassVar[str] = "date"

@override
def get_response_type_string(self) -> str:
# date field is ISO string in response
# https://github.com/yanyongyu/githubkit/issues/246
return "str"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand All @@ -317,6 +365,12 @@ def get_using_imports(self) -> set[str]:
imports.add("from datetime import date")
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from datetime import date")
return imports


@dataclass(kw_only=True)
class FileSchema(SchemaData):
Expand Down Expand Up @@ -346,6 +400,12 @@ def get_using_imports(self) -> set[str]:
imports.add("from githubkit.typing import FileTypes")
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from githubkit.typing import FileTypes")
return imports


@dataclass(kw_only=True)
class ListSchema(SchemaData):
Expand All @@ -366,6 +426,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
def get_param_type_string(self) -> str:
return f"list[{self.item_schema.get_param_type_string()}]"

@override
def get_response_type_string(self) -> str:
return f"list[{self.item_schema.get_response_type_string()}]"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand All @@ -392,6 +456,13 @@ def get_using_imports(self) -> set[str]:
imports.update(self.item_schema.get_using_imports())
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from githubkit.compat import PYDANTIC_V2")
imports.update(self.item_schema.get_response_imports())
return imports

@override
def _get_field_args(self) -> dict[str, str]:
args = super()._get_field_args()
Expand Down Expand Up @@ -433,6 +504,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
def get_param_type_string(self) -> str:
return f"UniqueList[{self.item_schema.get_param_type_string()}]"

@override
def get_response_type_string(self) -> str:
return f"UniqueList[{self.item_schema.get_response_type_string()}]"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand Down Expand Up @@ -462,6 +537,13 @@ def get_using_imports(self) -> set[str]:
imports.update(self.item_schema.get_using_imports())
return imports

@override
def get_response_imports(self) -> set[str]:
# imports = super().get_response_imports()
imports = {"from githubkit.typing import UniqueList"}
imports.update(self.item_schema.get_response_imports())
return imports

@override
def _get_field_args(self) -> dict[str, str]:
args = super()._get_field_args()
Expand Down Expand Up @@ -511,6 +593,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
def get_param_type_string(self) -> str:
return f"Literal[{', '.join(repr(value) for value in self.values)}]"

@override
def get_response_type_string(self) -> str:
return self.get_param_type_string()

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand All @@ -535,6 +621,12 @@ def get_using_imports(self) -> set[str]:
imports.add("from typing import Literal")
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from typing import Literal")
return imports


@dataclass(kw_only=True)
class ModelSchema(SchemaData):
Expand All @@ -552,8 +644,46 @@ def get_type_string(self, include_constraints: bool = True) -> str:

@override
def get_param_type_string(self) -> str:
"""Get type string used by model type class name and client request codegen.

Example:

```python
class ModelType(TypedDict):
...

class Client:
def create_xxx(
*,
data: ModelType,
) -> Response[Model, ModelResponseType]:
...
```
"""
return f"{self.class_name}Type"

@override
def get_response_type_string(self) -> str:
"""Get type string used by model resposne type class name
and client response codegen.

Example:

```python
class ModelResponseType(TypedDict):
...

class Client:
def create_xxx(
*,
data: ModelType,
) -> Response[Model, ModelResponseType]:
...
```
"""
# `XXXResponseType` has name conflicts in definition
return f"{self.class_name}TypeForResponse"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand Down Expand Up @@ -583,6 +713,10 @@ def get_param_imports(self) -> set[str]:
def get_using_imports(self) -> set[str]:
return {f"from ..models import {self.class_name}"}

@override
def get_response_imports(self) -> set[str]:
return {f"from ..types import {self.get_response_type_string()}"}

@override
def get_model_dependencies(self) -> list["ModelSchema"]:
result: list[ModelSchema] = []
Expand Down Expand Up @@ -624,6 +758,15 @@ def get_param_type_string(self) -> str:
types = ", ".join(schema.get_param_type_string() for schema in self.schemas)
return f"Union[{types}]"

@override
def get_response_type_string(self) -> str:
if len(self.schemas) == 0:
return "Any"
elif len(self.schemas) == 1:
return self.schemas[0].get_response_type_string()
types = ", ".join(schema.get_response_type_string() for schema in self.schemas)
return f"Union[{types}]"

@override
def get_model_imports(self) -> set[str]:
imports = super().get_model_imports()
Expand Down Expand Up @@ -656,6 +799,15 @@ def get_using_imports(self) -> set[str]:
imports.update(schema.get_using_imports())
return imports

@override
def get_response_imports(self) -> set[str]:
imports = super().get_response_imports()
imports.add("from typing import Union")
for schema in self.schemas:
imports.update(schema.get_response_imports())
return imports

@override
def _get_field_args(self) -> dict[str, str]:
args = super()._get_field_args()
if self.discriminator:
Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/models/group.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class {{ model.class_name }}({{ "ExtraGitHubModel" if model.allow_extra else "Gi
{{ build_model_docstring(model) | indent(4) }}

{% for prop in model.properties %}
{{ prop.get_model_defination() }}
{{ prop.get_model_definition() }}
{% endfor %}

{% endfor %}
Expand Down
Loading
Loading