Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions test/e2e/thv-operator/virtualmcp/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"os"
"time"

mcpclient "github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -45,6 +47,62 @@ func WaitForVirtualMCPServerReady(ctx context.Context, c client.Client, name, na
}, timeout, 5*time.Second).Should(gomega.Succeed())
}

// InitializedMCPClient holds an initialized MCP client with its associated context
type InitializedMCPClient struct {
Client *mcpclient.Client
Ctx context.Context
Cancel context.CancelFunc
}

// Close cleans up the MCP client resources
func (c *InitializedMCPClient) Close() {
if c.Cancel != nil {
c.Cancel()
}
if c.Client != nil {
_ = c.Client.Close()
}
}

// CreateInitializedMCPClient creates an MCP client, starts the transport, and initializes
// the connection with the given client name. Returns an InitializedMCPClient that should
// be closed when done using defer client.Close().
func CreateInitializedMCPClient(nodePort int32, clientName string, timeout time.Duration) (*InitializedMCPClient, error) {
serverURL := fmt.Sprintf("http://localhost:%d/mcp", nodePort)
mcpClient, err := mcpclient.NewStreamableHttpClient(serverURL)
if err != nil {
return nil, fmt.Errorf("failed to create MCP client: %w", err)
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)

if err := mcpClient.Start(ctx); err != nil {
cancel()
_ = mcpClient.Close()
return nil, fmt.Errorf("failed to start MCP client: %w", err)
}

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
initRequest.Params.ClientInfo = mcp.Implementation{
Name: clientName,
Version: "1.0.0",
}

if _, err := mcpClient.Initialize(ctx, initRequest); err != nil {
cancel()
_ = mcpClient.Close()
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
}

return &InitializedMCPClient{
Client: mcpClient,
Ctx: ctx,
Cancel: cancel,
}, nil
}

