Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 91 additions & 32 deletions patchwork/common/client/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@
import hashlib
import itertools
import time
from difflib import unified_diff
from enum import Enum
from itertools import chain
from pathlib import Path
from urllib.parse import urlparse

import git
import gitlab.const
from attrs import define
from azure.devops.connection import Connection
from azure.devops.released.client_factory import ClientFactory
from azure.devops.released.core.core_client import CoreClient
from azure.devops.released.git.git_client import GitClient
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, \
GitRepository, Comment, GitPullRequestCommentThread, GitTargetVersionDescriptor, GitBaseVersionDescriptor
from github import Auth, Consts, Github, GithubException, PullRequest
from github.GithubException import UnknownObjectException
from gitlab import Gitlab, GitlabAuthenticationError, GitlabError
from gitlab.v4.objects import ProjectMergeRequest
from giturlparse import GitUrlParsed, parse
from msrest.authentication import BasicAuthentication
from typing_extensions import Protocol, TypedDict
from typing_extensions import Protocol, TypedDict, Iterator

from patchwork.logger import logger

Expand All @@ -35,14 +35,6 @@ def get_slug_from_remote_url(remote_url: str) -> str:
return slug


@define
class Comment:
path: str
body: str
start_line: int | None
end_line: int


class IssueText(TypedDict):
title: str
body: str
Expand Down Expand Up @@ -386,18 +378,67 @@ def url(self) -> str:
def set_pr_description(self, body: str) -> None:
final_body = PullRequestProtocol._apply_pr_template(self, body)
body = GitPullRequest(description=final_body)
self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id)
self._pr = self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id)

def create_comment(
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
...
try:
comment_body = Comment(content=body)
comment_thread_body = GitPullRequestCommentThread(comments=[comment_body])
comment_thread = self.git_client.create_thread(comment_thread_body, repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id)
return body
except Exception as e:
logger.error(e)
return None

def __iter_comments(self) -> Iterator[tuple[GitPullRequestCommentThread, list[Comment]]]:
threads = self.git_client.get_threads(repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id)
for thread in threads:
comments = self.git_client.get_comments(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, project=self._pr.repository.project.id)
yield thread, comments

def reset_comments(self) -> None:
...
for thread, comments in self.__iter_comments():
comment_ids_to_delete = []
for comment in comments:
if comment.content.startswith(_COMMENT_MARKER):
comment_ids_to_delete.append(comment.id)
if len(comment_ids_to_delete) == len(comments):
for comment_id in comment_ids_to_delete:
self.git_client.delete_comment(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, comment_id=comment_id, project=self._pr.repository.project.id)

def texts(self) -> PullRequestTexts:
...
target_branch = self._pr.last_merge_source_commit.commit_id
feature_branch = self._pr.last_merge_target_commit.commit_id

repo = git.Repo(path=Path.cwd(), search_parent_directories=True)
for remote in repo.remotes:
remote.fetch()
target_commit = repo.commit(target_branch)
feature_commit = repo.commit(feature_branch)

diff_index = feature_commit.diff(target_commit)
diffs = dict()
for diff in diff_index:
a_path = diff.a_path
b_path = diff.b_path
a_blob = diff.a_blob.data_stream.read().decode("utf-8")
b_blob = diff.b_blob.data_stream.read().decode("utf-8")
diff_lines = unified_diff(a_blob.splitlines(keepends=True), b_blob.splitlines(keepends=True), a_path, b_path)
diff_content = "".join(diff_lines)
diffs[a_path] = diff_content

comments: list[PullRequestComment] = []
for _, raw_comments in self.__iter_comments():
for raw_comment in raw_comments:
comments.append(dict(user=raw_comment.author.display_name, body=raw_comment.content))
return dict(
title=self._pr.title or "",
body=self._pr.description or "",
comments=comments,
diffs=diffs,
)

class GithubClient(ScmPlatformClientProtocol):
DEFAULT_URL = Consts.DEFAULT_BASE_URL
Expand Down Expand Up @@ -450,11 +491,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
logger.warn(f"Failed to get issue: {e}")
return None

def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
def get_pr_by_url(self, url: str) -> GithubPullRequest | None:
slug, pr_id = self.get_slug_and_id_from_url(url)
return self.find_pr_by_id(slug, pr_id)

def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
def find_pr_by_id(self, slug: str, pr_id: int) -> GithubPullRequest | None:
repo = self.github.get_repo(slug)
try:
pr = repo.get_pull(pr_id)
Expand Down Expand Up @@ -508,7 +549,7 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
) -> GithubPullRequest:
# before creating a PR, check if one already exists
repo = self.github.get_repo(slug)
gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch)
Expand Down Expand Up @@ -579,11 +620,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
logger.warn(f"Failed to get issue: {e}")
return None

def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
def get_pr_by_url(self, url: str) -> GitlabMergeRequest | None:
slug, pr_id = self.get_slug_and_id_from_url(url)
return self.find_pr_by_id(slug, pr_id)

def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
def find_pr_by_id(self, slug: str, pr_id: int) -> GitlabMergeRequest | None:
project = self.gitlab.projects.get(slug)
try:
mr = project.mergerequests.get(pr_id)
Expand All @@ -599,7 +640,7 @@ def find_prs(
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[PullRequestProtocol]:
) -> list[GitlabMergeRequest]:
project = self.gitlab.projects.get(slug)
kwargs_list = dict(iterator=[True], state=[None], target_branch=[None], source_branch=[None])

