Skip to content

feat: add labels and assignees to the resulting combined-pr #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -10,11 +10,13 @@ require (
github.com/cli/go-gh/v2 v2.12.0
github.com/cli/shurcooL-graphql v0.0.4
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
)

require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cli/safeexec v1.0.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fatih/color v1.7.0 // indirect
github.com/henvic/httpretty v0.0.6 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -23,6 +25,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -42,8 +42,8 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8=
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI=
golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
68 changes: 56 additions & 12 deletions internal/cmd/combine_prs.go
Original file line number Diff line number Diff line change
@@ -12,7 +12,15 @@ import (
"github.com/github/gh-combine/internal/github"
)

func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient *api.RESTClient, repo github.Repo, pulls github.Pulls) error {
// Updated RESTClientInterface to match the method signatures of api.RESTClient
type RESTClientInterface interface {
Post(endpoint string, body io.Reader, response interface{}) error
Get(endpoint string, response interface{}) error
Delete(endpoint string, response interface{}) error
Patch(endpoint string, body io.Reader, response interface{}) error
}

func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient RESTClientInterface, repo github.Repo, pulls github.Pulls) error {
// Define the combined branch name
workingBranchName := combineBranchName + workingBranchSuffix

@@ -87,7 +95,7 @@ func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClien
// Create the combined PR
prBody := generatePRBody(combinedPRs, mergeFailedPRs)
prTitle := "Combined PRs"
err = createPullRequest(ctx, restClient, repo, prTitle, combineBranchName, repoDefaultBranch, prBody)
err = createPullRequest(ctx, restClient, repo, prTitle, combineBranchName, repoDefaultBranch, prBody, addLabels, addAssignees)
if err != nil {
return fmt.Errorf("failed to create combined PR: %w", err)
}
@@ -102,7 +110,7 @@ func isMergeConflictError(err error) bool {
}

// Find the default branch of a repository
func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.Repo) (string, error) {
func getDefaultBranch(ctx context.Context, client RESTClientInterface, repo github.Repo) (string, error) {
var repoInfo struct {
DefaultBranch string `json:"default_branch"`
}
@@ -115,7 +123,7 @@ func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.R
}

// Get the SHA of a given branch
func getBranchSHA(ctx context.Context, client *api.RESTClient, repo github.Repo, branch string) (string, error) {
func getBranchSHA(ctx context.Context, client RESTClientInterface, repo github.Repo, branch string) (string, error) {
var ref struct {
Object struct {
SHA string `json:"sha"`
@@ -148,13 +156,13 @@ func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
}

// deleteBranch deletes a branch in the repository
func deleteBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, branch string) error {
func deleteBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, branch string) error {
endpoint := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", repo.Owner, repo.Repo, branch)
return client.Delete(endpoint, nil)
}

// createBranch creates a new branch in the repository
func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, branch, sha string) error {
func createBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, branch, sha string) error {
endpoint := fmt.Sprintf("repos/%s/%s/git/refs", repo.Owner, repo.Repo)
payload := map[string]string{
"ref": "refs/heads/" + branch,
@@ -168,7 +176,7 @@ func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
}

// mergeBranch merges a branch into the base branch
func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo, base, head string) error {
func mergeBranch(ctx context.Context, client RESTClientInterface, repo github.Repo, base, head string) error {
endpoint := fmt.Sprintf("repos/%s/%s/merges", repo.Owner, repo.Repo)
payload := map[string]string{
"base": base,
@@ -182,7 +190,7 @@ func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
}

// updateRef updates a branch to point to the latest commit of another branch
func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, branch, sourceBranch string) error {
func updateRef(ctx context.Context, client RESTClientInterface, repo github.Repo, branch, sourceBranch string) error {
// Get the SHA of the source branch
var ref struct {
Object struct {
@@ -208,20 +216,56 @@ func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, br
return client.Patch(endpoint, body, nil)
}

// createPullRequest creates a new pull request
func createPullRequest(ctx context.Context, client *api.RESTClient, repo github.Repo, title, head, base, body string) error {
func createPullRequest(ctx context.Context, client RESTClientInterface, repo github.Repo, title, head, base, body string, labels, assignees []string) error {
endpoint := fmt.Sprintf("repos/%s/%s/pulls", repo.Owner, repo.Repo)
payload := map[string]string{
payload := map[string]interface{}{
"title": title,
"head": head,
"base": base,
"body": body,
}

requestBody, err := encodePayload(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
return client.Post(endpoint, requestBody, nil)

// Create the pull request
var prResponse struct {
Number int `json:"number"`
}
err = client.Post(endpoint, requestBody, &prResponse)
if err != nil {
return fmt.Errorf("failed to create pull request: %w", err)
}

// Add labels if provided
if len(labels) > 0 {
labelsEndpoint := fmt.Sprintf("repos/%s/%s/issues/%d/labels", repo.Owner, repo.Repo, prResponse.Number)
labelsPayload, err := encodePayload(map[string][]string{"labels": labels})
if err != nil {
return fmt.Errorf("failed to encode labels payload: %w", err)
}
err = client.Post(labelsEndpoint, labelsPayload, nil)
if err != nil {
return fmt.Errorf("failed to add labels: %w", err)
}
}

// Add assignees if provided
if len(assignees) > 0 {
assigneesEndpoint := fmt.Sprintf("repos/%s/%s/issues/%d/assignees", repo.Owner, repo.Repo, prResponse.Number)
assigneesPayload, err := encodePayload(map[string][]string{"assignees": assignees})
if err != nil {
return fmt.Errorf("failed to encode assignees payload: %w", err)
}
err = client.Post(assigneesEndpoint, assigneesPayload, nil)
if err != nil {
return fmt.Errorf("failed to add assignees: %w", err)
}
}

return nil
}

// encodePayload encodes a payload as JSON and returns an io.Reader
39 changes: 39 additions & 0 deletions internal/cmd/combine_prs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package cmd

import (
"context"
"strings"
"testing"

"github.com/github/gh-combine/internal/github"
"github.com/stretchr/testify/assert"
)

func TestCreatePullRequest(t *testing.T) {
client := &MockRESTClient{
PostFunc: func(endpoint string, body interface{}, response interface{}) error {
if strings.Contains(endpoint, "/pulls") {
if prResponse, ok := response.(*struct{ Number int }); ok {
prResponse.Number = 123 // Mock PR number
}
} else if strings.Contains(endpoint, "/labels") || strings.Contains(endpoint, "/assignees") {
// Mock successful label/assignee addition
return nil
}
return nil
},
}
repo := github.Repo{
Owner: "test-owner",
Repo: "test-repo",
}
title := "Test PR"
head := "test-branch"
base := "main"
body := "This is a test PR."
labels := []string{"bug", "enhancement"}
assignees := []string{"octocat", "hubot"}

err := createPullRequest(context.Background(), client, repo, title, head, base, body, labels, assignees)
assert.NoError(t, err)
}
64 changes: 64 additions & 0 deletions internal/cmd/mock_restclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package cmd

import (
"context"
"io"
"net/http"
)

type MockRESTClient struct {
PostFunc func(endpoint string, body interface{}, response interface{}) error
GetFunc func(endpoint string, response interface{}) error
DeleteFunc func(endpoint string, response interface{}) error
PatchFunc func(endpoint string, body io.Reader, response interface{}) error
}

// Updated the Post method to match the RESTClientInterface signature
func (m *MockRESTClient) Post(endpoint string, body io.Reader, response interface{}) error {
if m.PostFunc != nil {
return m.PostFunc(endpoint, body, response)
}
return nil
}

func (m *MockRESTClient) Get(endpoint string, response interface{}) error {
if m.GetFunc != nil {
return m.GetFunc(endpoint, response)
}
return nil
}

func (m *MockRESTClient) Delete(endpoint string, response interface{}) error {
if m.DeleteFunc != nil {
return m.DeleteFunc(endpoint, response)
}
return nil
}

// Updated the Patch method to match the RESTClientInterface signature
func (m *MockRESTClient) Patch(endpoint string, body io.Reader, response interface{}) error {
if m.PatchFunc != nil {
return m.PatchFunc(endpoint, body, response)
}
return nil
}

func (m *MockRESTClient) RequestWithContext(ctx context.Context, method string, path string, body io.Reader) (*http.Response, error) {
return nil, nil
}

func (m *MockRESTClient) Request(method string, path string, body io.Reader) (*http.Response, error) {
return nil, nil
}

func (m *MockRESTClient) DoWithContext(ctx context.Context, method string, path string, body io.Reader, response interface{}) error {
return nil
}

func (m *MockRESTClient) Do(method string, path string, body io.Reader, response interface{}) error {
return nil
}

func (m *MockRESTClient) Put(path string, body io.Reader, resp interface{}) error {
return nil
}
7 changes: 6 additions & 1 deletion internal/cmd/root.go
Original file line number Diff line number Diff line change
@@ -289,8 +289,13 @@ func processRepository(ctx context.Context, client *api.RESTClient, graphQlClien

Logger.Debug("Matched PRs", "repo", repo, "count", len(matchedPRs))

// Wrap the *api.RESTClient to implement RESTClientInterface
restClientWrapper := struct {
RESTClientInterface
}{client}

// Combine the PRs
err = CombinePRs(ctx, graphQlClient, client, repo, matchedPRs)
err = CombinePRs(ctx, graphQlClient, restClientWrapper, repo, matchedPRs)
if err != nil {
return fmt.Errorf("failed to combine PRs: %w", err)
}
15 changes: 15 additions & 0 deletions vendor/github.com/davecgh/go-spew/LICENSE

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Oops, something went wrong.
Loading
Oops, something went wrong.