diff --git a/patchwork/common/client/scm.py b/patchwork/common/client/scm.py index 22704f9a4..c0b6f47b4 100644 --- a/patchwork/common/client/scm.py +++ b/patchwork/common/client/scm.py @@ -4,26 +4,26 @@ import hashlib import itertools import time +from difflib import unified_diff from enum import Enum from itertools import chain from pathlib import Path -from urllib.parse import urlparse import git import gitlab.const -from attrs import define from azure.devops.connection import Connection from azure.devops.released.client_factory import ClientFactory from azure.devops.released.core.core_client import CoreClient from azure.devops.released.git.git_client import GitClient -from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository +from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, \ + GitRepository, Comment, GitPullRequestCommentThread, GitTargetVersionDescriptor, GitBaseVersionDescriptor from github import Auth, Consts, Github, GithubException, PullRequest from github.GithubException import UnknownObjectException from gitlab import Gitlab, GitlabAuthenticationError, GitlabError from gitlab.v4.objects import ProjectMergeRequest from giturlparse import GitUrlParsed, parse from msrest.authentication import BasicAuthentication -from typing_extensions import Protocol, TypedDict +from typing_extensions import Protocol, TypedDict, Iterator from patchwork.logger import logger @@ -35,14 +35,6 @@ def get_slug_from_remote_url(remote_url: str) -> str: return slug -@define -class Comment: - path: str - body: str - start_line: int | None - end_line: int - - class IssueText(TypedDict): title: str body: str @@ -386,18 +378,67 @@ def url(self) -> str: def set_pr_description(self, body: str) -> None: final_body = PullRequestProtocol._apply_pr_template(self, body) body = GitPullRequest(description=final_body) - self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id) + self._pr = self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id) def create_comment( self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None ) -> str | None: - ... + try: + comment_body = Comment(content=body) + comment_thread_body = GitPullRequestCommentThread(comments=[comment_body]) + comment_thread = self.git_client.create_thread(comment_thread_body, repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id) + return body + except Exception as e: + logger.error(e) + return None + + def __iter_comments(self) -> Iterator[tuple[GitPullRequestCommentThread, list[Comment]]]: + threads = self.git_client.get_threads(repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id) + for thread in threads: + comments = self.git_client.get_comments(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, project=self._pr.repository.project.id) + yield thread, comments def reset_comments(self) -> None: - ... + for thread, comments in self.__iter_comments(): + comment_ids_to_delete = [] + for comment in comments: + if comment.content.startswith(_COMMENT_MARKER): + comment_ids_to_delete.append(comment.id) + if len(comment_ids_to_delete) == len(comments): + for comment_id in comment_ids_to_delete: + self.git_client.delete_comment(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, comment_id=comment_id, project=self._pr.repository.project.id) def texts(self) -> PullRequestTexts: - ... + target_branch = self._pr.last_merge_source_commit.commit_id + feature_branch = self._pr.last_merge_target_commit.commit_id + + repo = git.Repo(path=Path.cwd(), search_parent_directories=True) + for remote in repo.remotes: + remote.fetch() + target_commit = repo.commit(target_branch) + feature_commit = repo.commit(feature_branch) + + diff_index = feature_commit.diff(target_commit) + diffs = dict() + for diff in diff_index: + a_path = diff.a_path + b_path = diff.b_path + a_blob = diff.a_blob.data_stream.read().decode("utf-8") + b_blob = diff.b_blob.data_stream.read().decode("utf-8") + diff_lines = unified_diff(a_blob.splitlines(keepends=True), b_blob.splitlines(keepends=True), a_path, b_path) + diff_content = "".join(diff_lines) + diffs[a_path] = diff_content + + comments: list[PullRequestComment] = [] + for _, raw_comments in self.__iter_comments(): + for raw_comment in raw_comments: + comments.append(dict(user=raw_comment.author.display_name, body=raw_comment.content)) + return dict( + title=self._pr.title or "", + body=self._pr.description or "", + comments=comments, + diffs=diffs, + ) class GithubClient(ScmPlatformClientProtocol): DEFAULT_URL = Consts.DEFAULT_BASE_URL @@ -450,11 +491,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None: logger.warn(f"Failed to get issue: {e}") return None - def get_pr_by_url(self, url: str) -> PullRequestProtocol | None: + def get_pr_by_url(self, url: str) -> GithubPullRequest | None: slug, pr_id = self.get_slug_and_id_from_url(url) return self.find_pr_by_id(slug, pr_id) - def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None: + def find_pr_by_id(self, slug: str, pr_id: int) -> GithubPullRequest | None: repo = self.github.get_repo(slug) try: pr = repo.get_pull(pr_id) @@ -508,7 +549,7 @@ def create_pr( body: str, original_branch: str, feature_branch: str, - ) -> PullRequestProtocol: + ) -> GithubPullRequest: # before creating a PR, check if one already exists repo = self.github.get_repo(slug) gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch) @@ -579,11 +620,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None: logger.warn(f"Failed to get issue: {e}") return None - def get_pr_by_url(self, url: str) -> PullRequestProtocol | None: + def get_pr_by_url(self, url: str) -> GitlabMergeRequest | None: slug, pr_id = self.get_slug_and_id_from_url(url) return self.find_pr_by_id(slug, pr_id) - def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None: + def find_pr_by_id(self, slug: str, pr_id: int) -> GitlabMergeRequest | None: project = self.gitlab.projects.get(slug) try: mr = project.mergerequests.get(pr_id) @@ -599,7 +640,7 @@ def find_prs( original_branch: str | None = None, feature_branch: str | None = None, limit: int | None = None, - ) -> list[PullRequestProtocol]: + ) -> list[GitlabMergeRequest]: project = self.gitlab.projects.get(slug) kwargs_list = dict(iterator=[True], state=[None], target_branch=[None], source_branch=[None]) @@ -630,7 +671,7 @@ def create_pr( body: str, original_branch: str, feature_branch: str, - ) -> PullRequestProtocol: + ) -> GitlabMergeRequest: # before creating a PR, check if one already exists project = self.gitlab.projects.get(slug) gl_mr = project.mergerequests.create( @@ -714,11 +755,27 @@ def set_url(self, url: str) -> None: self.__url = url def test(self) -> bool: - response = self.core_client.get_projects() - return next(iter(response), None) is not None + try: + proj = self.project + return True + except ValueError: + return False def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None: - ... + url_parts = url.split("/") + if len(url_parts) == 1: + logger.error(f"Invalid URL: {url}") + return None + + try: + resource_id = int(url_parts[-1]) + except ValueError: + logger.error(f"Invalid URL: {url}") + return None + + slug = "/".join(url_parts[-6:-3]) + + return slug, resource_id def find_issue_by_url(self, url: str) -> IssueText | None: ... @@ -726,11 +783,13 @@ def find_issue_by_url(self, url: str) -> IssueText | None: def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None: ... - def get_pr_by_url(self, url: str) -> PullRequestProtocol | None: - ... + def get_pr_by_url(self, url: str) -> AzureDevopsPullRequest | None: + slug, resource_id = self.get_slug_and_id_from_url(url) + return self.find_pr_by_id(slug, resource_id) - def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None: - ... + def find_pr_by_id(self, slug: str, pr_id: int) -> AzureDevopsPullRequest | None: + pr = self.git_client.get_pull_request(repository_id=self.repo.id, pull_request_id=pr_id, project=self.project.id) + return AzureDevopsPullRequest(pr, self.git_client, self.__pr_resource_html_url()) def find_prs( self, @@ -739,7 +798,7 @@ def find_prs( original_branch: str | None = None, feature_branch: str | None = None, limit: int | None = None, - ) -> list[PullRequestProtocol]: + ) -> list[AzureDevopsPullRequest]: kwargs_list = dict(status=[None], target_ref_name=[None], source_ref_name=[None]) if state is not None: @@ -777,7 +836,7 @@ def create_pr( body: str, original_branch: str, feature_branch: str, - ) -> PullRequestProtocol: + ) -> AzureDevopsPullRequest: # before creating a PR, check if one already exists pr_body = GitPullRequest( source_ref_name=f"refs/heads/{feature_branch}", diff --git a/patchwork/steps/CreatePRComment/CreatePRComment.py b/patchwork/steps/CreatePRComment/CreatePRComment.py index 5ca55b593..1af225b79 100644 --- a/patchwork/steps/CreatePRComment/CreatePRComment.py +++ b/patchwork/steps/CreatePRComment/CreatePRComment.py @@ -1,4 +1,4 @@ -from patchwork.common.client.scm import GithubClient, GitlabClient +from patchwork.common.client.scm import GithubClient, GitlabClient, AzureDevopsClient from patchwork.logger import logger from patchwork.step import Step, StepStatus @@ -15,6 +15,8 @@ def __init__(self, inputs: dict): self.scm_client = GithubClient(inputs["github_api_key"]) elif "gitlab_api_key" in inputs.keys(): self.scm_client = GitlabClient(inputs["gitlab_api_key"]) + elif "azuredevops_api_key" in inputs.keys(): + self.scm_client = AzureDevopsClient(inputs["azuredevops_api_key"]) else: raise ValueError(f'Missing required input data: "github_api_key" or "gitlab_api_key"') diff --git a/patchwork/steps/CreatePRComment/typed.py b/patchwork/steps/CreatePRComment/typed.py index b63d8b6bd..49a24a970 100644 --- a/patchwork/steps/CreatePRComment/typed.py +++ b/patchwork/steps/CreatePRComment/typed.py @@ -11,8 +11,9 @@ class __CreatePRCommentRequiredInputs(TypedDict): class CreatePRCommentInputs(__CreatePRCommentRequiredInputs, total=False): noisy_comments: Annotated[bool, StepTypeConfig(is_config=True)] scm_url: Annotated[str, StepTypeConfig(is_config=True)] - gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key"])] - github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key"])] + gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key", "azuredevops_api_key"])] + github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "azuredevops_api_key"])] + azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "github_api_key"])] class CreatePRCommentOutputs(TypedDict): diff --git a/patchwork/steps/ReadPRDiffs/ReadPRDiffs.py b/patchwork/steps/ReadPRDiffs/ReadPRDiffs.py index 45b733fee..554b31419 100644 --- a/patchwork/steps/ReadPRDiffs/ReadPRDiffs.py +++ b/patchwork/steps/ReadPRDiffs/ReadPRDiffs.py @@ -1,6 +1,6 @@ from typing_extensions import List -from patchwork.common.client.scm import GithubClient, GitlabClient +from patchwork.common.client.scm import GithubClient, GitlabClient, AzureDevopsClient from patchwork.step import Step from patchwork.steps.ReadPRDiffs.typed import ReadPRDiffsInputs, ReadPRDiffsOutputs @@ -26,17 +26,16 @@ def filter_by_extension(file, extensions): class ReadPRDiffs(Step, input_class=ReadPRDiffsInputs, output_class=ReadPRDiffsOutputs): - required_keys = {"pr_url"} def __init__(self, inputs: dict): super().__init__(inputs) - if not all(key in inputs.keys() for key in self.required_keys): - raise ValueError(f'Missing required data: "{self.required_keys}"') if "github_api_key" in inputs.keys(): self.scm_client = GithubClient(inputs["github_api_key"]) elif "gitlab_api_key" in inputs.keys(): self.scm_client = GitlabClient(inputs["gitlab_api_key"]) + elif "azuredevops_api_key" in inputs.keys(): + self.scm_client = AzureDevopsClient(inputs["azuredevops_api_key"]) else: raise ValueError(f'Missing required input data: "github_api_key" or "gitlab_api_key"') diff --git a/patchwork/steps/ReadPRDiffs/typed.py b/patchwork/steps/ReadPRDiffs/typed.py index 13f732259..93322fd95 100644 --- a/patchwork/steps/ReadPRDiffs/typed.py +++ b/patchwork/steps/ReadPRDiffs/typed.py @@ -10,8 +10,9 @@ class __ReadPRDiffsRequiredInputs(TypedDict): class ReadPRDiffsInputs(__ReadPRDiffsRequiredInputs, total=False): scm_url: Annotated[str, StepTypeConfig(is_config=True)] - gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key"])] - github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key"])] + gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["github_api_key", "azuredevops_api_key"])] + github_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "azuredevops_api_key"])] + azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True, or_op=["gitlab_api_key", "github_api_key"])] class ReadPRDiffsOutputs(TypedDict): diff --git a/pyproject.toml b/pyproject.toml index 3cd563b39..4c9353170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "patchwork-cli" -version = "0.0.85" +version = "0.0.86" description = "" authors = ["patched.codes"] license = "AGPL" diff --git a/tests/common/test_app.py b/tests/common/test_app.py index 3e4ea8bb3..f36e75fd8 100644 --- a/tests/common/test_app.py +++ b/tests/common/test_app.py @@ -48,6 +48,7 @@ def test_default_list_option_callback(runner): == """\ AutoFix DependencyUpgrade +GenerateCodeUsageExample GenerateDiagram GenerateDocstring GenerateREADME @@ -68,6 +69,7 @@ def test_config_list_option_callback(runner, config_dir, patchflow_file): == f"""\ AutoFix DependencyUpgrade +GenerateCodeUsageExample GenerateDiagram GenerateDocstring GenerateREADME