From baf444b95224f8786003cb9f45898c0b96f50abd Mon Sep 17 00:00:00 2001 From: Dustin Spicuzza Date: Fri, 20 Jun 2025 12:30:13 -0400 Subject: [PATCH] Add content tool --- README.md | 1 + .../snapshots/go/content/test_function.snap | 7 +++ .../tests/go/content/content_test.go | 56 +++++++++++++++++++ internal/tools/content.go | 54 ++++++++++++++++++ internal/tools/definition.go | 2 +- internal/tools/lsp-utilities.go | 22 ++++---- tools.go | 42 ++++++++++++++ 7 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 integrationtests/snapshots/go/content/test_function.snap create mode 100644 integrationtests/tests/go/content/content_test.go create mode 100644 internal/tools/content.go diff --git a/README.md b/README.md index 5bd3194..a91ebb2 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,7 @@ This is an [MCP](https://modelcontextprotocol.io/introduction) server that runs ## Tools - `definition`: Retrieves the complete source code definition of any symbol (function, type, constant, etc.) from your codebase. +- `content`: Retrieves the complete source code definition (function, type, constant, etc.) from your codebase at a specific location. - `references`: Locates all usages and references of a symbol throughout the codebase. - `diagnostics`: Provides diagnostic information for a specific file, including warnings and errors. - `hover`: Display documentation, type hints, or other hover information for a given location. diff --git a/integrationtests/snapshots/go/content/test_function.snap b/integrationtests/snapshots/go/content/test_function.snap new file mode 100644 index 0000000..0782543 --- /dev/null +++ b/integrationtests/snapshots/go/content/test_function.snap @@ -0,0 +1,7 @@ +Symbol: TestFunction +/TEST_OUTPUT/workspace/clean.go +Range: L31:C1 - L33:C2 + +31|func TestFunction() { +32| fmt.Println("This is a test function") +33|} diff --git a/integrationtests/tests/go/content/content_test.go b/integrationtests/tests/go/content/content_test.go new file mode 100644 index 0000000..07c7a4d --- /dev/null +++ b/integrationtests/tests/go/content/content_test.go @@ -0,0 +1,56 @@ +package content_test + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/isaacphi/mcp-language-server/integrationtests/tests/common" + "github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal" + "github.com/isaacphi/mcp-language-server/internal/tools" +) + +func TestContent(t *testing.T) { + suite := internal.GetTestSuite(t) + + ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second) + defer cancel() + + tests := []struct { + name string + file string + line int + column int + expectedText string + snapshotName string + }{ + { + name: "Function", + file: filepath.Join(suite.WorkspaceDir, "clean.go"), + line: 32, + column: 1, + expectedText: "func TestFunction()", + snapshotName: "test_function", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Call the ReadDefinition tool + result, err := tools.GetContentInfo(ctx, suite.Client, tc.file, tc.line, tc.column) + if err != nil { + t.Fatalf("Failed to read content: %v", err) + } + + // Check that the result contains relevant information + if !strings.Contains(result, tc.expectedText) { + t.Errorf("Content does not contain expected text: %s", tc.expectedText) + } + + // Use snapshot testing to verify exact output + common.SnapshotTest(t, "go", "content", tc.snapshotName, result) + }) + } +} diff --git a/internal/tools/content.go b/internal/tools/content.go new file mode 100644 index 0000000..fdde358 --- /dev/null +++ b/internal/tools/content.go @@ -0,0 +1,54 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +// GetContentInfo reads the source code definition of a symbol (function, type, constant, etc.) at the specified position +func GetContentInfo(ctx context.Context, client *lsp.Client, filePath string, line, column int) (string, error) { + // Open the file if not already open + err := client.OpenFile(ctx, filePath) + if err != nil { + return "", fmt.Errorf("could not open file: %v", err) + } + + // Convert 1-indexed line/column to 0-indexed for LSP protocol + position := protocol.Position{ + Line: uint32(line - 1), + Character: uint32(column - 1), + } + + location := protocol.Location{ + URI: protocol.DocumentUri("file://" + filePath), + Range: protocol.Range{ + Start: position, + End: position, + }, + } + + definition, loc, symbol, err := GetFullDefinition(ctx, client, location) + locationInfo := fmt.Sprintf( + "Symbol: %s\n"+ + "File: %s\n"+ + "Range: L%d:C%d - L%d:C%d\n\n", + symbol.GetName(), + strings.TrimPrefix(string(loc.URI), "file://"), + loc.Range.Start.Line+1, + loc.Range.Start.Character+1, + loc.Range.End.Line+1, + loc.Range.End.Character+1, + ) + + if err != nil { + return "", err + } + + definition = addLineNumbers(definition, int(loc.Range.Start.Line)+1) + + return locationInfo + definition, nil +} diff --git a/internal/tools/definition.go b/internal/tools/definition.go index 0af14a7..049033b 100644 --- a/internal/tools/definition.go +++ b/internal/tools/definition.go @@ -64,7 +64,7 @@ func ReadDefinition(ctx context.Context, client *lsp.Client, symbolName string) } banner := "---\n\n" - definition, loc, err := GetFullDefinition(ctx, client, loc) + definition, loc, _, err := GetFullDefinition(ctx, client, loc) locationInfo := fmt.Sprintf( "Symbol: %s\n"+ "File: %s\n"+ diff --git a/internal/tools/lsp-utilities.go b/internal/tools/lsp-utilities.go index ae7d70b..b41cbbe 100644 --- a/internal/tools/lsp-utilities.go +++ b/internal/tools/lsp-utilities.go @@ -12,7 +12,7 @@ import ( ) // Gets the full code block surrounding the start of the input location -func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation protocol.Location) (string, protocol.Location, error) { +func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation protocol.Location) (string, protocol.Location, protocol.DocumentSymbolResult, error) { symParams := protocol.DocumentSymbolParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: startLocation.URI, @@ -22,15 +22,16 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr // Get all symbols in document symResult, err := client.DocumentSymbol(ctx, symParams) if err != nil { - return "", protocol.Location{}, fmt.Errorf("failed to get document symbols: %w", err) + return "", protocol.Location{}, nil, fmt.Errorf("failed to get document symbols: %w", err) } symbols, err := symResult.Results() if err != nil { - return "", protocol.Location{}, fmt.Errorf("failed to process document symbols: %w", err) + return "", protocol.Location{}, nil, fmt.Errorf("failed to process document symbols: %w", err) } var symbolRange protocol.Range + var symbol protocol.DocumentSymbolResult found := false // Search for symbol at startLocation @@ -38,6 +39,7 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr searchSymbols = func(symbols []protocol.DocumentSymbolResult) bool { for _, sym := range symbols { if containsPosition(sym.GetRange(), startLocation.Range.Start) { + symbol = sym symbolRange = sym.GetRange() found = true return true @@ -62,14 +64,14 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr // Convert URI to filesystem path filePath, err := url.PathUnescape(strings.TrimPrefix(string(startLocation.URI), "file://")) if err != nil { - return "", protocol.Location{}, fmt.Errorf("failed to unescape URI: %w", err) + return "", protocol.Location{}, nil, fmt.Errorf("failed to unescape URI: %w", err) } // Read the file to get the full lines of the definition // because we may have a start and end column content, err := os.ReadFile(filePath) if err != nil { - return "", protocol.Location{}, fmt.Errorf("failed to read file: %w", err) + return "", protocol.Location{}, nil, fmt.Errorf("failed to read file: %w", err) } lines := strings.Split(string(content), "\n") @@ -79,7 +81,7 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr // Get the line at the end of the range if int(symbolRange.End.Line) >= len(lines) { - return "", protocol.Location{}, fmt.Errorf("line number out of range") + return "", protocol.Location{}, nil, fmt.Errorf("line number out of range") } line := lines[symbolRange.End.Line] @@ -128,14 +130,14 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr // Return the text within the range if int(symbolRange.End.Line) >= len(lines) { - return "", protocol.Location{}, fmt.Errorf("end line out of range") + return "", protocol.Location{}, nil, fmt.Errorf("end line out of range") } selectedLines := lines[symbolRange.Start.Line : symbolRange.End.Line+1] - return strings.Join(selectedLines, "\n"), startLocation, nil + return strings.Join(selectedLines, "\n"), startLocation, symbol, nil } - return "", protocol.Location{}, fmt.Errorf("symbol not found") + return "", protocol.Location{}, nil, fmt.Errorf("symbol not found") } // GetLineRangesToDisplay determines which lines should be displayed for a set of locations @@ -146,7 +148,7 @@ func GetLineRangesToDisplay(ctx context.Context, client *lsp.Client, locations [ // For each location, get its container and add relevant lines for _, loc := range locations { // Use GetFullDefinition to find container - _, containerLoc, err := GetFullDefinition(ctx, client, loc) + _, containerLoc, _, err := GetFullDefinition(ctx, client, loc) if err != nil { // If container not found, just use the location's line refLine := int(loc.Range.Start.Line) diff --git a/tools.go b/tools.go index f7de723..db38fd5 100644 --- a/tools.go +++ b/tools.go @@ -382,6 +382,48 @@ func (s *mcpServer) registerTools() error { return mcp.NewToolResultText(text), nil }) + contentTool := mcp.NewTool("content", + mcp.WithDescription("Read the source code definition of a symbol (function, type, constant, etc.) at the specified location."), + mcp.WithString("filePath", + mcp.Required(), + mcp.Description("The path to the file"), + ), + mcp.WithNumber("line", + mcp.Required(), + mcp.Description("The line number where the content is requested (1-indexed)"), + ), + mcp.WithNumber("column", + mcp.Required(), + mcp.Description("The column number where the content is requested (1-indexed)"), + ), + ) + + s.mcpServer.AddTool(contentTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments + filePath, err := request.RequireString("filePath") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + line, err := request.RequireInt("line") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + column, err := request.RequireInt("column") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + coreLogger.Debug("Executing content for file: %s line: %d column: %d", filePath, line, column) + text, err := tools.GetContentInfo(s.ctx, s.lspClient, filePath, line, column) + if err != nil { + coreLogger.Error("Failed to get content information: %v", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to get content: %v", err)), nil + } + return mcp.NewToolResultText(text), nil + }) + coreLogger.Info("Successfully registered all MCP tools") return nil }