Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions pkg/registry/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package registry

import (
"fmt"
"sync"

"github.com/stacklok/toolhive/pkg/config"
Expand All @@ -20,32 +21,33 @@ var (
defaultProviderMu sync.Mutex
)

// NewRegistryProvider creates a new registry provider based on the configuration
func NewRegistryProvider(cfg *config.Config) Provider {
// NewRegistryProvider creates a new registry provider based on the configuration.
// Returns an error if a custom registry is configured but cannot be reached.
func NewRegistryProvider(cfg *config.Config) (Provider, error) {
// Priority order:
// 1. API URL (if configured) - for live MCP Registry API queries
// 2. Remote URL (if configured) - for static JSON over HTTP
// 3. Local file path (if configured) - for local JSON file
// 4. Default - embedded registry data

if cfg != nil && len(cfg.RegistryApiUrl) > 0 {
// Use cached provider with persistent cache enabled by default
// This provides 1-hour TTL and works for both CLI and API server
provider, err := NewCachedAPIRegistryProvider(cfg.RegistryApiUrl, cfg.AllowPrivateRegistryIp, true)
if err != nil {
// Log error but fall back to default provider
// This prevents application from failing if API is temporarily unavailable
return NewLocalRegistryProvider()
return nil, fmt.Errorf("custom registry API at %s is not reachable: %w", cfg.RegistryApiUrl, err)
}
return provider
return provider, nil
}
if cfg != nil && len(cfg.RegistryUrl) > 0 {
return NewRemoteRegistryProvider(cfg.RegistryUrl, cfg.AllowPrivateRegistryIp)
provider, err := NewRemoteRegistryProvider(cfg.RegistryUrl, cfg.AllowPrivateRegistryIp)
if err != nil {
return nil, fmt.Errorf("custom registry at %s is not reachable: %w", cfg.RegistryUrl, err)
}
return provider, nil
}
if cfg != nil && len(cfg.LocalRegistryPath) > 0 {
return NewLocalRegistryProvider(cfg.LocalRegistryPath)
return NewLocalRegistryProvider(cfg.LocalRegistryPath), nil
}
return NewLocalRegistryProvider()
return NewLocalRegistryProvider(), nil
}

// GetDefaultProvider returns the default registry provider instance
Expand All @@ -63,7 +65,7 @@ func GetDefaultProviderWithConfig(configProvider config.Provider) (Provider, err
defaultProviderErr = err
return
}
defaultProvider = NewRegistryProvider(cfg)
defaultProvider, defaultProviderErr = NewRegistryProvider(cfg)
})

return defaultProvider, defaultProviderErr
Expand Down
12 changes: 9 additions & 3 deletions pkg/registry/provider_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ type RemoteRegistryProvider struct {
allowPrivateIp bool
}

