Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Draft PR status in responses #927

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/github3/checks.py
Expand Up @@ -58,7 +58,10 @@ def to_pull(self):
"""
from . import pulls

json = self._json(self._get(self.url), 200)
json = self._json(
self._get(self.url, headers=pulls.PULLS_PREVIEW_HEADERS),
200,
)
return self._instance_or_null(pulls.PullRequest, json)

refresh = to_pull
Expand Down
5 changes: 4 additions & 1 deletion src/github3/events.py
Expand Up @@ -161,7 +161,10 @@ def to_pull(self):
"""
from . import pulls

json = self._json(self._get(self.url), 200)
json = self._json(
self._get(self.url, headers=pulls.PULLS_PREVIEW_HEADERS),
200,
)
return self._instance_or_null(pulls.PullRequest, json)

refresh = to_pull
Expand Down
5 changes: 4 additions & 1 deletion src/github3/github.py
Expand Up @@ -1837,7 +1837,10 @@ def pull_request(self, owner, repository, number):
url = self._build_url(
"repos", owner, repository, "pulls", str(number)
)
json = self._json(self._get(url), 200)
json = self._json(
self._get(url, headers=pulls.PULLS_PREVIEW_HEADERS),
200,
)
return self._instance_or_null(pulls.PullRequest, json)

def rate_limit(self):
Expand Down
8 changes: 7 additions & 1 deletion src/github3/issues/issue.py
Expand Up @@ -344,7 +344,13 @@ def pull_request(self):
if self.pull_request_urls is not None:
pull_request_url = self.pull_request_urls.get("url")
if pull_request_url:
json = self._json(self._get(pull_request_url), 200)
json = self._json(
self._get(
pull_request_url,
headers=pulls.PULLS_PREVIEW_HEADERS,
),
200,
)
return self._instance_or_null(pulls.PullRequest, json)

@requires_auth
Expand Down
3 changes: 2 additions & 1 deletion src/github3/projects.py
Expand Up @@ -558,7 +558,8 @@ def retrieve_pull_request_from_content(self):
parsed = self._uri_parse(self.content_url)
_, owner, repository, _, number = parsed.path[1:].split("/", 5)
resp = self._get(
self._build_url("repos", owner, repository, "pulls", number)
self._build_url("repos", owner, repository, "pulls", number),
headers=pulls.PULLS_PREVIEW_HEADERS,
)
json = self._json(resp, 200)
return self._instance_or_null(pulls.PullRequest, json)
38 changes: 36 additions & 2 deletions src/github3/pulls.py
Expand Up @@ -14,6 +14,10 @@
from .issues import Issue
from .issues.comment import IssueComment

PULLS_PREVIEW_HEADERS = {
"Accept": "application/vnd.github.shadow-cat-preview"
}


