diff --git a/pkg/ratelimit/middleware.go b/pkg/ratelimit/middleware.go index 17137e3d81..c7ecfafd20 100644 --- a/pkg/ratelimit/middleware.go +++ b/pkg/ratelimit/middleware.go @@ -16,6 +16,7 @@ import ( "github.com/redis/go-redis/v9" v1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -119,7 +120,14 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction { return } - decision, err := limiter.Allow(r.Context(), parsed.ResourceID, "") + // When no identity is present (unauthenticated), userID stays empty + // and per-user buckets are skipped — only shared limits apply. CEL + // validation ensures perUser rate limits require auth to be enabled. + var userID string + if identity, ok := auth.IdentityFromContext(r.Context()); ok { + userID = identity.Subject + } + decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID) if err != nil { slog.Warn("rate limit check failed, allowing request", "error", err) next.ServeHTTP(w, r) diff --git a/pkg/ratelimit/middleware_test.go b/pkg/ratelimit/middleware_test.go index 41ce09bccf..ed76e72e0c 100644 --- a/pkg/ratelimit/middleware_test.go +++ b/pkg/ratelimit/middleware_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/mcp" ) @@ -29,6 +30,25 @@ func (d *dummyLimiter) Allow(context.Context, string, string) (*Decision, error) return d.decision, d.err } +// recordingLimiter captures the arguments passed to Allow. +type recordingLimiter struct { + toolName string + userID string +} + +func (r *recordingLimiter) Allow(_ context.Context, toolName, userID string) (*Decision, error) { + r.toolName = toolName + r.userID = userID + return &Decision{Allowed: true}, nil +} + +// withIdentity adds an auth.Identity with the given subject to the request context. +func withIdentity(r *http.Request, subject string) *http.Request { + identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: subject}} + ctx := auth.WithIdentity(r.Context(), identity) + return r.WithContext(ctx) +} + // withParsedMCPRequest adds a ParsedMCPRequest to the request context. func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request { parsed := &mcp.ParsedMCPRequest{ @@ -148,3 +168,43 @@ func TestRateLimitHandler_NonToolCallPassesThrough(t *testing.T) { assert.True(t, nextCalled, "non-tools/call should pass through regardless of limiter") assert.Equal(t, http.StatusOK, w.Code) } + +func TestRateLimitHandler_PassesUserID(t *testing.T) { + t.Parallel() + + recorder := &recordingLimiter{} + handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = withParsedMCPRequest(req, "tools/call", "echo", 1) + req = withIdentity(req, "alice@example.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "echo", recorder.toolName) + assert.Equal(t, "alice@example.com", recorder.userID) +} + +func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) { + t.Parallel() + + recorder := &recordingLimiter{} + handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = withParsedMCPRequest(req, "tools/call", "echo", 1) + // No identity in context — unauthenticated request. + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "echo", recorder.toolName) + assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID") +} diff --git a/test/e2e/thv-operator/acceptance_tests/helpers.go b/test/e2e/thv-operator/acceptance_tests/helpers.go index c861e4e775..7e5ff06bfc 100644 --- a/test/e2e/thv-operator/acceptance_tests/helpers.go +++ b/test/e2e/thv-operator/acceptance_tests/helpers.go @@ -16,6 +16,7 @@ import ( "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" @@ -24,9 +25,9 @@ import ( "github.com/stacklok/toolhive/test/e2e/images" ) -// DeployRedis creates a Redis Deployment and Service in the given namespace. -// No password is configured — matches the default empty THV_SESSION_REDIS_PASSWORD. -func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) { +// EnsureRedis creates a Redis Deployment and Service if they don't already exist, +// then waits for Redis to be ready. Safe to call concurrently from multiple test blocks. +func EnsureRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) { labels := map[string]string{"app": "redis"} deployment := &appsv1.Deployment{ @@ -51,7 +52,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout }, }, } - gomega.Expect(c.Create(ctx, deployment)).To(gomega.Succeed()) + if err := c.Create(ctx, deployment); err != nil && !apierrors.IsAlreadyExists(err) { + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + } service := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ @@ -65,7 +68,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout }, }, } - gomega.Expect(c.Create(ctx, service)).To(gomega.Succeed()) + if err := c.Create(ctx, service); err != nil && !apierrors.IsAlreadyExists(err) { + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + } ginkgo.By("Waiting for Redis to be ready") gomega.Eventually(func() error { @@ -99,7 +104,7 @@ func CleanupRedis(ctx context.Context, c client.Client, namespace string) { } // SendToolCall sends a JSON-RPC tools/call request and returns the HTTP status code and body. -func SendToolCall(httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) { +func SendToolCall(ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) { reqBody := map[string]any{ "jsonrpc": "2.0", "id": requestID, @@ -113,7 +118,7 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI gomega.Expect(err).ToNot(gomega.HaveOccurred()) url := fmt.Sprintf("http://localhost:%d/mcp", port) - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(bodyBytes)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes)) gomega.Expect(err).ToNot(gomega.HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") @@ -127,3 +132,104 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI return resp.StatusCode, respBody } + +// SendInitialize sends a JSON-RPC initialize request and returns the session ID +// from the Mcp-Session header. This must be called before tools/call when auth is enabled. +func SendInitialize( + ctx context.Context, httpClient *http.Client, port int32, bearerToken string, +) (sessionID string) { + reqBody := map[string]any{ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "e2e-test", + "version": "1.0.0", + }, + }, + } + bodyBytes, err := json.Marshal(reqBody) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + url := fmt.Sprintf("http://localhost:%d/mcp", port) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes)) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + if bearerToken != "" { + req.Header.Set("Authorization", "Bearer "+bearerToken) + } + + resp, err := httpClient.Do(req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer func() { _ = resp.Body.Close() }() + + gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK), + "initialize should succeed") + + sessionID = resp.Header.Get("Mcp-Session-Id") + gomega.Expect(sessionID).ToNot(gomega.BeEmpty(), + "initialize response should include Mcp-Session-Id header") + + return sessionID +} + +// SendAuthenticatedToolCallWithSession sends a JSON-RPC tools/call with Bearer token and session ID. +func SendAuthenticatedToolCallWithSession( + ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int, bearerToken, sessionID string, +) (int, []byte, string) { + reqBody := map[string]any{ + "jsonrpc": "2.0", + "id": requestID, + "method": "tools/call", + "params": map[string]any{ + "name": toolName, + "arguments": map[string]any{"input": "test"}, + }, + } + bodyBytes, err := json.Marshal(reqBody) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + url := fmt.Sprintf("http://localhost:%d/mcp", port) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes)) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Authorization", "Bearer "+bearerToken) + if sessionID != "" { + req.Header.Set("Mcp-Session-Id", sessionID) + } + + resp, err := httpClient.Do(req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer func() { _ = resp.Body.Close() }() + + retryAfter := resp.Header.Get("Retry-After") + + respBody, err := io.ReadAll(resp.Body) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + return resp.StatusCode, respBody, retryAfter +} + +// GetOIDCToken fetches a JWT from the mock OIDC server for the given subject. +func GetOIDCToken(ctx context.Context, httpClient *http.Client, oidcNodePort int32, subject string) string { + url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcNodePort, subject) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + resp, err := httpClient.Do(req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer func() { _ = resp.Body.Close() }() + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed()) + gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty(), "OIDC server should return a token") + + return tokenResp.AccessToken +} diff --git a/test/e2e/thv-operator/acceptance_tests/ratelimit_test.go b/test/e2e/thv-operator/acceptance_tests/ratelimit_test.go index 53a6ec78db..09a4284c67 100644 --- a/test/e2e/thv-operator/acceptance_tests/ratelimit_test.go +++ b/test/e2e/thv-operator/acceptance_tests/ratelimit_test.go @@ -22,10 +22,8 @@ import ( var _ = Describe("MCPServer Rate Limiting", Ordered, func() { var ( testNamespace = "default" - serverName = "ratelimit-test" timeout = 3 * time.Minute pollingInterval = 1 * time.Second - nodePort int32 httpClient *http.Client ) @@ -33,145 +31,294 @@ var _ = Describe("MCPServer Rate Limiting", Ordered, func() { httpClient = &http.Client{Timeout: 10 * time.Second} By("Deploying Redis for session storage and rate limiting") - DeployRedis(ctx, k8sClient, testNamespace, timeout, pollingInterval) - - By("Creating MCPServer with shared rate limit (maxTokens=3, refillPeriod=1m)") - server := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: serverName, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - Image: images.YardstickServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - MCPPort: 8080, - Env: []mcpv1alpha1.EnvVar{ - {Name: "TRANSPORT", Value: "streamable-http"}, - }, - SessionStorage: &mcpv1alpha1.SessionStorageConfig{ - Provider: "redis", - Address: fmt.Sprintf("redis.%s.svc.cluster.local:6379", testNamespace), + EnsureRedis(ctx, k8sClient, testNamespace, timeout, pollingInterval) + }) + + AfterAll(func() { + By("Cleaning up Redis") + CleanupRedis(ctx, k8sClient, testNamespace) + }) + + Context("shared rate limits", Ordered, func() { + var ( + serverName = "ratelimit-test" + nodePort int32 + ) + + BeforeAll(func() { + By("Creating MCPServer with shared rate limit (maxTokens=3, refillPeriod=1m)") + server := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: serverName, + Namespace: testNamespace, }, - RateLimiting: &mcpv1alpha1.RateLimitConfig{ - Shared: &mcpv1alpha1.RateLimitBucket{ - MaxTokens: 3, - RefillPeriod: metav1.Duration{Duration: time.Minute}, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + Env: []mcpv1alpha1.EnvVar{ + {Name: "TRANSPORT", Value: "streamable-http"}, + }, + SessionStorage: &mcpv1alpha1.SessionStorageConfig{ + Provider: "redis", + Address: fmt.Sprintf("redis.%s.svc.cluster.local:6379", testNamespace), + }, + RateLimiting: &mcpv1alpha1.RateLimitConfig{ + Shared: &mcpv1alpha1.RateLimitBucket{ + MaxTokens: 3, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, }, }, - }, - } - Expect(k8sClient.Create(ctx, server)).To(Succeed()) + } + Expect(k8sClient.Create(ctx, server)).To(Succeed()) - By("Waiting for MCPServer to be running") - testutil.WaitForMCPServerRunning(ctx, k8sClient, serverName, testNamespace, timeout, pollingInterval) + By("Waiting for MCPServer to be running") + testutil.WaitForMCPServerRunning(ctx, k8sClient, serverName, testNamespace, timeout, pollingInterval) - By("Creating NodePort service for MCPServer proxy") - testutil.CreateNodePortService(ctx, k8sClient, serverName, testNamespace) + By("Creating NodePort service for MCPServer proxy") + testutil.CreateNodePortService(ctx, k8sClient, serverName, testNamespace) - By("Getting NodePort") - nodePort = testutil.GetNodePort(ctx, k8sClient, serverName+"-nodeport", testNamespace, timeout, pollingInterval) - GinkgoWriter.Printf("MCPServer accessible at http://localhost:%d\n", nodePort) + By("Getting NodePort") + nodePort = testutil.GetNodePort(ctx, k8sClient, serverName+"-nodeport", testNamespace, timeout, pollingInterval) + GinkgoWriter.Printf("MCPServer accessible at http://localhost:%d\n", nodePort) - By("Waiting for proxy to be reachable") - Eventually(func() error { - resp, err := httpClient.Get(fmt.Sprintf("http://localhost:%d/health", nodePort)) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("health check returned %d", resp.StatusCode) - } - return nil - }, 2*time.Minute, pollingInterval).Should(Succeed()) - }) + By("Waiting for proxy to be reachable") + Eventually(func() error { + resp, err := httpClient.Get(fmt.Sprintf("http://localhost:%d/health", nodePort)) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("health check returned %d", resp.StatusCode) + } + return nil + }, 2*time.Minute, pollingInterval).Should(Succeed()) + }) - AfterAll(func() { - By("Cleaning up NodePort service") - _ = k8sClient.Delete(ctx, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{Name: serverName + "-nodeport", Namespace: testNamespace}, + AfterAll(func() { + By("Cleaning up NodePort service") + _ = k8sClient.Delete(ctx, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: serverName + "-nodeport", Namespace: testNamespace}, + }) + By("Cleaning up MCPServer") + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: serverName, Namespace: testNamespace}, + }) }) - By("Cleaning up MCPServer") - _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{Name: serverName, Namespace: testNamespace}, + It("should reject requests after shared limit exceeded (AC7)", func() { + By("Sending 3 requests within the rate limit — all should succeed") + for i := range 3 { + status, body := SendToolCall(ctx, httpClient, nodePort, "echo", i+1) + Expect(status).To(Equal(http.StatusOK), + "request %d should succeed, got status %d: %s", i+1, status, string(body)) + } + + By("Sending a 4th request — should be rate limited with HTTP 429") + status, body := SendToolCall(ctx, httpClient, nodePort, "echo", 4) + Expect(status).To(Equal(http.StatusTooManyRequests), + "4th request should be rate limited, body: %s", string(body)) + + By("Verifying JSON-RPC error code -32029") + var resp map[string]any + Expect(json.Unmarshal(body, &resp)).To(Succeed()) + + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue(), "response should have error object") + Expect(errObj["code"]).To(BeEquivalentTo(-32029)) + Expect(errObj["message"]).To(Equal("Rate limit exceeded")) + + data, ok := errObj["data"].(map[string]any) + Expect(ok).To(BeTrue(), "error should have data object") + Expect(data["retryAfterSeconds"]).To(BeNumerically(">", 0)) }) - By("Cleaning up Redis") - CleanupRedis(ctx, k8sClient, testNamespace) - }) + It("should accept CRD with both shared and per-tool config (AC8)", func() { + By("Creating a second MCPServer with both shared and tools config") + server2Name := "ratelimit-both" + server2 := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: server2Name, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + Env: []mcpv1alpha1.EnvVar{ + {Name: "TRANSPORT", Value: "streamable-http"}, + }, + SessionStorage: &mcpv1alpha1.SessionStorageConfig{ + Provider: "redis", + Address: fmt.Sprintf("redis.%s.svc.cluster.local:6379", testNamespace), + }, + RateLimiting: &mcpv1alpha1.RateLimitConfig{ + Shared: &mcpv1alpha1.RateLimitBucket{ + MaxTokens: 100, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + Tools: []mcpv1alpha1.ToolRateLimitConfig{ + { + Name: "echo", + Shared: &mcpv1alpha1.RateLimitBucket{ + MaxTokens: 10, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, + }, + } + Expect(k8sClient.Create(ctx, server2)).To(Succeed()) - It("should reject requests after shared limit exceeded (AC7)", func() { - By("Sending 3 requests within the rate limit — all should succeed") - for i := range 3 { - status, body := SendToolCall(httpClient, nodePort, "echo", i+1) - Expect(status).To(Equal(http.StatusOK), - "request %d should succeed, got status %d: %s", i+1, status, string(body)) - } - - By("Sending a 4th request — should be rate limited with HTTP 429") - status, body := SendToolCall(httpClient, nodePort, "echo", 4) - Expect(status).To(Equal(http.StatusTooManyRequests), - "4th request should be rate limited, body: %s", string(body)) - - By("Verifying JSON-RPC error code -32029") - var resp map[string]any - Expect(json.Unmarshal(body, &resp)).To(Succeed()) - - errObj, ok := resp["error"].(map[string]any) - Expect(ok).To(BeTrue(), "response should have error object") - Expect(errObj["code"]).To(BeEquivalentTo(-32029)) - Expect(errObj["message"]).To(Equal("Rate limit exceeded")) - - data, ok := errObj["data"].(map[string]any) - Expect(ok).To(BeTrue(), "error should have data object") - Expect(data["retryAfterSeconds"]).To(BeNumerically(">", 0)) + By("Waiting for MCPServer with both configs to be running") + testutil.WaitForMCPServerRunning(ctx, k8sClient, server2Name, testNamespace, timeout, pollingInterval) + + By("Cleaning up second MCPServer") + _ = k8sClient.Delete(ctx, server2) + }) }) - It("should accept CRD with both shared and per-tool config (AC8)", func() { - By("Creating a second MCPServer with both shared and tools config") - server2Name := "ratelimit-both" - server2 := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: server2Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - Image: images.YardstickServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - MCPPort: 8080, - Env: []mcpv1alpha1.EnvVar{ - {Name: "TRANSPORT", Value: "streamable-http"}, - }, - SessionStorage: &mcpv1alpha1.SessionStorageConfig{ - Provider: "redis", - Address: fmt.Sprintf("redis.%s.svc.cluster.local:6379", testNamespace), + Context("per-user rate limits", Ordered, func() { + var ( + serverName = "peruser-rl-test" + oidcServerName = "oidc-peruser-rl" + nodePort int32 + oidcNodePort int32 + oidcCleanup func() + ) + + BeforeAll(func() { + By("Deploying mock OIDC server for per-user identity") + var issuerURL string + issuerURL, oidcNodePort, oidcCleanup = testutil.DeployParameterizedOIDCServer( + ctx, k8sClient, oidcServerName, testNamespace, timeout, pollingInterval) + GinkgoWriter.Printf("Mock OIDC server: issuer=%s nodePort=%d\n", issuerURL, oidcNodePort) + + By("Creating MCPServer with per-user rate limit and inline OIDC auth") + server := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: serverName, + Namespace: testNamespace, }, - RateLimiting: &mcpv1alpha1.RateLimitConfig{ - Shared: &mcpv1alpha1.RateLimitBucket{ - MaxTokens: 100, - RefillPeriod: metav1.Duration{Duration: time.Minute}, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + Env: []mcpv1alpha1.EnvVar{ + {Name: "TRANSPORT", Value: "streamable-http"}, }, - Tools: []mcpv1alpha1.ToolRateLimitConfig{ - { - Name: "echo", - Shared: &mcpv1alpha1.RateLimitBucket{ - MaxTokens: 10, - RefillPeriod: metav1.Duration{Duration: time.Minute}, - }, + SessionStorage: &mcpv1alpha1.SessionStorageConfig{ + Provider: "redis", + Address: fmt.Sprintf("redis.%s.svc.cluster.local:6379", testNamespace), + }, + OIDCConfig: &mcpv1alpha1.OIDCConfigRef{ + Type: "inline", + Inline: &mcpv1alpha1.InlineOIDCConfig{ + Issuer: issuerURL, + Audience: "vmcp-audience", + JWKSAllowPrivateIP: true, + InsecureAllowHTTP: true, + }, + }, + RateLimiting: &mcpv1alpha1.RateLimitConfig{ + PerUser: &mcpv1alpha1.RateLimitBucket{ + MaxTokens: 2, + RefillPeriod: metav1.Duration{Duration: time.Minute}, }, }, }, - }, - } - Expect(k8sClient.Create(ctx, server2)).To(Succeed()) + } + Expect(k8sClient.Create(ctx, server)).To(Succeed()) - By("Waiting for MCPServer with both configs to be running") - testutil.WaitForMCPServerRunning(ctx, k8sClient, server2Name, testNamespace, timeout, pollingInterval) + By("Waiting for MCPServer to be running") + testutil.WaitForMCPServerRunning(ctx, k8sClient, serverName, testNamespace, timeout, pollingInterval) + + By("Creating NodePort service for MCPServer proxy") + testutil.CreateNodePortService(ctx, k8sClient, serverName, testNamespace) + + By("Getting NodePort") + nodePort = testutil.GetNodePort(ctx, k8sClient, serverName+"-nodeport", testNamespace, timeout, pollingInterval) + GinkgoWriter.Printf("MCPServer accessible at http://localhost:%d\n", nodePort) + + By("Waiting for proxy to be reachable") + Eventually(func() error { + resp, err := httpClient.Get(fmt.Sprintf("http://localhost:%d/health", nodePort)) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("health check returned %d", resp.StatusCode) + } + return nil + }, 2*time.Minute, pollingInterval).Should(Succeed()) + }) - By("Cleaning up second MCPServer") - _ = k8sClient.Delete(ctx, server2) + AfterAll(func() { + By("Cleaning up NodePort service") + _ = k8sClient.Delete(ctx, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: serverName + "-nodeport", Namespace: testNamespace}, + }) + By("Cleaning up MCPServer") + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: serverName, Namespace: testNamespace}, + }) + By("Cleaning up OIDC server") + if oidcCleanup != nil { + oidcCleanup() + } + }) + + It("should reject user after per-user limit exceeded and allow independent user (AC11, AC12)", func() { + By("Getting JWT for user-a") + tokenA := GetOIDCToken(ctx, httpClient, oidcNodePort, "user-a") + + By("Initializing MCP session for user-a") + sessionA := SendInitialize(ctx, httpClient, nodePort, tokenA) + + By("Sending 2 requests as user-a — all should succeed") + for i := range 2 { + status, body, _ := SendAuthenticatedToolCallWithSession(ctx, httpClient, nodePort, "echo", i+1, tokenA, sessionA) + Expect(status).To(Equal(http.StatusOK), + "user-a request %d should succeed, got status %d: %s", i+1, status, string(body)) + } + + By("Sending a 3rd request as user-a — should be rate limited with HTTP 429") + status, body, retryAfter := SendAuthenticatedToolCallWithSession(ctx, httpClient, nodePort, "echo", 3, tokenA, sessionA) + Expect(status).To(Equal(http.StatusTooManyRequests), + "user-a 3rd request should be rate limited, body: %s", string(body)) + + By("Verifying Retry-After header is present (AC12)") + Expect(retryAfter).ToNot(BeEmpty(), "Retry-After header should be set on 429 response") + + By("Verifying JSON-RPC error code -32029 with retryAfterSeconds") + var resp map[string]any + Expect(json.Unmarshal(body, &resp)).To(Succeed()) + + errObj, ok := resp["error"].(map[string]any) + Expect(ok).To(BeTrue(), "response should have error object") + Expect(errObj["code"]).To(BeEquivalentTo(-32029)) + + data, ok := errObj["data"].(map[string]any) + Expect(ok).To(BeTrue(), "error should have data object") + Expect(data["retryAfterSeconds"]).To(BeNumerically(">", 0)) + + By("Getting JWT for user-b") + tokenB := GetOIDCToken(ctx, httpClient, oidcNodePort, "user-b") + + By("Initializing MCP session for user-b") + sessionB := SendInitialize(ctx, httpClient, nodePort, tokenB) + + By("Sending request as user-b — should succeed (independent bucket)") + status, body, _ = SendAuthenticatedToolCallWithSession(ctx, httpClient, nodePort, "echo", 4, tokenB, sessionB) + Expect(status).To(Equal(http.StatusOK), + "user-b should not be blocked by user-a's limit, got status %d: %s", status, string(body)) + }) }) }) diff --git a/test/e2e/thv-operator/testutil/oidc.go b/test/e2e/thv-operator/testutil/oidc.go new file mode 100644 index 0000000000..02d8387211 --- /dev/null +++ b/test/e2e/thv-operator/testutil/oidc.go @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// DeployParameterizedOIDCServer deploys an in-cluster mock OIDC server that +// issues RSA-signed JWTs with a caller-controlled subject claim (via +// POST /token?subject=). The server is exposed via a NodePort so +// the test process (running outside the cluster) can reach it. +// +// Returns the in-cluster issuer URL (http://..svc.cluster.local) +// and a cleanup function that removes all created resources. +func DeployParameterizedOIDCServer( + ctx context.Context, + c client.Client, + name, namespace string, + timeout, pollingInterval time.Duration, +) (issuerURL string, allocatedNodePort int32, cleanup func()) { + configMapName := name + "-code" + + // Patch the placeholder issuer into the script so the JWT iss claim and + // the OIDC discovery document match the in-cluster service URL. + issuerURL = fmt.Sprintf("http://%s.%s.svc.cluster.local", name, namespace) + script := strings.ReplaceAll(parameterizedOIDCServerScript, + "http://OIDC_SERVICE_NAME.OIDC_NAMESPACE.svc.cluster.local", issuerURL) + + ginkgo.By("Creating ConfigMap with parameterized OIDC server code") + gomega.Expect(c.Create(ctx, &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{Name: configMapName, Namespace: namespace}, + Data: map[string]string{"server.py": script}, + })).To(gomega.Succeed()) + + ginkgo.By("Creating parameterized OIDC server pod") + gomega.Expect(c.Create(ctx, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{"app": name}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "oidc", + Image: "python:3.11-slim", + Command: []string{"sh", "-c", "pip install --no-cache-dir cryptography && python3 /app/server.py"}, + Ports: []corev1.ContainerPort{{ContainerPort: 8080}}, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/.well-known/openid-configuration", + Port: intstr.FromInt(8080), + }, + }, + InitialDelaySeconds: 5, + PeriodSeconds: 2, + FailureThreshold: 30, + }, + VolumeMounts: []corev1.VolumeMount{{Name: "code", MountPath: "/app"}}, + }}, + Volumes: []corev1.Volume{{ + Name: "code", + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{Name: configMapName}, + DefaultMode: ptr.To(int32(0755)), + }, + }, + }}, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Creating parameterized OIDC server service with auto-assigned NodePort") + oidcSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeNodePort, + Selector: map[string]string{"app": name}, + Ports: []corev1.ServicePort{{ + Port: 80, + TargetPort: intstr.FromInt(8080), + Protocol: corev1.ProtocolTCP, + }}, + }, + } + gomega.Expect(c.Create(ctx, oidcSvc)).To(gomega.Succeed()) + + // Read back the auto-assigned NodePort + gomega.Expect(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, oidcSvc)).To(gomega.Succeed()) + allocatedNodePort = oidcSvc.Spec.Ports[0].NodePort + gomega.Expect(allocatedNodePort).NotTo(gomega.BeZero(), "Kubernetes should auto-assign a NodePort") + + ginkgo.By("Waiting for parameterized OIDC server to be ready") + gomega.Eventually(func() bool { + pod := &corev1.Pod{} + if err := c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, pod); err != nil { + return false + } + if pod.Status.Phase != corev1.PodRunning { + return false + } + for _, cond := range pod.Status.Conditions { + if cond.Type == corev1.PodReady && cond.Status == corev1.ConditionTrue { + return true + } + } + return false + }, timeout, pollingInterval).Should(gomega.BeTrue(), "parameterized OIDC server should be ready") + + cleanup = func() { + _ = c.Delete(ctx, &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}}) + _ = c.Delete(ctx, &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}}) + _ = c.Delete(ctx, &corev1.ConfigMap{ObjectMeta: metav1.ObjectMeta{Name: configMapName, Namespace: namespace}}) + // Wait for the Pod and Service to be fully removed so their NodePort + // and name can be reused immediately in a subsequent test run. + gomega.Eventually(func() bool { + pod := &corev1.Pod{} + svc := &corev1.Service{} + podGone := apierrors.IsNotFound(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, pod)) + svcGone := apierrors.IsNotFound(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, svc)) + return podGone && svcGone + }, timeout, pollingInterval).Should(gomega.BeTrue(), "OIDC server pod and service should be fully deleted") + } + return issuerURL, allocatedNodePort, cleanup +} + +// parameterizedOIDCServerScript is a minimal Python OIDC server that issues +// RSA-signed RS256 JWTs with a caller-controlled subject. +// +// Usage: POST /token?subject=alice → returns {"access_token": "", ...} +// The subject defaults to "test-user" when the query parameter is omitted. +const parameterizedOIDCServerScript = ` +import base64, json, time, http.server, socketserver +from urllib.parse import urlparse, parse_qs +from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.backends import default_backend + +private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) +public_key = private_key.public_key() +pub_numbers = public_key.public_numbers() + +def to_b64url(num): + b = num.to_bytes((num.bit_length() + 7) // 8, byteorder="big") + return base64.urlsafe_b64encode(b).decode().rstrip("=") + +n_b64 = to_b64url(pub_numbers.n) +e_b64 = to_b64url(pub_numbers.e) +ISSUER = "http://OIDC_SERVICE_NAME.OIDC_NAMESPACE.svc.cluster.local" + +class H(http.server.BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/.well-known/openid-configuration": + self._json({"issuer": ISSUER, "authorization_endpoint": ISSUER+"/auth", + "token_endpoint": ISSUER+"/token", "jwks_uri": ISSUER+"/jwks", + "response_types_supported": ["code"], "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"]}) + elif self.path == "/jwks": + self._json({"keys": [{"kty": "RSA", "use": "sig", "kid": "k1", "alg": "RS256", "n": n_b64, "e": e_b64}]}) + else: + self.send_response(404); self.end_headers() + def do_POST(self): + if self.path.startswith("/token"): + params = parse_qs(urlparse(self.path).query) + sub = params.get("subject", ["test-user"])[0] + hdr = {"alg": "RS256", "typ": "JWT", "kid": "k1"} + pay = {"sub": sub, "iss": ISSUER, "aud": "vmcp-audience", "exp": int(time.time())+3600, "iat": int(time.time())} + def enc(d): return base64.urlsafe_b64encode(json.dumps(d, separators=(",",":")).encode()).decode().rstrip("=") + h64, p64 = enc(hdr), enc(pay) + sig = private_key.sign((h64+"."+p64).encode(), asym_padding.PKCS1v15(), hashes.SHA256()) + jwt = h64 + "." + p64 + "." + base64.urlsafe_b64encode(sig).decode().rstrip("=") + print(f"Issued JWT for sub={sub}", flush=True) + self._json({"access_token": jwt, "token_type": "Bearer", "expires_in": 3600}) + else: + self.send_response(404); self.end_headers() + def _json(self, obj): + body = json.dumps(obj).encode() + self.send_response(200); self.send_header("Content-Type","application/json"); self.end_headers(); self.wfile.write(body) + def log_message(self, f, *a): pass + +with socketserver.TCPServer(("", 8080), H) as s: + print("OIDC server ready on 8080", flush=True) + s.serve_forever() +` diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index 2ca2fdf21f..d68614afda 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -1195,188 +1195,15 @@ with socketserver.TCPServer(("", 8080), Handler) as httpd: } } -// ParameterizedOIDCServerScript is a minimal Python OIDC server that issues -// RSA-signed RS256 JWTs with a caller-controlled subject. -// -// Usage: POST /token?subject=alice → returns {"access_token": "", ...} -// The subject defaults to "test-user" when the query parameter is omitted. -// -// The issuer is derived from the service name: the server reads the HOST -// environment variable set by the caller via the ISSUER constant below. Tests -// that deploy this script must set the correct issuer URL in the VirtualMCPServer -// InlineOIDCConfig.Issuer field. -const ParameterizedOIDCServerScript = ` -import base64, json, time, http.server, socketserver -from urllib.parse import urlparse, parse_qs -from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.backends import default_backend - -private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) -public_key = private_key.public_key() -pub_numbers = public_key.public_numbers() - -def to_b64url(num): - b = num.to_bytes((num.bit_length() + 7) // 8, byteorder="big") - return base64.urlsafe_b64encode(b).decode().rstrip("=") - -n_b64 = to_b64url(pub_numbers.n) -e_b64 = to_b64url(pub_numbers.e) -ISSUER = "http://OIDC_SERVICE_NAME.OIDC_NAMESPACE.svc.cluster.local" - -class H(http.server.BaseHTTPRequestHandler): - def do_GET(self): - if self.path == "/.well-known/openid-configuration": - self._json({"issuer": ISSUER, "authorization_endpoint": ISSUER+"/auth", - "token_endpoint": ISSUER+"/token", "jwks_uri": ISSUER+"/jwks", - "response_types_supported": ["code"], "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"]}) - elif self.path == "/jwks": - self._json({"keys": [{"kty": "RSA", "use": "sig", "kid": "k1", "alg": "RS256", "n": n_b64, "e": e_b64}]}) - else: - self.send_response(404); self.end_headers() - def do_POST(self): - if self.path.startswith("/token"): - params = parse_qs(urlparse(self.path).query) - sub = params.get("subject", ["test-user"])[0] - hdr = {"alg": "RS256", "typ": "JWT", "kid": "k1"} - pay = {"sub": sub, "iss": ISSUER, "aud": "vmcp-audience", "exp": int(time.time())+3600, "iat": int(time.time())} - def enc(d): return base64.urlsafe_b64encode(json.dumps(d, separators=(",",":")).encode()).decode().rstrip("=") - h64, p64 = enc(hdr), enc(pay) - sig = private_key.sign((h64+"."+p64).encode(), asym_padding.PKCS1v15(), hashes.SHA256()) - jwt = h64 + "." + p64 + "." + base64.urlsafe_b64encode(sig).decode().rstrip("=") - print(f"Issued JWT for sub={sub}", flush=True) - self._json({"access_token": jwt, "token_type": "Bearer", "expires_in": 3600}) - else: - self.send_response(404); self.end_headers() - def _json(self, obj): - body = json.dumps(obj).encode() - self.send_response(200); self.send_header("Content-Type","application/json"); self.end_headers(); self.wfile.write(body) - def log_message(self, f, *a): pass - -with socketserver.TCPServer(("", 8080), H) as s: - print("OIDC server ready on 8080", flush=True) - s.serve_forever() -` - -// DeployParameterizedOIDCServer deploys an in-cluster mock OIDC server that -// issues RSA-signed JWTs with a caller-controlled subject claim (via -// POST /token?subject=). The server is exposed via a fixed NodePort so -// the test process (running outside the cluster) can reach it. -// -// Returns the in-cluster issuer URL (http://..svc.cluster.local) -// and a cleanup function that removes all created resources. +// DeployParameterizedOIDCServer delegates to testutil.DeployParameterizedOIDCServer. +// Kept here for backwards compatibility with existing virtualmcp tests. func DeployParameterizedOIDCServer( ctx context.Context, c client.Client, name, namespace string, timeout, pollingInterval time.Duration, ) (issuerURL string, allocatedNodePort int32, cleanup func()) { - configMapName := name + "-code" - - // Patch the placeholder issuer into the script so the JWT iss claim and - // the OIDC discovery document match the in-cluster service URL. - issuerURL = fmt.Sprintf("http://%s.%s.svc.cluster.local", name, namespace) - script := strings.ReplaceAll(ParameterizedOIDCServerScript, - "http://OIDC_SERVICE_NAME.OIDC_NAMESPACE.svc.cluster.local", issuerURL) - - ginkgo.By("Creating ConfigMap with parameterized OIDC server code") - gomega.Expect(c.Create(ctx, &corev1.ConfigMap{ - ObjectMeta: metav1.ObjectMeta{Name: configMapName, Namespace: namespace}, - Data: map[string]string{"server.py": script}, - })).To(gomega.Succeed()) - - ginkgo.By("Creating parameterized OIDC server pod") - mode := int32Ptr(0755) - gomega.Expect(c.Create(ctx, &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: namespace, - Labels: map[string]string{"app": name}, - }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "oidc", - Image: "python:3.11-slim", - Command: []string{"sh", "-c", "pip install --no-cache-dir cryptography && python3 /app/server.py"}, - Ports: []corev1.ContainerPort{{ContainerPort: 8080}}, - ReadinessProbe: &corev1.Probe{ - ProbeHandler: corev1.ProbeHandler{ - HTTPGet: &corev1.HTTPGetAction{ - Path: "/.well-known/openid-configuration", - Port: intstr.FromInt(8080), - }, - }, - InitialDelaySeconds: 5, - PeriodSeconds: 2, - FailureThreshold: 30, - }, - VolumeMounts: []corev1.VolumeMount{{Name: "code", MountPath: "/app"}}, - }}, - Volumes: []corev1.Volume{{ - Name: "code", - VolumeSource: corev1.VolumeSource{ - ConfigMap: &corev1.ConfigMapVolumeSource{ - LocalObjectReference: corev1.LocalObjectReference{Name: configMapName}, - DefaultMode: mode, - }, - }, - }}, - }, - })).To(gomega.Succeed()) - - ginkgo.By("Creating parameterized OIDC server service with auto-assigned NodePort") - oidcSvc := &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}, - Spec: corev1.ServiceSpec{ - Type: corev1.ServiceTypeNodePort, - Selector: map[string]string{"app": name}, - Ports: []corev1.ServicePort{{ - Port: 80, - TargetPort: intstr.FromInt(8080), - Protocol: corev1.ProtocolTCP, - }}, - }, - } - gomega.Expect(c.Create(ctx, oidcSvc)).To(gomega.Succeed()) - - // Read back the auto-assigned NodePort - gomega.Expect(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, oidcSvc)).To(gomega.Succeed()) - allocatedNodePort = oidcSvc.Spec.Ports[0].NodePort - gomega.Expect(allocatedNodePort).NotTo(gomega.BeZero(), "Kubernetes should auto-assign a NodePort") - - ginkgo.By("Waiting for parameterized OIDC server to be ready") - gomega.Eventually(func() bool { - pod := &corev1.Pod{} - if err := c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, pod); err != nil { - return false - } - if pod.Status.Phase != corev1.PodRunning { - return false - } - for _, cond := range pod.Status.Conditions { - if cond.Type == corev1.PodReady && cond.Status == corev1.ConditionTrue { - return true - } - } - return false - }, timeout, pollingInterval).Should(gomega.BeTrue(), "parameterized OIDC server should be ready") - - cleanup = func() { - _ = c.Delete(ctx, &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}}) - _ = c.Delete(ctx, &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: name, Namespace: namespace}}) - _ = c.Delete(ctx, &corev1.ConfigMap{ObjectMeta: metav1.ObjectMeta{Name: configMapName, Namespace: namespace}}) - // Wait for the Pod and Service to be fully removed so their fixed NodePort - // and name can be reused immediately in a subsequent test run. - gomega.Eventually(func() bool { - pod := &corev1.Pod{} - svc := &corev1.Service{} - podGone := apierrors.IsNotFound(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, pod)) - svcGone := apierrors.IsNotFound(c.Get(ctx, types.NamespacedName{Name: name, Namespace: namespace}, svc)) - return podGone && svcGone - }, timeout, pollingInterval).Should(gomega.BeTrue(), "OIDC server pod and service should be fully deleted") - } - return issuerURL, allocatedNodePort, cleanup + return testutil.DeployParameterizedOIDCServer(ctx, c, name, namespace, timeout, pollingInterval) } // CleanupMockHTTPServer removes the mock HTTP server resources