From 780c52c24cff9c2537cb4a7af6de935538610bf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20Echterh=C3=B6lter?= Date: Mon, 15 Sep 2025 10:37:42 +0200 Subject: [PATCH] feat(middleware): enhance middleware configuration for auth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Introduces optional authentication middleware in CreateMiddleware • Updates tests to validate middleware behavior with and without auth --- middleware/authzMiddlewares.go | 9 +++------ middleware/middleware.go | 13 +++++++++--- middleware/middleware_test.go | 37 +++++++++++++++++++++++++++++++--- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/middleware/authzMiddlewares.go b/middleware/authzMiddlewares.go index 84a202e..37be850 100644 --- a/middleware/authzMiddlewares.go +++ b/middleware/authzMiddlewares.go @@ -4,13 +4,10 @@ import ( "net/http" ) -// Middleware defines a function that wraps an http.Handler. -type Middleware func(http.Handler) http.Handler - -// CreateAuthMiddleware returns a slice of Middleware functions for authentication and authorization. +// CreateAuthMiddleware returns a slice of middleware functions for authentication and authorization. // The returned middlewares are: StoreWebToken, StoreAuthHeader, and StoreSpiffeHeader. -func CreateAuthMiddleware() []Middleware { - return []Middleware{ +func CreateAuthMiddleware() []func(http.Handler) http.Handler { + return []func(http.Handler) http.Handler{ StoreWebToken(), StoreAuthHeader(), StoreSpiffeHeader(), diff --git a/middleware/middleware.go b/middleware/middleware.go index 0e49a42..6f57e66 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -6,13 +6,20 @@ import ( "github.com/platform-mesh/golang-commons/logger" ) -// attaches a request-scoped logger (using the provided logger), assigns a request ID, and propagates that ID into the logger. -func CreateMiddleware(log *logger.Logger) []func(http.Handler) http.Handler { - return []func(http.Handler) http.Handler{ +// CreateMiddleware creates a middleware chain with logging, tracing, and optional authentication. +// It attaches a request-scoped logger (using the provided logger), assigns a request ID, and propagates that ID into the logger. +// When auth is true, authentication middlewares (StoreWebToken, StoreAuthHeader, StoreSpiffeHeader) are included. +func CreateMiddleware(log *logger.Logger, auth bool) []func(http.Handler) http.Handler { + mws := []func(http.Handler) http.Handler{ SetOtelTracingContext(), SentryRecoverer, StoreLoggerMiddleware(log), SetRequestId(), SetRequestIdInLogger(), } + + if auth { + mws = append(mws, CreateAuthMiddleware()...) + } + return mws } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 53d3855..2c002ff 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -9,11 +9,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCreateMiddleware(t *testing.T) { +func TestCreateMiddleware_WithoutAuth(t *testing.T) { log := testlogger.New() - middlewares := CreateMiddleware(log.Logger) + middlewares := CreateMiddleware(log.Logger, false) - // Should return 5 middlewares + // Should return 5 middlewares when auth is false assert.Len(t, middlewares, 5) // Each middleware should be a valid function @@ -39,3 +39,34 @@ func TestCreateMiddleware(t *testing.T) { assert.Equal(t, http.StatusOK, recorder.Code) } + +func TestCreateMiddleware_WithAuth(t *testing.T) { + log := testlogger.New() + middlewares := CreateMiddleware(log.Logger, true) + + // Should return 8 middlewares when auth is true (5 base + 3 auth) + assert.Len(t, middlewares, 8) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } + + // Test that middlewares can be chained + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Apply all middlewares + var finalHandler http.Handler = handler + for i := len(middlewares) - 1; i >= 0; i-- { + finalHandler = middlewares[i](finalHandler) + } + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + finalHandler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +}