Skip to content

Commit

Permalink
🎨 Fix type detection of select results in PyCharm
Browse files Browse the repository at this point in the history
  • Loading branch information
tiangolo committed Aug 25, 2021
1 parent af03df8 commit ce60193
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,49 +10,53 @@
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable

_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")


class Session(_Session):
@overload
def exec(
self,
statement: Select[_T],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T]]:
) -> Result[_TSelectParam]:
...

@overload
def exec(
self,
statement: SelectOfScalar[_T],
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[ScalarResult[_T]]:
) -> ScalarResult[_TSelectParam]:
...

def exec(
self,
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T], ScalarResult[_T]]:
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
results = super().execute(
statement,
params=params,
Expand Down Expand Up @@ -118,13 +122,13 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":

def get(
self,
entity: Type[_T],
entity: Type[_TSelectParam],
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
) -> Optional[_T]:
) -> Optional[_TSelectParam]:
return super().get(
entity,
ident,
Expand Down

0 comments on commit ce60193

Please sign in to comment.