class PullDestination(models.GitHubCore):
"""The object that represents a pull request destination.
Expand Down Expand Up @@ -211,6 +215,7 @@ def _update_attributes(self, pull):
self.commits_url = pull["commits_url"]
self.created_at = self._strptime(pull["created_at"])
self.diff_url = pull["diff_url"]
self.draft = pull.get("draft", False)
self.head = Head(pull["head"], self)
self.html_url = pull["html_url"]
self.id = pull["id"]
Expand Down Expand Up @@ -322,7 +327,14 @@ def create_review_requests(self, reviewers=None, team_reviewers=None):
data["team_reviewers"] = [
getattr(t, "slug", t) for t in team_reviewers
]
json = self._json(self._post(url, data=data), 201)
json = self._json(
self._post(
url,
data=data,
headers=PULLS_PREVIEW_HEADERS,
),
201,
)
return self._instance_or_null(ShortPullRequest, json)

@requires_auth
Expand Down Expand Up @@ -641,7 +653,14 @@ def update(
self._remove_none(data)

if data:
json = self._json(self._patch(self._api, data=dumps(data)), 200)
json = self._json(
self._patch(
self._api,
data=dumps(data),
headers=PULLS_PREVIEW_HEADERS,
),
200,
)

if json:
self._update_attributes(json)
Expand Down Expand Up @@ -796,6 +815,21 @@ class ShortPullRequest(_PullRequest):

The URL to retrieve the diff for this pull request via the API.

.. attribute:: draft

.. versionadded:: 1.3.1

A boolean attribute indicating whether the pull request is a draft
or not.

.. note::

Draft status is only available for repositories with GitHub Free
and GitHubPro, and in public and private repositories with
GitHub Team, GitHub Enterprise CLoud, and GitHub Enterprise
omgjlk marked this conversation as resolved.
Show resolved Hide resolved
Server. Unless specific Draft state is provided by GitHub API
we will set the attribute to to ``False``.

.. attribute:: head

A :class:`~github3.pulls.Head` object representing the head pull
Expand Down
21 changes: 18 additions & 3 deletions src/github3/repos/repo.py
Expand Up @@ -117,7 +117,14 @@ def _create_pull(self, data):
json = None
if data:
url = self._build_url("pulls", base_url=self._api)
json = self._json(self._post(url, data=data), 201)
json = self._json(
self._post(
url,
data=data,
headers=pulls.PULLS_PREVIEW_HEADERS,
),
201,
)
return self._instance_or_null(pulls.ShortPullRequest, json)

@decorators.requires_auth
Expand Down Expand Up @@ -2349,7 +2356,10 @@ def pull_request(self, number):
json = None
if int(number) > 0:
url = self._build_url("pulls", str(number), base_url=self._api)
json = self._json(self._get(url), 200)
json = self._json(
self._get(url, headers=pulls.PULLS_PREVIEW_HEADERS),
200,
)
return self._instance_or_null(pulls.PullRequest, json)

def pull_requests(
Expand Down Expand Up @@ -2408,7 +2418,12 @@ def pull_requests(
params.update(head=head, base=base, sort=sort, direction=direction)
self._remove_none(params)
return self._iter(
int(number), url, pulls.ShortPullRequest, params, etag
int(number),
url,
pulls.ShortPullRequest,
params,
etag,
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def readme(self):
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_github.py
@@ -1,6 +1,6 @@
import pytest

from github3 import GitHubEnterprise, GitHubError
from github3 import GitHubEnterprise, GitHubError, pulls
from github3.github import GitHub, GitHubStatus
from github3.projects import Project

Expand Down Expand Up @@ -540,7 +540,8 @@ def test_pull_request(self):
)

self.session.get.assert_called_once_with(
url_for("repos/octocat/hello-world/pulls/1")
url_for("repos/octocat/hello-world/pulls/1"),
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_pull_request_negative_id(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_issues_issue.py
Expand Up @@ -4,6 +4,7 @@
import dateutil.parser
import mock

from github3 import pulls
from github3.issues.label import Label
from github3.issues import Issue
from . import helper
Expand Down Expand Up @@ -270,7 +271,8 @@ def test_pull_request(self):
self.instance.pull_request()

self.session.get.assert_called_once_with(
self.instance.pull_request_urls["url"]
self.instance.pull_request_urls["url"],
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_pull_request_without_urls(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_projects.py
Expand Up @@ -8,6 +8,7 @@
from github3 import exceptions
from github3 import issues
from github3 import projects
from github3 import pulls


get_project_example_data = helper.create_example_data_helper(
Expand Down Expand Up @@ -297,7 +298,8 @@ def test_retrieve_pull_request_from_content(self):

self.session.get.assert_called_once_with(
"https://api.github.com/repos/api-playground/projects-test/"
"pulls/3"
"pulls/3",
headers=pulls.PULLS_PREVIEW_HEADERS,
)


Expand Down
7 changes: 6 additions & 1 deletion tests/unit/test_pulls.py
Expand Up @@ -40,6 +40,7 @@ def test_close(self):
"body": self.instance.body,
"state": "closed",
},
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_create_comment(self):
Expand Down Expand Up @@ -70,7 +71,9 @@ def test_create_review_requests(self):
self.instance.create_review_requests(reviewers=["sigmavirus24"])

self.session.post.assert_called_once_with(
url_for("requested_reviewers"), '{"reviewers": ["sigmavirus24"]}'
url_for("requested_reviewers"),
'{"reviewers": ["sigmavirus24"]}',
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_create_review(self):
Expand Down Expand Up @@ -165,6 +168,7 @@ def test_reopen(self):
"body": self.instance.body,
"state": "open",
},
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_review_requests(self):
Expand All @@ -186,6 +190,7 @@ def test_update(self):
"body": "my new body",
"state": "open",
},
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_attributes(self):
Expand Down
17 changes: 12 additions & 5 deletions tests/unit/test_repos_repo.py
Expand Up @@ -4,7 +4,7 @@
import pytest

from base64 import b64encode
from github3 import GitHubError
from github3 import GitHubError, pulls
from github3.exceptions import GitHubException
from github3.repos.comment import RepoComment
from github3.repos.commit import RepoCommit
Expand Down Expand Up @@ -324,7 +324,11 @@ def test_create_pull_private(self):
"""Verify the request for creating a pull request."""
data = {"title": "foo", "base": "master", "head": "feature_branch"}
self.instance._create_pull(data)
self.post_called_with(url_for("pulls"), data=data)
self.post_called_with(
url_for("pulls"),
data=data,
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_create_pull(self):
"""Verify the request for creating a pull request."""
Expand Down Expand Up @@ -804,7 +808,10 @@ def test_project(self):
def test_pull_request(self):
"""Verify the request for retrieving a pull request."""
self.instance.pull_request(1)
self.session.get.assert_called_once_with(url_for("pulls/1"))
self.session.get.assert_called_once_with(
url_for("pulls/1"),
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_pull_request_required_number(self):
"""Verify the request for retrieving a pull request."""
Expand Down Expand Up @@ -1273,7 +1280,7 @@ def test_pull_requests(self):
self.session.get.assert_called_once_with(
url_for("pulls"),
params={"per_page": 100, "sort": "created", "direction": "desc"},
headers={},
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_pull_requests_ignore_invalid_state(self):
Expand All @@ -1284,7 +1291,7 @@ def test_pull_requests_ignore_invalid_state(self):
self.session.get.assert_called_once_with(
url_for("pulls"),
params={"per_page": 100, "sort": "created", "direction": "desc"},
headers={},
headers=pulls.PULLS_PREVIEW_HEADERS,
)

def test_refs(self):
Expand Down