// getPodLogs retrieves logs from a specific pod container
func getPodLogs(ctx context.Context, namespace, podName, containerName string, previous bool) (string, error) {
// Get the rest config - try in-cluster first, then fall back to kubeconfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package virtualmcp

import (
"context"
"fmt"
"strings"
"time"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -209,32 +207,14 @@ var _ = Describe("VirtualMCPServer Aggregation Filtering", Ordered, func() {

Context("when tool filtering is configured", func() {
It("should only expose filtered tools from backend1", func() {
By("Creating MCP client for VirtualMCPServer")
serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort)
mcpClient, err := client.NewStreamableHttpClient(serverURL)
By("Creating and initializing MCP client for VirtualMCPServer")
mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "toolhive-filtering-test", 30*time.Second)
Expect(err).ToNot(HaveOccurred())
defer mcpClient.Close()

By("Starting transport and initializing connection")
testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err = mcpClient.Start(testCtx)
Expect(err).ToNot(HaveOccurred())

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "toolhive-filtering-test",
Version: "1.0.0",
}

_, err = mcpClient.Initialize(testCtx, initRequest)
Expect(err).ToNot(HaveOccurred())

By("Listing tools from VirtualMCPServer")
listRequest := mcp.ListToolsRequest{}
tools, err := mcpClient.ListTools(testCtx, listRequest)
tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest)
Expect(err).ToNot(HaveOccurred())

By(fmt.Sprintf("VirtualMCPServer exposes %d tools after filtering", len(tools.Tools)))
Expand Down Expand Up @@ -269,32 +249,14 @@ var _ = Describe("VirtualMCPServer Aggregation Filtering", Ordered, func() {
})

It("should still allow calling filtered tools", func() {
By("Creating MCP client for VirtualMCPServer")
serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort)
mcpClient, err := client.NewStreamableHttpClient(serverURL)
By("Creating and initializing MCP client for VirtualMCPServer")
mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "toolhive-filtering-test", 30*time.Second)
Expect(err).ToNot(HaveOccurred())
defer mcpClient.Close()

By("Starting transport and initializing connection")
testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err = mcpClient.Start(testCtx)
Expect(err).ToNot(HaveOccurred())

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "toolhive-filtering-test",
Version: "1.0.0",
}

_, err = mcpClient.Initialize(testCtx, initRequest)
Expect(err).ToNot(HaveOccurred())

By("Listing available tools")
listRequest := mcp.ListToolsRequest{}
tools, err := mcpClient.ListTools(testCtx, listRequest)
tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest)
Expect(err).ToNot(HaveOccurred())

// Find the backend1 echo tool
Expand All @@ -308,17 +270,14 @@ var _ = Describe("VirtualMCPServer Aggregation Filtering", Ordered, func() {
Expect(targetToolName).ToNot(BeEmpty(), "Should find echo tool from backend1")

By(fmt.Sprintf("Calling filtered echo tool: %s", targetToolName))
toolCallCtx, toolCallCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer toolCallCancel()

testInput := "filtered123"
callRequest := mcp.CallToolRequest{}
callRequest.Params.Name = targetToolName
callRequest.Params.Arguments = map[string]any{
"input": testInput,
}

result, err := mcpClient.CallTool(toolCallCtx, callRequest)
result, err := mcpClient.Client.CallTool(mcpClient.Ctx, callRequest)
Expect(err).ToNot(HaveOccurred(), "Should be able to call filtered tool")
Expect(result).ToNot(BeNil())
Expect(result.Content).ToNot(BeEmpty(), "Should have content in response")
Expand Down
55 changes: 7 additions & 48 deletions test/e2e/thv-operator/virtualmcp/virtualmcp_yardstick_base_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package virtualmcp

import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -217,32 +215,14 @@ var _ = Describe("VirtualMCPServer Yardstick Base", Ordered, func() {
})

It("should aggregate echo tools from both yardstick backends", func() {
By("Creating MCP client for VirtualMCPServer")
serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort)
mcpClient, err := client.NewStreamableHttpClient(serverURL)
By("Creating and initializing MCP client for VirtualMCPServer")
mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "toolhive-yardstick-test", 30*time.Second)
Expect(err).ToNot(HaveOccurred())
defer mcpClient.Close()

By("Starting transport and initializing connection")
testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err = mcpClient.Start(testCtx)
Expect(err).ToNot(HaveOccurred())

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "toolhive-yardstick-test",
Version: "1.0.0",
}

_, err = mcpClient.Initialize(testCtx, initRequest)
Expect(err).ToNot(HaveOccurred())

By("Listing tools from VirtualMCPServer")
listRequest := mcp.ListToolsRequest{}
tools, err := mcpClient.ListTools(testCtx, listRequest)
tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest)
Expect(err).ToNot(HaveOccurred())
Expect(tools.Tools).ToNot(BeEmpty(), "VirtualMCPServer should aggregate tools from backends")

Expand Down Expand Up @@ -279,32 +259,14 @@ var _ = Describe("VirtualMCPServer Yardstick Base", Ordered, func() {
})

It("should successfully call echo tool through VirtualMCPServer", func() {
By("Creating MCP client for VirtualMCPServer")
serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort)
mcpClient, err := client.NewStreamableHttpClient(serverURL)
By("Creating and initializing MCP client for VirtualMCPServer")
mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "toolhive-yardstick-test", 30*time.Second)
Expect(err).ToNot(HaveOccurred())
defer mcpClient.Close()

By("Starting transport and initializing connection")
testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err = mcpClient.Start(testCtx)
Expect(err).ToNot(HaveOccurred())

initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "toolhive-yardstick-test",
Version: "1.0.0",
}

_, err = mcpClient.Initialize(testCtx, initRequest)
Expect(err).ToNot(HaveOccurred())

By("Listing available tools")
listRequest := mcp.ListToolsRequest{}
tools, err := mcpClient.ListTools(testCtx, listRequest)
tools, err := mcpClient.Client.ListTools(mcpClient.Ctx, listRequest)
Expect(err).ToNot(HaveOccurred())
Expect(tools.Tools).ToNot(BeEmpty())

Expand All @@ -320,9 +282,6 @@ var _ = Describe("VirtualMCPServer Yardstick Base", Ordered, func() {
Expect(targetToolName).ToNot(BeEmpty(), "Should find an echo tool")

By(fmt.Sprintf("Calling echo tool: %s", targetToolName))
toolCallCtx, toolCallCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer toolCallCancel()

// Yardstick echo tool requires alphanumeric input
testInput := "hello123"
callRequest := mcp.CallToolRequest{}
Expand All @@ -331,7 +290,7 @@ var _ = Describe("VirtualMCPServer Yardstick Base", Ordered, func() {
"input": testInput,
}

result, err := mcpClient.CallTool(toolCallCtx, callRequest)
result, err := mcpClient.Client.CallTool(mcpClient.Ctx, callRequest)
Expect(err).ToNot(HaveOccurred(),
fmt.Sprintf("Should be able to call tool '%s' through VirtualMCPServer", targetToolName))
Expect(result).ToNot(BeNil())
Expand Down
Loading