// NewRemoteRegistryProvider creates a new remote registry provider
func NewRemoteRegistryProvider(registryURL string, allowPrivateIp bool) *RemoteRegistryProvider {
// NewRemoteRegistryProvider creates a new remote registry provider.
// Validates the registry is reachable before returning.
func NewRemoteRegistryProvider(registryURL string, allowPrivateIp bool) (*RemoteRegistryProvider, error) {
p := &RemoteRegistryProvider{
registryURL: registryURL,
allowPrivateIp: allowPrivateIp,
Expand All @@ -27,7 +28,12 @@ func NewRemoteRegistryProvider(registryURL string, allowPrivateIp bool) *RemoteR
// Initialize the base provider with the GetRegistry function
p.BaseProvider = NewBaseProvider(p.GetRegistry)

return p
// Validate the registry is reachable
if _, err := p.GetRegistry(); err != nil {
return nil, err
}

return p, nil
}

// GetRegistry returns the remote registry data
Expand Down
63 changes: 28 additions & 35 deletions pkg/registry/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,52 @@ func TestNewRegistryProvider(t *testing.T) {
name string
config *config.Config
expectedType string
expectError bool
}{
{
name: "nil config returns embedded provider",
config: nil,
expectedType: "*registry.LocalRegistryProvider",
expectError: false,
},
{
name: "empty registry URL returns embedded provider",
config: &config.Config{
RegistryUrl: "",
},
expectedType: "*registry.LocalRegistryProvider",
expectError: false,
},
{
name: "registry URL returns remote provider",
name: "unreachable registry URL returns error",
config: &config.Config{
RegistryUrl: "https://example.com/registry.json",
RegistryUrl: "https://non-existent-host-12345.com/registry.json",
},
expectedType: "*registry.RemoteRegistryProvider",
expectedType: "",
expectError: true,
},
{
name: "local registry path returns embedded provider with file path",
config: &config.Config{
LocalRegistryPath: "/path/to/registry.json",
},
expectedType: "*registry.LocalRegistryProvider",
},
{
name: "registry URL takes precedence over local path",
config: &config.Config{
RegistryUrl: "https://example.com/registry.json",
LocalRegistryPath: "/path/to/registry.json",
},
expectedType: "*registry.RemoteRegistryProvider",
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
provider := NewRegistryProvider(tt.config)
provider, err := NewRegistryProvider(tt.config)

if tt.expectError {
assert.Error(t, err)
assert.Nil(t, provider)
return
}

assert.NoError(t, err)
// Check the type of the provider
providerType := getTypeName(provider)
if providerType != tt.expectedType {
Expand Down Expand Up @@ -126,21 +130,7 @@ func TestLocalRegistryProvider(t *testing.T) {
}
}

func TestRemoteRegistryProvider(t *testing.T) {
t.Parallel()
// Note: This test would require a mock HTTP server for full testing
// For now, we just test the creation
provider := NewRemoteRegistryProvider("https://example.com/registry.json", false)

if provider == nil {
t.Fatal("NewRemoteRegistryProvider() returned nil")
}

// Test that it implements the interface
var _ Provider = provider
}

func TestRemoteRegistryProvider_GetRegistry_Error(t *testing.T) {
func TestRemoteRegistryProvider_CreationError(t *testing.T) {
t.Parallel()

tests := []struct {
Expand All @@ -164,16 +154,16 @@ func TestRemoteRegistryProvider_GetRegistry_Error(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

provider := NewRemoteRegistryProvider(tt.url, false)
registry, err := provider.GetRegistry()
provider, err := NewRemoteRegistryProvider(tt.url, false)

if tt.expectError {
assert.Error(t, err)
assert.Nil(t, registry)
assert.Nil(t, provider)
} else {
// This case would require a working HTTP server
assert.NoError(t, err)
assert.NotNil(t, registry)
assert.NotNil(t, provider)
// Test that it implements the interface
var _ Provider = provider
}
})
}
Expand Down Expand Up @@ -276,7 +266,8 @@ func TestGetRegistry(t *testing.T) {
require.NoError(t, err)

// Create provider with test config
provider := NewRegistryProvider(cfg)
provider, err := NewRegistryProvider(cfg)
require.NoError(t, err)
reg, err := provider.GetRegistry()
if err != nil {
t.Fatalf("Failed to get registry: %v", err)
Expand Down Expand Up @@ -319,7 +310,8 @@ func TestGetServer(t *testing.T) {
require.NoError(t, err)

// Create provider with test config
provider := NewRegistryProvider(cfg)
provider, err := NewRegistryProvider(cfg)
require.NoError(t, err)

// Test getting an existing server
server, err := provider.GetServer("osv")
Expand Down Expand Up @@ -371,7 +363,8 @@ func TestSearchServers(t *testing.T) {
require.NoError(t, err)

// Create provider with test config
provider := NewRegistryProvider(cfg)
provider, err := NewRegistryProvider(cfg)
require.NoError(t, err)

// Test searching for servers
servers, err := provider.SearchServers("search")
Expand Down
Loading