Skip to content

Commit

Permalink
wip: Add more types
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahmonod committed Oct 2, 2023
1 parent aaaa42d commit bc83ae7
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 51 deletions.
97 changes: 69 additions & 28 deletions src/github3/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,13 @@ class GitHub(models.GitHubCore):
"""

def __init__(
self, username="", password="", token="", session=None, api_version=""
):
self,
username: str = "",
password: str = "",
token: str = "",
session: t.Optional[session.GitHubSession] = None,
api_version: str = "",
) -> None:
"""Create a new GitHub instance to talk to the API.
:param str api_version:
Expand All @@ -79,13 +84,15 @@ def __init__(
elif username and password:
self.login(username, password)

def _repr(self):
def _repr(self) -> str:
if self.session.auth:
return f"<GitHub [{self.session.auth!r}]>"
return f"<Anonymous GitHub at 0x{id(self):x}>"

@requires_auth
def activate_membership(self, organization):
def activate_membership(
self, organization: t.Union[str, orgs.Organization]
) -> t.Optional[orgs.Membership]:
"""Activate the membership to an organization.
:param organization:
Expand All @@ -109,7 +116,9 @@ def activate_membership(self, organization):
return self._instance_or_null(orgs.Membership, _json)

@requires_auth
def add_email_addresses(self, addresses=[]):
def add_email_addresses(
self, addresses: t.Optional[t.List[str]] = None
) -> t.Sequence[users.Email]:
"""Add the addresses to the authenticated user's account.
:param list addresses:
Expand All @@ -123,9 +132,15 @@ def add_email_addresses(self, addresses=[]):
if addresses:
url = self._build_url("user", "emails")
json = self._json(self._post(url, data=addresses), 201)
return [users.Email(email, self) for email in json] if json else []
return (
[users.Email(email, self.session) for email in json]
if json
else []
)

def all_events(self, number=-1, etag=None):
def all_events(
self, number: int = -1, etag: t.Optional[str] = None
) -> t.Iterable[events.Event]:
"""Iterate over public events.
:param int number:
Expand All @@ -142,8 +157,12 @@ def all_events(self, number=-1, etag=None):
return self._iter(int(number), url, events.Event, etag=etag)

def all_organizations(
self, number=-1, since=None, etag=None, per_page=None
):
self,
number: int = -1,
since: t.Optional[int] = None,
etag: t.Optional[str] = None,
per_page: t.Optional[int] = None,
) -> t.Iterable[orgs.ShortOrganization]:
"""Iterate over every organization in the order they were created.
:param int number:
Expand Down Expand Up @@ -171,8 +190,12 @@ def all_organizations(
)

def all_repositories(
self, number=-1, since=None, etag=None, per_page=None
):
self,
number: int = -1,
since: t.Optional[int] = None,
etag: t.Optional[str] = None,
per_page: t.Optional[int] = None,
) -> t.Iterable[repo.Repository]:
"""Iterate over every repository in the order they were created.
:param int number:
Expand All @@ -199,7 +222,13 @@ def all_repositories(
etag=etag,
)

def all_users(self, number=-1, etag=None, per_page=None, since=None):
def all_users(
self,
number: int = -1,
etag: t.Optional[str] = None,
per_page: t.Optional[int] = None,
since: t.Optional[int] = None,
) -> t.Iterable[users.ShortUser]:
"""Iterate over every user in the order they signed up for GitHub.
.. versionchanged:: 1.0.0
Expand Down Expand Up @@ -229,7 +258,7 @@ def all_users(self, number=-1, etag=None, per_page=None, since=None):
params={"per_page": per_page, "since": since},
)

def app(self, app_slug):
def app(self, app_slug: t.Any) -> t.Optional[apps.App]:
"""Retrieve information about a specific app using its "slug".
.. versionadded:: 1.2.0
Expand All @@ -256,7 +285,9 @@ def app(self, app_slug):
return self._instance_or_null(apps.App, json)

@decorators.requires_app_bearer_auth
def app_installation(self, installation_id):
def app_installation(
self, installation_id: int
) -> t.Optional[apps.Installation]:
"""Retrieve a specific App installation by its ID.
.. versionadded: 1.2.0
Expand All @@ -283,7 +314,9 @@ def app_installation(self, installation_id):
return self._instance_or_null(apps.Installation, json)

