Skip to content

feat: implement PR combine #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -8,13 +8,13 @@ require github.com/briandowns/spinner v1.23.2

require (
github.com/cli/go-gh/v2 v2.12.0
github.com/cli/shurcooL-graphql v0.0.4
github.com/spf13/cobra v1.9.1
)

require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cli/safeexec v1.0.0 // indirect
github.com/cli/shurcooL-graphql v0.0.4 // indirect
github.com/fatih/color v1.7.0 // indirect
github.com/henvic/httpretty v0.0.6 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
203 changes: 203 additions & 0 deletions internal/cmd/combine_prs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package cmd

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"

"github.com/cli/go-gh/v2/pkg/api"
)

func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient *api.RESTClient, owner, repo string, matchedPRs []struct {
Number int
Title string
Branch string
Base string
BaseSHA string
}) error {
// Define the combined branch name
workingBranchName := combineBranchName + workingBranchSuffix

baseBranchSHA, err := getBranchSHA(ctx, restClient, owner, repo, baseBranch)
if err != nil {
return fmt.Errorf("failed to get SHA of main branch: %w", err)
}

// Delete any pre-existing working branch
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
if err != nil {
Logger.Debug("Working branch not found, continuing", "branch", workingBranchName)
}

// Delete any pre-existing combined branch
err = deleteBranch(ctx, restClient, owner, repo, combineBranchName)
if err != nil {
Logger.Debug("Combined branch not found, continuing", "branch", combineBranchName)
}

// Create the combined branch
err = createBranch(ctx, restClient, owner, repo, combineBranchName, baseBranchSHA)
if err != nil {
return fmt.Errorf("failed to create combined branch: %w", err)
}

// Create the working branch
err = createBranch(ctx, restClient, owner, repo, workingBranchName, baseBranchSHA)
if err != nil {
return fmt.Errorf("failed to create working branch: %w", err)
}

// Merge all PR branches into the working branch
var combinedPRs []string
var mergeFailedPRs []string
for _, pr := range matchedPRs {
err := mergeBranch(ctx, restClient, owner, repo, workingBranchName, pr.Branch)
if err != nil {
Logger.Warn("Failed to merge branch", "branch", pr.Branch, "error", err)
mergeFailedPRs = append(mergeFailedPRs, fmt.Sprintf("#%d", pr.Number))
} else {
Logger.Info("Merged branch", "branch", pr.Branch)
combinedPRs = append(combinedPRs, fmt.Sprintf("#%d - %s", pr.Number, pr.Title))
}
}

// Update the combined branch to the latest commit of the working branch
err = updateRef(ctx, restClient, owner, repo, combineBranchName, workingBranchName)
if err != nil {
return fmt.Errorf("failed to update combined branch: %w", err)
}

// Delete the temporary working branch
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
if err != nil {
Logger.Warn("Failed to delete working branch", "branch", workingBranchName, "error", err)
}

// Create the combined PR
prBody := generatePRBody(combinedPRs, mergeFailedPRs)
prTitle := "Combined PRs"
err = createPullRequest(ctx, restClient, owner, repo, prTitle, combineBranchName, baseBranch, prBody)
if err != nil {
return fmt.Errorf("failed to create combined PR: %w", err)
}

return nil
}

// Get the SHA of a given branch
func getBranchSHA(ctx context.Context, client *api.RESTClient, owner, repo, branch string) (string, error) {
var ref struct {
Object struct {
SHA string `json:"sha"`
} `json:"object"`
}
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, branch)
err := client.Get(endpoint, &ref)
if err != nil {
return "", fmt.Errorf("failed to get SHA of branch %s: %w", branch, err)
}
return ref.Object.SHA, nil
}

// generatePRBody generates the body for the combined PR
func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
body := "✅ The following pull requests have been successfully combined:\n"
for _, pr := range combinedPRs {
body += "- " + pr + "\n"
}
if len(mergeFailedPRs) > 0 {
body += "\n⚠️ The following pull requests could not be merged due to conflicts:\n"
for _, pr := range mergeFailedPRs {
body += "- " + pr + "\n"
}
}
return body
}

// deleteBranch deletes a branch in the repository
func deleteBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch string) error {
endpoint := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
return client.Delete(endpoint, nil)
}