Expand Down Expand Up @@ -630,7 +671,7 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
) -> GitlabMergeRequest:
# before creating a PR, check if one already exists
project = self.gitlab.projects.get(slug)
gl_mr = project.mergerequests.create(
Expand Down Expand Up @@ -714,23 +755,41 @@ def set_url(self, url: str) -> None:
self.__url = url

def test(self) -> bool:
response = self.core_client.get_projects()
return next(iter(response), None) is not None
try:
proj = self.project
return True
except ValueError:
return False

def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None:
...
url_parts = url.split("/")
if len(url_parts) == 1:
logger.error(f"Invalid URL: {url}")
return None

try:
resource_id = int(url_parts[-1])
except ValueError:
logger.error(f"Invalid URL: {url}")
return None

slug = "/".join(url_parts[-6:-3])

return slug, resource_id

def find_issue_by_url(self, url: str) -> IssueText | None:
...

def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
...

def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
...
def get_pr_by_url(self, url: str) -> AzureDevopsPullRequest | None:
slug, resource_id = self.get_slug_and_id_from_url(url)
return self.find_pr_by_id(slug, resource_id)

def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
...
def find_pr_by_id(self, slug: str, pr_id: int) -> AzureDevopsPullRequest | None:
pr = self.git_client.get_pull_request(repository_id=self.repo.id, pull_request_id=pr_id, project=self.project.id)
return AzureDevopsPullRequest(pr, self.git_client, self.__pr_resource_html_url())

def find_prs(
self,
Expand All @@ -739,7 +798,7 @@ def find_prs(
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[PullRequestProtocol]:
) -> list[AzureDevopsPullRequest]:
kwargs_list = dict(status=[None], target_ref_name=[None], source_ref_name=[None])

if state is not None:
Expand Down Expand Up @@ -777,7 +836,7 @@ def create_pr(
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
) -> AzureDevopsPullRequest:
# before creating a PR, check if one already exists
pr_body = GitPullRequest(
source_ref_name=f"refs/heads/{feature_branch}",
Expand Down
4 changes: 3 additions & 1 deletion patchwork/steps/CreatePRComment/CreatePRComment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from patchwork.common.client.scm import GithubClient, GitlabClient
from patchwork.common.client.scm import GithubClient, GitlabClient, AzureDevopsClient
from patchwork.logger import logger
from patchwork.step import Step, StepStatus

Expand All @@ -15,6 +15,8 @@ def __init__(self, inputs: dict):
self.scm_client = GithubClient(inputs["github_api_key"])
elif "gitlab_api_key" in inputs.keys():
self.scm_client = GitlabClient(inputs["gitlab_api_key"])
elif "azuredevops_api_key" in inputs.keys():
self.scm_client = AzureDevopsClient(inputs["azuredevops_api_key"])
else:
raise ValueError(f'Missing required input data: "github_api_key" or "gitlab_api_key"')

Expand Down
5 changes: 3 additions & 2 deletions patchwork/steps/CreatePRComment/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ class __CreatePRCommentRequiredInputs(TypedDict):
class CreatePRCommentInputs(__CreatePRCommentRequiredInputs, total=False):
noisy_comments: Annotated[bool, StepTypeConfig(is_config=True)]
scm_url: Annotated[str, StepTypeConfig(is_config=True)]
gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key"])]
github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key"])]
gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key", "azuredevops_api_key"])]
github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "azuredevops_api_key"])]
azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "github_api_key"])]


class CreatePRCommentOutputs(TypedDict):
Expand Down
7 changes: 3 additions & 4 deletions patchwork/steps/ReadPRDiffs/ReadPRDiffs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing_extensions import List

from patchwork.common.client.scm import GithubClient, GitlabClient
from patchwork.common.client.scm import GithubClient, GitlabClient, AzureDevopsClient
from patchwork.step import Step
from patchwork.steps.ReadPRDiffs.typed import ReadPRDiffsInputs, ReadPRDiffsOutputs

Expand All @@ -26,17 +26,16 @@ def filter_by_extension(file, extensions):


class ReadPRDiffs(Step, input_class=ReadPRDiffsInputs, output_class=ReadPRDiffsOutputs):
required_keys = {"pr_url"}

def __init__(self, inputs: dict):
super().__init__(inputs)
if not all(key in inputs.keys() for key in self.required_keys):
raise ValueError(f'Missing required data: "{self.required_keys}"')

if "github_api_key" in inputs.keys():
self.scm_client = GithubClient(inputs["github_api_key"])
elif "gitlab_api_key" in inputs.keys():
self.scm_client = GitlabClient(inputs["gitlab_api_key"])
elif "azuredevops_api_key" in inputs.keys():
self.scm_client = AzureDevopsClient(inputs["azuredevops_api_key"])
else:
raise ValueError(f'Missing required input data: "github_api_key" or "gitlab_api_key"')

Expand Down
5 changes: 3 additions & 2 deletions patchwork/steps/ReadPRDiffs/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ class __ReadPRDiffsRequiredInputs(TypedDict):

class ReadPRDiffsInputs(__ReadPRDiffsRequiredInputs, total=False):
scm_url: Annotated[str, StepTypeConfig(is_config=True)]
gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key"])]
github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key"])]
gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key", "azuredevops_api_key"])]
github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "azuredevops_api_key"])]
azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "github_api_key"])]


class ReadPRDiffsOutputs(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "patchwork-cli"
version = "0.0.85"
version = "0.0.86"
description = ""
authors = ["patched.codes"]
license = "AGPL"
Expand Down
2 changes: 2 additions & 0 deletions tests/common/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_default_list_option_callback(runner):
== """\
AutoFix
DependencyUpgrade
GenerateCodeUsageExample
GenerateDiagram
GenerateDocstring
GenerateREADME
Expand All @@ -68,6 +69,7 @@ def test_config_list_option_callback(runner, config_dir, patchflow_file):
== f"""\
AutoFix
DependencyUpgrade
GenerateCodeUsageExample
GenerateDiagram
GenerateDocstring
GenerateREADME
Expand Down
Loading