Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
387 changes: 386 additions & 1 deletion cmd/thv/app/group.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion cmd/thv/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
// Get debug mode flag
debugMode, _ := cmd.Flags().GetBool("debug")

return runSingleServer(ctx, &runFlags, serverOrImage, cmdArgs, debugMode, cmd, "")
}

// runSingleServer handles the core logic for running a single MCP server
func runSingleServer(ctx context.Context, runFlags *RunFlags, serverOrImage string, cmdArgs []string, debugMode bool, cmd *cobra.Command, groupName string) error { //nolint:lll
// Create container runtime
rt, err := container.NewFactory().Create(ctx)
if err != nil {
Expand Down Expand Up @@ -200,7 +205,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
}

// Build the run configuration
runnerConfig, err := BuildRunnerConfig(ctx, &runFlags, serverOrImage, cmdArgs, debugMode, cmd)
runnerConfig, err := BuildRunnerConfig(ctx, runFlags, serverOrImage, cmdArgs, debugMode, cmd, groupName)
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ func BuildRunnerConfig(
cmdArgs []string,
debugMode bool,
cmd *cobra.Command,
groupName string,
) (*runner.RunConfig, error) {
// Validate and setup basic configuration
validatedHost, err := ValidateAndNormaliseHostFlag(runFlags.Host)
Expand Down Expand Up @@ -258,7 +259,7 @@ func BuildRunnerConfig(
}

// Handle image retrieval
imageURL, serverMetadata, err := handleImageRetrieval(ctx, serverOrImage, runFlags)
imageURL, serverMetadata, err := handleImageRetrieval(ctx, serverOrImage, runFlags, groupName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -335,6 +336,7 @@ func handleImageRetrieval(
ctx context.Context,
serverOrImage string,
runFlags *RunFlags,
groupName string,
) (
string,
registry.ServerMetadata,
Expand All @@ -343,7 +345,7 @@ func handleImageRetrieval(

// Try to get server from registry (container or remote) or direct URL
imageURL, serverMetadata, err := retriever.GetMCPServer(
ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage)
ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage, groupName)
if err != nil {
return "", nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
}
Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_group.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions docs/cli/thv_group_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/api/v1/workload_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
req.Image,
"", // We do not let the user specify a CA cert path here.
retriever.VerifyImageWarn,
"", // TODO Add support for registry groups lookups for APi
)
if err != nil {
// Check if the error is due to context timeout
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/v1/workloads_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func makeMockRetriever(
) retriever.Retriever {
t.Helper()

return func(_ context.Context, serverOrImage string, _ string, verificationType string) (string, registry.ServerMetadata, error) {
return func(_ context.Context, serverOrImage string, _ string, verificationType string, _ string) (string, registry.ServerMetadata, error) {
assert.Equal(t, expectedServerOrImage, serverOrImage)
assert.Equal(t, retriever.VerifyImageWarn, verificationType)
return returnedImage, returnedServerMetadata, returnedError
Expand Down
58 changes: 58 additions & 0 deletions pkg/mcp/server/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"testing"

"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -248,3 +249,60 @@ func TestPrepareEnvironmentVariables(t *testing.T) {
})
}
}

func TestBuildServerConfig(t *testing.T) {
t.Parallel()

ctx := context.Background()
args := &runServerArgs{
Server: "test-server",
Name: "test-name",
Host: "127.0.0.1",
Env: map[string]string{"TEST_VAR": "test_value"},
}

tests := []struct {
name string
imageURL string
imageMetadata *registry.ImageMetadata
expectError bool
}{
{
name: "valid config with nil metadata",
imageURL: "test/image:latest",
imageMetadata: nil,
expectError: false, // Actually succeeds because container runtime creation works
},
{
name: "valid config with metadata",
imageURL: "test/image:latest",
imageMetadata: &registry.ImageMetadata{
BaseServerMetadata: registry.BaseServerMetadata{
Transport: "stdio",
},
Image: "test/image:latest",
Args: []string{"--test"},
EnvVars: []*registry.EnvVar{
{Name: "DEFAULT_VAR", Default: "default_value"},
},
},
expectError: false, // Actually succeeds and tests the type assertion line
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

runConfig, err := buildServerConfig(ctx, args, tt.imageURL, tt.imageMetadata)

if tt.expectError {
assert.Error(t, err)
assert.Nil(t, runConfig)
} else {
assert.NoError(t, err)
assert.NotNil(t, runConfig)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/mcp/server/run_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (h *Handler) RunServer(ctx context.Context, request mcp.CallToolRequest) (*

// Use retriever to properly fetch and prepare the MCP server
// TODO: make this configurable so we could warn or even fail
imageURL, serverMetadata, err := retriever.GetMCPServer(ctx, args.Server, "", "disabled")
imageURL, serverMetadata, err := retriever.GetMCPServer(ctx, args.Server, "", "disabled", "")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("Failed to get MCP server: %v", err)), nil
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/registry/provider_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ func (p *LocalRegistryProvider) GetRegistry() (*Registry, error) {
server.Name = name
}

// Set name field on servers within groups
for _, group := range registry.Groups {
if group != nil {
for name, server := range group.Servers {
server.Name = name
}
for name, server := range group.RemoteServers {
server.Name = name
}
}
}

return registry, nil
}

Expand Down
12 changes: 12 additions & 0 deletions pkg/registry/provider_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,17 @@ func (p *RemoteRegistryProvider) GetRegistry() (*Registry, error) {
server.Name = name
}

// Set name field on servers within groups
for _, group := range registry.Groups {
if group != nil {
for name, server := range group.Servers {
server.Name = name
}
for name, server := range group.RemoteServers {
server.Name = name
}
}
}

return registry, nil
}
104 changes: 104 additions & 0 deletions pkg/registry/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/config"
Expand Down Expand Up @@ -138,6 +139,45 @@ func TestRemoteRegistryProvider(t *testing.T) {
var _ Provider = provider
}

func TestRemoteRegistryProvider_GetRegistry_Error(t *testing.T) {
t.Parallel()

tests := []struct {
name string
url string
expectError bool
}{
{
name: "invalid URL scheme",
url: "invalid://url",
expectError: true,
},
{
name: "non-existent host",
url: "https://non-existent-host-12345.com/registry.json",
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

provider := NewRemoteRegistryProvider(tt.url, false)
registry, err := provider.GetRegistry()

if tt.expectError {
assert.Error(t, err)
assert.Nil(t, registry)
} else {
// This case would require a working HTTP server
assert.NoError(t, err)
assert.NotNil(t, registry)
}
})
}
}

func TestLocalRegistryProviderWithLocalFile(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -397,3 +437,67 @@ func TestListServers(t *testing.T) {
t.Errorf("ListServers() returned %d servers, want %d", len(servers), totalServers)
}
}

func TestParseRegistryData(t *testing.T) {
t.Parallel()

tests := []struct {
name string
data []byte
expectError bool
}{
{
name: "valid registry data",
data: []byte(`{
"version": "1.0.0",
"last_updated": "2023-01-01T00:00:00Z",
"servers": {
"test-server": {
"image": "test/image:latest",
"description": "Test server"
}
}
}`),
expectError: false,
},
{
name: "invalid JSON",
data: []byte(`invalid json`),
expectError: true,
},
{
name: "empty data",
data: []byte(``),
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

registry, err := parseRegistryData(tt.data)

if tt.expectError {
assert.Error(t, err)
assert.Nil(t, registry)
} else {
assert.NoError(t, err)
assert.NotNil(t, registry)
}
})
}
}

func TestLocalRegistryProvider_FileReadError(t *testing.T) {
t.Parallel()

// Test with non-existent file path
provider := NewLocalRegistryProvider("/non/existent/path/registry.json")

registry, err := provider.GetRegistry()

assert.Error(t, err)
assert.Nil(t, registry)
assert.Contains(t, err.Error(), "failed to read local registry file")
}
Loading
Loading