// createBranch creates a new branch in the repository
func createBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch, sha string) error {
endpoint := fmt.Sprintf("repos/%s/%s/git/refs", owner, repo)
payload := map[string]string{
"ref": "refs/heads/" + branch,
"sha": sha,
}
body, err := encodePayload(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
return client.Post(endpoint, body, nil)
}

// mergeBranch merges a branch into the base branch
func mergeBranch(ctx context.Context, client *api.RESTClient, owner, repo, base, head string) error {
endpoint := fmt.Sprintf("repos/%s/%s/merges", owner, repo)
payload := map[string]string{
"base": base,
"head": head,
}
body, err := encodePayload(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
return client.Post(endpoint, body, nil)
}

// updateRef updates a branch to point to the latest commit of another branch
func updateRef(ctx context.Context, client *api.RESTClient, owner, repo, branch, sourceBranch string) error {
// Get the SHA of the source branch
var ref struct {
Object struct {
SHA string `json:"sha"`
} `json:"object"`
}
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, sourceBranch)
err := client.Get(endpoint, &ref)
if err != nil {
return fmt.Errorf("failed to get SHA of source branch: %w", err)
}

// Update the branch to point to the new SHA
endpoint = fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
payload := map[string]interface{}{
"sha": ref.Object.SHA,
"force": true,
}
body, err := encodePayload(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
return client.Patch(endpoint, body, nil)
}

// createPullRequest creates a new pull request
func createPullRequest(ctx context.Context, client *api.RESTClient, owner, repo, title, head, base, body string) error {
endpoint := fmt.Sprintf("repos/%s/%s/pulls", owner, repo)
payload := map[string]string{
"title": title,
"head": head,
"base": base,
"body": body,
}
requestBody, err := encodePayload(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
return client.Post(endpoint, requestBody, nil)
}

// encodePayload encodes a payload as JSON and returns an io.Reader
func encodePayload(payload interface{}) (io.Reader, error) {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
167 changes: 167 additions & 0 deletions internal/cmd/match_criteria.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package cmd

import (
"context"
"fmt"
"regexp"
"strings"

"github.com/cli/go-gh/v2/pkg/api"
graphql "github.com/cli/shurcooL-graphql"
)

// checks if a PR matches all filtering criteria
@@ -113,3 +118,165 @@ func labelsMatchCriteria(prLabels []struct{ Name string }) bool {

return true
}

// GraphQL response structure for PR status info
type prStatusResponse struct {
Data struct {
Repository struct {
PullRequest struct {
ReviewDecision string `json:"reviewDecision"`
Commits struct {
Nodes []struct {
Commit struct {
StatusCheckRollup *struct {
State string `json:"state"`
} `json:"statusCheckRollup"`
} `json:"commit"`
} `json:"nodes"`
} `json:"commits"`
} `json:"pullRequest"`
} `json:"repository"`
} `json:"data"`
Errors []struct {
Message string `json:"message"`
} `json:"errors,omitempty"`
}

// GetPRStatusInfo fetches both CI status and approval status using GitHub's GraphQL API
func GetPRStatusInfo(ctx context.Context, graphQlClient *api.GraphQLClient, owner, repo string, prNumber int) (*prStatusResponse, error) {
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// Continue processing
}

// Define a struct with embedded graphql query
var query struct {
Repository struct {
PullRequest struct {
ReviewDecision string
Commits struct {
Nodes []struct {
Commit struct {
StatusCheckRollup *struct {
State string
}
}
}
} `graphql:"commits(last: 1)"`
} `graphql:"pullRequest(number: $prNumber)"`
} `graphql:"repository(owner: $owner, name: $repo)"`
}

// Prepare GraphQL query variables
variables := map[string]interface{}{
"owner": graphql.String(owner),
"repo": graphql.String(repo),
"prNumber": graphql.Int(prNumber),
}

// Execute GraphQL query
err := graphQlClient.Query("PullRequestStatus", &query, variables)
if err != nil {
return nil, fmt.Errorf("GraphQL query failed: %w", err)
}

// Convert to our response format
response := &prStatusResponse{}
response.Data.Repository.PullRequest.ReviewDecision = query.Repository.PullRequest.ReviewDecision

if len(query.Repository.PullRequest.Commits.Nodes) > 0 {
response.Data.Repository.PullRequest.Commits.Nodes = make([]struct {
Commit struct {
StatusCheckRollup *struct {
State string `json:"state"`
} `json:"statusCheckRollup"`
} `json:"commit"`
}, len(query.Repository.PullRequest.Commits.Nodes))

for i, node := range query.Repository.PullRequest.Commits.Nodes {
if node.Commit.StatusCheckRollup != nil {
response.Data.Repository.PullRequest.Commits.Nodes[i].Commit.StatusCheckRollup = &struct {
State string `json:"state"`
}{
State: node.Commit.StatusCheckRollup.State,
}
}
}
}

return response, nil
}

// PrMeetsRequirements checks if a PR meets additional requirements beyond basic criteria
func PrMeetsRequirements(ctx context.Context, graphQlClient *api.GraphQLClient, owner, repo string, prNumber int) (bool, error) {
// If no additional requirements are specified, the PR meets requirements
if !requireCI && !mustBeApproved {
return true, nil
}

// Fetch PR status info once
response, err := GetPRStatusInfo(ctx, graphQlClient, owner, repo, prNumber)
if err != nil {
return false, err
}

// Check CI status if required
if requireCI {
passing := isCIPassing(response)
if !passing {
return false, nil
}
}

// Check approval status if required
if mustBeApproved {
approved := isPRApproved(response)
if !approved {
return false, nil
}
}

return true, nil
}

// isCIPassing checks if the CI status is passing based on the response
func isCIPassing(response *prStatusResponse) bool {
commits := response.Data.Repository.PullRequest.Commits.Nodes
if len(commits) == 0 {
Logger.Debug("No commits found for PR")
return false
}

statusCheckRollup := commits[0].Commit.StatusCheckRollup
if statusCheckRollup == nil {
Logger.Debug("No status checks found for PR")
return true // If no checks defined, consider it passing
}

if statusCheckRollup.State != "SUCCESS" {
Logger.Debug("PR failed CI check", "status", statusCheckRollup.State)
return false
}

return true
}

// isPRApproved checks if the PR is approved based on the response
func isPRApproved(response *prStatusResponse) bool {
reviewDecision := response.Data.Repository.PullRequest.ReviewDecision
Logger.Debug("PR review decision", "decision", reviewDecision)

switch reviewDecision {
case "APPROVED":
return true
case "": // When no reviews are required
Logger.Debug("PR has no required reviewers")
return true // If no reviews required, consider it approved
default:
Logger.Debug("PR not approved", "decision", reviewDecision)
return false
}
}
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.