diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 13c216fb2..9ef8bcad8 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/viper" "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/config" ) var rootCmd = &cobra.Command{ @@ -120,10 +121,47 @@ This command checks: } logger.Infof("Validating configuration: %s", configPath) - // TODO: Validate configuration - // This will be implemented in a future PR when pkg/vmcp is added - return fmt.Errorf("validate command not yet implemented") + // Load configuration from YAML + loader := config.NewYAMLLoader(configPath) + cfg, err := loader.Load() + if err != nil { + logger.Errorf("Failed to load configuration: %v", err) + return fmt.Errorf("configuration loading failed: %w", err) + } + + logger.Debugf("Configuration loaded successfully, performing validation...") + + // Validate configuration + validator := config.NewValidator() + if err := validator.Validate(cfg); err != nil { + logger.Errorf("Configuration validation failed: %v", err) + return fmt.Errorf("validation failed: %w", err) + } + + logger.Infof("✓ Configuration is valid") + logger.Infof(" Name: %s", cfg.Name) + logger.Infof(" Group: %s", cfg.GroupRef) + logger.Infof(" Incoming Auth: %s", cfg.IncomingAuth.Type) + logger.Infof(" Outgoing Auth: %s (source: %s)", + func() string { + if len(cfg.OutgoingAuth.Backends) > 0 { + return fmt.Sprintf("%d backends configured", len(cfg.OutgoingAuth.Backends)) + } + return "default only" + }(), + cfg.OutgoingAuth.Source) + logger.Infof(" Conflict Resolution: %s", cfg.Aggregation.ConflictResolution) + + if cfg.TokenCache != nil { + logger.Infof(" Token Cache: %s", cfg.TokenCache.Provider) + } + + if len(cfg.CompositeTools) > 0 { + logger.Infof(" Composite Tools: %d defined", len(cfg.CompositeTools)) + } + + return nil }, } } diff --git a/examples/vmcp-config-invalid.yaml b/examples/vmcp-config-invalid.yaml new file mode 100644 index 000000000..da0233211 --- /dev/null +++ b/examples/vmcp-config-invalid.yaml @@ -0,0 +1,28 @@ +# Invalid Virtual MCP Server Configuration Example +# This file has intentional errors to test validation + +# Missing required field: name +# name: "test-vmcp" +group: "test-group" + +# Invalid auth type +incoming_auth: + type: invalid_type + oidc: + issuer: "https://keycloak.example.com" + client_id: "test-client" + client_secret_env: "TEST_SECRET" + audience: "vmcp" + scopes: ["openid"] + +# Invalid source +outgoing_auth: + source: invalid_source + default: + type: pass_through + +# Invalid conflict resolution strategy +aggregation: + conflict_resolution: invalid_strategy + conflict_resolution_config: + prefix_format: "{workload}_" diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index e944eb825..8dba220fd 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -11,6 +11,14 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" ) +// Token cache provider types +const ( + // CacheProviderMemory represents in-memory token cache provider + CacheProviderMemory = "memory" + // CacheProviderRedis represents Redis token cache provider + CacheProviderRedis = "redis" +) + // Config is the unified configuration model for Virtual MCP Server. // This is platform-agnostic and used by both CLI and Kubernetes deployments. // diff --git a/pkg/vmcp/config/validator.go b/pkg/vmcp/config/validator.go new file mode 100644 index 000000000..dc863cef6 --- /dev/null +++ b/pkg/vmcp/config/validator.go @@ -0,0 +1,528 @@ +package config + +import ( + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// DefaultValidator implements comprehensive configuration validation. +type DefaultValidator struct{} + +// NewValidator creates a new configuration validator. +func NewValidator() *DefaultValidator { + return &DefaultValidator{} +} + +// Validate performs comprehensive validation of the configuration. +func (v *DefaultValidator) Validate(cfg *Config) error { + if cfg == nil { + return fmt.Errorf("%w: configuration is nil", vmcp.ErrInvalidConfig) + } + + var errors []string + + // Validate basic fields + if err := v.validateBasicFields(cfg); err != nil { + errors = append(errors, err.Error()) + } + + // Validate incoming authentication + if err := v.validateIncomingAuth(cfg.IncomingAuth); err != nil { + errors = append(errors, err.Error()) + } + + // Validate outgoing authentication + if err := v.validateOutgoingAuth(cfg.OutgoingAuth); err != nil { + errors = append(errors, err.Error()) + } + + // Validate aggregation configuration + if err := v.validateAggregation(cfg.Aggregation); err != nil { + errors = append(errors, err.Error()) + } + + // Validate token cache configuration + if err := v.validateTokenCache(cfg.TokenCache); err != nil { + errors = append(errors, err.Error()) + } + + // Validate operational configuration + if err := v.validateOperational(cfg.Operational); err != nil { + errors = append(errors, err.Error()) + } + + // Validate composite tools + if err := v.validateCompositeTools(cfg.CompositeTools); err != nil { + errors = append(errors, err.Error()) + } + + if len(errors) > 0 { + return fmt.Errorf("%w:\n - %s", vmcp.ErrInvalidConfig, strings.Join(errors, "\n - ")) + } + + return nil +} + +func (*DefaultValidator) validateBasicFields(cfg *Config) error { + if cfg.Name == "" { + return fmt.Errorf("name is required") + } + + if cfg.GroupRef == "" { + return fmt.Errorf("group reference is required") + } + + return nil +} + +func (v *DefaultValidator) validateIncomingAuth(auth *IncomingAuthConfig) error { + if auth == nil { + return fmt.Errorf("incoming_auth is required") + } + + // Validate auth type + validTypes := []string{"oidc", "local", "anonymous"} + if !contains(validTypes, auth.Type) { + return fmt.Errorf("incoming_auth.type must be one of: %s", strings.Join(validTypes, ", ")) + } + + // Validate OIDC configuration + if auth.Type == "oidc" { + if auth.OIDC == nil { + return fmt.Errorf("incoming_auth.oidc is required when type is 'oidc'") + } + + if auth.OIDC.Issuer == "" { + return fmt.Errorf("incoming_auth.oidc.issuer is required") + } + + if auth.OIDC.ClientID == "" { + return fmt.Errorf("incoming_auth.oidc.client_id is required") + } + + if auth.OIDC.Audience == "" { + return fmt.Errorf("incoming_auth.oidc.audience is required") + } + + // Client secret should be set (either directly or via env var reference) + if auth.OIDC.ClientSecret == "" { + return fmt.Errorf("incoming_auth.oidc.client_secret is required") + } + } + + // Validate authorization configuration + if auth.Authz != nil { + if err := v.validateAuthz(auth.Authz); err != nil { + return fmt.Errorf("incoming_auth.authz: %w", err) + } + } + + return nil +} + +func (*DefaultValidator) validateAuthz(authz *AuthzConfig) error { + validTypes := []string{"cedar", "none"} + if !contains(validTypes, authz.Type) { + return fmt.Errorf("type must be one of: %s", strings.Join(validTypes, ", ")) + } + + if authz.Type == "cedar" && len(authz.Policies) == 0 { + return fmt.Errorf("policies are required when type is 'cedar'") + } + + return nil +} + +func (v *DefaultValidator) validateOutgoingAuth(auth *OutgoingAuthConfig) error { + if auth == nil { + return fmt.Errorf("outgoing_auth is required") + } + + // Validate source + validSources := []string{"inline", "discovered", "mixed"} + if !contains(validSources, auth.Source) { + return fmt.Errorf("outgoing_auth.source must be one of: %s", strings.Join(validSources, ", ")) + } + + // Validate default strategy + if auth.Default != nil { + if err := v.validateBackendAuthStrategy("default", auth.Default); err != nil { + return fmt.Errorf("outgoing_auth.default: %w", err) + } + } + + // Validate per-backend strategies + for backendName, strategy := range auth.Backends { + if err := v.validateBackendAuthStrategy(backendName, strategy); err != nil { + return fmt.Errorf("outgoing_auth.backends.%s: %w", backendName, err) + } + } + + return nil +} + +func (*DefaultValidator) validateBackendAuthStrategy(_ string, strategy *BackendAuthStrategy) error { + if strategy == nil { + return fmt.Errorf("strategy is nil") + } + + validTypes := []string{ + "pass_through", "token_exchange", "client_credentials", + "service_account", "header_injection", "oauth_proxy", + } + if !contains(validTypes, strategy.Type) { + return fmt.Errorf("type must be one of: %s", strings.Join(validTypes, ", ")) + } + + // Validate type-specific requirements + switch strategy.Type { + case "token_exchange": + // Token exchange requires specific metadata + required := []string{"token_url", "client_id", "audience"} + for _, field := range required { + if _, ok := strategy.Metadata[field]; !ok { + return fmt.Errorf("token_exchange requires metadata field: %s", field) + } + } + + case "service_account": + // Service account requires credentials + if _, ok := strategy.Metadata["credentials_env"]; !ok { + return fmt.Errorf("service_account requires metadata field: credentials_env") + } + + case "header_injection": + // Header injection requires header name and value/format + if _, ok := strategy.Metadata["header_name"]; !ok { + return fmt.Errorf("header_injection requires metadata field: header_name") + } + } + + return nil +} + +func (v *DefaultValidator) validateAggregation(agg *AggregationConfig) error { + if agg == nil { + return fmt.Errorf("aggregation is required") + } + + // Validate conflict resolution strategy + validStrategies := []vmcp.ConflictResolutionStrategy{ + vmcp.ConflictStrategyPrefix, + vmcp.ConflictStrategyPriority, + vmcp.ConflictStrategyManual, + } + if !containsStrategy(validStrategies, agg.ConflictResolution) { + return fmt.Errorf("conflict_resolution must be one of: prefix, priority, manual") + } + + // Validate strategy-specific configuration + if agg.ConflictResolutionConfig == nil { + return fmt.Errorf("conflict_resolution_config is required") + } + + if err := v.validateConflictStrategy(agg); err != nil { + return err + } + + return v.validateToolConfigurations(agg.Tools) +} + +// validateConflictStrategy validates strategy-specific configuration +func (*DefaultValidator) validateConflictStrategy(agg *AggregationConfig) error { + switch agg.ConflictResolution { + case vmcp.ConflictStrategyPrefix: + if agg.ConflictResolutionConfig.PrefixFormat == "" { + return fmt.Errorf("prefix_format is required for prefix strategy") + } + + case vmcp.ConflictStrategyPriority: + if len(agg.ConflictResolutionConfig.PriorityOrder) == 0 { + return fmt.Errorf("priority_order is required for priority strategy") + } + + case vmcp.ConflictStrategyManual: + // Manual strategy requires explicit overrides + if len(agg.Tools) == 0 { + return fmt.Errorf("tool overrides are required for manual strategy") + } + } + + return nil +} + +// validateToolConfigurations validates tool override configurations +func (v *DefaultValidator) validateToolConfigurations(tools []*WorkloadToolConfig) error { + workloadNames := make(map[string]bool) + for i, tool := range tools { + if tool.Workload == "" { + return fmt.Errorf("tools[%d].workload is required", i) + } + + if workloadNames[tool.Workload] { + return fmt.Errorf("duplicate workload configuration: %s", tool.Workload) + } + workloadNames[tool.Workload] = true + + if err := v.validateToolOverrides(tool.Overrides, i); err != nil { + return err + } + } + + return nil +} + +// validateToolOverrides validates individual tool overrides +func (*DefaultValidator) validateToolOverrides(overrides map[string]*ToolOverride, toolIndex int) error { + for toolName, override := range overrides { + if override.Name == "" && override.Description == "" { + return fmt.Errorf("tools[%d].overrides.%s: at least one of name or description must be specified", toolIndex, toolName) + } + } + return nil +} + +func (*DefaultValidator) validateTokenCache(cache *TokenCacheConfig) error { + if cache == nil { + return nil // Token cache is optional + } + + validProviders := []string{CacheProviderMemory, CacheProviderRedis} + if !contains(validProviders, cache.Provider) { + return fmt.Errorf("token_cache.provider must be one of: %s", strings.Join(validProviders, ", ")) + } + + switch cache.Provider { + case CacheProviderMemory: + if cache.Memory == nil { + return fmt.Errorf("token_cache.memory is required when provider is 'memory'") + } + if cache.Memory.MaxEntries <= 0 { + return fmt.Errorf("token_cache.memory.max_entries must be positive") + } + if cache.Memory.TTLOffset < 0 { + return fmt.Errorf("token_cache.memory.ttl_offset cannot be negative") + } + + case CacheProviderRedis: + if cache.Redis == nil { + return fmt.Errorf("token_cache.redis is required when provider is 'redis'") + } + if cache.Redis.Address == "" { + return fmt.Errorf("token_cache.redis.address is required") + } + if cache.Redis.TTLOffset < 0 { + return fmt.Errorf("token_cache.redis.ttl_offset cannot be negative") + } + } + + return nil +} + +func (v *DefaultValidator) validateOperational(ops *OperationalConfig) error { + if ops == nil { + return nil // Operational config is optional (defaults apply) + } + + // Validate timeouts + if ops.Timeouts != nil { + if ops.Timeouts.Default <= 0 { + return fmt.Errorf("operational.timeouts.default must be positive") + } + + for workload, timeout := range ops.Timeouts.PerWorkload { + if timeout <= 0 { + return fmt.Errorf("operational.timeouts.per_workload.%s must be positive", workload) + } + } + } + + // Validate failure handling + if ops.FailureHandling != nil { + if err := v.validateFailureHandling(ops.FailureHandling); err != nil { + return fmt.Errorf("operational.failure_handling: %w", err) + } + } + + return nil +} + +func (*DefaultValidator) validateFailureHandling(fh *FailureHandlingConfig) error { + if fh.HealthCheckInterval <= 0 { + return fmt.Errorf("health_check_interval must be positive") + } + + if fh.UnhealthyThreshold <= 0 { + return fmt.Errorf("unhealthy_threshold must be positive") + } + + validModes := []string{"fail", "best_effort"} + if !contains(validModes, fh.PartialFailureMode) { + return fmt.Errorf("partial_failure_mode must be one of: %s", strings.Join(validModes, ", ")) + } + + // Validate circuit breaker + if fh.CircuitBreaker != nil && fh.CircuitBreaker.Enabled { + if fh.CircuitBreaker.FailureThreshold <= 0 { + return fmt.Errorf("circuit_breaker.failure_threshold must be positive") + } + if fh.CircuitBreaker.Timeout <= 0 { + return fmt.Errorf("circuit_breaker.timeout must be positive") + } + } + + return nil +} + +func (v *DefaultValidator) validateCompositeTools(tools []*CompositeToolConfig) error { + if len(tools) == 0 { + return nil // Composite tools are optional + } + + toolNames := make(map[string]bool) + + for i, tool := range tools { + // Validate basic fields + if tool.Name == "" { + return fmt.Errorf("composite_tools[%d].name is required", i) + } + + if toolNames[tool.Name] { + return fmt.Errorf("duplicate composite tool name: %s", tool.Name) + } + toolNames[tool.Name] = true + + if tool.Description == "" { + return fmt.Errorf("composite_tools[%d].description is required", i) + } + + if tool.Timeout <= 0 { + return fmt.Errorf("composite_tools[%d].timeout must be positive", i) + } + + // Validate steps + if len(tool.Steps) == 0 { + return fmt.Errorf("composite_tools[%d] must have at least one step", i) + } + + if err := v.validateWorkflowSteps(tool.Name, tool.Steps); err != nil { + return fmt.Errorf("composite_tools[%d]: %w", i, err) + } + } + + return nil +} + +func (v *DefaultValidator) validateWorkflowSteps(_ string, steps []*WorkflowStepConfig) error { + stepIDs := make(map[string]bool) + + for i, step := range steps { + if err := v.validateStepBasics(step, i, stepIDs); err != nil { + return err + } + + if err := v.validateStepType(step, i); err != nil { + return err + } + + if err := v.validateStepDependencies(step, i, stepIDs); err != nil { + return err + } + + if err := v.validateStepErrorHandling(step, i); err != nil { + return err + } + } + + return nil +} + +// validateStepBasics validates basic step requirements (ID uniqueness) +func (*DefaultValidator) validateStepBasics(step *WorkflowStepConfig, index int, stepIDs map[string]bool) error { + if step.ID == "" { + return fmt.Errorf("step[%d].id is required", index) + } + + if stepIDs[step.ID] { + return fmt.Errorf("duplicate step ID: %s", step.ID) + } + stepIDs[step.ID] = true + + return nil +} + +// validateStepType validates step type and type-specific requirements +func (*DefaultValidator) validateStepType(step *WorkflowStepConfig, index int) error { + validTypes := []string{"tool", "elicitation"} + if !contains(validTypes, step.Type) { + return fmt.Errorf("step[%d].type must be one of: %s", index, strings.Join(validTypes, ", ")) + } + + switch step.Type { + case "tool": + if step.Tool == "" { + return fmt.Errorf("step[%d].tool is required for tool steps", index) + } + + case "elicitation": + if step.Message == "" { + return fmt.Errorf("step[%d].message is required for elicitation steps", index) + } + if len(step.Schema) == 0 { + return fmt.Errorf("step[%d].schema is required for elicitation steps", index) + } + // Note: timeout validation is optional - defaults are set during loading + } + + return nil +} + +// validateStepDependencies validates step dependency references +func (*DefaultValidator) validateStepDependencies(step *WorkflowStepConfig, index int, stepIDs map[string]bool) error { + for _, depID := range step.DependsOn { + if !stepIDs[depID] { + return fmt.Errorf("step[%d].depends_on references non-existent step: %s", index, depID) + } + } + return nil +} + +// validateStepErrorHandling validates step error handling configuration +func (*DefaultValidator) validateStepErrorHandling(step *WorkflowStepConfig, index int) error { + if step.OnError == nil { + return nil + } + + validActions := []string{"abort", "continue", "retry"} + if !contains(validActions, step.OnError.Action) { + return fmt.Errorf("step[%d].on_error.action must be one of: %s", index, strings.Join(validActions, ", ")) + } + + if step.OnError.Action == "retry" && step.OnError.RetryCount <= 0 { + return fmt.Errorf("step[%d].on_error.retry_count must be positive for retry action", index) + } + + return nil +} + +// Helper functions + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func containsStrategy(slice []vmcp.ConflictResolutionStrategy, item vmcp.ConflictResolutionStrategy) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/pkg/vmcp/config/validator_test.go b/pkg/vmcp/config/validator_test.go new file mode 100644 index 000000000..c7bb64678 --- /dev/null +++ b/pkg/vmcp/config/validator_test.go @@ -0,0 +1,547 @@ +package config + +import ( + "strings" + "testing" + "time" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +func TestValidator_ValidateBasicFields(t *testing.T) { + t.Parallel() + tests := []struct { + name string + cfg *Config + wantErr bool + errMsg string + }{ + { + name: "valid configuration", + cfg: &Config{ + Name: "test-vmcp", + GroupRef: "test-group", + IncomingAuth: &IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &OutgoingAuthConfig{ + Source: "inline", + }, + Aggregation: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &ConflictResolutionConfig{ + PrefixFormat: "{workload}_", + }, + }, + }, + wantErr: false, + }, + { + name: "missing name", + cfg: &Config{ + GroupRef: "test-group", + IncomingAuth: &IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &OutgoingAuthConfig{ + Source: "inline", + }, + Aggregation: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &ConflictResolutionConfig{ + PrefixFormat: "{workload}_", + }, + }, + }, + wantErr: true, + errMsg: "name is required", + }, + { + name: "missing group reference", + cfg: &Config{ + Name: "test-vmcp", + IncomingAuth: &IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &OutgoingAuthConfig{ + Source: "inline", + }, + Aggregation: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &ConflictResolutionConfig{ + PrefixFormat: "{workload}_", + }, + }, + }, + wantErr: true, + errMsg: "group reference is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.Validate(tt.cfg) + + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil { + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Validate() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} + +func TestValidator_ValidateIncomingAuth(t *testing.T) { + t.Parallel() + tests := []struct { + name string + auth *IncomingAuthConfig + wantErr bool + errMsg string + }{ + { + name: "valid anonymous auth", + auth: &IncomingAuthConfig{ + Type: "anonymous", + }, + wantErr: false, + }, + { + name: "valid OIDC auth", + auth: &IncomingAuthConfig{ + Type: "oidc", + OIDC: &OIDCConfig{ + Issuer: "https://example.com", + ClientID: "test-client", + ClientSecret: "test-secret", + Audience: "vmcp", + Scopes: []string{"openid"}, + }, + }, + wantErr: false, + }, + { + name: "invalid auth type", + auth: &IncomingAuthConfig{ + Type: "invalid", + }, + wantErr: true, + errMsg: "incoming_auth.type must be one of", + }, + { + name: "OIDC without config", + auth: &IncomingAuthConfig{ + Type: "oidc", + }, + wantErr: true, + errMsg: "incoming_auth.oidc is required", + }, + { + name: "OIDC missing issuer", + auth: &IncomingAuthConfig{ + Type: "oidc", + OIDC: &OIDCConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + Audience: "vmcp", + }, + }, + wantErr: true, + errMsg: "issuer is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.validateIncomingAuth(tt.auth) + + if (err != nil) != tt.wantErr { + t.Errorf("validateIncomingAuth() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateIncomingAuth() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} + +func TestValidator_ValidateOutgoingAuth(t *testing.T) { + t.Parallel() + tests := []struct { + name string + auth *OutgoingAuthConfig + wantErr bool + errMsg string + }{ + { + name: "valid inline source with pass_through default", + auth: &OutgoingAuthConfig{ + Source: "inline", + Default: &BackendAuthStrategy{ + Type: "pass_through", + }, + }, + wantErr: false, + }, + { + name: "valid token_exchange backend", + auth: &OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*BackendAuthStrategy{ + "github": { + Type: "token_exchange", + Metadata: map[string]any{ + "token_url": "https://example.com/token", + "client_id": "test-client", + "audience": "github-api", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "invalid source", + auth: &OutgoingAuthConfig{ + Source: "invalid", + }, + wantErr: true, + errMsg: "outgoing_auth.source must be one of", + }, + { + name: "invalid backend auth type", + auth: &OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*BackendAuthStrategy{ + "test": { + Type: "invalid", + }, + }, + }, + wantErr: true, + errMsg: "type must be one of", + }, + { + name: "token_exchange missing required metadata", + auth: &OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*BackendAuthStrategy{ + "github": { + Type: "token_exchange", + Metadata: map[string]any{ + "client_id": "test-client", + // Missing token_url and audience + }, + }, + }, + }, + wantErr: true, + errMsg: "token_exchange requires metadata field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.validateOutgoingAuth(tt.auth) + + if (err != nil) != tt.wantErr { + t.Errorf("validateOutgoingAuth() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateOutgoingAuth() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} + +func TestValidator_ValidateAggregation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + agg *AggregationConfig + wantErr bool + errMsg string + }{ + { + name: "valid prefix strategy", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &ConflictResolutionConfig{ + PrefixFormat: "{workload}_", + }, + }, + wantErr: false, + }, + { + name: "valid priority strategy", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPriority, + ConflictResolutionConfig: &ConflictResolutionConfig{ + PriorityOrder: []string{"github", "jira"}, + }, + }, + wantErr: false, + }, + { + name: "valid manual strategy", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyManual, + ConflictResolutionConfig: &ConflictResolutionConfig{}, + Tools: []*WorkloadToolConfig{ + { + Workload: "github", + Overrides: map[string]*ToolOverride{ + "create_issue": { + Name: "gh_create_issue", + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "prefix strategy missing format", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &ConflictResolutionConfig{}, + }, + wantErr: true, + errMsg: "prefix_format is required", + }, + { + name: "priority strategy missing order", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPriority, + ConflictResolutionConfig: &ConflictResolutionConfig{}, + }, + wantErr: true, + errMsg: "priority_order is required", + }, + { + name: "manual strategy missing overrides", + agg: &AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyManual, + ConflictResolutionConfig: &ConflictResolutionConfig{}, + }, + wantErr: true, + errMsg: "tool overrides are required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.validateAggregation(tt.agg) + + if (err != nil) != tt.wantErr { + t.Errorf("validateAggregation() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateAggregation() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} + +func TestValidator_ValidateTokenCache(t *testing.T) { + t.Parallel() + tests := []struct { + name string + cache *TokenCacheConfig + wantErr bool + errMsg string + }{ + { + name: "nil cache (optional)", + cache: nil, + wantErr: false, + }, + { + name: "valid memory cache", + cache: &TokenCacheConfig{ + Provider: CacheProviderMemory, + Memory: &MemoryCacheConfig{ + MaxEntries: 1000, + TTLOffset: 5 * time.Minute, + }, + }, + wantErr: false, + }, + { + name: "valid redis cache", + cache: &TokenCacheConfig{ + Provider: "redis", + Redis: &RedisCacheConfig{ + Address: "localhost:6379", + TTLOffset: 5 * time.Minute, + }, + }, + wantErr: false, + }, + { + name: "invalid provider", + cache: &TokenCacheConfig{ + Provider: "invalid", + }, + wantErr: true, + errMsg: "token_cache.provider must be one of", + }, + { + name: "memory cache with negative max entries", + cache: &TokenCacheConfig{ + Provider: CacheProviderMemory, + Memory: &MemoryCacheConfig{ + MaxEntries: -1, + }, + }, + wantErr: true, + errMsg: "max_entries must be positive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.validateTokenCache(tt.cache) + + if (err != nil) != tt.wantErr { + t.Errorf("validateTokenCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateTokenCache() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} + +func TestValidator_ValidateCompositeTools(t *testing.T) { + t.Parallel() + tests := []struct { + name string + tools []*CompositeToolConfig + wantErr bool + errMsg string + }{ + { + name: "nil tools (optional)", + tools: nil, + wantErr: false, + }, + { + name: "valid composite tool", + tools: []*CompositeToolConfig{ + { + Name: "deploy_workflow", + Description: "Deploy workflow", + Timeout: 30 * time.Minute, + Steps: []*WorkflowStepConfig{ + { + ID: "merge", + Type: "tool", + Tool: "github.merge_pr", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "missing tool name", + tools: []*CompositeToolConfig{ + { + Description: "Deploy workflow", + Timeout: 30 * time.Minute, + Steps: []*WorkflowStepConfig{ + { + ID: "merge", + Type: "tool", + Tool: "github.merge_pr", + }, + }, + }, + }, + wantErr: true, + errMsg: "name is required", + }, + { + name: "duplicate tool name", + tools: []*CompositeToolConfig{ + { + Name: "deploy", + Description: "Deploy workflow", + Timeout: 30 * time.Minute, + Steps: []*WorkflowStepConfig{ + { + ID: "merge", + Type: "tool", + Tool: "github.merge_pr", + }, + }, + }, + { + Name: "deploy", + Description: "Another deploy workflow", + Timeout: 30 * time.Minute, + Steps: []*WorkflowStepConfig{ + { + ID: "merge", + Type: "tool", + Tool: "jira.create_issue", + }, + }, + }, + }, + wantErr: true, + errMsg: "duplicate composite tool name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + v := NewValidator() + err := v.validateCompositeTools(tt.tools) + + if (err != nil) != tt.wantErr { + t.Errorf("validateCompositeTools() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateCompositeTools() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +} diff --git a/pkg/vmcp/config/yaml_loader.go b/pkg/vmcp/config/yaml_loader.go new file mode 100644 index 000000000..7a4229b75 --- /dev/null +++ b/pkg/vmcp/config/yaml_loader.go @@ -0,0 +1,572 @@ +package config + +import ( + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// YAMLLoader loads configuration from a YAML file. +// This is the CLI-specific loader that parses the YAML format defined in the proposal. +type YAMLLoader struct { + filePath string +} + +// NewYAMLLoader creates a new YAML configuration loader. +func NewYAMLLoader(filePath string) *YAMLLoader { + return &YAMLLoader{filePath: filePath} +} + +// Load reads and parses the YAML configuration file. +func (l *YAMLLoader) Load() (*Config, error) { + data, err := os.ReadFile(l.filePath) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var raw rawConfig + if err := yaml.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + + cfg, err := l.transformToConfig(&raw) + if err != nil { + return nil, fmt.Errorf("failed to transform config: %w", err) + } + + return cfg, nil +} + +// rawConfig represents the YAML structure as defined in the proposal. +type rawConfig struct { + Name string `yaml:"name"` + Group string `yaml:"group"` + + IncomingAuth rawIncomingAuth `yaml:"incoming_auth"` + OutgoingAuth rawOutgoingAuth `yaml:"outgoing_auth"` + Aggregation rawAggregation `yaml:"aggregation"` + TokenCache *rawTokenCache `yaml:"token_cache"` + Operational *rawOperational `yaml:"operational"` + + CompositeTools []*rawCompositeTool `yaml:"composite_tools"` +} + +type rawIncomingAuth struct { + Type string `yaml:"type"` + OIDC *struct { + Issuer string `yaml:"issuer"` + ClientID string `yaml:"client_id"` + ClientSecretEnv string `yaml:"client_secret_env"` + Audience string `yaml:"audience"` + Scopes []string `yaml:"scopes"` + } `yaml:"oidc"` + Authz *struct { + Type string `yaml:"type"` + Policies []string `yaml:"policies"` + } `yaml:"authz"` +} + +type rawOutgoingAuth struct { + Source string `yaml:"source"` + Default *rawBackendAuthStrategy `yaml:"default"` + Backends map[string]*rawBackendAuthStrategy `yaml:"backends"` +} + +type rawBackendAuthStrategy struct { + Type string `yaml:"type"` + TokenExchange *rawTokenExchangeAuth `yaml:"token_exchange"` + ServiceAccount *rawServiceAccountAuth `yaml:"service_account"` +} + +type rawTokenExchangeAuth struct { + TokenURL string `yaml:"token_url"` + ClientID string `yaml:"client_id"` + ClientSecretEnv string `yaml:"client_secret_env"` + Audience string `yaml:"audience"` + Scopes []string `yaml:"scopes"` + SubjectTokenType string `yaml:"subject_token_type"` +} + +type rawServiceAccountAuth struct { + CredentialsEnv string `yaml:"credentials_env"` + HeaderName string `yaml:"header_name"` + HeaderFormat string `yaml:"header_format"` +} + +type rawAggregation struct { + ConflictResolution string `yaml:"conflict_resolution"` + ConflictResolutionConfig *rawConflictResolutionConfig `yaml:"conflict_resolution_config"` + Tools []*rawWorkloadToolConfig `yaml:"tools"` +} + +type rawConflictResolutionConfig struct { + PrefixFormat string `yaml:"prefix_format"` + PriorityOrder []string `yaml:"priority_order"` +} + +type rawWorkloadToolConfig struct { + Workload string `yaml:"workload"` + Filter []string `yaml:"filter"` + Overrides map[string]*rawToolOverride `yaml:"overrides"` +} + +type rawToolOverride struct { + Name string `yaml:"name"` + Description string `yaml:"description"` +} + +type rawTokenCache struct { + Provider string `yaml:"provider"` + Config struct { + MaxEntries int `yaml:"max_entries"` + TTLOffset string `yaml:"ttl_offset"` + Address string `yaml:"address"` + DB int `yaml:"db"` + KeyPrefix string `yaml:"key_prefix"` + Password string `yaml:"password"` + } `yaml:"config"` +} + +type rawOperational struct { + Timeouts struct { + Default string `yaml:"default"` + PerWorkload map[string]string `yaml:"per_workload"` + } `yaml:"timeouts"` + FailureHandling struct { + HealthCheckInterval string `yaml:"health_check_interval"` + UnhealthyThreshold int `yaml:"unhealthy_threshold"` + PartialFailureMode string `yaml:"partial_failure_mode"` + CircuitBreaker struct { + Enabled bool `yaml:"enabled"` + FailureThreshold int `yaml:"failure_threshold"` + Timeout string `yaml:"timeout"` + } `yaml:"circuit_breaker"` + } `yaml:"failure_handling"` +} + +type rawCompositeTool struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Parameters map[string]map[string]any `yaml:"parameters"` + Timeout string `yaml:"timeout"` + Steps []*rawWorkflowStep `yaml:"steps"` +} + +type rawWorkflowStep struct { + ID string `yaml:"id"` + Type string `yaml:"type"` + Tool string `yaml:"tool"` + Arguments map[string]any `yaml:"arguments"` + Condition string `yaml:"condition"` + DependsOn []string `yaml:"depends_on"` + OnError *rawStepErrorHandling `yaml:"on_error"` + Message string `yaml:"message"` + Schema map[string]any `yaml:"schema"` + Timeout string `yaml:"timeout"` + OnDecline *rawElicitationResponse `yaml:"on_decline"` + OnCancel *rawElicitationResponse `yaml:"on_cancel"` +} + +type rawStepErrorHandling struct { + Action string `yaml:"action"` + RetryCount int `yaml:"retry_count"` + RetryDelay string `yaml:"retry_delay"` +} + +type rawElicitationResponse struct { + Action string `yaml:"action"` +} + +// transformToConfig converts the raw YAML structure to the unified Config model. +func (l *YAMLLoader) transformToConfig(raw *rawConfig) (*Config, error) { + cfg := &Config{ + Name: raw.Name, + GroupRef: raw.Group, + } + + // Transform incoming auth + incomingAuth, err := l.transformIncomingAuth(&raw.IncomingAuth) + if err != nil { + return nil, fmt.Errorf("incoming_auth: %w", err) + } + cfg.IncomingAuth = incomingAuth + + // Transform outgoing auth + outgoingAuth, err := l.transformOutgoingAuth(&raw.OutgoingAuth) + if err != nil { + return nil, fmt.Errorf("outgoing_auth: %w", err) + } + cfg.OutgoingAuth = outgoingAuth + + // Transform aggregation + aggregation, err := l.transformAggregation(&raw.Aggregation) + if err != nil { + return nil, fmt.Errorf("aggregation: %w", err) + } + cfg.Aggregation = aggregation + + // Transform token cache + if raw.TokenCache != nil { + tokenCache, err := l.transformTokenCache(raw.TokenCache) + if err != nil { + return nil, fmt.Errorf("token_cache: %w", err) + } + cfg.TokenCache = tokenCache + } + + // Transform operational + if raw.Operational != nil { + operational, err := l.transformOperational(raw.Operational) + if err != nil { + return nil, fmt.Errorf("operational: %w", err) + } + cfg.Operational = operational + } + + // Transform composite tools + if len(raw.CompositeTools) > 0 { + compositeTools, err := l.transformCompositeTools(raw.CompositeTools) + if err != nil { + return nil, fmt.Errorf("composite_tools: %w", err) + } + cfg.CompositeTools = compositeTools + } + + return cfg, nil +} + +func (*YAMLLoader) transformIncomingAuth(raw *rawIncomingAuth) (*IncomingAuthConfig, error) { + cfg := &IncomingAuthConfig{ + Type: raw.Type, + } + + if raw.OIDC != nil { + // Resolve environment variable for client secret + clientSecret := os.Getenv(raw.OIDC.ClientSecretEnv) + if clientSecret == "" && raw.OIDC.ClientSecretEnv != "" { + return nil, fmt.Errorf("environment variable %s not set for client_secret", raw.OIDC.ClientSecretEnv) + } + + cfg.OIDC = &OIDCConfig{ + Issuer: raw.OIDC.Issuer, + ClientID: raw.OIDC.ClientID, + ClientSecret: clientSecret, + Audience: raw.OIDC.Audience, + Scopes: raw.OIDC.Scopes, + } + } + + if raw.Authz != nil { + cfg.Authz = &AuthzConfig{ + Type: raw.Authz.Type, + Policies: raw.Authz.Policies, + } + } + + return cfg, nil +} + +func (l *YAMLLoader) transformOutgoingAuth(raw *rawOutgoingAuth) (*OutgoingAuthConfig, error) { + cfg := &OutgoingAuthConfig{ + Source: raw.Source, + Backends: make(map[string]*BackendAuthStrategy), + } + + if raw.Default != nil { + strategy, err := l.transformBackendAuthStrategy(raw.Default) + if err != nil { + return nil, fmt.Errorf("default: %w", err) + } + cfg.Default = strategy + } + + for name, rawStrategy := range raw.Backends { + strategy, err := l.transformBackendAuthStrategy(rawStrategy) + if err != nil { + return nil, fmt.Errorf("backend %s: %w", name, err) + } + cfg.Backends[name] = strategy + } + + return cfg, nil +} + +func (*YAMLLoader) transformBackendAuthStrategy(raw *rawBackendAuthStrategy) (*BackendAuthStrategy, error) { + strategy := &BackendAuthStrategy{ + Type: raw.Type, + Metadata: make(map[string]any), + } + + switch raw.Type { + case "token_exchange": + if raw.TokenExchange == nil { + return nil, fmt.Errorf("token_exchange configuration is required") + } + + // Resolve client secret from environment + clientSecret := os.Getenv(raw.TokenExchange.ClientSecretEnv) + if clientSecret == "" && raw.TokenExchange.ClientSecretEnv != "" { + return nil, fmt.Errorf("environment variable %s not set", raw.TokenExchange.ClientSecretEnv) + } + + strategy.Metadata = map[string]any{ + "token_url": raw.TokenExchange.TokenURL, + "client_id": raw.TokenExchange.ClientID, + "client_secret": clientSecret, + "audience": raw.TokenExchange.Audience, + "scopes": raw.TokenExchange.Scopes, + "subject_token_type": raw.TokenExchange.SubjectTokenType, + } + + case "service_account": + if raw.ServiceAccount == nil { + return nil, fmt.Errorf("service_account configuration is required") + } + + // Resolve credentials from environment + credentials := os.Getenv(raw.ServiceAccount.CredentialsEnv) + if credentials == "" { + return nil, fmt.Errorf("environment variable %s not set", raw.ServiceAccount.CredentialsEnv) + } + + strategy.Metadata = map[string]any{ + "credentials": credentials, + "credentials_env": raw.ServiceAccount.CredentialsEnv, + "header_name": raw.ServiceAccount.HeaderName, + "header_format": raw.ServiceAccount.HeaderFormat, + } + } + + return strategy, nil +} + +// transformAggregation transforms raw aggregation configuration. +// Error return is maintained for consistency with other transform methods and future validation. +// +//nolint:unparam // error return kept for interface consistency +func (*YAMLLoader) transformAggregation(raw *rawAggregation) (*AggregationConfig, error) { + strategy := vmcp.ConflictResolutionStrategy(raw.ConflictResolution) + + cfg := &AggregationConfig{ + ConflictResolution: strategy, + ConflictResolutionConfig: &ConflictResolutionConfig{}, + } + + if raw.ConflictResolutionConfig != nil { + cfg.ConflictResolutionConfig.PrefixFormat = raw.ConflictResolutionConfig.PrefixFormat + cfg.ConflictResolutionConfig.PriorityOrder = raw.ConflictResolutionConfig.PriorityOrder + } + + for _, rawTool := range raw.Tools { + tool := &WorkloadToolConfig{ + Workload: rawTool.Workload, + Filter: rawTool.Filter, + Overrides: make(map[string]*ToolOverride), + } + + for name, override := range rawTool.Overrides { + tool.Overrides[name] = &ToolOverride{ + Name: override.Name, + Description: override.Description, + } + } + + cfg.Tools = append(cfg.Tools, tool) + } + + return cfg, nil +} + +func (*YAMLLoader) transformTokenCache(raw *rawTokenCache) (*TokenCacheConfig, error) { + cfg := &TokenCacheConfig{ + Provider: raw.Provider, + } + + switch raw.Provider { + case CacheProviderMemory: + ttlOffset, err := time.ParseDuration(raw.Config.TTLOffset) + if err != nil { + return nil, fmt.Errorf("invalid ttl_offset: %w", err) + } + + cfg.Memory = &MemoryCacheConfig{ + MaxEntries: raw.Config.MaxEntries, + TTLOffset: ttlOffset, + } + + case CacheProviderRedis: + ttlOffset, err := time.ParseDuration(raw.Config.TTLOffset) + if err != nil { + return nil, fmt.Errorf("invalid ttl_offset: %w", err) + } + + cfg.Redis = &RedisCacheConfig{ + Address: raw.Config.Address, + DB: raw.Config.DB, + KeyPrefix: raw.Config.KeyPrefix, + Password: raw.Config.Password, + TTLOffset: ttlOffset, + } + } + + return cfg, nil +} + +func (*YAMLLoader) transformOperational(raw *rawOperational) (*OperationalConfig, error) { + cfg := &OperationalConfig{} + + // Transform timeouts + if raw.Timeouts.Default != "" { + defaultTimeout, err := time.ParseDuration(raw.Timeouts.Default) + if err != nil { + return nil, fmt.Errorf("invalid default timeout: %w", err) + } + + cfg.Timeouts = &TimeoutConfig{ + Default: defaultTimeout, + PerWorkload: make(map[string]time.Duration), + } + + for workload, timeoutStr := range raw.Timeouts.PerWorkload { + timeout, err := time.ParseDuration(timeoutStr) + if err != nil { + return nil, fmt.Errorf("invalid timeout for workload %s: %w", workload, err) + } + cfg.Timeouts.PerWorkload[workload] = timeout + } + } + + // Transform failure handling + healthCheckInterval, err := time.ParseDuration(raw.FailureHandling.HealthCheckInterval) + if err != nil { + return nil, fmt.Errorf("invalid health_check_interval: %w", err) + } + + cfg.FailureHandling = &FailureHandlingConfig{ + HealthCheckInterval: healthCheckInterval, + UnhealthyThreshold: raw.FailureHandling.UnhealthyThreshold, + PartialFailureMode: raw.FailureHandling.PartialFailureMode, + } + + // Transform circuit breaker + if raw.FailureHandling.CircuitBreaker.Enabled { + cbTimeout, err := time.ParseDuration(raw.FailureHandling.CircuitBreaker.Timeout) + if err != nil { + return nil, fmt.Errorf("invalid circuit_breaker timeout: %w", err) + } + + cfg.FailureHandling.CircuitBreaker = &CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: raw.FailureHandling.CircuitBreaker.FailureThreshold, + Timeout: cbTimeout, + } + } + + return cfg, nil +} + +func (l *YAMLLoader) transformCompositeTools(raw []*rawCompositeTool) ([]*CompositeToolConfig, error) { + var tools []*CompositeToolConfig + + for _, rawTool := range raw { + timeout, err := time.ParseDuration(rawTool.Timeout) + if err != nil { + return nil, fmt.Errorf("tool %s: invalid timeout: %w", rawTool.Name, err) + } + + tool := &CompositeToolConfig{ + Name: rawTool.Name, + Description: rawTool.Description, + Parameters: make(map[string]ParameterSchema), + Timeout: timeout, + } + + // Transform parameters + for name, paramMap := range rawTool.Parameters { + typeVal, ok := paramMap["type"] + if !ok { + return nil, fmt.Errorf("tool %s, parameter %s: missing 'type' field", rawTool.Name, name) + } + typeStr, ok := typeVal.(string) + if !ok { + return nil, fmt.Errorf("tool %s, parameter %s: 'type' field must be a string", rawTool.Name, name) + } + param := ParameterSchema{ + Type: typeStr, + } + if def, ok := paramMap["default"]; ok { + param.Default = def + } + tool.Parameters[name] = param + } + + // Transform steps + for _, rawStep := range rawTool.Steps { + step, err := l.transformWorkflowStep(rawStep) + if err != nil { + return nil, fmt.Errorf("tool %s, step %s: %w", rawTool.Name, rawStep.ID, err) + } + tool.Steps = append(tool.Steps, step) + } + + tools = append(tools, tool) + } + + return tools, nil +} + +func (*YAMLLoader) transformWorkflowStep(raw *rawWorkflowStep) (*WorkflowStepConfig, error) { + step := &WorkflowStepConfig{ + ID: raw.ID, + Type: raw.Type, + Tool: raw.Tool, + Arguments: raw.Arguments, + Condition: raw.Condition, + DependsOn: raw.DependsOn, + Message: raw.Message, + Schema: raw.Schema, + } + + if raw.Timeout != "" { + timeout, err := time.ParseDuration(raw.Timeout) + if err != nil { + return nil, fmt.Errorf("invalid timeout: %w", err) + } + step.Timeout = timeout + } else if raw.Type == "elicitation" { + // Set default timeout for elicitation steps + step.Timeout = 5 * time.Minute + } + + if raw.OnError != nil { + step.OnError = &StepErrorHandling{ + Action: raw.OnError.Action, + RetryCount: raw.OnError.RetryCount, + } + if raw.OnError.RetryDelay != "" { + retryDelay, err := time.ParseDuration(raw.OnError.RetryDelay) + if err != nil { + return nil, fmt.Errorf("invalid retry_delay: %w", err) + } + step.OnError.RetryDelay = retryDelay + } + } + + if raw.OnDecline != nil { + step.OnDecline = &ElicitationResponseConfig{ + Action: raw.OnDecline.Action, + } + } + + if raw.OnCancel != nil { + step.OnCancel = &ElicitationResponseConfig{ + Action: raw.OnCancel.Action, + } + } + + return step, nil +} diff --git a/pkg/vmcp/config/yaml_loader_test.go b/pkg/vmcp/config/yaml_loader_test.go new file mode 100644 index 000000000..acd96df34 --- /dev/null +++ b/pkg/vmcp/config/yaml_loader_test.go @@ -0,0 +1,488 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +func TestYAMLLoader_Load(t *testing.T) { + t.Parallel() + tests := []struct { + name string + yaml string + envVars map[string]string + want func(*testing.T, *Config) + wantErr bool + errMsg string + }{ + { + name: "valid minimal configuration", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + want: func(t *testing.T, cfg *Config) { + t.Helper() + if cfg.Name != "test-vmcp" { + t.Errorf("Name = %v, want test-vmcp", cfg.Name) + } + if cfg.GroupRef != "test-group" { + t.Errorf("GroupRef = %v, want test-group", cfg.GroupRef) + } + if cfg.IncomingAuth.Type != "anonymous" { + t.Errorf("IncomingAuth.Type = %v, want anonymous", cfg.IncomingAuth.Type) + } + if cfg.OutgoingAuth.Source != "inline" { + t.Errorf("OutgoingAuth.Source = %v, want inline", cfg.OutgoingAuth.Source) + } + if cfg.Aggregation.ConflictResolution != vmcp.ConflictStrategyPrefix { + t.Errorf("ConflictResolution = %v, want prefix", cfg.Aggregation.ConflictResolution) + } + }, + wantErr: false, + }, + { + name: "valid OIDC configuration with env vars", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: oidc + oidc: + issuer: https://auth.example.com + client_id: test-client + client_secret_env: TEST_SECRET + audience: vmcp + scopes: + - openid + - profile + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + envVars: map[string]string{ + "TEST_SECRET": "my-secret-value", + }, + want: func(t *testing.T, cfg *Config) { + t.Helper() + if cfg.IncomingAuth.Type != "oidc" { + t.Errorf("IncomingAuth.Type = %v, want oidc", cfg.IncomingAuth.Type) + } + if cfg.IncomingAuth.OIDC == nil { + t.Fatal("IncomingAuth.OIDC is nil") + } + if cfg.IncomingAuth.OIDC.Issuer != "https://auth.example.com" { + t.Errorf("OIDC.Issuer = %v, want https://auth.example.com", cfg.IncomingAuth.OIDC.Issuer) + } + if cfg.IncomingAuth.OIDC.ClientSecret != "my-secret-value" { + t.Errorf("OIDC.ClientSecret = %v, want my-secret-value", cfg.IncomingAuth.OIDC.ClientSecret) + } + }, + wantErr: false, + }, + { + name: "valid configuration with token cache", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" + +token_cache: + provider: memory + config: + max_entries: 1000 + ttl_offset: 5m +`, + want: func(t *testing.T, cfg *Config) { + t.Helper() + if cfg.TokenCache == nil { + t.Fatal("TokenCache is nil") + } + if cfg.TokenCache.Provider != CacheProviderMemory { + t.Errorf("TokenCache.Provider = %v, want memory", cfg.TokenCache.Provider) + } + if cfg.TokenCache.Memory == nil { + t.Fatal("TokenCache.Memory is nil") + } + if cfg.TokenCache.Memory.MaxEntries != 1000 { + t.Errorf("Memory.MaxEntries = %v, want 1000", cfg.TokenCache.Memory.MaxEntries) + } + if cfg.TokenCache.Memory.TTLOffset != 5*time.Minute { + t.Errorf("Memory.TTLOffset = %v, want 5m", cfg.TokenCache.Memory.TTLOffset) + } + }, + wantErr: false, + }, + { + name: "valid configuration with composite tools", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" + +composite_tools: + - name: deploy_workflow + description: Deploy and notify + parameters: + pr_number: + type: integer + timeout: 30m + steps: + - id: merge + type: tool + tool: github.merge_pr + arguments: + pr: "{{.params.pr_number}}" + - id: notify + type: tool + tool: slack.post_message + arguments: + message: "Deployed PR {{.params.pr_number}}" + depends_on: + - merge +`, + want: func(t *testing.T, cfg *Config) { + t.Helper() + if len(cfg.CompositeTools) != 1 { + t.Fatalf("CompositeTools length = %v, want 1", len(cfg.CompositeTools)) + } + tool := cfg.CompositeTools[0] + if tool.Name != "deploy_workflow" { + t.Errorf("Tool.Name = %v, want deploy_workflow", tool.Name) + } + if tool.Timeout != 30*time.Minute { + t.Errorf("Tool.Timeout = %v, want 30m", tool.Timeout) + } + if len(tool.Steps) != 2 { + t.Errorf("Tool.Steps length = %v, want 2", len(tool.Steps)) + } + }, + wantErr: false, + }, + { + name: "invalid YAML syntax", + yaml: ` +name: test-vmcp +group: test-group +incoming_auth + type: anonymous +`, + wantErr: true, + errMsg: "failed to parse YAML", + }, + { + name: "missing environment variable", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: oidc + oidc: + issuer: https://auth.example.com + client_id: test-client + client_secret_env: MISSING_VAR + audience: vmcp + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + wantErr: true, + errMsg: "environment variable MISSING_VAR not set", + }, + { + name: "invalid duration format", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" + +token_cache: + provider: memory + config: + max_entries: 1000 + ttl_offset: invalid-duration +`, + wantErr: true, + errMsg: "invalid ttl_offset", + }, + { + name: "composite tool with missing parameter type", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" + +composite_tools: + - name: test_tool + description: Test tool + timeout: 5m + parameters: + param1: + default: "value" + steps: + - id: step1 + type: tool + tool: some.tool +`, + wantErr: true, + errMsg: "missing 'type' field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Set up environment variables + for k, v := range tt.envVars { + os.Setenv(k, v) + defer os.Unsetenv(k) + } + + // Create temporary file with YAML content + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(tmpFile, []byte(tt.yaml), 0644); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + + // Load configuration + loader := NewYAMLLoader(tmpFile) + cfg, err := loader.Load() + + // Check error expectation + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Load() error message = %v, want to contain %v", err.Error(), tt.errMsg) + } + return + } + + // Verify configuration + if tt.want != nil && cfg != nil { + tt.want(t, cfg) + } + }) + } +} + +func TestYAMLLoader_LoadFileNotFound(t *testing.T) { + t.Parallel() + loader := NewYAMLLoader("/non/existent/file.yaml") + _, err := loader.Load() + + if err == nil { + t.Error("Load() expected error for non-existent file, got nil") + } + + if !strings.Contains(err.Error(), "failed to read config file") { + t.Errorf("Load() error = %v, want to contain 'failed to read config file'", err) + } +} + +func TestYAMLLoader_IntegrationWithValidator(t *testing.T) { + t.Parallel() + tests := []struct { + name string + yaml string + envVars map[string]string + shouldPass bool + errMsg string + }{ + { + name: "valid configuration passes validation", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + shouldPass: true, + }, + { + name: "configuration with missing name fails validation", + yaml: ` +group: test-group + +incoming_auth: + type: anonymous + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + shouldPass: false, + errMsg: "name is required", + }, + { + name: "configuration with invalid auth type fails validation", + yaml: ` +name: test-vmcp +group: test-group + +incoming_auth: + type: invalid_type + +outgoing_auth: + source: inline + default: + type: pass_through + +aggregation: + conflict_resolution: prefix + conflict_resolution_config: + prefix_format: "{workload}_" +`, + shouldPass: false, + errMsg: "incoming_auth.type must be one of", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Set up environment variables + for k, v := range tt.envVars { + os.Setenv(k, v) + defer os.Unsetenv(k) + } + + // Create temporary file + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(tmpFile, []byte(tt.yaml), 0644); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + + // Load and validate + loader := NewYAMLLoader(tmpFile) + cfg, err := loader.Load() + if err != nil { + if tt.shouldPass { + t.Fatalf("Load() unexpected error = %v", err) + } + return + } + + validator := NewValidator() + err = validator.Validate(cfg) + + if tt.shouldPass && err != nil { + t.Errorf("Validate() unexpected error = %v", err) + } + + if !tt.shouldPass && err == nil { + t.Error("Validate() expected error, got nil") + } + + if !tt.shouldPass && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Validate() error = %v, want to contain %v", err.Error(), tt.errMsg) + } + } + }) + } +}