Skip to content
10 changes: 9 additions & 1 deletion pkg/ratelimit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/redis/go-redis/v9"

v1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/transport/types"
)
Expand Down Expand Up @@ -119,7 +120,14 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
return
}

decision, err := limiter.Allow(r.Context(), parsed.ResourceID, "")
// When no identity is present (unauthenticated), userID stays empty
// and per-user buckets are skipped — only shared limits apply. CEL
// validation ensures perUser rate limits require auth to be enabled.
var userID string
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
Comment thread
jerm-dro marked this conversation as resolved.
userID = identity.Subject
}
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
if err != nil {
slog.Warn("rate limit check failed, allowing request", "error", err)
next.ServeHTTP(w, r)
Expand Down
60 changes: 60 additions & 0 deletions pkg/ratelimit/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/mcp"
)

Expand All @@ -29,6 +30,25 @@ func (d *dummyLimiter) Allow(context.Context, string, string) (*Decision, error)
return d.decision, d.err
}

// recordingLimiter captures the arguments passed to Allow.
type recordingLimiter struct {
toolName string
userID string
}

func (r *recordingLimiter) Allow(_ context.Context, toolName, userID string) (*Decision, error) {
r.toolName = toolName
r.userID = userID
return &Decision{Allowed: true}, nil
}
Comment thread
jerm-dro marked this conversation as resolved.

// withIdentity adds an auth.Identity with the given subject to the request context.
func withIdentity(r *http.Request, subject string) *http.Request {
identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: subject}}
ctx := auth.WithIdentity(r.Context(), identity)
return r.WithContext(ctx)
}

// withParsedMCPRequest adds a ParsedMCPRequest to the request context.
func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request {
parsed := &mcp.ParsedMCPRequest{
Expand Down Expand Up @@ -148,3 +168,43 @@ func TestRateLimitHandler_NonToolCallPassesThrough(t *testing.T) {
assert.True(t, nextCalled, "non-tools/call should pass through regardless of limiter")
assert.Equal(t, http.StatusOK, w.Code)
}

func TestRateLimitHandler_PassesUserID(t *testing.T) {
t.Parallel()

recorder := &recordingLimiter{}
handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
req = withParsedMCPRequest(req, "tools/call", "echo", 1)
req = withIdentity(req, "alice@example.com")
w := httptest.NewRecorder()

handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "echo", recorder.toolName)
assert.Equal(t, "alice@example.com", recorder.userID)
}

func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
t.Parallel()

recorder := &recordingLimiter{}
handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
req = withParsedMCPRequest(req, "tools/call", "echo", 1)
// No identity in context — unauthenticated request.
w := httptest.NewRecorder()

handler.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "echo", recorder.toolName)
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
}
120 changes: 113 additions & 7 deletions test/e2e/thv-operator/acceptance_tests/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
Expand All @@ -24,9 +25,9 @@ import (
"github.com/stacklok/toolhive/test/e2e/images"
)

// DeployRedis creates a Redis Deployment and Service in the given namespace.
// No password is configured — matches the default empty THV_SESSION_REDIS_PASSWORD.
func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) {
// EnsureRedis creates a Redis Deployment and Service if they don't already exist,
// then waits for Redis to be ready. Safe to call concurrently from multiple test blocks.
func EnsureRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) {
labels := map[string]string{"app": "redis"}

deployment := &appsv1.Deployment{
Expand All @@ -51,7 +52,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout
},
},
}
gomega.Expect(c.Create(ctx, deployment)).To(gomega.Succeed())
if err := c.Create(ctx, deployment); err != nil && !apierrors.IsAlreadyExists(err) {
gomega.Expect(err).ToNot(gomega.HaveOccurred())
}

service := &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -65,7 +68,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout
},
},
}
gomega.Expect(c.Create(ctx, service)).To(gomega.Succeed())
if err := c.Create(ctx, service); err != nil && !apierrors.IsAlreadyExists(err) {
gomega.Expect(err).ToNot(gomega.HaveOccurred())
}

