From 4df2d3d8035bae8af77bc371c1e9517e6a9e611f Mon Sep 17 00:00:00 2001 From: William Martin <williammartin@github.com> Date: Fri, 11 Apr 2025 22:20:59 +0200 Subject: [PATCH] WIP: schema generation --- cmd/tools2md/main.go | 28 + go.mod | 2 +- go.sum | 4 +- .../testdata/TestFullSchema.golden.md | 235 +++ ...tMultipleCategoriesMultipleTools.golden.md | 14 + .../testdata/TestOneCategoryOneTool.golden.md | 6 + .../TestToolsWithProperties.golden.md | 7 + internal/tools2md/tools2md.go | 174 +++ internal/tools2md/tools2md_test.go | 253 ++++ pkg/github/code_scanning.go | 170 ++- pkg/github/code_scanning_test.go | 44 +- pkg/github/issues.go | 825 ++++++----- pkg/github/issues_test.go | 187 ++- pkg/github/pullrequests.go | 1293 +++++++++-------- pkg/github/pullrequests_test.go | 267 ++-- pkg/github/repositories.go | 927 ++++++------ pkg/github/repositories_test.go | 180 ++- pkg/github/search.go | 283 ++-- pkg/github/search_test.go | 69 +- pkg/github/server.go | 200 ++- pkg/github/server_test.go | 16 +- schema.md | 235 +++ 22 files changed, 3310 insertions(+), 2109 deletions(-) create mode 100644 cmd/tools2md/main.go create mode 100644 internal/tools2md/testdata/TestFullSchema.golden.md create mode 100644 internal/tools2md/testdata/TestMultipleCategoriesMultipleTools.golden.md create mode 100644 internal/tools2md/testdata/TestOneCategoryOneTool.golden.md create mode 100644 internal/tools2md/testdata/TestToolsWithProperties.golden.md create mode 100644 internal/tools2md/tools2md.go create mode 100644 internal/tools2md/tools2md_test.go create mode 100644 schema.md diff --git a/cmd/tools2md/main.go b/cmd/tools2md/main.go new file mode 100644 index 000000000..4c5cf3c36 --- /dev/null +++ b/cmd/tools2md/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/github/github-mcp-server/internal/tools2md" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/translations" +) + +var filepath = flag.String("filepath", "schema.md", "filepath to schema file") + +func main() { + tools := github.DefaultTools(translations.NullTranslationHelper) + md := tools2md.Convert(tools) + if *filepath == "" { + panic("filepath cannot be empty") + } + + err := os.WriteFile(*filepath, []byte(md), 0600) + if err != nil { + panic(err) + } + + fmt.Println("Schema file generated successfully at", *filepath) +} diff --git a/go.mod b/go.mod index 858690cde..2ea9d7ae3 100644 --- a/go.mod +++ b/go.mod @@ -57,7 +57,7 @@ require ( go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/proto/otlp v1.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sys v0.31.0 // indirect + golang.org/x/sys v0.32.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/protobuf v1.36.5 // indirect diff --git a/go.sum b/go.sum index 19d368ded..96c331dda 100644 --- a/go.sum +++ b/go.sum @@ -145,8 +145,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= diff --git a/internal/tools2md/testdata/TestFullSchema.golden.md b/internal/tools2md/testdata/TestFullSchema.golden.md new file mode 100644 index 000000000..0f1fb55e6 --- /dev/null +++ b/internal/tools2md/testdata/TestFullSchema.golden.md @@ -0,0 +1,235 @@ +## Tools + +### Users + +- **get_me** - Get details of the authenticated GitHub user. Use this when a request include "me", "my"... + - `reason`: Optional: reason the session was created (string, optional) + +### Issues + +- **get_issue** - Get details of a specific issue in a GitHub repository + - `issue_number`: The number of the issue (number, required) + - `owner`: The owner of the repository (string, required) + - `repo`: The name of the repository (string, required) + +- **search_issues** - Search for issues and pull requests across GitHub repositories + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub issues search syntax (string, required) + - `sort`: Sort field (comments, reactions, created, etc.) (string, optional) + +- **list_issues** - List issues in a GitHub repository with filtering options + - `direction`: Sort direction ('asc', 'desc') (string, optional) + - `labels`: Filter by labels (array, optional) + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `since`: Filter by date (ISO 8601 timestamp) (string, optional) + - `sort`: Sort by ('created', 'updated', 'comments') (string, optional) + - `state`: Filter by state ('open', 'closed', 'all') (string, optional) + +- **get_issue_comments** - Get comments for a GitHub issue + - `issue_number`: Issue number (number, required) + - `owner`: Repository owner (string, required) + - `page`: Page number (number, optional) + - `per_page`: Number of records per page (number, optional) + - `repo`: Repository name (string, required) + +- **create_issue** - Create a new issue in a GitHub repository + - `assignees`: Usernames to assign to this issue (array, optional) + - `body`: Issue body content (string, optional) + - `labels`: Labels to apply to this issue (array, optional) + - `milestone`: Milestone number (number, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `title`: Issue title (string, required) + +- **add_issue_comment** - Add a comment to an existing issue + - `body`: Comment text (string, required) + - `issue_number`: Issue number to comment on (number, required) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **update_issue** - Update an existing issue in a GitHub repository + - `assignees`: New assignees (array, optional) + - `body`: New description (string, optional) + - `issue_number`: Issue number to update (number, required) + - `labels`: New labels (array, optional) + - `milestone`: New milestone number (number, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `state`: New state ('open' or 'closed') (string, optional) + - `title`: New title (string, optional) + +### Pull Requests + +- **get_pull_request** - Get details of a specific pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **list_pull_requests** - List and filter repository pull requests + - `base`: Filter by base branch (string, optional) + - `direction`: Sort direction ('asc', 'desc') (string, optional) + - `head`: Filter by head user/org and branch (string, optional) + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sort`: Sort by ('created', 'updated', 'popularity', 'long-running') (string, optional) + - `state`: Filter by state ('open', 'closed', 'all') (string, optional) + +- **get_pull_request_files** - Get the list of files changed in a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_status** - Get the combined status of all status checks for a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_comments** - Get the review comments on a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_reviews** - Get the reviews on a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **merge_pull_request** - Merge a pull request + - `commit_message`: Extra detail for merge commit (string, optional) + - `commit_title`: Title for merge commit (string, optional) + - `merge_method`: Merge method ('merge', 'squash', 'rebase') (string, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **update_pull_request_branch** - Update a pull request branch with the latest changes from the base branch + - `expectedHeadSha`: The expected SHA of the pull request's HEAD ref (string, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **create_pull_request_review** - Create a review on a pull request + - `body`: Review comment text (string, optional) + - `comments`: Line-specific comments array of objects to place comments on pull request changes. Requires path and body. For line comments use line or position. For multi-line comments use start_line and line with optional side parameters. (array, optional) + - `commitId`: SHA of commit to review (string, optional) + - `event`: Review action ('APPROVE', 'REQUEST_CHANGES', 'COMMENT') (string, required) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **create_pull_request** - Create a new pull request in a GitHub repository + - `base`: Branch to merge into (string, required) + - `body`: PR description (string, optional) + - `draft`: Create as draft PR (boolean, optional) + - `head`: Branch containing changes (string, required) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `title`: PR title (string, required) + +- **update_pull_request** - Update an existing pull request in a GitHub repository + - `base`: New base branch name (string, optional) + - `body`: New description (string, optional) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number to update (number, required) + - `repo`: Repository name (string, required) + - `state`: New state ('open' or 'closed') (string, optional) + - `title`: New title (string, optional) + +### Repositories + +- **get_file_contents** - Get the contents of a file or directory from a GitHub repository + - `branch`: Branch to get contents from (string, optional) + - `owner`: Repository owner (username or organization) (string, required) + - `path`: Path to file/directory (string, required) + - `repo`: Repository name (string, required) + +- **get_commit** - Get details for a commit from a GitHub repository + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sha`: Commit SHA, branch name, or tag name (string, required) + +- **list_commits** - Get list of commits of a branch in a GitHub repository + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sha`: Branch name (string, optional) + +- **create_or_update_file** - Create or update a single file in a GitHub repository + - `branch`: Branch to create/update the file in (string, required) + - `content`: Content of the file (string, required) + - `message`: Commit message (string, required) + - `owner`: Repository owner (username or organization) (string, required) + - `path`: Path where to create/update the file (string, required) + - `repo`: Repository name (string, required) + - `sha`: SHA of file being replaced (for updates) (string, optional) + +- **create_repository** - Create a new GitHub repository in your account + - `autoInit`: Initialize with README (boolean, optional) + - `description`: Repository description (string, optional) + - `name`: Repository name (string, required) + - `private`: Whether repo should be private (boolean, optional) + +- **fork_repository** - Fork a GitHub repository to your account or specified organization + - `organization`: Organization to fork to (string, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **create_branch** - Create a new branch in a GitHub repository + - `branch`: Name for new branch (string, required) + - `from_branch`: Source branch (defaults to repo default) (string, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **push_files** - Push multiple files to a GitHub repository in a single commit + - `branch`: Branch to push to (string, required) + - `files`: Array of file objects to push, each object with path (string) and content (string) (array, required) + - `message`: Commit message (string, required) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +### Search + +- **search_repositories** - Search for GitHub repositories + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `query`: Search query (string, required) + +- **search_code** - Search for code across GitHub repositories + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub code search syntax (string, required) + - `sort`: Sort field ('indexed' only) (string, optional) + +- **search_users** - Search for GitHub users + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub users search syntax (string, required) + - `sort`: Sort field (followers, repositories, joined) (string, optional) + +### Code Scanning + +- **get_code_scanning_alert** - Get details of a specific code scanning alert in a GitHub repository. + - `alertNumber`: The number of the alert. (number, required) + - `owner`: The owner of the repository. (string, required) + - `repo`: The name of the repository. (string, required) + +- **list_code_scanning_alerts** - List code scanning alerts in a GitHub repository. + - `owner`: The owner of the repository. (string, required) + - `ref`: The Git reference for the results you want to list. (string, optional) + - `repo`: The name of the repository. (string, required) + - `severity`: Only code scanning alerts with this severity will be returned. Possible values are: critical, high, medium, low, warning, note, error. (string, optional) + - `state`: State of the code scanning alerts to list. Set to closed to list only closed code scanning alerts. Default: open (string, optional) diff --git a/internal/tools2md/testdata/TestMultipleCategoriesMultipleTools.golden.md b/internal/tools2md/testdata/TestMultipleCategoriesMultipleTools.golden.md new file mode 100644 index 000000000..40fcf3436 --- /dev/null +++ b/internal/tools2md/testdata/TestMultipleCategoriesMultipleTools.golden.md @@ -0,0 +1,14 @@ +## Tools + +### Test Category 1 + +- **test_tool_1** - A tool for testing + - No parameters required + +- **test_tool_3** - Yet another tool for testing + - No parameters required + +### Test Category 2 + +- **test_tool_2** - Another tool for testing + - No parameters required diff --git a/internal/tools2md/testdata/TestOneCategoryOneTool.golden.md b/internal/tools2md/testdata/TestOneCategoryOneTool.golden.md new file mode 100644 index 000000000..0d2bb2a9a --- /dev/null +++ b/internal/tools2md/testdata/TestOneCategoryOneTool.golden.md @@ -0,0 +1,6 @@ +## Tools + +### Test Category + +- **test_tool** - A tool for testing + - No parameters required diff --git a/internal/tools2md/testdata/TestToolsWithProperties.golden.md b/internal/tools2md/testdata/TestToolsWithProperties.golden.md new file mode 100644 index 000000000..95c89806d --- /dev/null +++ b/internal/tools2md/testdata/TestToolsWithProperties.golden.md @@ -0,0 +1,7 @@ +## Tools + +### Test Category + +- **test_tool** - A tool for testing + - `prop_1`: A test property (string, required) + - `prop_2`: Another test property (number, optional) diff --git a/internal/tools2md/tools2md.go b/internal/tools2md/tools2md.go new file mode 100644 index 000000000..420ed4d30 --- /dev/null +++ b/internal/tools2md/tools2md.go @@ -0,0 +1,174 @@ +package tools2md + +import ( + "fmt" + "slices" + "strings" + + "github.com/github/github-mcp-server/pkg/github" +) + +func byCategoryPriority(a, b categorisedTools) int { + var priorityMap = map[github.Category]int{ + github.CategoryUsers: 0, + github.CategoryIssues: 1, + github.CategoryPullRequests: 2, + github.CategoryRepositories: 3, + github.CategorySearch: 4, + github.CategoryCodeScanning: 5, + } + + pa, oka := priorityMap[a.category] + pb, okb := priorityMap[b.category] + + // if both categories are not in the map, sort by our priorities + if oka && okb { + return pa - pb + } + + // If either one was in the map then priortise that one. + if oka { + return -1 + } + + if okb { + return 1 + } + + // if neither were in the map, sort alphabetically, which helps with test ordering. + return strings.Compare(string(a.category), string(b.category)) +} + +type categorisedToolMap map[github.Category][]github.Tool + +func (m categorisedToolMap) add(tool github.Tool) { + m[tool.Category] = append(m[tool.Category], tool) +} + +type categorisedTools struct { + category github.Category + tools []github.Tool +} + +type sortedCategorisedTools []categorisedTools + +func (m categorisedToolMap) sorted() sortedCategorisedTools { + var out sortedCategorisedTools + for category, tools := range m { + out = append(out, categorisedTools{ + category: category, + tools: tools, + }) + } + + slices.SortStableFunc(out, byCategoryPriority) + return out +} + +// Replace a lot of this with a Go template. +// Much TDD. +func Convert(tools github.Tools) string { + if len(tools) == 0 { + return "" + } + + toolMap := categorisedToolMap{} + for _, tool := range tools { + toolMap.add(tool) + } + + var md markdownBuilder + + md.h2("Tools") + md.newline() + + sortedToolMap := toolMap.sorted() + for i, categorisedTools := range sortedToolMap { + md.h3(string(categorisedTools.category)) + md.newline() + + for j, tool := range categorisedTools.tools { + md.textf("- %s - %s", bold(tool.Definition.Name), tool.Definition.Description) + md.newline() + + if len(tool.Definition.InputSchema.Properties) == 0 { + md.text(" - No parameters required") + md.newline() + } else { + // order the properties alphabetically to maintain a consistent order + // maybe in future do some kind of grouping like pagination together. + var propNames []string + for propName := range tool.Definition.InputSchema.Properties { + propNames = append(propNames, propName) + } + slices.Sort(propNames) + + for _, propName := range propNames { + prop := tool.Definition.InputSchema.Properties[propName] + propSchema := prop.(map[string]any) + required := func() string { + if slices.Contains(tool.Definition.InputSchema.Required, propName) { + return "required" + } + return "optional" + }() + + md.textf(" - %s: %s (%s, %s)", code(propName), propSchema["description"], propSchema["type"], required) + md.newline() + } + } + + // if not the last tool in the category, add a newline + if j < len(categorisedTools.tools)-1 { + md.newline() + } + } + + // if not the last category, add a newline + if i < len(sortedToolMap)-1 { + md.newline() + } + } + + return md.String() +} + +type markdownBuilder struct { + content strings.Builder +} + +func (b *markdownBuilder) h2(text string) { + b.content.WriteString(fmt.Sprintf("## %s\n", text)) +} + +func (b *markdownBuilder) h3(text string) { + b.content.WriteString(fmt.Sprintf("### %s\n", text)) +} + +func (b *markdownBuilder) bold(text string) { + b.content.WriteString(fmt.Sprintf("**%s**", text)) +} + +func (b *markdownBuilder) newline() { + b.content.WriteString("\n") +} + +func (b *markdownBuilder) text(text string) { + b.content.WriteString(text) +} + +func (b *markdownBuilder) textf(format string, args ...any) { + b.content.WriteString(fmt.Sprintf(format, args...)) +} + +func (b *markdownBuilder) String() string { + return b.content.String() +} + +func bold(text string) string { + return fmt.Sprintf("**%s**", text) +} + +func code(text string) string { + return fmt.Sprintf("`%s`", text) +} diff --git a/internal/tools2md/tools2md_test.go b/internal/tools2md/tools2md_test.go new file mode 100644 index 000000000..a47a539bc --- /dev/null +++ b/internal/tools2md/tools2md_test.go @@ -0,0 +1,253 @@ +package tools2md_test + +import ( + "flag" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/github/github-mcp-server/internal/tools2md" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-cmp/cmp" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +// func TestSimple(t *testing.T) { +// server := github.NewServer( +// func(_ context.Context) (*gogithub.Client, error) { +// panic("not implemented") +// }, +// "0.0.1", +// false, +// translations.NullTranslationHelper, +// ) + +// // TODO: handle pagination +// // as per https://github.com/mark3labs/mcp-go/blob/cc777fcbf3176d0e76634f58047707d1f666cae8/client/stdio.go#L464 +// request := mcp.JSONRPCRequest{ +// JSONRPC: mcp.JSONRPC_VERSION, +// ID: "1", +// Request: mcp.Request{ +// Method: "tools/list", +// }, +// Params: mcp.PaginatedRequest{ +// Request: mcp.Request{}, +// }, +// } + +// raw, err := json.Marshal(request) +// if err != nil { +// t.Fatalf("failed to marshal request: %v", err) +// } + +// message := server.HandleMessage( +// context.Background(), +// raw, +// ) + +// response, ok := message.(mcp.JSONRPCResponse) +// require.True(t, ok, "expected JSONRPCResponse, got %T", message) + +// listToolsResult, ok := response.Result.(mcp.ListToolsResult) +// require.True(t, ok, "expected ListToolsResult, got %T", response.Result) + +// listToolsResult.Tools[0] + +// } + +var update = flag.Bool("update", false, "update .golden files") +var diffMd = flag.Bool("diffmd", false, "on failure create a .golden.diff.md file for external comparison") + +func TestNoToolsReturnsEmptyString(t *testing.T) { + tools := github.Tools{} + + result := tools2md.Convert(tools) + require.Empty(t, result, "expected empty string when converting no tools") +} + +func TestOneCategoryOneTool(t *testing.T) { + goldenFilePath := goldenFilePath(t.Name()) + + tools := github.Tools{ + { + Definition: mcp.Tool{ + Name: "test_tool", + Description: "A tool for testing", + }, + Category: "Test Category", + }, + } + + md := tools2md.Convert(tools) + if *update { + require.NoError( + t, + os.WriteFile(goldenFilePath, []byte(md), 0600), + "failed to update golden file", + ) + } + + golden, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "failed to read golden file") + + if diff := cmp.Diff(string(golden), md); diff != "" { + if *diffMd { + diffFilePath := strings.ReplaceAll(goldenFilePath, ".golden.md", ".golden.diff.md") + require.NoError( + t, + os.WriteFile(diffFilePath, []byte(md), 0600), + "failed to update diff file", + ) + } + + t.Errorf("golden file mismatch\n%s", diff) + } +} + +func TestMultipleCategoriesMultipleTools(t *testing.T) { + goldenFilePath := goldenFilePath(t.Name()) + + tools := github.Tools{ + { + Definition: mcp.Tool{ + Name: "test_tool_1", + Description: "A tool for testing", + }, + Category: "Test Category 1", + }, + { + Definition: mcp.Tool{ + Name: "test_tool_2", + Description: "Another tool for testing", + }, + Category: "Test Category 2", + }, + { + Definition: mcp.Tool{ + Name: "test_tool_3", + Description: "Yet another tool for testing", + }, + Category: "Test Category 1", + }, + } + + md := tools2md.Convert(tools) + if *update { + require.NoError( + t, + os.WriteFile(goldenFilePath, []byte(md), 0600), + "failed to update golden file", + ) + } + + golden, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "failed to read golden file") + + if diff := cmp.Diff(string(golden), md); diff != "" { + if *diffMd { + diffFilePath := strings.ReplaceAll(goldenFilePath, ".golden.md", ".golden.diff.md") + require.NoError( + t, + os.WriteFile(diffFilePath, []byte(md), 0600), + "failed to update diff file", + ) + } + + t.Errorf("golden file mismatch\n%s", diff) + } +} + +func TestToolsWithProperties(t *testing.T) { + goldenFilePath := goldenFilePath(t.Name()) + + tools := github.Tools{ + { + Definition: mcp.Tool{ + Name: "test_tool", + Description: "A tool for testing", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "prop_1": map[string]any{ + "description": "A test property", + "type": "string", + }, + "prop_2": map[string]any{ + "description": "Another test property", + "type": "number", + }, + }, + Required: []string{"prop_1"}, + }, + }, + Category: "Test Category", + }, + } + + md := tools2md.Convert(tools) + if *update { + require.NoError( + t, + os.WriteFile(goldenFilePath, []byte(md), 0600), + "failed to update golden file", + ) + } + + golden, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "failed to read golden file") + + if diff := cmp.Diff(string(golden), md); diff != "" { + if *diffMd { + diffFilePath := strings.ReplaceAll(goldenFilePath, ".golden.md", ".golden.diff.md") + require.NoError( + t, + os.WriteFile(diffFilePath, []byte(md), 0600), + "failed to update diff file", + ) + } + + t.Errorf("golden file mismatch\n%s", diff) + } +} + +func TestFullSchema(t *testing.T) { + goldenFilePath := goldenFilePath(t.Name()) + + tools := github.DefaultTools(translations.NullTranslationHelper) + + md := tools2md.Convert(tools) + if *update { + require.NoError( + t, + os.WriteFile(goldenFilePath, []byte(md), 0600), + "failed to update golden file", + ) + } + + golden, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "failed to read golden file") + + if diff := cmp.Diff(string(golden), md); diff != "" { + if *diffMd { + diffFilePath := strings.ReplaceAll(goldenFilePath, ".golden.md", ".golden.diff.md") + require.NoError( + t, + os.WriteFile(diffFilePath, []byte(md), 0600), + "failed to update diff file", + ) + } + + t.Errorf("golden file mismatch\n%s", diff) + } +} + +// In case we use subtests +func goldenFilePath(testName string) string { + return filepath.Join( + "testdata", + strings.ReplaceAll(testName+".golden.md", "/", "_"), + ) +} diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 4fc029bf6..3474cff98 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -13,8 +13,10 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_code_scanning_alert", +func GetCodeScanningAlert(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_code_scanning_alert", mcp.WithDescription(t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository.")), mcp.WithString("owner", mcp.Required(), @@ -29,50 +31,57 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe mcp.Description("The number of the alert."), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - alertNumber, err := RequiredInt(request, "alertNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := RequiredInt(request, "alertNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return nil, fmt.Errorf("failed to get alert: %w", err) - } - defer func() { _ = resp.Body.Close() }() + alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) + if err != nil { + return nil, fmt.Errorf("failed to get alert: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(alert) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal alert: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil - } - r, err := json.Marshal(alert) - if err != nil { - return nil, fmt.Errorf("failed to marshal alert: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryCodeScanning, + } } -func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("list_code_scanning_alerts", +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "list_code_scanning_alerts", mcp.WithDescription(t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository.")), mcp.WithString("owner", mcp.Required(), @@ -93,51 +102,56 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel mcp.Description("Only code scanning alerts with this severity will be returned. Possible values are: critical, high, medium, low, warning, note, error."), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - ref, err := OptionalParam[string](request, "ref") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - state, err := OptionalParam[string](request, "state") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - severity, err := OptionalParam[string](request, "severity") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + ref, err := OptionalParam[string](request, "ref") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := OptionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) - if err != nil { - return nil, fmt.Errorf("failed to list alerts: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) + if err != nil { + return nil, fmt.Errorf("failed to list alerts: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(alerts) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal alerts: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil - } - r, err := json.Marshal(alerts) - if err != nil { - return nil, fmt.Errorf("failed to marshal alerts: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryCodeScanning, + } } diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index c9895e269..1a6831320 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -15,15 +15,14 @@ import ( func Test_GetCodeScanningAlert(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetCodeScanningAlert(translations.NullTranslationHelper) - assert.Equal(t, "get_code_scanning_alert", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "alertNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber"}) + assert.Equal(t, "get_code_scanning_alert", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "alertNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "alertNumber"}) // Setup mock alert for success case mockAlert := &github.Alert{ @@ -82,13 +81,13 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetCodeScanningAlert(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -117,17 +116,16 @@ func Test_GetCodeScanningAlert(t *testing.T) { func Test_ListCodeScanningAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "list_code_scanning_alerts", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "ref") - assert.Contains(t, tool.InputSchema.Properties, "state") - assert.Contains(t, tool.InputSchema.Properties, "severity") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + tool := ListCodeScanningAlerts(translations.NullTranslationHelper) + + assert.Equal(t, "list_code_scanning_alerts", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "ref") + assert.Contains(t, tool.Definition.InputSchema.Properties, "state") + assert.Contains(t, tool.Definition.InputSchema.Properties, "severity") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo"}) // Setup mock alerts for success case mockAlerts := []*github.Alert{ @@ -201,13 +199,13 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + tool := ListCodeScanningAlerts(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 16c34141c..5969e92a0 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -15,8 +15,10 @@ import ( ) // GetIssue creates a tool to get details of a specific issue in a GitHub repository. -func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_issue", +func GetIssue(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_issue", mcp.WithDescription(t("TOOL_GET_ISSUE_DESCRIPTION", "Get details of a specific issue in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -31,50 +33,57 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Description("The number of the issue"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - issueNumber, err := RequiredInt(request, "issue_number") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) - if err != nil { - return nil, fmt.Errorf("failed to get issue: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) + if err != nil { + return nil, fmt.Errorf("failed to get issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get issue: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(issue) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal issue: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get issue: %s", string(body))), nil - } - r, err := json.Marshal(issue) - if err != nil { - return nil, fmt.Errorf("failed to marshal issue: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryIssues, + } } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("add_issue_comment", +func AddIssueComment(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "add_issue_comment", mcp.WithDescription(t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to an existing issue")), mcp.WithString("owner", mcp.Required(), @@ -93,58 +102,65 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc mcp.Description("Comment text"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - issueNumber, err := RequiredInt(request, "issue_number") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - body, err := requiredParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + body, err := requiredParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - comment := &github.IssueComment{ - Body: github.Ptr(body), - } + comment := &github.IssueComment{ + Body: github.Ptr(body), + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) - if err != nil { - return nil, fmt.Errorf("failed to create comment: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) + if err != nil { + return nil, fmt.Errorf("failed to create comment: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(createdComment) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil - } - r, err := json.Marshal(createdComment) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryIssues, + } } // SearchIssues creates a tool to search for issues and pull requests. -func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("search_issues", +func SearchIssues(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "search_issues", mcp.WithDescription(t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues and pull requests across GitHub repositories")), mcp.WithString("q", mcp.Required(), @@ -172,63 +188,70 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredParam[string](request, "q") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sort, err := OptionalParam[string](request, "sort") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - order, err := OptionalParam[string](request, "order") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := OptionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.perPage, - Page: pagination.page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.perPage, + Page: pagination.page, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - result, resp, err := client.Search.Issues(ctx, query, opts) - if err != nil { - return nil, fmt.Errorf("failed to search issues: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Issues(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search issues: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to search issues: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to search issues: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryIssues, + } } // CreateIssue creates a tool to create a new issue in a GitHub repository. -func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_issue", +func CreateIssue(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_issue", mcp.WithDescription(t("TOOL_CREATE_ISSUE_DESCRIPTION", "Create a new issue in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -265,88 +288,95 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t mcp.Description("Milestone number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - title, err := requiredParam[string](request, "title") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Optional parameters - body, err := OptionalParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Optional parameters + body, err := OptionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Get assignees - assignees, err := OptionalStringArrayParam(request, "assignees") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Get assignees + assignees, err := OptionalStringArrayParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Get labels - labels, err := OptionalStringArrayParam(request, "labels") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Get labels + labels, err := OptionalStringArrayParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Get optional milestone - milestone, err := OptionalIntParam(request, "milestone") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Get optional milestone + milestone, err := OptionalIntParam(request, "milestone") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - var milestoneNum *int - if milestone != 0 { - milestoneNum = &milestone - } + var milestoneNum *int + if milestone != 0 { + milestoneNum = &milestone + } - // Create the issue request - issueRequest := &github.IssueRequest{ - Title: github.Ptr(title), - Body: github.Ptr(body), - Assignees: &assignees, - Labels: &labels, - Milestone: milestoneNum, - } + // Create the issue request + issueRequest := &github.IssueRequest{ + Title: github.Ptr(title), + Body: github.Ptr(body), + Assignees: &assignees, + Labels: &labels, + Milestone: milestoneNum, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) - if err != nil { - return nil, fmt.Errorf("failed to create issue: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) + if err != nil { + return nil, fmt.Errorf("failed to create issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create issue: %s", string(body))), nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(issue) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create issue: %s", string(body))), nil - } - r, err := json.Marshal(issue) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryIssues, + } } // ListIssues creates a tool to list and filter repository issues -func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("list_issues", +func ListIssues(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "list_issues", mcp.WithDescription(t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository with filtering options")), mcp.WithString("owner", mcp.Required(), @@ -381,90 +411,97 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.IssueListByRepoOptions{} + opts := &github.IssueListByRepoOptions{} - // Set optional parameters if provided - opts.State, err = OptionalParam[string](request, "state") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Set optional parameters if provided + opts.State, err = OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Get labels - opts.Labels, err = OptionalStringArrayParam(request, "labels") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + // Get labels + opts.Labels, err = OptionalStringArrayParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts.Sort, err = OptionalParam[string](request, "sort") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + opts.Sort, err = OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts.Direction, err = OptionalParam[string](request, "direction") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + opts.Direction, err = OptionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - since, err := OptionalParam[string](request, "since") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if since != "" { - timestamp, err := parseISOTimestamp(since) + since, err := OptionalParam[string](request, "since") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil + return mcp.NewToolResultError(err.Error()), nil + } + if since != "" { + timestamp, err := parseISOTimestamp(since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil + } + opts.Since = timestamp } - opts.Since = timestamp - } - if page, ok := request.Params.Arguments["page"].(float64); ok { - opts.Page = int(page) - } + if page, ok := request.Params.Arguments["page"].(float64); ok { + opts.Page = int(page) + } - if perPage, ok := request.Params.Arguments["perPage"].(float64); ok { - opts.PerPage = int(perPage) - } + if perPage, ok := request.Params.Arguments["perPage"].(float64); ok { + opts.PerPage = int(perPage) + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) - if err != nil { - return nil, fmt.Errorf("failed to list issues: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list issues: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(issues) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal issues: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", string(body))), nil - } - r, err := json.Marshal(issues) - if err != nil { - return nil, fmt.Errorf("failed to marshal issues: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryIssues, + } } // UpdateIssue creates a tool to update an existing issue in a GitHub repository. -func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("update_issue", +func UpdateIssue(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "update_issue", mcp.WithDescription(t("TOOL_UPDATE_ISSUE_DESCRIPTION", "Update an existing issue in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -508,105 +545,112 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t mcp.Description("New milestone number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - issueNumber, err := RequiredInt(request, "issue_number") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Create the issue request with only provided fields - issueRequest := &github.IssueRequest{} + // Create the issue request with only provided fields + issueRequest := &github.IssueRequest{} - // Set optional parameters if provided - title, err := OptionalParam[string](request, "title") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if title != "" { - issueRequest.Title = github.Ptr(title) - } + // Set optional parameters if provided + title, err := OptionalParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if title != "" { + issueRequest.Title = github.Ptr(title) + } - body, err := OptionalParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if body != "" { - issueRequest.Body = github.Ptr(body) - } + body, err := OptionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { + issueRequest.Body = github.Ptr(body) + } - state, err := OptionalParam[string](request, "state") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if state != "" { - issueRequest.State = github.Ptr(state) - } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if state != "" { + issueRequest.State = github.Ptr(state) + } - // Get labels - labels, err := OptionalStringArrayParam(request, "labels") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if len(labels) > 0 { - issueRequest.Labels = &labels - } + // Get labels + labels, err := OptionalStringArrayParam(request, "labels") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(labels) > 0 { + issueRequest.Labels = &labels + } - // Get assignees - assignees, err := OptionalStringArrayParam(request, "assignees") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if len(assignees) > 0 { - issueRequest.Assignees = &assignees - } + // Get assignees + assignees, err := OptionalStringArrayParam(request, "assignees") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if len(assignees) > 0 { + issueRequest.Assignees = &assignees + } - milestone, err := OptionalIntParam(request, "milestone") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if milestone != 0 { - milestoneNum := milestone - issueRequest.Milestone = &milestoneNum - } + milestone, err := OptionalIntParam(request, "milestone") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if milestone != 0 { + milestoneNum := milestone + issueRequest.Milestone = &milestoneNum + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) - if err != nil { - return nil, fmt.Errorf("failed to update issue: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) + if err != nil { + return nil, fmt.Errorf("failed to update issue: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update issue: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(updatedIssue) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to update issue: %s", string(body))), nil - } - r, err := json.Marshal(updatedIssue) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryIssues, + } } // GetIssueComments creates a tool to get comments for a GitHub issue. -func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_issue_comments", +func GetIssueComments(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_issue_comments", mcp.WithDescription(t("TOOL_GET_ISSUE_COMMENTS_DESCRIPTION", "Get comments for a GitHub issue")), mcp.WithString("owner", mcp.Required(), @@ -627,60 +671,65 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun mcp.Description("Number of records per page"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - issueNumber, err := RequiredInt(request, "issue_number") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - page, err := OptionalIntParamWithDefault(request, "page", 1) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + issueNumber, err := RequiredInt(request, "issue_number") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + page, err := OptionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.IssueListCommentsOptions{ - ListOptions: github.ListOptions{ - Page: page, - PerPage: perPage, - }, - } + opts := &github.IssueListCommentsOptions{ + ListOptions: github.ListOptions{ + Page: page, + PerPage: perPage, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) - if err != nil { - return nil, fmt.Errorf("failed to get issue comments: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) + if err != nil { + return nil, fmt.Errorf("failed to get issue comments: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(comments) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil - } - r, err := json.Marshal(comments) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryIssues, + } } // parseISOTimestamp parses an ISO 8601 timestamp string into a time.Time object. diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 61ca0ae7a..8f1799755 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -17,15 +17,14 @@ import ( func Test_GetIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetIssue(translations.NullTranslationHelper) - assert.Equal(t, "get_issue", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "issue_number") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + assert.Equal(t, "get_issue", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "issue_number") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "issue_number"}) // Setup mock issue for success case mockIssue := &github.Issue{ @@ -82,13 +81,13 @@ func Test_GetIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetIssue(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetIssue(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -113,16 +112,15 @@ func Test_GetIssue(t *testing.T) { func Test_AddIssueComment(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := AddIssueComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := AddIssueComment(translations.NullTranslationHelper) - assert.Equal(t, "add_issue_comment", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "issue_number") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number", "body"}) + assert.Equal(t, "add_issue_comment", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "issue_number", "body"}) // Setup mock comment for success case mockComment := &github.IssueComment{ @@ -185,7 +183,7 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := AddIssueComment(stubGetClientFn(client), translations.NullTranslationHelper) + tool := AddIssueComment(translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -201,7 +199,7 @@ func Test_AddIssueComment(t *testing.T) { } // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -236,17 +234,16 @@ func Test_AddIssueComment(t *testing.T) { func Test_SearchIssues(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "search_issues", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "q") - assert.Contains(t, tool.InputSchema.Properties, "sort") - assert.Contains(t, tool.InputSchema.Properties, "order") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + tool := SearchIssues(translations.NullTranslationHelper) + + assert.Equal(t, "search_issues", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "q") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sort") + assert.Contains(t, tool.Definition.InputSchema.Properties, "order") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"q"}) // Setup mock search results mockSearchResult := &github.IssuesSearchResult{ @@ -352,13 +349,13 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchIssues(stubGetClientFn(client), translations.NullTranslationHelper) + tool := SearchIssues(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -392,19 +389,18 @@ func Test_SearchIssues(t *testing.T) { func Test_CreateIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "create_issue", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "title") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "assignees") - assert.Contains(t, tool.InputSchema.Properties, "labels") - assert.Contains(t, tool.InputSchema.Properties, "milestone") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "title"}) + tool := CreateIssue(translations.NullTranslationHelper) + + assert.Equal(t, "create_issue", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "title") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.Contains(t, tool.Definition.InputSchema.Properties, "assignees") + assert.Contains(t, tool.Definition.InputSchema.Properties, "labels") + assert.Contains(t, tool.Definition.InputSchema.Properties, "milestone") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "title"}) // Setup mock issue for success case mockIssue := &github.Issue{ @@ -506,13 +502,13 @@ func Test_CreateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateIssue(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreateIssue(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -566,21 +562,21 @@ func Test_CreateIssue(t *testing.T) { func Test_ListIssues(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := ListIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "list_issues", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "state") - assert.Contains(t, tool.InputSchema.Properties, "labels") - assert.Contains(t, tool.InputSchema.Properties, "sort") - assert.Contains(t, tool.InputSchema.Properties, "direction") - assert.Contains(t, tool.InputSchema.Properties, "since") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + tool := ListIssues(translations.NullTranslationHelper) + + assert.Equal(t, "list_issues", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "state") + assert.Contains(t, tool.Definition.InputSchema.Properties, "labels") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sort") + assert.Contains(t, tool.Definition.InputSchema.Properties, "direction") + assert.Contains(t, tool.Definition.InputSchema.Properties, "since") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo"}) // Setup mock issues for success case mockIssues := []*github.Issue{ @@ -698,13 +694,13 @@ func Test_ListIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListIssues(stubGetClientFn(client), translations.NullTranslationHelper) + tool := ListIssues(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -742,21 +738,21 @@ func Test_ListIssues(t *testing.T) { func Test_UpdateIssue(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := UpdateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "update_issue", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "issue_number") - assert.Contains(t, tool.InputSchema.Properties, "title") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "state") - assert.Contains(t, tool.InputSchema.Properties, "labels") - assert.Contains(t, tool.InputSchema.Properties, "assignees") - assert.Contains(t, tool.InputSchema.Properties, "milestone") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + + tool := UpdateIssue(translations.NullTranslationHelper) + + assert.Equal(t, "update_issue", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.Definition.InputSchema.Properties, "title") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.Contains(t, tool.Definition.InputSchema.Properties, "state") + assert.Contains(t, tool.Definition.InputSchema.Properties, "labels") + assert.Contains(t, tool.Definition.InputSchema.Properties, "assignees") + assert.Contains(t, tool.Definition.InputSchema.Properties, "milestone") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "issue_number"}) // Setup mock issue for success case mockIssue := &github.Issue{ @@ -882,13 +878,13 @@ func Test_UpdateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateIssue(stubGetClientFn(client), translations.NullTranslationHelper) + tool := UpdateIssue(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -999,17 +995,16 @@ func Test_ParseISOTimestamp(t *testing.T) { func Test_GetIssueComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetIssueComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "get_issue_comments", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "issue_number") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.Contains(t, tool.InputSchema.Properties, "per_page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "issue_number"}) + tool := GetIssueComments(translations.NullTranslationHelper) + + assert.Equal(t, "get_issue_comments", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "issue_number") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.Contains(t, tool.Definition.InputSchema.Properties, "per_page") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "issue_number"}) // Setup mock comments for success case mockComments := []*github.IssueComment{ @@ -1100,13 +1095,13 @@ func Test_GetIssueComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetIssueComments(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetIssueComments(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 14aeb9187..b162b26c6 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -14,8 +14,10 @@ import ( ) // GetPullRequest creates a tool to get details of a specific pull request. -func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request", +func GetPullRequest(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_pull_request", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DESCRIPTION", "Get details of a specific pull request")), mcp.WithString("owner", mcp.Required(), @@ -30,50 +32,57 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) mcp.Description("Pull request number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return nil, fmt.Errorf("failed to get pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil - } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("update_pull_request", +func UpdatePullRequest(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -104,93 +113,100 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("Allow maintainer edits"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Build the update struct only with provided fields - update := &github.PullRequest{} - updateNeeded := false + // Build the update struct only with provided fields + update := &github.PullRequest{} + updateNeeded := false - if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } else if ok { - update.Title = github.Ptr(title) - updateNeeded = true - } + if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Title = github.Ptr(title) + updateNeeded = true + } - if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } else if ok { - update.Body = github.Ptr(body) - updateNeeded = true - } + if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Body = github.Ptr(body) + updateNeeded = true + } - if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } else if ok { - update.State = github.Ptr(state) - updateNeeded = true - } + if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.State = github.Ptr(state) + updateNeeded = true + } - if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } else if ok { - update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - updateNeeded = true - } + if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} + updateNeeded = true + } - if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } else if ok { - update.MaintainerCanModify = github.Ptr(maintainerCanModify) - updateNeeded = true - } + if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } else if ok { + update.MaintainerCanModify = github.Ptr(maintainerCanModify) + updateNeeded = true + } - if !updateNeeded { - return mcp.NewToolResultError("No update parameters provided."), nil - } + if !updateNeeded { + return mcp.NewToolResultError("No update parameters provided."), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return nil, fmt.Errorf("failed to update pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(pr) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil - } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryPullRequests, + } } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("list_pull_requests", +func ListPullRequests(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "list_pull_requests", mcp.WithDescription(t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List and filter repository pull requests")), mcp.WithString("owner", mcp.Required(), @@ -217,82 +233,89 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - state, err := OptionalParam[string](request, "state") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - head, err := OptionalParam[string](request, "head") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - base, err := OptionalParam[string](request, "base") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sort, err := OptionalParam[string](request, "sort") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - direction, err := OptionalParam[string](request, "direction") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + head, err := OptionalParam[string](request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + base, err := OptionalParam[string](request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + direction, err := OptionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.PullRequestListOptions{ - State: state, - Head: head, - Base: base, - Sort: sort, - Direction: direction, - ListOptions: github.ListOptions{ - PerPage: pagination.perPage, - Page: pagination.page, - }, - } + opts := &github.PullRequestListOptions{ + State: state, + Head: head, + Base: base, + Sort: sort, + Direction: direction, + ListOptions: github.ListOptions{ + PerPage: pagination.perPage, + Page: pagination.page, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) - if err != nil { - return nil, fmt.Errorf("failed to list pull requests: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list pull requests: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(body))), nil + } + + r, err := json.Marshal(prs) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(body))), nil - } - r, err := json.Marshal(prs) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("merge_pull_request", +func MergePullRequest(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "merge_pull_request", mcp.WithDescription(t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request")), mcp.WithString("owner", mcp.Required(), @@ -316,67 +339,74 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun mcp.Description("Merge method ('merge', 'squash', 'rebase')"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - commitTitle, err := OptionalParam[string](request, "commit_title") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - commitMessage, err := OptionalParam[string](request, "commit_message") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - mergeMethod, err := OptionalParam[string](request, "merge_method") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitTitle, err := OptionalParam[string](request, "commit_title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + commitMessage, err := OptionalParam[string](request, "commit_message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + mergeMethod, err := OptionalParam[string](request, "merge_method") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - options := &github.PullRequestOptions{ - CommitTitle: commitTitle, - MergeMethod: mergeMethod, - } + options := &github.PullRequestOptions{ + CommitTitle: commitTitle, + MergeMethod: mergeMethod, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) - if err != nil { - return nil, fmt.Errorf("failed to merge pull request: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) + if err != nil { + return nil, fmt.Errorf("failed to merge pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(body))), nil + } + + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryPullRequests, + } } // GetPullRequestFiles creates a tool to get the list of files changed in a pull request. -func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request_files", +func GetPullRequestFiles(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_pull_request_files", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_FILES_DESCRIPTION", "Get the list of files changed in a pull request")), mcp.WithString("owner", mcp.Required(), @@ -391,51 +421,58 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper mcp.Description("Pull request number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - opts := &github.ListOptions{} - files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) - if err != nil { - return nil, fmt.Errorf("failed to get pull request files: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + opts := &github.ListOptions{} + files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) + if err != nil { + return nil, fmt.Errorf("failed to get pull request files: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request files: %s", string(body))), nil + } + + r, err := json.Marshal(files) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request files: %s", string(body))), nil - } - r, err := json.Marshal(files) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // GetPullRequestStatus creates a tool to get the combined status of all status checks for a pull request. -func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request_status", +func GetPullRequestStatus(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_pull_request_status", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_STATUS_DESCRIPTION", "Get the combined status of all status checks for a pull request")), mcp.WithString("owner", mcp.Required(), @@ -450,65 +487,72 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe mcp.Description("Pull request number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - // First get the PR to find the head SHA - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) - } - defer func() { _ = resp.Body.Close() }() + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + // First get the PR to find the head SHA + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return nil, fmt.Errorf("failed to get pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + } + + // Get combined status for the head SHA + status, resp, err := client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to get combined status: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil - } + defer func() { _ = resp.Body.Close() }() - // Get combined status for the head SHA - status, resp, err := client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) - if err != nil { - return nil, fmt.Errorf("failed to get combined status: %w", err) - } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get combined status: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(status) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get combined status: %s", string(body))), nil - } - r, err := json.Marshal(status) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("update_pull_request_branch", +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "update_pull_request_branch", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update a pull request branch with the latest changes from the base branch")), mcp.WithString("owner", mcp.Required(), @@ -526,63 +570,70 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe mcp.Description("The expected SHA of the pull request's HEAD ref"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - expectedHeadSHA, err := OptionalParam[string](request, "expectedHeadSha") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - opts := &github.PullRequestBranchUpdateOptions{} - if expectedHeadSHA != "" { - opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + expectedHeadSHA, err := OptionalParam[string](request, "expectedHeadSha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + opts := &github.PullRequestBranchUpdateOptions{} + if expectedHeadSHA != "" { + opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return mcp.NewToolResultText("Pull request branch update is in progress"), nil - } - return nil, fmt.Errorf("failed to update pull request branch: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Pull request branch update is in progress"), nil + } + return nil, fmt.Errorf("failed to update pull request branch: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusAccepted { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusAccepted { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(body))), nil + } + + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryPullRequests, + } } // GetPullRequestComments creates a tool to get the review comments on a pull request. -func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request_comments", +func GetPullRequestComments(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_pull_request_comments", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_COMMENTS_DESCRIPTION", "Get the review comments on a pull request")), mcp.WithString("owner", mcp.Required(), @@ -597,56 +648,63 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel mcp.Description("Pull request number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.PullRequestListCommentsOptions{ - ListOptions: github.ListOptions{ - PerPage: 100, - }, - } + opts := &github.PullRequestListCommentsOptions{ + ListOptions: github.ListOptions{ + PerPage: 100, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) - if err != nil { - return nil, fmt.Errorf("failed to get pull request comments: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) + if err != nil { + return nil, fmt.Errorf("failed to get pull request comments: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request comments: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(comments) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request comments: %s", string(body))), nil - } - r, err := json.Marshal(comments) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // GetPullRequestReviews creates a tool to get the reviews on a pull request. -func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request_reviews", +func GetPullRequestReviews(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_pull_request_reviews", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get the reviews on a pull request")), mcp.WithString("owner", mcp.Required(), @@ -661,50 +719,57 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp mcp.Description("Pull request number"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) - if err != nil { - return nil, fmt.Errorf("failed to get pull request reviews: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) + if err != nil { + return nil, fmt.Errorf("failed to get pull request reviews: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(reviews) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil - } - r, err := json.Marshal(reviews) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryPullRequests, + } } // CreatePullRequestReview creates a tool to submit a review on a pull request. -func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_pull_request_review", +func CreatePullRequestReview(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_pull_request_review", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a review on a pull request")), mcp.WithString("owner", mcp.Required(), @@ -769,138 +834,145 @@ func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHe mcp.Description("Line-specific comments array of objects to place comments on pull request changes. Requires path and body. For line comments use line or position. For multi-line comments use start_line and line with optional side parameters."), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pullNumber, err := RequiredInt(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - event, err := requiredParam[string](request, "event") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - // Create review request - reviewRequest := &github.PullRequestReviewRequest{ - Event: github.Ptr(event), - } - - // Add body if provided - body, err := OptionalParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if body != "" { - reviewRequest.Body = github.Ptr(body) - } - - // Add commit ID if provided - commitID, err := OptionalParam[string](request, "commitId") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if commitID != "" { - reviewRequest.CommitID = github.Ptr(commitID) - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + event, err := requiredParam[string](request, "event") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Add comments if provided - if commentsObj, ok := request.Params.Arguments["comments"].([]interface{}); ok && len(commentsObj) > 0 { - comments := []*github.DraftReviewComment{} + // Create review request + reviewRequest := &github.PullRequestReviewRequest{ + Event: github.Ptr(event), + } - for _, c := range commentsObj { - commentMap, ok := c.(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("each comment must be an object with path and body"), nil - } + // Add body if provided + body, err := OptionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if body != "" { + reviewRequest.Body = github.Ptr(body) + } - path, ok := commentMap["path"].(string) - if !ok || path == "" { - return mcp.NewToolResultError("each comment must have a path"), nil - } + // Add commit ID if provided + commitID, err := OptionalParam[string](request, "commitId") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if commitID != "" { + reviewRequest.CommitID = github.Ptr(commitID) + } - body, ok := commentMap["body"].(string) - if !ok || body == "" { - return mcp.NewToolResultError("each comment must have a body"), nil + // Add comments if provided + if commentsObj, ok := request.Params.Arguments["comments"].([]interface{}); ok && len(commentsObj) > 0 { + comments := []*github.DraftReviewComment{} + + for _, c := range commentsObj { + commentMap, ok := c.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("each comment must be an object with path and body"), nil + } + + path, ok := commentMap["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("each comment must have a path"), nil + } + + body, ok := commentMap["body"].(string) + if !ok || body == "" { + return mcp.NewToolResultError("each comment must have a body"), nil + } + + _, hasPosition := commentMap["position"].(float64) + _, hasLine := commentMap["line"].(float64) + _, hasSide := commentMap["side"].(string) + _, hasStartLine := commentMap["start_line"].(float64) + _, hasStartSide := commentMap["start_side"].(string) + + switch { + case !hasPosition && !hasLine: + return mcp.NewToolResultError("each comment must have either position or line"), nil + case hasPosition && (hasLine || hasSide || hasStartLine || hasStartSide): + return mcp.NewToolResultError("position cannot be combined with line, side, start_line, or start_side"), nil + case hasStartSide && !hasSide: + return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil + } + + comment := &github.DraftReviewComment{ + Path: github.Ptr(path), + Body: github.Ptr(body), + } + + if positionFloat, ok := commentMap["position"].(float64); ok { + comment.Position = github.Ptr(int(positionFloat)) + } else if lineFloat, ok := commentMap["line"].(float64); ok { + comment.Line = github.Ptr(int(lineFloat)) + } + if side, ok := commentMap["side"].(string); ok { + comment.Side = github.Ptr(side) + } + if startLineFloat, ok := commentMap["start_line"].(float64); ok { + comment.StartLine = github.Ptr(int(startLineFloat)) + } + if startSide, ok := commentMap["start_side"].(string); ok { + comment.StartSide = github.Ptr(startSide) + } + + comments = append(comments, comment) } - _, hasPosition := commentMap["position"].(float64) - _, hasLine := commentMap["line"].(float64) - _, hasSide := commentMap["side"].(string) - _, hasStartLine := commentMap["start_line"].(float64) - _, hasStartSide := commentMap["start_side"].(string) - - switch { - case !hasPosition && !hasLine: - return mcp.NewToolResultError("each comment must have either position or line"), nil - case hasPosition && (hasLine || hasSide || hasStartLine || hasStartSide): - return mcp.NewToolResultError("position cannot be combined with line, side, start_line, or start_side"), nil - case hasStartSide && !hasSide: - return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil - } + reviewRequest.Comments = comments + } - comment := &github.DraftReviewComment{ - Path: github.Ptr(path), - Body: github.Ptr(body), - } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest) + if err != nil { + return nil, fmt.Errorf("failed to create pull request review: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if positionFloat, ok := commentMap["position"].(float64); ok { - comment.Position = github.Ptr(int(positionFloat)) - } else if lineFloat, ok := commentMap["line"].(float64); ok { - comment.Line = github.Ptr(int(lineFloat)) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) } - if side, ok := commentMap["side"].(string); ok { - comment.Side = github.Ptr(side) - } - if startLineFloat, ok := commentMap["start_line"].(float64); ok { - comment.StartLine = github.Ptr(int(startLineFloat)) - } - if startSide, ok := commentMap["start_side"].(string); ok { - comment.StartSide = github.Ptr(startSide) - } - - comments = append(comments, comment) + return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request review: %s", string(body))), nil } - reviewRequest.Comments = comments - } - - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest) - if err != nil { - return nil, fmt.Errorf("failed to create pull request review: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(review) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request review: %s", string(body))), nil - } - r, err := json.Marshal(review) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryPullRequests, + } } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_pull_request", +func CreatePullRequest(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_pull_request", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -932,79 +1004,84 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("Allow maintainer edits"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - title, err := requiredParam[string](request, "title") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - head, err := requiredParam[string](request, "head") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - base, err := requiredParam[string](request, "base") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + head, err := requiredParam[string](request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + base, err := requiredParam[string](request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - body, err := OptionalParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + body, err := OptionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - draft, err := OptionalParam[bool](request, "draft") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + draft, err := OptionalParam[bool](request, "draft") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - newPR := &github.NewPullRequest{ - Title: github.Ptr(title), - Head: github.Ptr(head), - Base: github.Ptr(base), - } + newPR := &github.NewPullRequest{ + Title: github.Ptr(title), + Head: github.Ptr(head), + Base: github.Ptr(base), + } - if body != "" { - newPR.Body = github.Ptr(body) - } + if body != "" { + newPR.Body = github.Ptr(body) + } - newPR.Draft = github.Ptr(draft) - newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + newPR.Draft = github.Ptr(draft) + newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) - if err != nil { - return nil, fmt.Errorf("failed to create pull request: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + if err != nil { + return nil, fmt.Errorf("failed to create pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil - } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryPullRequests, + } } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 3c20dfc2c..2383148db 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -16,15 +16,14 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetPullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetPullRequest(translations.NullTranslationHelper) - assert.Equal(t, "get_pull_request", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "get_pull_request", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case mockPR := &github.PullRequest{ @@ -94,13 +93,13 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetPullRequest(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -128,20 +127,19 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "update_pull_request", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.Contains(t, tool.InputSchema.Properties, "title") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "state") - assert.Contains(t, tool.InputSchema.Properties, "base") - assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + tool := UpdatePullRequest(translations.NullTranslationHelper) + + assert.Equal(t, "update_pull_request", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.Definition.InputSchema.Properties, "title") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.Contains(t, tool.Definition.InputSchema.Properties, "state") + assert.Contains(t, tool.Definition.InputSchema.Properties, "base") + assert.Contains(t, tool.Definition.InputSchema.Properties, "maintainer_can_modify") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case mockUpdatedPR := &github.PullRequest{ @@ -257,13 +255,13 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + tool := UpdatePullRequest(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -310,21 +308,20 @@ func Test_UpdatePullRequest(t *testing.T) { func Test_ListPullRequests(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "list_pull_requests", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "state") - assert.Contains(t, tool.InputSchema.Properties, "head") - assert.Contains(t, tool.InputSchema.Properties, "base") - assert.Contains(t, tool.InputSchema.Properties, "sort") - assert.Contains(t, tool.InputSchema.Properties, "direction") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + tool := ListPullRequests(translations.NullTranslationHelper) + + assert.Equal(t, "list_pull_requests", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "state") + assert.Contains(t, tool.Definition.InputSchema.Properties, "head") + assert.Contains(t, tool.Definition.InputSchema.Properties, "base") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sort") + assert.Contains(t, tool.Definition.InputSchema.Properties, "direction") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo"}) // Setup mock PRs for success case mockPRs := []*github.PullRequest{ @@ -403,13 +400,13 @@ func Test_ListPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) + tool := ListPullRequests(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -440,18 +437,17 @@ func Test_ListPullRequests(t *testing.T) { func Test_MergePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := MergePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "merge_pull_request", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.Contains(t, tool.InputSchema.Properties, "commit_title") - assert.Contains(t, tool.InputSchema.Properties, "commit_message") - assert.Contains(t, tool.InputSchema.Properties, "merge_method") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + tool := MergePullRequest(translations.NullTranslationHelper) + + assert.Equal(t, "merge_pull_request", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.Definition.InputSchema.Properties, "commit_title") + assert.Contains(t, tool.Definition.InputSchema.Properties, "commit_message") + assert.Contains(t, tool.Definition.InputSchema.Properties, "merge_method") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock merge result for success case mockMergeResult := &github.PullRequestMergeResult{ @@ -518,13 +514,13 @@ func Test_MergePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := MergePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + tool := MergePullRequest(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -551,15 +547,14 @@ func Test_MergePullRequest(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetPullRequestFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetPullRequestFiles(translations.NullTranslationHelper) - assert.Equal(t, "get_pull_request_files", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "get_pull_request_files", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR files for success case mockFiles := []*github.CommitFile{ @@ -630,13 +625,13 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestFiles(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetPullRequestFiles(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -667,15 +662,14 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetPullRequestStatus(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetPullRequestStatus(translations.NullTranslationHelper) - assert.Equal(t, "get_pull_request_status", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "get_pull_request_status", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for successful PR fetch mockPR := &github.PullRequest{ @@ -790,13 +784,13 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestStatus(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetPullRequestStatus(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -828,16 +822,15 @@ func Test_GetPullRequestStatus(t *testing.T) { func Test_UpdatePullRequestBranch(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequestBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := UpdatePullRequestBranch(translations.NullTranslationHelper) - assert.Equal(t, "update_pull_request_branch", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.Contains(t, tool.InputSchema.Properties, "expectedHeadSha") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "update_pull_request_branch", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.Definition.InputSchema.Properties, "expectedHeadSha") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock update result for success case mockUpdateResult := &github.PullRequestBranchUpdateResponse{ @@ -917,13 +910,13 @@ func Test_UpdatePullRequestBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequestBranch(stubGetClientFn(client), translations.NullTranslationHelper) + tool := UpdatePullRequestBranch(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -944,15 +937,14 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetPullRequestComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetPullRequestComments(translations.NullTranslationHelper) - assert.Equal(t, "get_pull_request_comments", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "get_pull_request_comments", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR comments for success case mockComments := []*github.PullRequestComment{ @@ -1033,13 +1025,13 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestComments(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetPullRequestComments(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -1071,15 +1063,14 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetPullRequestReviews(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetPullRequestReviews(translations.NullTranslationHelper) - assert.Equal(t, "get_pull_request_reviews", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) + assert.Equal(t, "get_pull_request_reviews", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR reviews for success case mockReviews := []*github.PullRequestReview{ @@ -1156,13 +1147,13 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestReviews(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetPullRequestReviews(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -1194,19 +1185,18 @@ func Test_GetPullRequestReviews(t *testing.T) { func Test_CreatePullRequestReview(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreatePullRequestReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "create_pull_request_review", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "event") - assert.Contains(t, tool.InputSchema.Properties, "commitId") - assert.Contains(t, tool.InputSchema.Properties, "comments") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "event"}) + tool := CreatePullRequestReview(translations.NullTranslationHelper) + + assert.Equal(t, "create_pull_request_review", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.Contains(t, tool.Definition.InputSchema.Properties, "event") + assert.Contains(t, tool.Definition.InputSchema.Properties, "commitId") + assert.Contains(t, tool.Definition.InputSchema.Properties, "comments") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "pullNumber", "event"}) // Setup mock review for success case mockReview := &github.PullRequestReview{ @@ -1523,13 +1513,13 @@ func Test_CreatePullRequestReview(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequestReview(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreatePullRequestReview(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -1565,20 +1555,19 @@ func Test_CreatePullRequestReview(t *testing.T) { func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "create_pull_request", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "title") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "head") - assert.Contains(t, tool.InputSchema.Properties, "base") - assert.Contains(t, tool.InputSchema.Properties, "draft") - assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "title", "head", "base"}) + tool := CreatePullRequest(translations.NullTranslationHelper) + + assert.Equal(t, "create_pull_request", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "title") + assert.Contains(t, tool.Definition.InputSchema.Properties, "body") + assert.Contains(t, tool.Definition.InputSchema.Properties, "head") + assert.Contains(t, tool.Definition.InputSchema.Properties, "base") + assert.Contains(t, tool.Definition.InputSchema.Properties, "draft") + assert.Contains(t, tool.Definition.InputSchema.Properties, "maintainer_can_modify") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "title", "head", "base"}) // Setup mock PR for success case mockPR := &github.PullRequest{ @@ -1678,13 +1667,13 @@ func Test_CreatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreatePullRequest(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 56500eafe..9877261aa 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -13,8 +13,10 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_commit", +func GetCommit(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_commit", mcp.WithDescription(t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -30,59 +32,66 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sha, err := requiredParam[string](request, "sha") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sha, err := requiredParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.ListOptions{ - Page: pagination.page, - PerPage: pagination.perPage, - } + opts := &github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) - if err != nil { - return nil, fmt.Errorf("failed to get commit: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) + if err != nil { + return nil, fmt.Errorf("failed to get commit: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil + } + + r, err := json.Marshal(commit) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil - } - r, err := json.Marshal(commit) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryRepositories, + } } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("list_commits", +func ListCommits(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "list_commits", mcp.WithDescription(t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -97,62 +106,69 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sha, err := OptionalParam[string](request, "sha") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sha, err := OptionalParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.CommitsListOptions{ - SHA: sha, - ListOptions: github.ListOptions{ - Page: pagination.page, - PerPage: pagination.perPage, - }, - } + opts := &github.CommitsListOptions{ + SHA: sha, + ListOptions: github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) - if err != nil { - return nil, fmt.Errorf("failed to list commits: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) + if err != nil { + return nil, fmt.Errorf("failed to list commits: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list commits: %s", string(body))), nil + } + + r, err := json.Marshal(commits) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to list commits: %s", string(body))), nil - } - r, err := json.Marshal(commits) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryRepositories, + } } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_or_update_file", +func CreateOrUpdateFile(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_or_update_file", mcp.WithDescription(t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -182,82 +198,89 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF mcp.Description("SHA of file being replaced (for updates)"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - path, err := requiredParam[string](request, "path") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - content, err := requiredParam[string](request, "content") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - message, err := requiredParam[string](request, "message") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - branch, err := requiredParam[string](request, "branch") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredParam[string](request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + content, err := requiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredParam[string](request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Convert content to base64 - contentBytes := []byte(content) + // Convert content to base64 + contentBytes := []byte(content) - // Create the file options - opts := &github.RepositoryContentFileOptions{ - Message: github.Ptr(message), - Content: contentBytes, - Branch: github.Ptr(branch), - } + // Create the file options + opts := &github.RepositoryContentFileOptions{ + Message: github.Ptr(message), + Content: contentBytes, + Branch: github.Ptr(branch), + } - // If SHA is provided, set it (for updates) - sha, err := OptionalParam[string](request, "sha") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - if sha != "" { - opts.SHA = github.Ptr(sha) - } + // If SHA is provided, set it (for updates) + sha, err := OptionalParam[string](request, "sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if sha != "" { + opts.SHA = github.Ptr(sha) + } - // Create or update the file - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) - if err != nil { - return nil, fmt.Errorf("failed to create/update file: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Create or update the file + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) + if err != nil { + return nil, fmt.Errorf("failed to create/update file: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 && resp.StatusCode != 201 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create/update file: %s", string(body))), nil + } - if resp.StatusCode != 200 && resp.StatusCode != 201 { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(fileContent) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create/update file: %s", string(body))), nil - } - r, err := json.Marshal(fileContent) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryRepositories, + } } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_repository", +func CreateRepository(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_repository", mcp.WithDescription(t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account")), mcp.WithString("name", mcp.Required(), @@ -273,61 +296,68 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun mcp.Description("Initialize with README"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name, err := requiredParam[string](request, "name") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - description, err := OptionalParam[string](request, "description") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - private, err := OptionalParam[bool](request, "private") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - autoInit, err := OptionalParam[bool](request, "autoInit") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + name, err := requiredParam[string](request, "name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + private, err := OptionalParam[bool](request, "private") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + autoInit, err := OptionalParam[bool](request, "autoInit") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - repo := &github.Repository{ - Name: github.Ptr(name), - Description: github.Ptr(description), - Private: github.Ptr(private), - AutoInit: github.Ptr(autoInit), - } + repo := &github.Repository{ + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) - if err != nil { - return nil, fmt.Errorf("failed to create repository: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) + if err != nil { + return nil, fmt.Errorf("failed to create repository: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create repository: %s", string(body))), nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(createdRepo) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to create repository: %s", string(body))), nil - } - r, err := json.Marshal(createdRepo) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryRepositories, + } } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_file_contents", +func GetFileContents(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "get_file_contents", mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -345,62 +375,69 @@ func GetFileContents(getClient GetClientFn, t translations.TranslationHelperFunc mcp.Description("Branch to get contents from"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - path, err := requiredParam[string](request, "path") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - branch, err := OptionalParam[string](request, "branch") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - opts := &github.RepositoryContentGetOptions{Ref: branch} - fileContent, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) - if err != nil { - return nil, fmt.Errorf("failed to get file contents: %w", err) - } - defer func() { _ = resp.Body.Close() }() + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + path, err := requiredParam[string](request, "path") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := OptionalParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + client, err := getClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get file contents: %s", string(body))), nil - } + opts := &github.RepositoryContentGetOptions{Ref: branch} + fileContent, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) + if err != nil { + return nil, fmt.Errorf("failed to get file contents: %w", err) + } + defer func() { _ = resp.Body.Close() }() - var result interface{} - if fileContent != nil { - result = fileContent - } else { - result = dirContent - } + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get file contents: %s", string(body))), nil + } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } + var result interface{} + if fileContent != nil { + result = fileContent + } else { + result = dirContent + } + + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } - return mcp.NewToolResultText(string(r)), nil - } + return mcp.NewToolResultText(string(r)), nil + } + }, + Access: ReadOnly, + Category: CategoryRepositories, + } } // ForkRepository creates a tool to fork a repository. -func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("fork_repository", +func ForkRepository(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "fork_repository", mcp.WithDescription(t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization")), mcp.WithString("owner", mcp.Required(), @@ -414,60 +451,67 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) mcp.Description("Organization to fork to"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - org, err := OptionalParam[string](request, "organization") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + org, err := OptionalParam[string](request, "organization") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.RepositoryCreateForkOptions{} - if org != "" { - opts.Organization = org - } + opts := &github.RepositoryCreateForkOptions{} + if org != "" { + opts.Organization = org + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return mcp.NewToolResultText("Fork is in progress"), nil - } - return nil, fmt.Errorf("failed to fork repository: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return mcp.NewToolResultText("Fork is in progress"), nil + } + return nil, fmt.Errorf("failed to fork repository: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusAccepted { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to fork repository: %s", string(body))), nil + } - if resp.StatusCode != http.StatusAccepted { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(forkedRepo) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to fork repository: %s", string(body))), nil - } - r, err := json.Marshal(forkedRepo) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryRepositories, + } } // CreateBranch creates a tool to create a new branch. -func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_branch", +func CreateBranch(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "create_branch", mcp.WithDescription(t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository")), mcp.WithString("owner", mcp.Required(), @@ -485,74 +529,80 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( mcp.Description("Source branch (defaults to repo default)"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - branch, err := requiredParam[string](request, "branch") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - fromBranch, err := OptionalParam[string](request, "from_branch") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + fromBranch, err := OptionalParam[string](request, "from_branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + // Get the source branch SHA + var ref *github.Reference + + if fromBranch == "" { + // Get default branch if from_branch not specified + repository, resp, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to get repository: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Get the source branch SHA - var ref *github.Reference + fromBranch = *repository.DefaultBranch + } - if fromBranch == "" { - // Get default branch if from_branch not specified - repository, resp, err := client.Repositories.Get(ctx, owner, repo) + // Get SHA of source branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) if err != nil { - return nil, fmt.Errorf("failed to get repository: %w", err) + return nil, fmt.Errorf("failed to get reference: %w", err) } defer func() { _ = resp.Body.Close() }() - fromBranch = *repository.DefaultBranch - } - - // Get SHA of source branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) - if err != nil { - return nil, fmt.Errorf("failed to get reference: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Create new branch + newRef := &github.Reference{ + Ref: github.Ptr("refs/heads/" + branch), + Object: &github.GitObject{SHA: ref.Object.SHA}, + } - // Create new branch - newRef := &github.Reference{ - Ref: github.Ptr("refs/heads/" + branch), - Object: &github.GitObject{SHA: ref.Object.SHA}, - } + createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) + if err != nil { + return nil, fmt.Errorf("failed to create branch: %w", err) + } + defer func() { _ = resp.Body.Close() }() - createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) - if err != nil { - return nil, fmt.Errorf("failed to create branch: %w", err) - } - defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(createdRef) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } - r, err := json.Marshal(createdRef) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryRepositories, + } } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("push_files", +func PushFiles(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool("push_files", mcp.WithDescription(t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit")), mcp.WithString("owner", mcp.Required(), @@ -591,109 +641,114 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too mcp.Description("Commit message"), ), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - branch, err := requiredParam[string](request, "branch") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - message, err := requiredParam[string](request, "message") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - // Parse files parameter - this should be an array of objects with path and content - filesObj, ok := request.Params.Arguments["files"].([]interface{}) - if !ok { - return mcp.NewToolResultError("files parameter must be an array of objects with path and content"), nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) - if err != nil { - return nil, fmt.Errorf("failed to get branch reference: %w", err) - } - defer func() { _ = resp.Body.Close() }() + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + branch, err := requiredParam[string](request, "branch") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + message, err := requiredParam[string](request, "message") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) - if err != nil { - return nil, fmt.Errorf("failed to get base commit: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Parse files parameter - this should be an array of objects with path and content + filesObj, ok := request.Params.Arguments["files"].([]interface{}) + if !ok { + return mcp.NewToolResultError("files parameter must be an array of objects with path and content"), nil + } - // Create tree entries for all files - var entries []*github.TreeEntry + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - for _, file := range filesObj { - fileMap, ok := file.(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("each file must be an object with path and content"), nil + // Get the reference for the branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return nil, fmt.Errorf("failed to get branch reference: %w", err) } + defer func() { _ = resp.Body.Close() }() - path, ok := fileMap["path"].(string) - if !ok || path == "" { - return mcp.NewToolResultError("each file must have a path"), nil + // Get the commit object that the branch points to + baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return nil, fmt.Errorf("failed to get base commit: %w", err) } + defer func() { _ = resp.Body.Close() }() - content, ok := fileMap["content"].(string) - if !ok { - return mcp.NewToolResultError("each file must have content"), nil + // Create tree entries for all files + var entries []*github.TreeEntry + + for _, file := range filesObj { + fileMap, ok := file.(map[string]interface{}) + if !ok { + return mcp.NewToolResultError("each file must be an object with path and content"), nil + } + + path, ok := fileMap["path"].(string) + if !ok || path == "" { + return mcp.NewToolResultError("each file must have a path"), nil + } + + content, ok := fileMap["content"].(string) + if !ok { + return mcp.NewToolResultError("each file must have content"), nil + } + + // Create a tree entry for the file + entries = append(entries, &github.TreeEntry{ + Path: github.Ptr(path), + Mode: github.Ptr("100644"), // Regular file mode + Type: github.Ptr("blob"), + Content: github.Ptr(content), + }) } - // Create a tree entry for the file - entries = append(entries, &github.TreeEntry{ - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - Content: github.Ptr(content), - }) - } + // Create a new tree with the file entries + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) + if err != nil { + return nil, fmt.Errorf("failed to create tree: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Create a new tree with the file entries - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) - if err != nil { - return nil, fmt.Errorf("failed to create tree: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Create a new commit + commit := &github.Commit{ + Message: github.Ptr(message), + Tree: newTree, + Parents: []*github.Commit{{SHA: baseCommit.SHA}}, + } + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if err != nil { + return nil, fmt.Errorf("failed to create commit: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Create a new commit - commit := &github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) - if err != nil { - return nil, fmt.Errorf("failed to create commit: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Update the reference to point to the new commit + ref.Object.SHA = newCommit.SHA + updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) + if err != nil { + return nil, fmt.Errorf("failed to update reference: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Update the reference to point to the new commit - ref.Object.SHA = newCommit.SHA - updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) - if err != nil { - return nil, fmt.Errorf("failed to update reference: %w", err) - } - defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(updatedRef) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } - r, err := json.Marshal(updatedRef) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: Write, + Category: CategoryRepositories, + } } diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 20f96dde5..cad1ff6f5 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -17,16 +17,15 @@ import ( func Test_GetFileContents(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetFileContents(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetFileContents(translations.NullTranslationHelper) - assert.Equal(t, "get_file_contents", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "path") - assert.Contains(t, tool.InputSchema.Properties, "branch") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path"}) + assert.Equal(t, "get_file_contents", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "path") + assert.Contains(t, tool.Definition.InputSchema.Properties, "branch") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "path"}) // Setup mock file content for success case mockFileContent := &github.RepositoryContent{ @@ -132,7 +131,7 @@ func Test_GetFileContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetFileContents(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetFileContents(translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -148,7 +147,7 @@ func Test_GetFileContents(t *testing.T) { } // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -188,15 +187,14 @@ func Test_GetFileContents(t *testing.T) { func Test_ForkRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ForkRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := ForkRepository(translations.NullTranslationHelper) - assert.Equal(t, "fork_repository", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "organization") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + assert.Equal(t, "fork_repository", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "organization") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo"}) // Setup mock forked repo for success case mockForkedRepo := &github.Repository{ @@ -259,13 +257,13 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ForkRepository(stubGetClientFn(client), translations.NullTranslationHelper) + tool := ForkRepository(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -286,16 +284,15 @@ func Test_ForkRepository(t *testing.T) { func Test_CreateBranch(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := CreateBranch(translations.NullTranslationHelper) - assert.Equal(t, "create_branch", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "branch") - assert.Contains(t, tool.InputSchema.Properties, "from_branch") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "branch"}) + assert.Equal(t, "create_branch", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "branch") + assert.Contains(t, tool.Definition.InputSchema.Properties, "from_branch") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "branch"}) // Setup mock repository for default branch test mockRepo := &github.Repository{ @@ -445,13 +442,13 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateBranch(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreateBranch(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -477,15 +474,14 @@ func Test_CreateBranch(t *testing.T) { func Test_GetCommit(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetCommit(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := GetCommit(translations.NullTranslationHelper) - assert.Equal(t, "get_commit", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "sha") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "sha"}) + assert.Equal(t, "get_commit", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sha") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "sha"}) mockCommit := &github.RepositoryCommit{ SHA: github.Ptr("abc123def456"), @@ -572,13 +568,13 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCommit(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetCommit(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -607,17 +603,16 @@ func Test_GetCommit(t *testing.T) { func Test_ListCommits(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListCommits(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "list_commits", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "sha") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + tool := ListCommits(translations.NullTranslationHelper) + + assert.Equal(t, "list_commits", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sha") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo"}) // Setup mock commits for success case mockCommits := []*github.RepositoryCommit{ @@ -744,13 +739,13 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCommits(stubGetClientFn(client), translations.NullTranslationHelper) + tool := ListCommits(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -781,19 +776,18 @@ func Test_ListCommits(t *testing.T) { func Test_CreateOrUpdateFile(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateOrUpdateFile(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "create_or_update_file", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "path") - assert.Contains(t, tool.InputSchema.Properties, "content") - assert.Contains(t, tool.InputSchema.Properties, "message") - assert.Contains(t, tool.InputSchema.Properties, "branch") - assert.Contains(t, tool.InputSchema.Properties, "sha") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "path", "content", "message", "branch"}) + tool := CreateOrUpdateFile(translations.NullTranslationHelper) + + assert.Equal(t, "create_or_update_file", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "path") + assert.Contains(t, tool.Definition.InputSchema.Properties, "content") + assert.Contains(t, tool.Definition.InputSchema.Properties, "message") + assert.Contains(t, tool.Definition.InputSchema.Properties, "branch") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sha") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "path", "content", "message", "branch"}) // Setup mock file content response mockFileResponse := &github.RepositoryContentResponse{ @@ -905,13 +899,13 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateOrUpdateFile(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreateOrUpdateFile(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -944,16 +938,15 @@ func Test_CreateOrUpdateFile(t *testing.T) { func Test_CreateRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := CreateRepository(translations.NullTranslationHelper) - assert.Equal(t, "create_repository", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "name") - assert.Contains(t, tool.InputSchema.Properties, "description") - assert.Contains(t, tool.InputSchema.Properties, "private") - assert.Contains(t, tool.InputSchema.Properties, "autoInit") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"name"}) + assert.Equal(t, "create_repository", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "name") + assert.Contains(t, tool.Definition.InputSchema.Properties, "description") + assert.Contains(t, tool.Definition.InputSchema.Properties, "private") + assert.Contains(t, tool.Definition.InputSchema.Properties, "autoInit") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"name"}) // Setup mock repository response mockRepo := &github.Repository{ @@ -1053,13 +1046,13 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateRepository(stubGetClientFn(client), translations.NullTranslationHelper) + tool := CreateRepository(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -1090,17 +1083,16 @@ func Test_CreateRepository(t *testing.T) { func Test_PushFiles(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PushFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "push_files", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "branch") - assert.Contains(t, tool.InputSchema.Properties, "files") - assert.Contains(t, tool.InputSchema.Properties, "message") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "branch", "files", "message"}) + tool := PushFiles(translations.NullTranslationHelper) + + assert.Equal(t, "push_files", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "owner") + assert.Contains(t, tool.Definition.InputSchema.Properties, "repo") + assert.Contains(t, tool.Definition.InputSchema.Properties, "branch") + assert.Contains(t, tool.Definition.InputSchema.Properties, "files") + assert.Contains(t, tool.Definition.InputSchema.Properties, "message") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"owner", "repo", "branch", "files", "message"}) // Setup mock objects mockRef := &github.Reference{ @@ -1386,13 +1378,13 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PushFiles(stubGetClientFn(client), translations.NullTranslationHelper) + tool := PushFiles(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/pkg/github/search.go b/pkg/github/search.go index 75810e245..3c7645851 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -13,8 +13,10 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("search_repositories", +func SearchRepositories(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "search_repositories", mcp.WithDescription(t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Search for GitHub repositories")), mcp.WithString("query", mcp.Required(), @@ -22,53 +24,60 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredParam[string](request, "query") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query, err := requiredParam[string](request, "query") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.SearchOptions{ - ListOptions: github.ListOptions{ - Page: pagination.page, - PerPage: pagination.perPage, - }, - } + opts := &github.SearchOptions{ + ListOptions: github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return nil, fmt.Errorf("failed to search repositories: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search repositories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategorySearch, + } } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("search_code", +func SearchCode(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "search_code", mcp.WithDescription(t("TOOL_SEARCH_CODE_DESCRIPTION", "Search for code across GitHub repositories")), mcp.WithString("q", mcp.Required(), @@ -83,64 +92,71 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredParam[string](request, "q") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sort, err := OptionalParam[string](request, "sort") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - order, err := OptionalParam[string](request, "order") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := OptionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.perPage, - Page: pagination.page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.perPage, + Page: pagination.page, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return nil, fmt.Errorf("failed to search code: %w", err) - } - defer func() { _ = resp.Body.Close() }() + result, resp, err := client.Search.Code(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search code: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategorySearch, + } } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("search_users", +func SearchUsers(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool( + "search_users", mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")), mcp.WithString("q", mcp.Required(), @@ -156,57 +172,62 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t ), WithPagination(), ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query, err := requiredParam[string](request, "q") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - sort, err := OptionalParam[string](request, "sort") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - order, err := OptionalParam[string](request, "order") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - pagination, err := OptionalPaginationParams(request) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query, err := requiredParam[string](request, "q") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + order, err := OptionalParam[string](request, "order") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.perPage, - Page: pagination.page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.perPage, + Page: pagination.page, + }, + } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - result, resp, err := client.Search.Users(ctx, query, opts) - if err != nil { - return nil, fmt.Errorf("failed to search users: %w", err) - } - defer func() { _ = resp.Body.Close() }() + result, resp, err := client.Search.Users(ctx, query, opts) + if err != nil { + return nil, fmt.Errorf("failed to search users: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to search users: %s", string(body))), nil + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(result) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal response: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to search users: %s", string(body))), nil - } - r, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategorySearch, + } } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index b61518e47..5057d2ac1 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -15,15 +15,14 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool := SearchRepositories(translations.NullTranslationHelper) - assert.Equal(t, "search_repositories", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "query") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"query"}) + assert.Equal(t, "search_repositories", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "query") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"query"}) // Setup mock search results mockSearchResult := &github.RepositoriesSearchResult{ @@ -122,13 +121,13 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + tool := SearchRepositories(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -162,17 +161,16 @@ func Test_SearchRepositories(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "search_code", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "q") - assert.Contains(t, tool.InputSchema.Properties, "sort") - assert.Contains(t, tool.InputSchema.Properties, "order") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + tool := SearchCode(translations.NullTranslationHelper) + + assert.Equal(t, "search_code", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "q") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sort") + assert.Contains(t, tool.Definition.InputSchema.Properties, "order") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"q"}) // Setup mock search results mockSearchResult := &github.CodeSearchResult{ @@ -273,13 +271,13 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) + tool := SearchCode(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { @@ -313,17 +311,16 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "search_users", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "q") - assert.Contains(t, tool.InputSchema.Properties, "sort") - assert.Contains(t, tool.InputSchema.Properties, "order") - assert.Contains(t, tool.InputSchema.Properties, "perPage") - assert.Contains(t, tool.InputSchema.Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"}) + tool := SearchUsers(translations.NullTranslationHelper) + + assert.Equal(t, "search_users", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "q") + assert.Contains(t, tool.Definition.InputSchema.Properties, "sort") + assert.Contains(t, tool.Definition.InputSchema.Properties, "order") + assert.Contains(t, tool.Definition.InputSchema.Properties, "perPage") + assert.Contains(t, tool.Definition.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.Definition.InputSchema.Required, []string{"q"}) // Setup mock search results mockSearchResult := &github.UsersSearchResult{ @@ -428,13 +425,13 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) + tool := SearchUsers(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/pkg/github/server.go b/pkg/github/server.go index 490a81051..47f6b91a1 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -16,6 +16,101 @@ import ( type GetClientFn func(context.Context) (*github.Client, error) +type Access int + +const ( + // Zero value is writer, that way if forgotten, it won't be included + // in read only configuration. + Write Access = iota + ReadOnly +) + +type Handler func(getClient GetClientFn) server.ToolHandlerFunc + +type Tool struct { + Definition mcp.Tool + Handler Handler + Access Access + Category Category +} + +type Category string + +const ( + // CategoryUsers is the category for user-related tools. + CategoryUsers Category = "Users" + // CategoryIssues is the category for issue-related tools. + CategoryIssues Category = "Issues" + // CategoryPullRequests is the category for pull request-related tools. + CategoryPullRequests Category = "Pull Requests" + // CategoryRepositories is the category for repository-related tools. + CategoryRepositories Category = "Repositories" + // CategorySearch is the category for search-related tools. + CategorySearch Category = "Search" + // CategoryCodeScanning is the category for code scanning-related tools. + CategoryCodeScanning Category = "Code Scanning" +) + +type Tools []Tool + +func (t Tools) ReadOnly() []Tool { + var readOnlyTools []Tool + for _, tool := range t { + if tool.Access == ReadOnly { + readOnlyTools = append(readOnlyTools, tool) + } + } + return readOnlyTools +} + +func DefaultTools(t translations.TranslationHelperFunc) Tools { + return []Tool{ + // Users + GetMe(t), + + // Issues + GetIssue(t), + SearchIssues(t), + ListIssues(t), + GetIssueComments(t), + CreateIssue(t), + AddIssueComment(t), + UpdateIssue(t), + + // Pull Requests + GetPullRequest(t), + ListPullRequests(t), + GetPullRequestFiles(t), + GetPullRequestStatus(t), + GetPullRequestComments(t), + GetPullRequestReviews(t), + MergePullRequest(t), + UpdatePullRequestBranch(t), + CreatePullRequestReview(t), + CreatePullRequest(t), + UpdatePullRequest(t), + + // Repositories + SearchRepositories(t), + GetFileContents(t), + GetCommit(t), + ListCommits(t), + CreateOrUpdateFile(t), + CreateRepository(t), + ForkRepository(t), + CreateBranch(t), + PushFiles(t), + + // Search + SearchCode(t), + SearchUsers(t), + + // Code Scanning + GetCodeScanningAlert(t), + ListCodeScanningAlerts(t), + } +} + // NewServer creates a new GitHub MCP server with the specified GH client and logger. func NewServer(getClient GetClientFn, version string, readOnly bool, t translations.TranslationHelperFunc, opts ...server.ServerOption) *server.MCPServer { // Add default options @@ -32,99 +127,66 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati opts..., ) - // Add GitHub Resources + // // Add GitHub Resources s.AddResourceTemplate(GetRepositoryResourceContent(getClient, t)) s.AddResourceTemplate(GetRepositoryResourceBranchContent(getClient, t)) s.AddResourceTemplate(GetRepositoryResourceCommitContent(getClient, t)) s.AddResourceTemplate(GetRepositoryResourceTagContent(getClient, t)) s.AddResourceTemplate(GetRepositoryResourcePrContent(getClient, t)) - // Add GitHub tools - Issues - s.AddTool(GetIssue(getClient, t)) - s.AddTool(SearchIssues(getClient, t)) - s.AddTool(ListIssues(getClient, t)) - s.AddTool(GetIssueComments(getClient, t)) - if !readOnly { - s.AddTool(CreateIssue(getClient, t)) - s.AddTool(AddIssueComment(getClient, t)) - s.AddTool(UpdateIssue(getClient, t)) + // Add GitHub Tools + tools := DefaultTools(t) + if readOnly { + tools = tools.ReadOnly() } - // Add GitHub tools - Pull Requests - s.AddTool(GetPullRequest(getClient, t)) - s.AddTool(ListPullRequests(getClient, t)) - s.AddTool(GetPullRequestFiles(getClient, t)) - s.AddTool(GetPullRequestStatus(getClient, t)) - s.AddTool(GetPullRequestComments(getClient, t)) - s.AddTool(GetPullRequestReviews(getClient, t)) - if !readOnly { - s.AddTool(MergePullRequest(getClient, t)) - s.AddTool(UpdatePullRequestBranch(getClient, t)) - s.AddTool(CreatePullRequestReview(getClient, t)) - s.AddTool(CreatePullRequest(getClient, t)) - s.AddTool(UpdatePullRequest(getClient, t)) + for _, tool := range tools { + s.AddTool(tool.Definition, tool.Handler(getClient)) } - // Add GitHub tools - Repositories - s.AddTool(SearchRepositories(getClient, t)) - s.AddTool(GetFileContents(getClient, t)) - s.AddTool(GetCommit(getClient, t)) - s.AddTool(ListCommits(getClient, t)) - if !readOnly { - s.AddTool(CreateOrUpdateFile(getClient, t)) - s.AddTool(CreateRepository(getClient, t)) - s.AddTool(ForkRepository(getClient, t)) - s.AddTool(CreateBranch(getClient, t)) - s.AddTool(PushFiles(getClient, t)) - } - - // Add GitHub tools - Search - s.AddTool(SearchCode(getClient, t)) - s.AddTool(SearchUsers(getClient, t)) - - // Add GitHub tools - Users - s.AddTool(GetMe(getClient, t)) - - // Add GitHub tools - Code Scanning - s.AddTool(GetCodeScanningAlert(getClient, t)) - s.AddTool(ListCodeScanningAlerts(getClient, t)) return s } // GetMe creates a tool to get details of the authenticated user. -func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_me", +func GetMe(t translations.TranslationHelperFunc) Tool { + return Tool{ + Definition: mcp.NewTool("get_me", mcp.WithDescription(t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request include \"me\", \"my\"...")), mcp.WithString("reason", mcp.Description("Optional: reason the session was created"), ), ), - func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - user, resp, err := client.Users.Get(ctx, "") - if err != nil { - return nil, fmt.Errorf("failed to get user: %w", err) - } - defer func() { _ = resp.Body.Close() }() + Handler: func(getClient GetClientFn) server.ToolHandlerFunc { + return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + user, resp, err := client.Users.Get(ctx, "") + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get user: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(user) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to marshal user: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to get user: %s", string(body))), nil - } - r, err := json.Marshal(user) - if err != nil { - return nil, fmt.Errorf("failed to marshal user: %w", err) + return mcp.NewToolResultText(string(r)), nil } - - return mcp.NewToolResultText(string(r)), nil - } + }, + Access: ReadOnly, + Category: CategoryUsers, + } } // OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 3ee9851af..070e8df48 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -23,13 +23,13 @@ func stubGetClientFn(client *github.Client) GetClientFn { func Test_GetMe(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetMe(stubGetClientFn(mockClient), translations.NullTranslationHelper) - assert.Equal(t, "get_me", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "reason") - assert.Empty(t, tool.InputSchema.Required) // No required parameters + tool := GetMe(translations.NullTranslationHelper) + + assert.Equal(t, "get_me", tool.Definition.Name) + assert.NotEmpty(t, tool.Definition.Description) + assert.Contains(t, tool.Definition.InputSchema.Properties, "reason") + assert.Empty(t, tool.Definition.InputSchema.Required) // No required parameters // Setup mock user response mockUser := &github.User{ @@ -102,13 +102,13 @@ func Test_GetMe(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetMe(stubGetClientFn(client), translations.NullTranslationHelper) + tool := GetMe(translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, err := handler(context.Background(), request) + result, err := tool.Handler(stubGetClientFn(client))(context.Background(), request) // Verify results if tc.expectError { diff --git a/schema.md b/schema.md new file mode 100644 index 000000000..0f1fb55e6 --- /dev/null +++ b/schema.md @@ -0,0 +1,235 @@ +## Tools + +### Users + +- **get_me** - Get details of the authenticated GitHub user. Use this when a request include "me", "my"... + - `reason`: Optional: reason the session was created (string, optional) + +### Issues + +- **get_issue** - Get details of a specific issue in a GitHub repository + - `issue_number`: The number of the issue (number, required) + - `owner`: The owner of the repository (string, required) + - `repo`: The name of the repository (string, required) + +- **search_issues** - Search for issues and pull requests across GitHub repositories + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub issues search syntax (string, required) + - `sort`: Sort field (comments, reactions, created, etc.) (string, optional) + +- **list_issues** - List issues in a GitHub repository with filtering options + - `direction`: Sort direction ('asc', 'desc') (string, optional) + - `labels`: Filter by labels (array, optional) + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `since`: Filter by date (ISO 8601 timestamp) (string, optional) + - `sort`: Sort by ('created', 'updated', 'comments') (string, optional) + - `state`: Filter by state ('open', 'closed', 'all') (string, optional) + +- **get_issue_comments** - Get comments for a GitHub issue + - `issue_number`: Issue number (number, required) + - `owner`: Repository owner (string, required) + - `page`: Page number (number, optional) + - `per_page`: Number of records per page (number, optional) + - `repo`: Repository name (string, required) + +- **create_issue** - Create a new issue in a GitHub repository + - `assignees`: Usernames to assign to this issue (array, optional) + - `body`: Issue body content (string, optional) + - `labels`: Labels to apply to this issue (array, optional) + - `milestone`: Milestone number (number, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `title`: Issue title (string, required) + +- **add_issue_comment** - Add a comment to an existing issue + - `body`: Comment text (string, required) + - `issue_number`: Issue number to comment on (number, required) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **update_issue** - Update an existing issue in a GitHub repository + - `assignees`: New assignees (array, optional) + - `body`: New description (string, optional) + - `issue_number`: Issue number to update (number, required) + - `labels`: New labels (array, optional) + - `milestone`: New milestone number (number, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `state`: New state ('open' or 'closed') (string, optional) + - `title`: New title (string, optional) + +### Pull Requests + +- **get_pull_request** - Get details of a specific pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **list_pull_requests** - List and filter repository pull requests + - `base`: Filter by base branch (string, optional) + - `direction`: Sort direction ('asc', 'desc') (string, optional) + - `head`: Filter by head user/org and branch (string, optional) + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sort`: Sort by ('created', 'updated', 'popularity', 'long-running') (string, optional) + - `state`: Filter by state ('open', 'closed', 'all') (string, optional) + +- **get_pull_request_files** - Get the list of files changed in a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_status** - Get the combined status of all status checks for a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_comments** - Get the review comments on a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **get_pull_request_reviews** - Get the reviews on a pull request + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **merge_pull_request** - Merge a pull request + - `commit_message`: Extra detail for merge commit (string, optional) + - `commit_title`: Title for merge commit (string, optional) + - `merge_method`: Merge method ('merge', 'squash', 'rebase') (string, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **update_pull_request_branch** - Update a pull request branch with the latest changes from the base branch + - `expectedHeadSha`: The expected SHA of the pull request's HEAD ref (string, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **create_pull_request_review** - Create a review on a pull request + - `body`: Review comment text (string, optional) + - `comments`: Line-specific comments array of objects to place comments on pull request changes. Requires path and body. For line comments use line or position. For multi-line comments use start_line and line with optional side parameters. (array, optional) + - `commitId`: SHA of commit to review (string, optional) + - `event`: Review action ('APPROVE', 'REQUEST_CHANGES', 'COMMENT') (string, required) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number (number, required) + - `repo`: Repository name (string, required) + +- **create_pull_request** - Create a new pull request in a GitHub repository + - `base`: Branch to merge into (string, required) + - `body`: PR description (string, optional) + - `draft`: Create as draft PR (boolean, optional) + - `head`: Branch containing changes (string, required) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `title`: PR title (string, required) + +- **update_pull_request** - Update an existing pull request in a GitHub repository + - `base`: New base branch name (string, optional) + - `body`: New description (string, optional) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `owner`: Repository owner (string, required) + - `pullNumber`: Pull request number to update (number, required) + - `repo`: Repository name (string, required) + - `state`: New state ('open' or 'closed') (string, optional) + - `title`: New title (string, optional) + +### Repositories + +- **get_file_contents** - Get the contents of a file or directory from a GitHub repository + - `branch`: Branch to get contents from (string, optional) + - `owner`: Repository owner (username or organization) (string, required) + - `path`: Path to file/directory (string, required) + - `repo`: Repository name (string, required) + +- **get_commit** - Get details for a commit from a GitHub repository + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sha`: Commit SHA, branch name, or tag name (string, required) + +- **list_commits** - Get list of commits of a branch in a GitHub repository + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `sha`: Branch name (string, optional) + +- **create_or_update_file** - Create or update a single file in a GitHub repository + - `branch`: Branch to create/update the file in (string, required) + - `content`: Content of the file (string, required) + - `message`: Commit message (string, required) + - `owner`: Repository owner (username or organization) (string, required) + - `path`: Path where to create/update the file (string, required) + - `repo`: Repository name (string, required) + - `sha`: SHA of file being replaced (for updates) (string, optional) + +- **create_repository** - Create a new GitHub repository in your account + - `autoInit`: Initialize with README (boolean, optional) + - `description`: Repository description (string, optional) + - `name`: Repository name (string, required) + - `private`: Whether repo should be private (boolean, optional) + +- **fork_repository** - Fork a GitHub repository to your account or specified organization + - `organization`: Organization to fork to (string, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **create_branch** - Create a new branch in a GitHub repository + - `branch`: Name for new branch (string, required) + - `from_branch`: Source branch (defaults to repo default) (string, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +- **push_files** - Push multiple files to a GitHub repository in a single commit + - `branch`: Branch to push to (string, required) + - `files`: Array of file objects to push, each object with path (string) and content (string) (array, required) + - `message`: Commit message (string, required) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + +### Search + +- **search_repositories** - Search for GitHub repositories + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `query`: Search query (string, required) + +- **search_code** - Search for code across GitHub repositories + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub code search syntax (string, required) + - `sort`: Sort field ('indexed' only) (string, optional) + +- **search_users** - Search for GitHub users + - `order`: Sort order ('asc' or 'desc') (string, optional) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `q`: Search query using GitHub users search syntax (string, required) + - `sort`: Sort field (followers, repositories, joined) (string, optional) + +### Code Scanning + +- **get_code_scanning_alert** - Get details of a specific code scanning alert in a GitHub repository. + - `alertNumber`: The number of the alert. (number, required) + - `owner`: The owner of the repository. (string, required) + - `repo`: The name of the repository. (string, required) + +- **list_code_scanning_alerts** - List code scanning alerts in a GitHub repository. + - `owner`: The owner of the repository. (string, required) + - `ref`: The Git reference for the results you want to list. (string, optional) + - `repo`: The name of the repository. (string, required) + - `severity`: Only code scanning alerts with this severity will be returned. Possible values are: critical, high, medium, low, warning, note, error. (string, optional) + - `state`: State of the code scanning alerts to list. Set to closed to list only closed code scanning alerts. Default: open (string, optional)