diff --git a/gitlab/v4/objects/appearance.py b/gitlab/v4/objects/appearance.py index 0639c13fa..f6643f40d 100644 --- a/gitlab/v4/objects/appearance.py +++ b/gitlab/v4/objects/appearance.py @@ -61,4 +61,4 @@ def update( def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[ApplicationAppearance]: - return cast(ApplicationAppearance, super().get(id=id, **kwargs)) + return cast(Optional[ApplicationAppearance], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/export_import.py b/gitlab/v4/objects/export_import.py index 7e01f47f9..6bba322a2 100644 --- a/gitlab/v4/objects/export_import.py +++ b/gitlab/v4/objects/export_import.py @@ -27,7 +27,7 @@ class GroupExportManager(GetWithoutIdMixin, CreateMixin, RESTManager): def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[GroupExport]: - return cast(GroupExport, super().get(id=id, **kwargs)) + return cast(Optional[GroupExport], super().get(id=id, **kwargs)) class GroupImport(RESTObject): @@ -42,7 +42,7 @@ class GroupImportManager(GetWithoutIdMixin, RESTManager): def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[GroupImport]: - return cast(GroupImport, super().get(id=id, **kwargs)) + return cast(Optional[GroupImport], super().get(id=id, **kwargs)) class ProjectExport(DownloadMixin, RefreshMixin, RESTObject): @@ -58,7 +58,7 @@ class ProjectExportManager(GetWithoutIdMixin, CreateMixin, RESTManager): def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[ProjectExport]: - return cast(ProjectExport, super().get(id=id, **kwargs)) + return cast(Optional[ProjectExport], super().get(id=id, **kwargs)) class ProjectImport(RefreshMixin, RESTObject): @@ -73,4 +73,4 @@ class ProjectImportManager(GetWithoutIdMixin, RESTManager): def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[ProjectImport]: - return cast(ProjectImport, super().get(id=id, **kwargs)) + return cast(Optional[ProjectImport], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/merge_request_approvals.py b/gitlab/v4/objects/merge_request_approvals.py index f05b9778e..2bbd39926 100644 --- a/gitlab/v4/objects/merge_request_approvals.py +++ b/gitlab/v4/objects/merge_request_approvals.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union from gitlab import exceptions as exc from gitlab.base import RequiredOptional, RESTManager, RESTObject @@ -45,6 +45,11 @@ class ProjectApprovalManager(GetWithoutIdMixin, UpdateMixin, RESTManager): ) _update_uses_post = True + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectApproval]: + return cast(Optional[ProjectApproval], super().get(id=id, **kwargs)) + @exc.on_http_error(exc.GitlabUpdateError) def set_approvers( self, @@ -105,6 +110,11 @@ class ProjectMergeRequestApprovalManager(GetWithoutIdMixin, UpdateMixin, RESTMan _update_attrs = RequiredOptional(required=("approvals_required",)) _update_uses_post = True + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectMergeRequestApproval]: + return cast(Optional[ProjectMergeRequestApproval], super().get(id=id, **kwargs)) + @exc.on_http_error(exc.GitlabUpdateError) def set_approvers( self, @@ -241,3 +251,10 @@ class ProjectMergeRequestApprovalStateManager(GetWithoutIdMixin, RESTManager): _path = "/projects/{project_id}/merge_requests/{mr_iid}/approval_state" _obj_cls = ProjectMergeRequestApprovalState _from_parent_attrs = {"project_id": "project_id", "mr_iid": "iid"} + + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectMergeRequestApprovalState]: + return cast( + Optional[ProjectMergeRequestApprovalState], super().get(id=id, **kwargs) + ) diff --git a/gitlab/v4/objects/notification_settings.py b/gitlab/v4/objects/notification_settings.py index f1f7cce87..b5a37971e 100644 --- a/gitlab/v4/objects/notification_settings.py +++ b/gitlab/v4/objects/notification_settings.py @@ -1,3 +1,5 @@ +from typing import Any, cast, Optional, Union + from gitlab.base import RequiredOptional, RESTManager, RESTObject from gitlab.mixins import GetWithoutIdMixin, SaveMixin, UpdateMixin @@ -36,6 +38,11 @@ class NotificationSettingsManager(GetWithoutIdMixin, UpdateMixin, RESTManager): ), ) + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[NotificationSettings]: + return cast(Optional[NotificationSettings], super().get(id=id, **kwargs)) + class GroupNotificationSettings(NotificationSettings): pass @@ -46,6 +53,11 @@ class GroupNotificationSettingsManager(NotificationSettingsManager): _obj_cls = GroupNotificationSettings _from_parent_attrs = {"group_id": "id"} + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[GroupNotificationSettings]: + return cast(Optional[GroupNotificationSettings], super().get(id=id, **kwargs)) + class ProjectNotificationSettings(NotificationSettings): pass @@ -55,3 +67,8 @@ class ProjectNotificationSettingsManager(NotificationSettingsManager): _path = "/projects/{project_id}/notification_settings" _obj_cls = ProjectNotificationSettings _from_parent_attrs = {"project_id": "id"} + + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectNotificationSettings]: + return cast(Optional[ProjectNotificationSettings], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/pipelines.py b/gitlab/v4/objects/pipelines.py index fd597dad8..ac4290f25 100644 --- a/gitlab/v4/objects/pipelines.py +++ b/gitlab/v4/objects/pipelines.py @@ -246,3 +246,8 @@ class ProjectPipelineTestReportManager(GetWithoutIdMixin, RESTManager): _path = "/projects/{project_id}/pipelines/{pipeline_id}/test_report" _obj_cls = ProjectPipelineTestReport _from_parent_attrs = {"project_id": "project_id", "pipeline_id": "id"} + + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectPipelineTestReport]: + return cast(Optional[ProjectPipelineTestReport], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/push_rules.py b/gitlab/v4/objects/push_rules.py index 89c3e644a..b948a01fb 100644 --- a/gitlab/v4/objects/push_rules.py +++ b/gitlab/v4/objects/push_rules.py @@ -54,4 +54,4 @@ class ProjectPushRulesManager( def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[ProjectPushRules]: - return cast(ProjectPushRules, super().get(id=id, **kwargs)) + return cast(Optional[ProjectPushRules], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/settings.py b/gitlab/v4/objects/settings.py index 0fb7f8a40..96f253939 100644 --- a/gitlab/v4/objects/settings.py +++ b/gitlab/v4/objects/settings.py @@ -118,4 +118,4 @@ def update( def get( self, id: Optional[Union[int, str]] = None, **kwargs: Any ) -> Optional[ApplicationSettings]: - return cast(ApplicationSettings, super().get(id=id, **kwargs)) + return cast(Optional[ApplicationSettings], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/statistics.py b/gitlab/v4/objects/statistics.py index 18b2be8c7..2941f9143 100644 --- a/gitlab/v4/objects/statistics.py +++ b/gitlab/v4/objects/statistics.py @@ -1,3 +1,5 @@ +from typing import Any, cast, Optional, Union + from gitlab.base import RESTManager, RESTObject from gitlab.mixins import GetWithoutIdMixin, RefreshMixin @@ -22,6 +24,11 @@ class ProjectAdditionalStatisticsManager(GetWithoutIdMixin, RESTManager): _obj_cls = ProjectAdditionalStatistics _from_parent_attrs = {"project_id": "id"} + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectAdditionalStatistics]: + return cast(Optional[ProjectAdditionalStatistics], super().get(id=id, **kwargs)) + class IssuesStatistics(RefreshMixin, RESTObject): _id_attr = None @@ -31,6 +38,11 @@ class IssuesStatisticsManager(GetWithoutIdMixin, RESTManager): _path = "/issues_statistics" _obj_cls = IssuesStatistics + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[IssuesStatistics]: + return cast(Optional[IssuesStatistics], super().get(id=id, **kwargs)) + class GroupIssuesStatistics(RefreshMixin, RESTObject): _id_attr = None @@ -41,6 +53,11 @@ class GroupIssuesStatisticsManager(GetWithoutIdMixin, RESTManager): _obj_cls = GroupIssuesStatistics _from_parent_attrs = {"group_id": "id"} + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[GroupIssuesStatistics]: + return cast(Optional[GroupIssuesStatistics], super().get(id=id, **kwargs)) + class ProjectIssuesStatistics(RefreshMixin, RESTObject): _id_attr = None @@ -50,3 +67,8 @@ class ProjectIssuesStatisticsManager(GetWithoutIdMixin, RESTManager): _path = "/projects/{project_id}/issues_statistics" _obj_cls = ProjectIssuesStatistics _from_parent_attrs = {"project_id": "id"} + + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[ProjectIssuesStatistics]: + return cast(Optional[ProjectIssuesStatistics], super().get(id=id, **kwargs)) diff --git a/gitlab/v4/objects/users.py b/gitlab/v4/objects/users.py index fac448aff..568e019da 100644 --- a/gitlab/v4/objects/users.py +++ b/gitlab/v4/objects/users.py @@ -3,7 +3,7 @@ https://docs.gitlab.com/ee/api/users.html https://docs.gitlab.com/ee/api/projects.html#list-projects-starred-by-a-user """ -from typing import Any, cast, Dict, List, Union +from typing import Any, cast, Dict, List, Optional, Union import requests @@ -120,6 +120,11 @@ class CurrentUserStatusManager(GetWithoutIdMixin, UpdateMixin, RESTManager): _obj_cls = CurrentUserStatus _update_attrs = RequiredOptional(optional=("emoji", "message")) + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[CurrentUserStatus]: + return cast(Optional[CurrentUserStatus], super().get(id=id, **kwargs)) + class CurrentUser(RESTObject): _id_attr = None @@ -135,6 +140,11 @@ class CurrentUserManager(GetWithoutIdMixin, RESTManager): _path = "/user" _obj_cls = CurrentUser + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[CurrentUser]: + return cast(Optional[CurrentUser], super().get(id=id, **kwargs)) + class User(SaveMixin, ObjectDeleteMixin, RESTObject): _short_print_attr = "username" @@ -390,6 +400,11 @@ class UserStatusManager(GetWithoutIdMixin, RESTManager): _obj_cls = UserStatus _from_parent_attrs = {"user_id": "id"} + def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any + ) -> Optional[UserStatus]: + return cast(Optional[UserStatus], super().get(id=id, **kwargs)) + class UserActivitiesManager(ListMixin, RESTManager): _path = "/user/activities" diff --git a/tests/meta/test_ensure_type_hints.py b/tests/meta/test_ensure_type_hints.py index a770afba3..2449324b3 100644 --- a/tests/meta/test_ensure_type_hints.py +++ b/tests/meta/test_ensure_type_hints.py @@ -4,8 +4,10 @@ Original notes by John L. Villalovos """ +import dataclasses +import functools import inspect -from typing import Tuple, Type +from typing import Optional, Type import _pytest @@ -13,6 +15,23 @@ import gitlab.v4.objects +@functools.total_ordering +@dataclasses.dataclass(frozen=True) +class ClassInfo: + name: str + type: Type + + def __lt__(self, other: object) -> bool: + if not isinstance(other, ClassInfo): + return NotImplemented + return (self.type.__module__, self.name) < (other.type.__module__, other.name) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ClassInfo): + return NotImplemented + return (self.type.__module__, self.name) == (other.type.__module__, other.name) + + def pytest_generate_tests(metafunc: _pytest.python.Metafunc) -> None: """Find all of the classes in gitlab.v4.objects and pass them to our test function""" @@ -35,38 +54,84 @@ def pytest_generate_tests(metafunc: _pytest.python.Metafunc) -> None: if not class_name.endswith("Manager"): continue - class_info_set.add((class_name, class_value)) + class_info_set.add(ClassInfo(name=class_name, type=class_value)) + + metafunc.parametrize("class_info", sorted(class_info_set)) - metafunc.parametrize("class_info", class_info_set) + +GET_ID_METHOD_TEMPLATE = """ +def get( + self, id: Union[str, int], lazy: bool = False, **kwargs: Any +) -> {obj_cls.__name__}: + return cast({obj_cls.__name__}, super().get(id=id, lazy=lazy, **kwargs)) + +You may also need to add the following imports: +from typing import Any, cast, Union" +""" + +GET_WITHOUT_ID_METHOD_TEMPLATE = """ +def get( + self, id: Optional[Union[int, str]] = None, **kwargs: Any +) -> Optional[{obj_cls.__name__}]: + return cast(Optional[{obj_cls.__name__}], super().get(id=id, **kwargs)) + +You may also need to add the following imports: +from typing import Any, cast, Optional, Union" +""" class TestTypeHints: - def test_check_get_function_type_hints(self, class_info: Tuple[str, Type]) -> None: + def test_check_get_function_type_hints(self, class_info: ClassInfo) -> None: """Ensure classes derived from GetMixin have defined a 'get()' method with correct type-hints. """ - class_name, class_value = class_info - if not class_name.endswith("Manager"): - return + self.get_check_helper( + base_type=gitlab.mixins.GetMixin, + class_info=class_info, + method_template=GET_ID_METHOD_TEMPLATE, + optional_return=False, + ) - mro = class_value.mro() + def test_check_get_without_id_function_type_hints( + self, class_info: ClassInfo + ) -> None: + """Ensure classes derived from GetMixin have defined a 'get()' method with + correct type-hints. + """ + self.get_check_helper( + base_type=gitlab.mixins.GetWithoutIdMixin, + class_info=class_info, + method_template=GET_WITHOUT_ID_METHOD_TEMPLATE, + optional_return=True, + ) + + def get_check_helper( + self, + *, + base_type: Type, + class_info: ClassInfo, + method_template: str, + optional_return: bool, + ) -> None: + if not class_info.name.endswith("Manager"): + return + mro = class_info.type.mro() # The class needs to be derived from GetMixin or we ignore it - if gitlab.mixins.GetMixin not in mro: + if base_type not in mro: return - obj_cls = class_value._obj_cls - signature = inspect.signature(class_value.get) - filename = inspect.getfile(class_value) + obj_cls = class_info.type._obj_cls + signature = inspect.signature(class_info.type.get) + filename = inspect.getfile(class_info.type) fail_message = ( - f"class definition for {class_name!r} in file {filename!r} " + f"class definition for {class_info.name!r} in file {filename!r} " f"must have defined a 'get' method with a return annotation of " f"{obj_cls} but found {signature.return_annotation}\n" f"Recommend adding the followinng method:\n" - f"def get(\n" - f" self, id: Union[str, int], lazy: bool = False, **kwargs: Any\n" - f" ) -> {obj_cls.__name__}:\n" - f" return cast({obj_cls.__name__}, super().get(id=id, lazy=lazy, " - f"**kwargs))\n" ) - assert obj_cls == signature.return_annotation, fail_message + fail_message += method_template.format(obj_cls=obj_cls) + check_type = obj_cls + if optional_return: + check_type = Optional[obj_cls] + assert check_type == signature.return_annotation, fail_message