@decorators.requires_app_bearer_auth
def app_installations(self, number=-1):
def app_installations(
self, number: int = -1
) -> t.Iterable[apps.Installation]:
"""Retrieve the list of installations for the authenticated app.
.. versionadded:: 1.2.0
Expand All @@ -310,7 +343,9 @@ def app_installations(self, number=-1):
)

@decorators.requires_app_bearer_auth
def app_installation_for_organization(self, organization):
def app_installation_for_organization(
self, organization: str
) -> t.Iterable[apps.Installation]:
"""Retrieve an App installation for a specific organization.
.. versionadded:: 1.2.0
Expand All @@ -337,7 +372,9 @@ def app_installation_for_organization(self, organization):
return self._instance_or_null(apps.Installation, json)

@decorators.requires_app_bearer_auth
def app_installation_for_repository(self, owner, repository):
def app_installation_for_repository(
self, owner: str, repository: str
) -> t.Optional[apps.Installation]:
"""Retrieve an App installation for a specific repository.
.. versionadded:: 1.2.0
Expand Down Expand Up @@ -1283,7 +1320,7 @@ def issues_on(
assignee=None,
mentioned=None,
labels=None,
sort=None,
sort: str | None = None,
direction=None,
since=None,
number=-1,
Expand Down Expand Up @@ -2037,13 +2074,13 @@ def repositories(

def repositories_by(
self,
username,
type=None,
sort=None,
direction=None,
number=-1,
etag=None,
):
username: str,
type: t.Optional[str] = None,
sort: t.Optional[str] = None,
direction: t.Optional[str] = None,
number: int = -1,
etag: t.Optional[str] = None,
) -> t.Iterable[repo.ShortRepository]:
"""List public repositories for the specified ``username``.
.. versionadded:: 0.6
Expand Down Expand Up @@ -2073,7 +2110,7 @@ def repositories_by(
"""
url = self._build_url("users", username, "repos")

params = {}
params: t.MutableMapping[str, t.Union[int, str, None]] = {}
if type in ("all", "owner", "member"):
params.update(type=type)
if sort in ("created", "updated", "pushed", "full_name"):
Expand All @@ -2085,7 +2122,9 @@ def repositories_by(
int(number), url, repo.ShortRepository, params, etag
)

def repository(self, owner, repository):
def repository(
self, owner: str, repository: str
) -> t.Optional[repo.Repository]:
"""Retrieve the desired repository.
:param str owner:
Expand All @@ -2104,7 +2143,9 @@ def repository(self, owner, repository):
return self._instance_or_null(repo.Repository, json)

@requires_auth
def repository_invitations(self, number=-1, etag=None):
def repository_invitations(
self, number: int = -1, etag: t.Optional[str] = None
) -> t.Iterator[invitation.Invitation]:
"""Iterate over the repository invitations for the current user.
:param int number:
Expand Down
2 changes: 1 addition & 1 deletion src/github3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _repr(self):
return f"<github3-core at 0x{id(self):x}>"

@staticmethod
def _remove_none(data):
def _remove_none(data: t.Optional[t.MutableMapping[str, t.Any]]):
if not data:
return
for k, v in list(data.items()):
Expand Down
4 changes: 2 additions & 2 deletions src/github3/pulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _update_attributes(self, pull):
self.created_at = self._strptime(pull["created_at"])
self.diff_url = pull["diff_url"]
self.head = Head(pull["head"], self)
self.html_url = pull["html_url"]
self.html_url: str = pull["html_url"]
self.id = pull["id"]
self.issue_url = pull["issue_url"]
self.links = pull["_links"]
Expand Down Expand Up @@ -992,7 +992,7 @@ def _update_attributes(self, review):
# PullReview.
self.commit_id = review.get("commit_id", None)
self.html_url = review["html_url"]
self.user = users.ShortUser(review["user"], self)
self.user = users.ShortUser(review["user"], self.session)
self.state = review["state"]
self.submitted_at = self._strptime(review.get("submitted_at"))
self.pull_request_url = review["pull_request_url"]
Expand Down
48 changes: 28 additions & 20 deletions src/github3/repos/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import base64
import json as jsonlib
import typing as t

import uritemplate as urit # type: ignore

Expand Down Expand Up @@ -91,7 +92,7 @@ def _update_attributes(self, repo):
self.milestones_urlt = urit.URITemplate(repo["milestones_url"])
self.name = repo["name"]
self.notifications_urlt = urit.URITemplate(repo["notifications_url"])
self.owner = users.ShortUser(repo["owner"], self)
self.owner = users.ShortUser(repo["owner"], self.session)
self.private = repo["private"]
self.pulls_urlt = urit.URITemplate(repo["pulls_url"])
self.releases_urlt = urit.URITemplate(repo["releases_url"])
Expand All @@ -103,13 +104,13 @@ def _update_attributes(self, repo):
self.teams_url = repo["teams_url"]
self.trees_urlt = urit.URITemplate(repo["trees_url"])

def _repr(self):
def _repr(self) -> str:
return f"<{self.class_name} [{self}]>"

def __str__(self):
def __str__(self) -> str:
return self.full_name

def _create_pull(self, data):
def _create_pull(self, data: t.MutableMapping[str, t.Any] | None):
self._remove_none(data)
json = None
if data:
Expand Down Expand Up @@ -543,7 +544,9 @@ def contributor_statistics(self, number=-1, etag=None):
url = self._build_url("stats", "contributors", base_url=self._api)
return self._iter(int(number), url, stats.ContributorStats, etag=etag)

def contributors(self, anon=False, number=-1, etag=None):
def contributors(
self, anon: bool = False, number: int = -1, etag: str | None = None
) -> t.Iterable[users.Contributor]:
"""Iterate over the contributors to this repository.
:param bool anon:
Expand All @@ -559,7 +562,7 @@ def contributors(self, anon=False, number=-1, etag=None):
:class:`~github3.users.Contributor`
"""
url = self._build_url("contributors", base_url=self._api)
params = {}
params: t.MutableMapping[str, int | str | None] = {}
if anon:
params = {"anon": "true"}
return self._iter(int(number), url, users.Contributor, params, etag)
Expand Down Expand Up @@ -940,16 +943,21 @@ def create_file(
self._remove_none(data)
json = self._json(self._put(url, data=jsonlib.dumps(data)), 201)
if json and "content" in json and "commit" in json:
json["content"] = contents.Contents(json["content"], self)
json["commit"] = git.Commit(json["commit"], self)
json["content"] = contents.Contents(
json["content"], self.session
)
json["commit"] = git.Commit(json["commit"], self.session)
return json

@decorators.requires_auth
def create_fork(self, organization=None):
def create_fork(
self, organization: str | None = None
) -> t.Optional["Repository"]:
"""Create a fork of this repository.
:param str organization:
(required), login for organization to create the fork under
(optional), login for organization to create the fork under. If
omitted the fork will be a personal fork.
:returns:
the fork of this repository
:rtype:
Expand Down Expand Up @@ -991,7 +999,7 @@ def create_hook(self, name, config, events=["push"], active=True):
"active": active,
}
json = self._json(self._post(url, data=data), 201)
return hook.Hook(json, self) if json else None
return hook.Hook(json, self.session) if json else None

@decorators.requires_auth
def create_issue(
Expand Down Expand Up @@ -2422,14 +2430,14 @@ def pull_request(self, number):

def pull_requests(
self,
state=None,
head=None,
base=None,
sort="created",
direction="desc",
number=-1,
etag=None,
):
state: str | None = None,
head: str | None = None,
base: str | None = None,
sort: str = "created",
direction: str = "desc",
number: int = -1,
etag: str | None = None,
) -> t.Iterable["pulls.ShortPullRequest"]:
"""List pull requests on repository.
.. versionchanged:: 0.9.0
Expand Down Expand Up @@ -2466,7 +2474,7 @@ def pull_requests(
:class:`~github3.pulls.ShortPullRequest`
"""
url = self._build_url("pulls", base_url=self._api)
params = {}
params: t.MutableMapping[str, int | str | None] = {}

if state:
state = state.lower()
Expand Down

0 comments on commit bc83ae7

Please sign in to comment.