From bfb03857ae2e059e17f8e174515adaabe5dee61f Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 09:56:22 +0800 Subject: [PATCH 1/7] Apply public middleware to signature generation --- handler/signatures.go | 20 -------------- metal/kernel/router.go | 3 +++ metal/kernel/router_signature_test.go | 38 +++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 20 deletions(-) create mode 100644 metal/kernel/router_signature_test.go diff --git a/handler/signatures.go b/handler/signatures.go index 61c56f67..1e080f92 100644 --- a/handler/signatures.go +++ b/handler/signatures.go @@ -51,10 +51,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ receivedAt := time.Unix(req.Timestamp, 0) req.Origin = r.Header.Get(portal.IntendedOriginHeader) - if err = s.isRequestWithinTimeframe(serverTime, receivedAt); err != nil { - return http.LogBadRequestError(err.Error(), err) - } - var keySignature *database.APIKeySignatures if keySignature, err = s.CreateSignature(req, serverTime); err != nil { return http.LogInternalError(err.Error(), err) @@ -80,22 +76,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ return nil // A nil return indicates success. } -func (s *SignaturesHandler) isRequestWithinTimeframe(serverTime, receivedAt time.Time) error { - skew := 15 * time.Second - - earliestValidTime := serverTime.Add(-skew) - if receivedAt.Before(earliestValidTime) { - return fmt.Errorf("the request timestamp [%s] is too old", receivedAt.Format(portal.DatesLayout)) - } - - latestValidTime := serverTime.Add(skew) - if receivedAt.After(latestValidTime) { - return fmt.Errorf("the request timestamp [%s] is from the future", receivedAt.Format(portal.DatesLayout)) - } - - return nil -} - func (s *SignaturesHandler) CreateSignature(request payload.SignatureRequest, serverTime time.Time) (*database.APIKeySignatures, error) { var err error var token *database.APIKey diff --git a/metal/kernel/router.go b/metal/kernel/router.go index e8014b5f..43f643db 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -31,9 +31,12 @@ type Router struct { } func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerFunc { + pm := middleware.MakePublicMiddleware() + return http.MakeApiHandler( r.Pipeline.Chain( apiHandler, + pm.Handle, ), ) } diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go new file mode 100644 index 00000000..d49e5138 --- /dev/null +++ b/metal/kernel/router_signature_test.go @@ -0,0 +1,38 @@ +package kernel + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/oullin/pkg/middleware" + "github.com/oullin/pkg/portal" +) + +func TestSignatureRoute_PublicMiddleware(t *testing.T) { + r := Router{ + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{}, + validator: portal.GetDefaultValidator(), + } + r.Signature() + + req := httptest.NewRequest("POST", "/generate-signature", nil) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + + req2 := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req2.Header.Set(portal.RequestIDHeader, "req-1") + req2.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + rec2 := httptest.NewRecorder() + r.Mux.ServeHTTP(rec2, req2) + if rec2.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec2.Code) + } +} From 7a8ef47cfb2ae7ecce6d7c18ecc86b686a0607d9 Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 10:09:20 +0800 Subject: [PATCH 2/7] Split signature route tests into sub-tests --- metal/kernel/router_signature_test.go | 32 +++++++++++++++------------ 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go index d49e5138..b3ecaf1c 100644 --- a/metal/kernel/router_signature_test.go +++ b/metal/kernel/router_signature_test.go @@ -20,19 +20,23 @@ func TestSignatureRoute_PublicMiddleware(t *testing.T) { } r.Signature() - req := httptest.NewRequest("POST", "/generate-signature", nil) - rec := httptest.NewRecorder() - r.Mux.ServeHTTP(rec, req) - if rec.Code != http.StatusUnauthorized { - t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) - } + t.Run("request without public headers is unauthorized", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", nil) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) - req2 := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) - req2.Header.Set(portal.RequestIDHeader, "req-1") - req2.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) - rec2 := httptest.NewRecorder() - r.Mux.ServeHTTP(rec2, req2) - if rec2.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec2.Code) - } + t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + } + }) } From f2f25594d0017762eb5d82079a70edbdacdfdef0 Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 10:15:24 +0800 Subject: [PATCH 3/7] Share public middleware instance across routes --- metal/kernel/app.go | 9 +++++---- metal/kernel/kernel_test.go | 5 +++-- metal/kernel/router.go | 15 +++++++-------- metal/kernel/router_signature_test.go | 7 ++++--- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/metal/kernel/app.go b/metal/kernel/app.go index b9627343..18fda6e2 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -42,10 +42,11 @@ func MakeApp(env *env.Environment, validator *portal.Validator) (*App, error) { } router := Router{ - Env: env, - Db: db, - Mux: baseHttp.NewServeMux(), - validator: validator, + Env: env, + Db: db, + Mux: baseHttp.NewServeMux(), + validator: validator, + publicMiddleware: middleware.MakePublicMiddleware(), Pipeline: middleware.Pipeline{ Env: env, ApiKeys: &repository.ApiKeys{DB: db}, diff --git a/metal/kernel/kernel_test.go b/metal/kernel/kernel_test.go index 8ac19909..34cd00f5 100644 --- a/metal/kernel/kernel_test.go +++ b/metal/kernel/kernel_test.go @@ -96,7 +96,7 @@ func TestAppHelpers(t *testing.T) { app := &App{} mux := http.NewServeMux() - r := Router{Mux: mux} + r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware()} app.SetRouter(r) @@ -141,7 +141,8 @@ func TestAppBootRoutes(t *testing.T) { ApiKeys: &repository.ApiKeys{DB: &database.Connection{}}, TokenHandler: handler, }, - Db: &database.Connection{}, + Db: &database.Connection{}, + publicMiddleware: middleware.MakePublicMiddleware(), } app := &App{} diff --git a/metal/kernel/router.go b/metal/kernel/router.go index 43f643db..9df1f666 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -23,20 +23,19 @@ func addStaticRoute[H StaticRouteResource](r *Router, path, file string, maker f } type Router struct { - Env *env.Environment - Mux *baseHttp.ServeMux - Pipeline middleware.Pipeline - Db *database.Connection - validator *portal.Validator + Env *env.Environment + Mux *baseHttp.ServeMux + Pipeline middleware.Pipeline + Db *database.Connection + validator *portal.Validator + publicMiddleware middleware.PublicMiddleware } func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerFunc { - pm := middleware.MakePublicMiddleware() - return http.MakeApiHandler( r.Pipeline.Chain( apiHandler, - pm.Handle, + r.publicMiddleware.Handle, ), ) } diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go index b3ecaf1c..180f2112 100644 --- a/metal/kernel/router_signature_test.go +++ b/metal/kernel/router_signature_test.go @@ -14,9 +14,10 @@ import ( func TestSignatureRoute_PublicMiddleware(t *testing.T) { r := Router{ - Mux: http.NewServeMux(), - Pipeline: middleware.Pipeline{}, - validator: portal.GetDefaultValidator(), + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{}, + validator: portal.GetDefaultValidator(), + publicMiddleware: middleware.MakePublicMiddleware(), } r.Signature() From 351cfa8c49bf4010a3586ef6849effad098a0991 Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 11:20:23 +0800 Subject: [PATCH 4/7] Restrict public middleware to whitelisted IP in production --- .env.example | 4 ++ .env.gh.example | 1 + .env.prod.example | 3 + metal/kernel/app.go | 74 ++++++++++++------------ metal/kernel/kernel_test.go | 4 +- metal/kernel/router_signature_test.go | 48 ++++++++++----- pkg/middleware/public_middleware.go | 48 +++++++++------ pkg/middleware/public_middleware_test.go | 55 ++++++++++++++++-- 8 files changed, 159 insertions(+), 78 deletions(-) diff --git a/.env.example b/.env.example index 8ed420dc..9614e3bd 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,10 @@ ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log" ENV_APP_LOGS_DATE_FORMAT="2006_02_01" ENV_APP_MASTER_KEY= +# --- Public middleware +# Optional IP whitelist for production-only public routes +ENV_PUBLIC_ALLOWED_IP= + # --- DB ENV_DB_USER_NAME="gus" ENV_DB_USER_PASSWORD="password" diff --git a/.env.gh.example b/.env.gh.example index 4d18b68a..ec9f6582 100644 --- a/.env.gh.example +++ b/.env.gh.example @@ -2,3 +2,4 @@ ENV_HTTP_PORT=8080 ENV_DOCKER_USER=gocanto ENV_DOCKER_USER_GROUP=ggroup CADDY_LOGS_PATH=./storage/logs/caddy +ENV_PUBLIC_ALLOWED_IP= diff --git a/.env.prod.example b/.env.prod.example index ebbf82b7..873ff085 100644 --- a/.env.prod.example +++ b/.env.prod.example @@ -5,6 +5,9 @@ ENV_APP_LOG_LEVEL=debug ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log" ENV_APP_LOGS_DATE_FORMAT="2006_02_01" +# --- Public middleware +ENV_PUBLIC_ALLOWED_IP=31.97.60.190 + # --- DB ENV_DB_PORT= ENV_DB_HOST= diff --git a/metal/kernel/app.go b/metal/kernel/app.go index 18fda6e2..28d190e5 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -1,58 +1,58 @@ package kernel import ( - "fmt" - baseHttp "net/http" + "fmt" + baseHttp "net/http" - "github.com/oullin/database" - "github.com/oullin/database/repository" - "github.com/oullin/metal/env" - "github.com/oullin/pkg/auth" - "github.com/oullin/pkg/llogs" - "github.com/oullin/pkg/middleware" - "github.com/oullin/pkg/portal" + "github.com/oullin/database" + "github.com/oullin/database/repository" + "github.com/oullin/metal/env" + "github.com/oullin/pkg/auth" + "github.com/oullin/pkg/llogs" + "github.com/oullin/pkg/middleware" + "github.com/oullin/pkg/portal" ) type App struct { - router *Router - sentry *portal.Sentry - logs llogs.Driver - validator *portal.Validator - env *env.Environment - db *database.Connection + router *Router + sentry *portal.Sentry + logs llogs.Driver + validator *portal.Validator + env *env.Environment + db *database.Connection } -func MakeApp(env *env.Environment, validator *portal.Validator) (*App, error) { - tokenHandler, err := auth.MakeTokensHandler( - []byte(env.App.MasterKey), - ) +func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) { + tokenHandler, err := auth.MakeTokensHandler( + []byte(e.App.MasterKey), + ) if err != nil { return nil, fmt.Errorf("bootstrapping error > could not create a token handler: %w", err) } - db := MakeDbConnection(env) + db := MakeDbConnection(e) app := App{ - env: env, - validator: validator, - logs: MakeLogs(env), - sentry: MakeSentry(env), - db: db, - } + env: e, + validator: validator, + logs: MakeLogs(e), + sentry: MakeSentry(e), + db: db, + } router := Router{ - Env: env, - Db: db, - Mux: baseHttp.NewServeMux(), - validator: validator, - publicMiddleware: middleware.MakePublicMiddleware(), - Pipeline: middleware.Pipeline{ - Env: env, - ApiKeys: &repository.ApiKeys{DB: db}, - TokenHandler: tokenHandler, - }, - } + Env: e, + Db: db, + Mux: baseHttp.NewServeMux(), + validator: validator, + publicMiddleware: middleware.MakePublicMiddleware(env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"), e.App.IsProduction()), + Pipeline: middleware.Pipeline{ + Env: e, + ApiKeys: &repository.ApiKeys{DB: db}, + TokenHandler: tokenHandler, + }, + } app.SetRouter(router) diff --git a/metal/kernel/kernel_test.go b/metal/kernel/kernel_test.go index 34cd00f5..06002d67 100644 --- a/metal/kernel/kernel_test.go +++ b/metal/kernel/kernel_test.go @@ -96,7 +96,7 @@ func TestAppHelpers(t *testing.T) { app := &App{} mux := http.NewServeMux() - r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware()} + r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware("", false)} app.SetRouter(r) @@ -142,7 +142,7 @@ func TestAppBootRoutes(t *testing.T) { TokenHandler: handler, }, Db: &database.Connection{}, - publicMiddleware: middleware.MakePublicMiddleware(), + publicMiddleware: middleware.MakePublicMiddleware("", false), } app := &App{} diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go index 180f2112..9dbef9c6 100644 --- a/metal/kernel/router_signature_test.go +++ b/metal/kernel/router_signature_test.go @@ -17,11 +17,11 @@ func TestSignatureRoute_PublicMiddleware(t *testing.T) { Mux: http.NewServeMux(), Pipeline: middleware.Pipeline{}, validator: portal.GetDefaultValidator(), - publicMiddleware: middleware.MakePublicMiddleware(), - } - r.Signature() + publicMiddleware: middleware.MakePublicMiddleware("", false), + } + r.Signature() - t.Run("request without public headers is unauthorized", func(t *testing.T) { + t.Run("request without public headers is unauthorized", func(t *testing.T) { req := httptest.NewRequest("POST", "/generate-signature", nil) rec := httptest.NewRecorder() r.Mux.ServeHTTP(rec, req) @@ -30,14 +30,34 @@ func TestSignatureRoute_PublicMiddleware(t *testing.T) { } }) - t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { - req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) - rec := httptest.NewRecorder() - r.Mux.ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) - } - }) + t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + } + }) + + t.Run("production rejects requests from non-whitelisted IP", func(t *testing.T) { + r := Router{ + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{}, + validator: portal.GetDefaultValidator(), + publicMiddleware: middleware.MakePublicMiddleware("31.97.60.190", true), + } + r.Signature() + + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) } diff --git a/pkg/middleware/public_middleware.go b/pkg/middleware/public_middleware.go index 94e41254..5fd56d8c 100644 --- a/pkg/middleware/public_middleware.go +++ b/pkg/middleware/public_middleware.go @@ -18,24 +18,30 @@ import ( // a simple in-memory rate limiter keyed by client IP. Reuse of a // request ID within a TTL window is rejected via TTLCache. type PublicMiddleware struct { - clockSkew time.Duration - disallowFuture bool - requestTTL time.Duration - rateLimiter *limiter.MemoryLimiter - requestCache *cache.TTLCache - now func() time.Time + clockSkew time.Duration + disallowFuture bool + requestTTL time.Duration + rateLimiter *limiter.MemoryLimiter + requestCache *cache.TTLCache + now func() time.Time + allowedIP string + isProduction bool } // MakePublicMiddleware constructs a PublicMiddleware with sane defaults. -func MakePublicMiddleware() PublicMiddleware { - return PublicMiddleware{ - clockSkew: 5 * time.Minute, - disallowFuture: true, - requestTTL: 5 * time.Minute, - rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), - requestCache: cache.NewTTLCache(), - now: time.Now, - } +// allowedIP restricts traffic to a specific client IP when isProduction is true. +// When not in production or allowedIP is blank, all IPs are permitted. +func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware { + return PublicMiddleware{ + clockSkew: 5 * time.Minute, + disallowFuture: true, + requestTTL: 5 * time.Minute, + rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), + requestCache: cache.NewTTLCache(), + now: time.Now, + allowedIP: strings.TrimSpace(allowedIP), + isProduction: isProduction, + } } func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { @@ -50,10 +56,14 @@ func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { return mwguards.InvalidRequestError("Invalid authentication headers", "") } - ip := portal.ParseClientIP(r) - if ip == "" { - return mwguards.InvalidRequestError("Invalid client IP", "") - } + ip := portal.ParseClientIP(r) + if ip == "" { + return mwguards.InvalidRequestError("Invalid client IP", "") + } + + if p.isProduction && p.allowedIP != "" && ip != p.allowedIP { + return mwguards.InvalidRequestError("Invalid client IP", "unauthorised ip: "+ip) + } limiterKey := ip if p.rateLimiter.TooMany(limiterKey) { diff --git a/pkg/middleware/public_middleware_test.go b/pkg/middleware/public_middleware_test.go index 60608382..ed6f8d44 100644 --- a/pkg/middleware/public_middleware_test.go +++ b/pkg/middleware/public_middleware_test.go @@ -13,7 +13,7 @@ import ( ) func TestPublicMiddleware_InvalidHeaders(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) base := time.Unix(1_700_000_000, 0) @@ -58,7 +58,7 @@ func TestPublicMiddleware_InvalidHeaders(t *testing.T) { } func TestPublicMiddleware_TimestampExpired(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) @@ -75,7 +75,7 @@ func TestPublicMiddleware_TimestampExpired(t *testing.T) { } func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) pm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } @@ -107,7 +107,50 @@ func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { req3.Header.Set(portal.RequestIDHeader, "def") req3.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) req3.Header.Set("X-Forwarded-For", "1.2.3.4") - if err := handler(rec3, req3); err == nil || err.Status != http.StatusTooManyRequests { - t.Fatalf("expected rate limit error, got %#v", err) - } + if err := handler(rec3, req3); err == nil || err.Status != http.StatusTooManyRequests { + t.Fatalf("expected rate limit error, got %#v", err) + } +} + +func TestPublicMiddleware_IPWhitelist(t *testing.T) { + base := time.Unix(1_700_000_000, 0) + pm := MakePublicMiddleware("31.97.60.190", true) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + + t.Run("allowed ip passes", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "31.97.60.190") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) + + t.Run("other ip rejected in production", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized, got %#v", err) + } + }) + + t.Run("non-production skips restriction", func(t *testing.T) { + pm := MakePublicMiddleware("31.97.60.190", false) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) } From afba612c314c42fbc9f09a41d769cd27aae4e1dc Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 11:20:30 +0800 Subject: [PATCH 5/7] Inject public IP into environment --- .env.prod.example | 2 +- metal/env/env.go | 11 +++--- metal/kernel/app.go | 72 ++++++++++++++++++------------------- metal/kernel/factory.go | 13 ++++--- metal/kernel/kernel_test.go | 9 +++-- 5 files changed, 58 insertions(+), 49 deletions(-) diff --git a/.env.prod.example b/.env.prod.example index 873ff085..a0393fed 100644 --- a/.env.prod.example +++ b/.env.prod.example @@ -6,7 +6,7 @@ ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log" ENV_APP_LOGS_DATE_FORMAT="2006_02_01" # --- Public middleware -ENV_PUBLIC_ALLOWED_IP=31.97.60.190 +ENV_PUBLIC_ALLOWED_IP= # --- DB ENV_DB_PORT= diff --git a/metal/env/env.go b/metal/env/env.go index b2099d2d..b54d13cb 100644 --- a/metal/env/env.go +++ b/metal/env/env.go @@ -7,11 +7,12 @@ import ( ) type Environment struct { - App AppEnvironment - DB DBEnvironment - Logs LogsEnvironment - Network NetEnvironment - Sentry SentryEnvironment + App AppEnvironment + DB DBEnvironment + Logs LogsEnvironment + Network NetEnvironment + Sentry SentryEnvironment + PublicAllowedIP string } // SecretsDir defines where secret files are read from. It can be overridden in diff --git a/metal/kernel/app.go b/metal/kernel/app.go index 28d190e5..6e14a0fa 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -1,58 +1,58 @@ package kernel import ( - "fmt" - baseHttp "net/http" + "fmt" + baseHttp "net/http" - "github.com/oullin/database" - "github.com/oullin/database/repository" - "github.com/oullin/metal/env" - "github.com/oullin/pkg/auth" - "github.com/oullin/pkg/llogs" - "github.com/oullin/pkg/middleware" - "github.com/oullin/pkg/portal" + "github.com/oullin/database" + "github.com/oullin/database/repository" + "github.com/oullin/metal/env" + "github.com/oullin/pkg/auth" + "github.com/oullin/pkg/llogs" + "github.com/oullin/pkg/middleware" + "github.com/oullin/pkg/portal" ) type App struct { - router *Router - sentry *portal.Sentry - logs llogs.Driver - validator *portal.Validator - env *env.Environment - db *database.Connection + router *Router + sentry *portal.Sentry + logs llogs.Driver + validator *portal.Validator + env *env.Environment + db *database.Connection } func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) { - tokenHandler, err := auth.MakeTokensHandler( - []byte(e.App.MasterKey), - ) + tokenHandler, err := auth.MakeTokensHandler( + []byte(e.App.MasterKey), + ) if err != nil { return nil, fmt.Errorf("bootstrapping error > could not create a token handler: %w", err) } - db := MakeDbConnection(e) + db := MakeDbConnection(e) app := App{ - env: e, - validator: validator, - logs: MakeLogs(e), - sentry: MakeSentry(e), - db: db, - } + env: e, + validator: validator, + logs: MakeLogs(e), + sentry: MakeSentry(e), + db: db, + } router := Router{ - Env: e, - Db: db, - Mux: baseHttp.NewServeMux(), - validator: validator, - publicMiddleware: middleware.MakePublicMiddleware(env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"), e.App.IsProduction()), - Pipeline: middleware.Pipeline{ - Env: e, - ApiKeys: &repository.ApiKeys{DB: db}, - TokenHandler: tokenHandler, - }, - } + Env: e, + Db: db, + Mux: baseHttp.NewServeMux(), + validator: validator, + publicMiddleware: middleware.MakePublicMiddleware(e.PublicAllowedIP, e.App.IsProduction()), + Pipeline: middleware.Pipeline{ + Env: e, + ApiKeys: &repository.ApiKeys{DB: db}, + TokenHandler: tokenHandler, + }, + } app.SetRouter(router) diff --git a/metal/kernel/factory.go b/metal/kernel/factory.go index 3a92956d..3fcca4b9 100644 --- a/metal/kernel/factory.go +++ b/metal/kernel/factory.go @@ -93,6 +93,8 @@ func MakeEnv(validate *portal.Validator) *env.Environment { CSP: env.GetEnvVar("ENV_SENTRY_CSP"), } + publicAllowedIP := env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP") + if _, err := validate.Rejects(app); err != nil { panic(errorSuffix + "invalid [APP] model: " + validate.GetErrorsAsJson()) } @@ -114,11 +116,12 @@ func MakeEnv(validate *portal.Validator) *env.Environment { } blog := &env.Environment{ - App: app, - DB: db, - Logs: logsCreds, - Network: net, - Sentry: sentryEnvironment, + App: app, + DB: db, + Logs: logsCreds, + Network: net, + Sentry: sentryEnvironment, + PublicAllowedIP: publicAllowedIP, } if _, err := validate.Rejects(blog); err != nil { diff --git a/metal/kernel/kernel_test.go b/metal/kernel/kernel_test.go index 06002d67..13c06a3e 100644 --- a/metal/kernel/kernel_test.go +++ b/metal/kernel/kernel_test.go @@ -33,6 +33,7 @@ func validEnvVars(t *testing.T) { t.Setenv("ENV_HTTP_PORT", "8080") t.Setenv("ENV_SENTRY_DSN", "dsn") t.Setenv("ENV_SENTRY_CSP", "csp") + t.Setenv("ENV_PUBLIC_ALLOWED_IP", "1.2.3.4") } func TestMakeEnv(t *testing.T) { @@ -43,6 +44,10 @@ func TestMakeEnv(t *testing.T) { if env.App.Name != "guss" { t.Fatalf("env not loaded") } + + if env.PublicAllowedIP != "1.2.3.4" { + t.Fatalf("expected public allowed ip to be loaded") + } } func TestIgnite(t *testing.T) { @@ -96,7 +101,7 @@ func TestAppHelpers(t *testing.T) { app := &App{} mux := http.NewServeMux() - r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware("", false)} + r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware("", false)} app.SetRouter(r) @@ -142,7 +147,7 @@ func TestAppBootRoutes(t *testing.T) { TokenHandler: handler, }, Db: &database.Connection{}, - publicMiddleware: middleware.MakePublicMiddleware("", false), + publicMiddleware: middleware.MakePublicMiddleware("", false), } app := &App{} From 95841aa998078577d36aefa99d992b46da1460cd Mon Sep 17 00:00:00 2001 From: Gus Date: Tue, 9 Sep 2025 11:47:38 +0800 Subject: [PATCH 6/7] move allowed ip into network env --- metal/env/env.go | 11 ++- metal/env/network.go | 6 +- metal/kernel/app.go | 16 ++--- metal/kernel/factory.go | 19 +++--- metal/kernel/kernel_test.go | 28 ++++++-- metal/kernel/router.go | 13 ++-- metal/kernel/router_signature_test.go | 72 ++++++++++---------- pkg/middleware/pipeline.go | 7 +- pkg/middleware/public_middleware.go | 50 +++++++------- pkg/middleware/public_middleware_test.go | 86 ++++++++++++------------ 10 files changed, 162 insertions(+), 146 deletions(-) diff --git a/metal/env/env.go b/metal/env/env.go index b54d13cb..b2099d2d 100644 --- a/metal/env/env.go +++ b/metal/env/env.go @@ -7,12 +7,11 @@ import ( ) type Environment struct { - App AppEnvironment - DB DBEnvironment - Logs LogsEnvironment - Network NetEnvironment - Sentry SentryEnvironment - PublicAllowedIP string + App AppEnvironment + DB DBEnvironment + Logs LogsEnvironment + Network NetEnvironment + Sentry SentryEnvironment } // SecretsDir defines where secret files are read from. It can be overridden in diff --git a/metal/env/network.go b/metal/env/network.go index f66d91c6..5ca17dd4 100644 --- a/metal/env/network.go +++ b/metal/env/network.go @@ -1,8 +1,10 @@ package env type NetEnvironment struct { - HttpHost string `validate:"required,lowercase,min=7"` - HttpPort string `validate:"required,numeric,oneof=8080"` + HttpHost string `validate:"required,lowercase,min=7"` + HttpPort string `validate:"required,numeric,oneof=8080"` + PublicAllowedIP string `validate:"required_if=IsProduction true,omitempty,ip"` + IsProduction bool `validate:"-"` } func (e NetEnvironment) GetHttpPort() string { diff --git a/metal/kernel/app.go b/metal/kernel/app.go index 6e14a0fa..e509cb2c 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -42,15 +42,15 @@ func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) { } router := Router{ - Env: e, - Db: db, - Mux: baseHttp.NewServeMux(), - validator: validator, - publicMiddleware: middleware.MakePublicMiddleware(e.PublicAllowedIP, e.App.IsProduction()), + Env: e, + Db: db, + Mux: baseHttp.NewServeMux(), + validator: validator, Pipeline: middleware.Pipeline{ - Env: e, - ApiKeys: &repository.ApiKeys{DB: db}, - TokenHandler: tokenHandler, + Env: e, + ApiKeys: &repository.ApiKeys{DB: db}, + TokenHandler: tokenHandler, + PublicMiddleware: middleware.MakePublicMiddleware(e.Network.PublicAllowedIP, e.App.IsProduction()), }, } diff --git a/metal/kernel/factory.go b/metal/kernel/factory.go index 3fcca4b9..2067864c 100644 --- a/metal/kernel/factory.go +++ b/metal/kernel/factory.go @@ -84,8 +84,10 @@ func MakeEnv(validate *portal.Validator) *env.Environment { } net := env.NetEnvironment{ - HttpHost: env.GetEnvVar("ENV_HTTP_HOST"), - HttpPort: env.GetEnvVar("ENV_HTTP_PORT"), + HttpHost: env.GetEnvVar("ENV_HTTP_HOST"), + HttpPort: env.GetEnvVar("ENV_HTTP_PORT"), + PublicAllowedIP: env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"), + IsProduction: app.IsProduction(), } sentryEnvironment := env.SentryEnvironment{ @@ -93,8 +95,6 @@ func MakeEnv(validate *portal.Validator) *env.Environment { CSP: env.GetEnvVar("ENV_SENTRY_CSP"), } - publicAllowedIP := env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP") - if _, err := validate.Rejects(app); err != nil { panic(errorSuffix + "invalid [APP] model: " + validate.GetErrorsAsJson()) } @@ -116,12 +116,11 @@ func MakeEnv(validate *portal.Validator) *env.Environment { } blog := &env.Environment{ - App: app, - DB: db, - Logs: logsCreds, - Network: net, - Sentry: sentryEnvironment, - PublicAllowedIP: publicAllowedIP, + App: app, + DB: db, + Logs: logsCreds, + Network: net, + Sentry: sentryEnvironment, } if _, err := validate.Rejects(blog); err != nil { diff --git a/metal/kernel/kernel_test.go b/metal/kernel/kernel_test.go index 13c06a3e..a2d48bd6 100644 --- a/metal/kernel/kernel_test.go +++ b/metal/kernel/kernel_test.go @@ -45,11 +45,25 @@ func TestMakeEnv(t *testing.T) { t.Fatalf("env not loaded") } - if env.PublicAllowedIP != "1.2.3.4" { + if env.Network.PublicAllowedIP != "1.2.3.4" { t.Fatalf("expected public allowed ip to be loaded") } } +func TestMakeEnvRequiresIPInProduction(t *testing.T) { + validEnvVars(t) + t.Setenv("ENV_APP_ENV_TYPE", "production") + t.Setenv("ENV_PUBLIC_ALLOWED_IP", "") + + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic") + } + }() + + MakeEnv(portal.GetDefaultValidator()) +} + func TestIgnite(t *testing.T) { content := "ENV_APP_NAME=guss\n" + "ENV_APP_ENV_TYPE=local\n" + @@ -101,7 +115,7 @@ func TestAppHelpers(t *testing.T) { app := &App{} mux := http.NewServeMux() - r := Router{Mux: mux, publicMiddleware: middleware.MakePublicMiddleware("", false)} + r := Router{Mux: mux, Pipeline: middleware.Pipeline{PublicMiddleware: middleware.MakePublicMiddleware("", false)}} app.SetRouter(r) @@ -142,12 +156,12 @@ func TestAppBootRoutes(t *testing.T) { Env: env, Mux: http.NewServeMux(), Pipeline: middleware.Pipeline{ - Env: env, - ApiKeys: &repository.ApiKeys{DB: &database.Connection{}}, - TokenHandler: handler, + Env: env, + ApiKeys: &repository.ApiKeys{DB: &database.Connection{}}, + TokenHandler: handler, + PublicMiddleware: middleware.MakePublicMiddleware("", false), }, - Db: &database.Connection{}, - publicMiddleware: middleware.MakePublicMiddleware("", false), + Db: &database.Connection{}, } app := &App{} diff --git a/metal/kernel/router.go b/metal/kernel/router.go index 9df1f666..8b76564b 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -23,19 +23,18 @@ func addStaticRoute[H StaticRouteResource](r *Router, path, file string, maker f } type Router struct { - Env *env.Environment - Mux *baseHttp.ServeMux - Pipeline middleware.Pipeline - Db *database.Connection - validator *portal.Validator - publicMiddleware middleware.PublicMiddleware + Env *env.Environment + Mux *baseHttp.ServeMux + Pipeline middleware.Pipeline + Db *database.Connection + validator *portal.Validator } func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerFunc { return http.MakeApiHandler( r.Pipeline.Chain( apiHandler, - r.publicMiddleware.Handle, + r.Pipeline.PublicMiddleware.Handle, ), ) } diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go index 9dbef9c6..e246cca2 100644 --- a/metal/kernel/router_signature_test.go +++ b/metal/kernel/router_signature_test.go @@ -14,14 +14,15 @@ import ( func TestSignatureRoute_PublicMiddleware(t *testing.T) { r := Router{ - Mux: http.NewServeMux(), - Pipeline: middleware.Pipeline{}, - validator: portal.GetDefaultValidator(), - publicMiddleware: middleware.MakePublicMiddleware("", false), - } - r.Signature() + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{ + PublicMiddleware: middleware.MakePublicMiddleware("", false), + }, + validator: portal.GetDefaultValidator(), + } + r.Signature() - t.Run("request without public headers is unauthorized", func(t *testing.T) { + t.Run("request without public headers is unauthorized", func(t *testing.T) { req := httptest.NewRequest("POST", "/generate-signature", nil) rec := httptest.NewRecorder() r.Mux.ServeHTTP(rec, req) @@ -30,34 +31,35 @@ func TestSignatureRoute_PublicMiddleware(t *testing.T) { } }) - t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { - req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) - rec := httptest.NewRecorder() - r.Mux.ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) - } - }) + t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + } + }) - t.Run("production rejects requests from non-whitelisted IP", func(t *testing.T) { - r := Router{ - Mux: http.NewServeMux(), - Pipeline: middleware.Pipeline{}, - validator: portal.GetDefaultValidator(), - publicMiddleware: middleware.MakePublicMiddleware("31.97.60.190", true), - } - r.Signature() + t.Run("production rejects requests from non-whitelisted IP", func(t *testing.T) { + r := Router{ + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{ + PublicMiddleware: middleware.MakePublicMiddleware("31.97.60.190", true), + }, + validator: portal.GetDefaultValidator(), + } + r.Signature() - req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) - req.Header.Set("X-Forwarded-For", "1.2.3.4") - rec := httptest.NewRecorder() - r.Mux.ServeHTTP(rec, req) - if rec.Code != http.StatusUnauthorized { - t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) - } - }) + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) } diff --git a/pkg/middleware/pipeline.go b/pkg/middleware/pipeline.go index 8220f127..ac53b585 100644 --- a/pkg/middleware/pipeline.go +++ b/pkg/middleware/pipeline.go @@ -8,9 +8,10 @@ import ( ) type Pipeline struct { - Env *env.Environment - ApiKeys *repository.ApiKeys - TokenHandler *auth.TokenHandler + Env *env.Environment + ApiKeys *repository.ApiKeys + TokenHandler *auth.TokenHandler + PublicMiddleware PublicMiddleware } func (m Pipeline) Chain(h http.ApiHandler, handlers ...http.Middleware) http.ApiHandler { diff --git a/pkg/middleware/public_middleware.go b/pkg/middleware/public_middleware.go index 5fd56d8c..fb9aa365 100644 --- a/pkg/middleware/public_middleware.go +++ b/pkg/middleware/public_middleware.go @@ -18,30 +18,30 @@ import ( // a simple in-memory rate limiter keyed by client IP. Reuse of a // request ID within a TTL window is rejected via TTLCache. type PublicMiddleware struct { - clockSkew time.Duration - disallowFuture bool - requestTTL time.Duration - rateLimiter *limiter.MemoryLimiter - requestCache *cache.TTLCache - now func() time.Time - allowedIP string - isProduction bool + clockSkew time.Duration + disallowFuture bool + requestTTL time.Duration + rateLimiter *limiter.MemoryLimiter + requestCache *cache.TTLCache + now func() time.Time + allowedIP string + isProduction bool } // MakePublicMiddleware constructs a PublicMiddleware with sane defaults. // allowedIP restricts traffic to a specific client IP when isProduction is true. // When not in production or allowedIP is blank, all IPs are permitted. func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware { - return PublicMiddleware{ - clockSkew: 5 * time.Minute, - disallowFuture: true, - requestTTL: 5 * time.Minute, - rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), - requestCache: cache.NewTTLCache(), - now: time.Now, - allowedIP: strings.TrimSpace(allowedIP), - isProduction: isProduction, - } + return PublicMiddleware{ + clockSkew: 5 * time.Minute, + disallowFuture: true, + requestTTL: 5 * time.Minute, + rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), + requestCache: cache.NewTTLCache(), + now: time.Now, + allowedIP: strings.TrimSpace(allowedIP), + isProduction: isProduction, + } } func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { @@ -56,14 +56,14 @@ func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { return mwguards.InvalidRequestError("Invalid authentication headers", "") } - ip := portal.ParseClientIP(r) - if ip == "" { - return mwguards.InvalidRequestError("Invalid client IP", "") - } + ip := portal.ParseClientIP(r) + if ip == "" { + return mwguards.InvalidRequestError("Invalid client IP", "") + } - if p.isProduction && p.allowedIP != "" && ip != p.allowedIP { - return mwguards.InvalidRequestError("Invalid client IP", "unauthorised ip: "+ip) - } + if p.isProduction && ip != p.allowedIP { + return mwguards.InvalidRequestError("Invalid client IP", "unauthorised ip: "+ip) + } limiterKey := ip if p.rateLimiter.TooMany(limiterKey) { diff --git a/pkg/middleware/public_middleware_test.go b/pkg/middleware/public_middleware_test.go index ed6f8d44..73fd1c0e 100644 --- a/pkg/middleware/public_middleware_test.go +++ b/pkg/middleware/public_middleware_test.go @@ -13,7 +13,7 @@ import ( ) func TestPublicMiddleware_InvalidHeaders(t *testing.T) { - pm := MakePublicMiddleware("", false) + pm := MakePublicMiddleware("", false) handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) base := time.Unix(1_700_000_000, 0) @@ -58,7 +58,7 @@ func TestPublicMiddleware_InvalidHeaders(t *testing.T) { } func TestPublicMiddleware_TimestampExpired(t *testing.T) { - pm := MakePublicMiddleware("", false) + pm := MakePublicMiddleware("", false) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) @@ -75,7 +75,7 @@ func TestPublicMiddleware_TimestampExpired(t *testing.T) { } func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { - pm := MakePublicMiddleware("", false) + pm := MakePublicMiddleware("", false) pm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } @@ -107,50 +107,50 @@ func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { req3.Header.Set(portal.RequestIDHeader, "def") req3.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) req3.Header.Set("X-Forwarded-For", "1.2.3.4") - if err := handler(rec3, req3); err == nil || err.Status != http.StatusTooManyRequests { - t.Fatalf("expected rate limit error, got %#v", err) - } + if err := handler(rec3, req3); err == nil || err.Status != http.StatusTooManyRequests { + t.Fatalf("expected rate limit error, got %#v", err) + } } func TestPublicMiddleware_IPWhitelist(t *testing.T) { - base := time.Unix(1_700_000_000, 0) - pm := MakePublicMiddleware("31.97.60.190", true) - pm.now = func() time.Time { return base } - handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + base := time.Unix(1_700_000_000, 0) + pm := MakePublicMiddleware("31.97.60.190", true) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) - t.Run("allowed ip passes", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) - req.Header.Set("X-Forwarded-For", "31.97.60.190") - if err := handler(rec, req); err != nil { - t.Fatalf("expected request to pass, got %#v", err) - } - }) + t.Run("allowed ip passes", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "31.97.60.190") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) - t.Run("other ip rejected in production", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) - req.Header.Set("X-Forwarded-For", "1.2.3.4") - if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { - t.Fatalf("expected unauthorized, got %#v", err) - } - }) + t.Run("other ip rejected in production", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized, got %#v", err) + } + }) - t.Run("non-production skips restriction", func(t *testing.T) { - pm := MakePublicMiddleware("31.97.60.190", false) - pm.now = func() time.Time { return base } - handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set(portal.RequestIDHeader, "req-1") - req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) - req.Header.Set("X-Forwarded-For", "1.2.3.4") - if err := handler(rec, req); err != nil { - t.Fatalf("expected request to pass, got %#v", err) - } - }) + t.Run("non-production skips restriction", func(t *testing.T) { + pm := MakePublicMiddleware("31.97.60.190", false) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) } From 2da549f839e8172988fbe5c84e4efc84957d3708 Mon Sep 17 00:00:00 2001 From: Gustavo Ocanto Date: Tue, 9 Sep 2025 12:18:12 +0800 Subject: [PATCH 7/7] format --- metal/kernel/app.go | 11 ++++--- metal/kernel/factory.go | 2 +- pkg/middleware/public_middleware.go | 50 ++++++++++++++++++++--------- pkg/portal/support.go | 1 + 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/metal/kernel/app.go b/metal/kernel/app.go index e509cb2c..b581f976 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -47,10 +47,13 @@ func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) { Mux: baseHttp.NewServeMux(), validator: validator, Pipeline: middleware.Pipeline{ - Env: e, - ApiKeys: &repository.ApiKeys{DB: db}, - TokenHandler: tokenHandler, - PublicMiddleware: middleware.MakePublicMiddleware(e.Network.PublicAllowedIP, e.App.IsProduction()), + Env: e, + ApiKeys: &repository.ApiKeys{DB: db}, + TokenHandler: tokenHandler, + PublicMiddleware: middleware.MakePublicMiddleware( + e.Network.PublicAllowedIP, + e.Network.IsProduction, + ), }, } diff --git a/metal/kernel/factory.go b/metal/kernel/factory.go index 2067864c..12d6e96a 100644 --- a/metal/kernel/factory.go +++ b/metal/kernel/factory.go @@ -87,7 +87,7 @@ func MakeEnv(validate *portal.Validator) *env.Environment { HttpHost: env.GetEnvVar("ENV_HTTP_HOST"), HttpPort: env.GetEnvVar("ENV_HTTP_PORT"), PublicAllowedIP: env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"), - IsProduction: app.IsProduction(), + IsProduction: app.IsProduction(), // --- only needed for validation purposes } sentryEnvironment := env.SentryEnvironment{ diff --git a/pkg/middleware/public_middleware.go b/pkg/middleware/public_middleware.go index fb9aa365..c0357443 100644 --- a/pkg/middleware/public_middleware.go +++ b/pkg/middleware/public_middleware.go @@ -28,9 +28,6 @@ type PublicMiddleware struct { isProduction bool } -// MakePublicMiddleware constructs a PublicMiddleware with sane defaults. -// allowedIP restricts traffic to a specific client IP when isProduction is true. -// When not in production or allowedIP is blank, all IPs are permitted. func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware { return PublicMiddleware{ clockSkew: 5 * time.Minute, @@ -46,38 +43,41 @@ func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { return func(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { - if err := p.guardDependencies(); err != nil { + if err := p.GuardDependencies(); err != nil { return err } - reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader)) + uri := portal.GenerateURL(r) ts := strings.TrimSpace(r.Header.Get(portal.TimestampHeader)) + reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader)) + if reqID == "" || ts == "" { return mwguards.InvalidRequestError("Invalid authentication headers", "") } - ip := portal.ParseClientIP(r) - if ip == "" { - return mwguards.InvalidRequestError("Invalid client IP", "") - } + limiterKey := strings.Join([]string{uri, reqID, ts}, "|") - if p.isProduction && ip != p.allowedIP { - return mwguards.InvalidRequestError("Invalid client IP", "unauthorised ip: "+ip) - } - - limiterKey := ip if p.rateLimiter.TooMany(limiterKey) { return mwguards.RateLimitedError("Too many requests", "Too many requests for key: "+limiterKey) } + if err := p.HasInvalidIP(r); err != nil { + p.rateLimiter.Fail(limiterKey) + + return err + } + vt := NewValidTimestamp(ts, p.now) if err := vt.Validate(p.clockSkew, p.disallowFuture); err != nil { + p.rateLimiter.Fail(limiterKey) + return err } - key := strings.Join([]string{limiterKey, reqID, ip}, "|") + key := strings.Join([]string{limiterKey, reqID}, "|") if p.requestCache.UseOnce(key, p.requestTTL) { p.rateLimiter.Fail(limiterKey) + return mwguards.UnauthenticatedError( "Invalid request id", "duplicate request id: "+key, @@ -89,17 +89,35 @@ func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { } } -func (p PublicMiddleware) guardDependencies() *http.ApiError { +func (p PublicMiddleware) HasInvalidIP(r *baseHttp.Request) *http.ApiError { + ip := portal.ParseClientIP(r) + + if ip == "" { + return mwguards.InvalidRequestError("Clients IPs are required to access this endpoint", "") + } + + if p.isProduction && ip != p.allowedIP { + return mwguards.InvalidRequestError("The given IP is not allowed", "unauthorised ip: "+ip) + } + + return nil +} + +func (p PublicMiddleware) GuardDependencies() *http.ApiError { missing := []string{} + if p.requestCache == nil { missing = append(missing, "requestCache") } + if p.rateLimiter == nil { missing = append(missing, "rateLimiter") } + if len(missing) > 0 { err := fmt.Errorf("public middleware missing dependencies: %s", strings.Join(missing, ",")) return http.LogInternalError("public middleware missing dependencies", err) } + return nil } diff --git a/pkg/portal/support.go b/pkg/portal/support.go index 99d49704..1d0c8441 100644 --- a/pkg/portal/support.go +++ b/pkg/portal/support.go @@ -94,6 +94,7 @@ func BuildCanonical(method string, u *url.URL, username, public, ts, nonce, body func ParseClientIP(r *baseHttp.Request) string { // prefer X-Forwarded-For if present + xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")) if xff != "" { // take first IP