ginkgo.By("Waiting for Redis to be ready")
gomega.Eventually(func() error {
Expand Down Expand Up @@ -99,7 +104,7 @@ func CleanupRedis(ctx context.Context, c client.Client, namespace string) {
}

// SendToolCall sends a JSON-RPC tools/call request and returns the HTTP status code and body.
func SendToolCall(httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) {
func SendToolCall(ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) {
reqBody := map[string]any{
"jsonrpc": "2.0",
"id": requestID,
Expand All @@ -113,7 +118,7 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI
gomega.Expect(err).ToNot(gomega.HaveOccurred())

url := fmt.Sprintf("http://localhost:%d/mcp", port)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(bodyBytes))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
gomega.Expect(err).ToNot(gomega.HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
Expand All @@ -127,3 +132,104 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI

return resp.StatusCode, respBody
}

// SendInitialize sends a JSON-RPC initialize request and returns the session ID
// from the Mcp-Session header. This must be called before tools/call when auth is enabled.
func SendInitialize(
ctx context.Context, httpClient *http.Client, port int32, bearerToken string,
) (sessionID string) {
reqBody := map[string]any{
"jsonrpc": "2.0",
"id": 0,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-03-26",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "e2e-test",
"version": "1.0.0",
},
},
}
bodyBytes, err := json.Marshal(reqBody)
gomega.Expect(err).ToNot(gomega.HaveOccurred())

url := fmt.Sprintf("http://localhost:%d/mcp", port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
gomega.Expect(err).ToNot(gomega.HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
if bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+bearerToken)
}

resp, err := httpClient.Do(req)
gomega.Expect(err).ToNot(gomega.HaveOccurred())
defer func() { _ = resp.Body.Close() }()

gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK),
"initialize should succeed")

sessionID = resp.Header.Get("Mcp-Session-Id")
gomega.Expect(sessionID).ToNot(gomega.BeEmpty(),
"initialize response should include Mcp-Session-Id header")

return sessionID
}

// SendAuthenticatedToolCallWithSession sends a JSON-RPC tools/call with Bearer token and session ID.
func SendAuthenticatedToolCallWithSession(
ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int, bearerToken, sessionID string,
) (int, []byte, string) {
reqBody := map[string]any{
"jsonrpc": "2.0",
"id": requestID,
"method": "tools/call",
"params": map[string]any{
"name": toolName,
"arguments": map[string]any{"input": "test"},
},
}
bodyBytes, err := json.Marshal(reqBody)
gomega.Expect(err).ToNot(gomega.HaveOccurred())

url := fmt.Sprintf("http://localhost:%d/mcp", port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
gomega.Expect(err).ToNot(gomega.HaveOccurred())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set("Authorization", "Bearer "+bearerToken)
if sessionID != "" {
req.Header.Set("Mcp-Session-Id", sessionID)
}

resp, err := httpClient.Do(req)
gomega.Expect(err).ToNot(gomega.HaveOccurred())
defer func() { _ = resp.Body.Close() }()

retryAfter := resp.Header.Get("Retry-After")

respBody, err := io.ReadAll(resp.Body)
gomega.Expect(err).ToNot(gomega.HaveOccurred())

return resp.StatusCode, respBody, retryAfter
}

// GetOIDCToken fetches a JWT from the mock OIDC server for the given subject.
func GetOIDCToken(ctx context.Context, httpClient *http.Client, oidcNodePort int32, subject string) string {
url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcNodePort, subject)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
gomega.Expect(err).ToNot(gomega.HaveOccurred())

resp, err := httpClient.Do(req)
gomega.Expect(err).ToNot(gomega.HaveOccurred())
defer func() { _ = resp.Body.Close() }()

var tokenResp struct {
AccessToken string `json:"access_token"`
}
gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed())
gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty(), "OIDC server should return a token")

return tokenResp.AccessToken
}
Loading
Loading