Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/config-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ source:
# Source type: git, or api
type: api

# Data format: toolhive (native) or upstream (MCP registry format)
# Use 'upstream' for the official MCP registry format
# Data format: only 'upstream' is supported for API sources
format: upstream

# API endpoint configuration
Expand Down
32 changes: 22 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,25 @@ func (c *Config) validate() error {
return fmt.Errorf("source.type is required")
}

if err := c.validateSourceConfigByType(); err != nil {
return err
}

// Validate sync policy
if c.SyncPolicy == nil || c.SyncPolicy.Interval == "" {
return fmt.Errorf("syncPolicy.interval is required")
}

// Try to parse the interval to ensure it's valid
if _, err := time.ParseDuration(c.SyncPolicy.Interval); err != nil {
return fmt.Errorf("syncPolicy.interval must be a valid duration (e.g., '30m', '1h'): %w", err)
}

return nil
}

// validateSourceConfigByType validates the source configuration by the source type
func (c *Config) validateSourceConfigByType() error {
// Validate source-specific settings
switch c.Source.Type {
case SourceTypeGit:
Expand All @@ -307,6 +326,9 @@ func (c *Config) validate() error {
if c.Source.API.Endpoint == "" {
return fmt.Errorf("source.api.endpoint is required")
}
if c.Source.Format != "" && c.Source.Format != SourceFormatUpstream {
return fmt.Errorf("source.format must be either empty or %s when type is api, got %s", SourceFormatUpstream, c.Source.Format)
}

case SourceTypeFile:
if c.Source.File == nil {
Expand All @@ -320,15 +342,5 @@ func (c *Config) validate() error {
return fmt.Errorf("unsupported source type: %s", c.Source.Type)
}

// Validate sync policy
if c.SyncPolicy == nil || c.SyncPolicy.Interval == "" {
return fmt.Errorf("syncPolicy.interval is required")
}

// Try to parse the interval to ensure it's valid
if _, err := time.ParseDuration(c.SyncPolicy.Interval); err != nil {
return fmt.Errorf("syncPolicy.interval must be a valid duration (e.g., '30m', '1h'): %w", err)
}

return nil
}
14 changes: 14 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,20 @@ func TestConfigValidate(t *testing.T) {
wantErr: true,
errMsg: "source.file.path is required",
},
{
name: "invalid_format_when_type_is_api",
config: &Config{
Source: SourceConfig{
Type: "api",
Format: "toolhive",
API: &APIConfig{
Endpoint: "http://example.com",
},
},
},
wantErr: true,
errMsg: "source.format must be either empty or upstream",
},
{
name: "unsupported_source_type",
config: &Config{
Expand Down
51 changes: 19 additions & 32 deletions internal/sources/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ import (
)

// APISourceHandler handles registry data from API endpoints
// It detects the format (ToolHive vs Upstream) and delegates to the appropriate handler
// It validates the Upstream format and delegates to the appropriate handler
type APISourceHandler struct {
httpClient httpclient.Client
validator SourceDataValidator
toolhiveHandler *ToolHiveAPIHandler
upstreamHandler *UpstreamAPIHandler
}

Expand All @@ -26,7 +25,6 @@ func NewAPISourceHandler() *APISourceHandler {
return &APISourceHandler{
httpClient: httpClient,
validator: NewSourceDataValidator(),
toolhiveHandler: NewToolHiveAPIHandler(httpClient),
upstreamHandler: NewUpstreamAPIHandler(httpClient),
}
}
Expand All @@ -38,6 +36,11 @@ func (*APISourceHandler) Validate(source *config.SourceConfig) error {
config.SourceTypeAPI, source.Type)
}

if source.Format != "" && source.Format != config.SourceFormatUpstream {
return fmt.Errorf("unsupported format: expected %s or empty, got %s",
config.SourceFormatUpstream, source.Format)
}

if source.API == nil {
return fmt.Errorf("api configuration is required for source type %s",
config.SourceTypeAPI)
Expand All @@ -51,7 +54,7 @@ func (*APISourceHandler) Validate(source *config.SourceConfig) error {
}

// FetchRegistry retrieves registry data from the API endpoint
// It auto-detects the format and delegates to the appropriate handler
// It validates the Upstream format and delegates to the appropriate handler
func (h *APISourceHandler) FetchRegistry(ctx context.Context, cfg *config.Config) (*FetchResult, error) {
logger := log.FromContext(ctx)

Expand All @@ -60,14 +63,13 @@ func (h *APISourceHandler) FetchRegistry(ctx context.Context, cfg *config.Config
return nil, fmt.Errorf("source validation failed: %w", err)
}

// Detect format and get appropriate handler
handler, format, err := h.detectFormatAndGetHandler(ctx, cfg)
// Validate Upstream format and get appropriate handler
handler, err := h.validateUstreamFormat(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("format detection failed: %w", err)
return nil, fmt.Errorf("upstream format validation failed: %w", err)
}

logger.Info("Detected API format, delegating to handler",
"format", format)
logger.Info("Validated Upstream format, delegating to handler")

// Delegate to the appropriate handler
return handler.FetchRegistry(ctx, cfg)
Expand All @@ -80,48 +82,33 @@ func (h *APISourceHandler) CurrentHash(ctx context.Context, cfg *config.Config)
return "", fmt.Errorf("source validation failed: %w", err)
}

// Detect format and get appropriate handler
handler, _, err := h.detectFormatAndGetHandler(ctx, cfg)
// Validate Upstream format and get appropriate handler
handler, err := h.validateUstreamFormat(ctx, cfg)
if err != nil {
return "", fmt.Errorf("format detection failed: %w", err)
return "", fmt.Errorf("upstream format validation failed: %w", err)
}

// Delegate to the appropriate handler
return handler.CurrentHash(ctx, cfg)
}

// apiFormatHandler is an internal interface for format-specific handlers
type apiFormatHandler interface {
Validate(ctx context.Context, endpoint string) error
FetchRegistry(ctx context.Context, cfg *config.Config) (*FetchResult, error)
CurrentHash(ctx context.Context, cfg *config.Config) (string, error)
}

// detectFormatAndGetHandler detects the API format and returns the appropriate handler
func (h *APISourceHandler) detectFormatAndGetHandler(
// validateUstreamFormat validates the Upstream format and returns the appropriate handler
func (h *APISourceHandler) validateUstreamFormat(
ctx context.Context,
cfg *config.Config,
) (apiFormatHandler, string, error) {
) (*UpstreamAPIHandler, error) {
logger := log.FromContext(ctx)
endpoint := h.getBaseURL(cfg)

// Try ToolHive format first (/v0/info)
toolhiveErr := h.toolhiveHandler.Validate(ctx, endpoint)
if toolhiveErr == nil {
logger.Info("Validated as ToolHive format")
return h.toolhiveHandler, "toolhive", nil
}
logger.V(1).Info("ToolHive format validation failed", "error", toolhiveErr.Error())

// Try upstream format (/openapi.yaml)
upstreamErr := h.upstreamHandler.Validate(ctx, endpoint)
if upstreamErr == nil {
logger.Info("Validated as upstream MCP Registry format")
return h.upstreamHandler, "upstream", nil
return h.upstreamHandler, nil
}
logger.V(1).Info("Upstream format validation failed", "error", upstreamErr.Error())

return nil, "", fmt.Errorf("unable to detect valid API format (tried toolhive and upstream)")
return nil, fmt.Errorf("unable to validate Upstream format")
}

// getBaseURL extracts and normalizes the base URL
Expand Down
57 changes: 14 additions & 43 deletions internal/sources/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ import (
)

const (
toolhiveInfoPath = "/v0/info"
toolhiveServersPath = "/v0/servers"
openapiPath = "/openapi.yaml"
openapiPath = "/openapi.yaml"
)

func TestAPISources(t *testing.T) {
Expand Down Expand Up @@ -54,6 +52,17 @@ var _ = Describe("APISourceHandler", func() {
Expect(err.Error()).To(ContainSubstring("invalid source type"))
})

It("should reject non-Upstream format", func() {
source := &config.SourceConfig{
Type: config.SourceTypeAPI,
Format: config.SourceFormatToolHive,
}

err := handler.Validate(source)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("unsupported format:"))
})

It("should reject missing API configuration", func() {
source := &config.SourceConfig{
Type: config.SourceTypeAPI,
Expand Down Expand Up @@ -91,49 +100,11 @@ var _ = Describe("APISourceHandler", func() {
})
})

Describe("Format Detection", func() {
Context("ToolHive Format", func() {
BeforeEach(func() {
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case toolhiveInfoPath:
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"version":"1.0.0","last_updated":"2025-01-14T00:00:00Z","source":"file:/data/registry.json","total_servers":5}`))
case toolhiveServersPath:
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"servers":[],"total":0}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
})

It("should detect and validate ToolHive format", func() {
registryConfig := &config.Config{
Source: config.SourceConfig{
Type: config.SourceTypeAPI,
API: &config.APIConfig{
Endpoint: mockServer.URL,
},
},
}
result, err := handler.FetchRegistry(ctx, registryConfig)
Expect(err).NotTo(HaveOccurred())
Expect(result).NotTo(BeNil())
Expect(result.Format).To(Equal(config.SourceFormatToolHive))
})
})

Describe("Upstream Format Validation", func() {
Context("Upstream Format", func() {
BeforeEach(func() {
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case toolhiveInfoPath:
// Return 404 for /v0/info (upstream doesn't have this)
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"detail":"Endpoint not found"}`))
case openapiPath:
w.Header().Set("Content-Type", "application/x-yaml")
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -189,7 +160,7 @@ openapi: 3.1.0

_, err := handler.FetchRegistry(ctx, registryConfig)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("format detection failed"))
Expect(err.Error()).To(ContainSubstring("upstream format validation failed"))
})
})
})
Expand Down
Loading
Loading