diff --git a/.env.example b/.env.example index 8ed420dc..9614e3bd 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,10 @@ ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log" ENV_APP_LOGS_DATE_FORMAT="2006_02_01" ENV_APP_MASTER_KEY= +# --- Public middleware +# Optional IP whitelist for production-only public routes +ENV_PUBLIC_ALLOWED_IP= + # --- DB ENV_DB_USER_NAME="gus" ENV_DB_USER_PASSWORD="password" diff --git a/.env.gh.example b/.env.gh.example index 4d18b68a..ec9f6582 100644 --- a/.env.gh.example +++ b/.env.gh.example @@ -2,3 +2,4 @@ ENV_HTTP_PORT=8080 ENV_DOCKER_USER=gocanto ENV_DOCKER_USER_GROUP=ggroup CADDY_LOGS_PATH=./storage/logs/caddy +ENV_PUBLIC_ALLOWED_IP= diff --git a/.env.prod.example b/.env.prod.example index ebbf82b7..a0393fed 100644 --- a/.env.prod.example +++ b/.env.prod.example @@ -5,6 +5,9 @@ ENV_APP_LOG_LEVEL=debug ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log" ENV_APP_LOGS_DATE_FORMAT="2006_02_01" +# --- Public middleware +ENV_PUBLIC_ALLOWED_IP= + # --- DB ENV_DB_PORT= ENV_DB_HOST= diff --git a/handler/signatures.go b/handler/signatures.go index 61c56f67..1e080f92 100644 --- a/handler/signatures.go +++ b/handler/signatures.go @@ -51,10 +51,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ receivedAt := time.Unix(req.Timestamp, 0) req.Origin = r.Header.Get(portal.IntendedOriginHeader) - if err = s.isRequestWithinTimeframe(serverTime, receivedAt); err != nil { - return http.LogBadRequestError(err.Error(), err) - } - var keySignature *database.APIKeySignatures if keySignature, err = s.CreateSignature(req, serverTime); err != nil { return http.LogInternalError(err.Error(), err) @@ -80,22 +76,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ return nil // A nil return indicates success. } -func (s *SignaturesHandler) isRequestWithinTimeframe(serverTime, receivedAt time.Time) error { - skew := 15 * time.Second - - earliestValidTime := serverTime.Add(-skew) - if receivedAt.Before(earliestValidTime) { - return fmt.Errorf("the request timestamp [%s] is too old", receivedAt.Format(portal.DatesLayout)) - } - - latestValidTime := serverTime.Add(skew) - if receivedAt.After(latestValidTime) { - return fmt.Errorf("the request timestamp [%s] is from the future", receivedAt.Format(portal.DatesLayout)) - } - - return nil -} - func (s *SignaturesHandler) CreateSignature(request payload.SignatureRequest, serverTime time.Time) (*database.APIKeySignatures, error) { var err error var token *database.APIKey diff --git a/metal/env/network.go b/metal/env/network.go index f66d91c6..5ca17dd4 100644 --- a/metal/env/network.go +++ b/metal/env/network.go @@ -1,8 +1,10 @@ package env type NetEnvironment struct { - HttpHost string `validate:"required,lowercase,min=7"` - HttpPort string `validate:"required,numeric,oneof=8080"` + HttpHost string `validate:"required,lowercase,min=7"` + HttpPort string `validate:"required,numeric,oneof=8080"` + PublicAllowedIP string `validate:"required_if=IsProduction true,omitempty,ip"` + IsProduction bool `validate:"-"` } func (e NetEnvironment) GetHttpPort() string { diff --git a/metal/kernel/app.go b/metal/kernel/app.go index b9627343..b581f976 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -22,34 +22,38 @@ type App struct { db *database.Connection } -func MakeApp(env *env.Environment, validator *portal.Validator) (*App, error) { +func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) { tokenHandler, err := auth.MakeTokensHandler( - []byte(env.App.MasterKey), + []byte(e.App.MasterKey), ) if err != nil { return nil, fmt.Errorf("bootstrapping error > could not create a token handler: %w", err) } - db := MakeDbConnection(env) + db := MakeDbConnection(e) app := App{ - env: env, + env: e, validator: validator, - logs: MakeLogs(env), - sentry: MakeSentry(env), + logs: MakeLogs(e), + sentry: MakeSentry(e), db: db, } router := Router{ - Env: env, + Env: e, Db: db, Mux: baseHttp.NewServeMux(), validator: validator, Pipeline: middleware.Pipeline{ - Env: env, + Env: e, ApiKeys: &repository.ApiKeys{DB: db}, TokenHandler: tokenHandler, + PublicMiddleware: middleware.MakePublicMiddleware( + e.Network.PublicAllowedIP, + e.Network.IsProduction, + ), }, } diff --git a/metal/kernel/factory.go b/metal/kernel/factory.go index 3a92956d..12d6e96a 100644 --- a/metal/kernel/factory.go +++ b/metal/kernel/factory.go @@ -84,8 +84,10 @@ func MakeEnv(validate *portal.Validator) *env.Environment { } net := env.NetEnvironment{ - HttpHost: env.GetEnvVar("ENV_HTTP_HOST"), - HttpPort: env.GetEnvVar("ENV_HTTP_PORT"), + HttpHost: env.GetEnvVar("ENV_HTTP_HOST"), + HttpPort: env.GetEnvVar("ENV_HTTP_PORT"), + PublicAllowedIP: env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"), + IsProduction: app.IsProduction(), // --- only needed for validation purposes } sentryEnvironment := env.SentryEnvironment{ diff --git a/metal/kernel/kernel_test.go b/metal/kernel/kernel_test.go index 8ac19909..a2d48bd6 100644 --- a/metal/kernel/kernel_test.go +++ b/metal/kernel/kernel_test.go @@ -33,6 +33,7 @@ func validEnvVars(t *testing.T) { t.Setenv("ENV_HTTP_PORT", "8080") t.Setenv("ENV_SENTRY_DSN", "dsn") t.Setenv("ENV_SENTRY_CSP", "csp") + t.Setenv("ENV_PUBLIC_ALLOWED_IP", "1.2.3.4") } func TestMakeEnv(t *testing.T) { @@ -43,6 +44,24 @@ func TestMakeEnv(t *testing.T) { if env.App.Name != "guss" { t.Fatalf("env not loaded") } + + if env.Network.PublicAllowedIP != "1.2.3.4" { + t.Fatalf("expected public allowed ip to be loaded") + } +} + +func TestMakeEnvRequiresIPInProduction(t *testing.T) { + validEnvVars(t) + t.Setenv("ENV_APP_ENV_TYPE", "production") + t.Setenv("ENV_PUBLIC_ALLOWED_IP", "") + + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic") + } + }() + + MakeEnv(portal.GetDefaultValidator()) } func TestIgnite(t *testing.T) { @@ -96,7 +115,7 @@ func TestAppHelpers(t *testing.T) { app := &App{} mux := http.NewServeMux() - r := Router{Mux: mux} + r := Router{Mux: mux, Pipeline: middleware.Pipeline{PublicMiddleware: middleware.MakePublicMiddleware("", false)}} app.SetRouter(r) @@ -137,9 +156,10 @@ func TestAppBootRoutes(t *testing.T) { Env: env, Mux: http.NewServeMux(), Pipeline: middleware.Pipeline{ - Env: env, - ApiKeys: &repository.ApiKeys{DB: &database.Connection{}}, - TokenHandler: handler, + Env: env, + ApiKeys: &repository.ApiKeys{DB: &database.Connection{}}, + TokenHandler: handler, + PublicMiddleware: middleware.MakePublicMiddleware("", false), }, Db: &database.Connection{}, } diff --git a/metal/kernel/router.go b/metal/kernel/router.go index e8014b5f..8b76564b 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -34,6 +34,7 @@ func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerF return http.MakeApiHandler( r.Pipeline.Chain( apiHandler, + r.Pipeline.PublicMiddleware.Handle, ), ) } diff --git a/metal/kernel/router_signature_test.go b/metal/kernel/router_signature_test.go new file mode 100644 index 00000000..e246cca2 --- /dev/null +++ b/metal/kernel/router_signature_test.go @@ -0,0 +1,65 @@ +package kernel + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/oullin/pkg/middleware" + "github.com/oullin/pkg/portal" +) + +func TestSignatureRoute_PublicMiddleware(t *testing.T) { + r := Router{ + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{ + PublicMiddleware: middleware.MakePublicMiddleware("", false), + }, + validator: portal.GetDefaultValidator(), + } + r.Signature() + + t.Run("request without public headers is unauthorized", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", nil) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) + + t.Run("request with public headers but invalid body is bad request", func(t *testing.T) { + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code) + } + }) + + t.Run("production rejects requests from non-whitelisted IP", func(t *testing.T) { + r := Router{ + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{ + PublicMiddleware: middleware.MakePublicMiddleware("31.97.60.190", true), + }, + validator: portal.GetDefaultValidator(), + } + r.Signature() + + req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{")) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix())) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) +} diff --git a/pkg/middleware/pipeline.go b/pkg/middleware/pipeline.go index 8220f127..ac53b585 100644 --- a/pkg/middleware/pipeline.go +++ b/pkg/middleware/pipeline.go @@ -8,9 +8,10 @@ import ( ) type Pipeline struct { - Env *env.Environment - ApiKeys *repository.ApiKeys - TokenHandler *auth.TokenHandler + Env *env.Environment + ApiKeys *repository.ApiKeys + TokenHandler *auth.TokenHandler + PublicMiddleware PublicMiddleware } func (m Pipeline) Chain(h http.ApiHandler, handlers ...http.Middleware) http.ApiHandler { diff --git a/pkg/middleware/public_middleware.go b/pkg/middleware/public_middleware.go index 94e41254..c0357443 100644 --- a/pkg/middleware/public_middleware.go +++ b/pkg/middleware/public_middleware.go @@ -24,10 +24,11 @@ type PublicMiddleware struct { rateLimiter *limiter.MemoryLimiter requestCache *cache.TTLCache now func() time.Time + allowedIP string + isProduction bool } -// MakePublicMiddleware constructs a PublicMiddleware with sane defaults. -func MakePublicMiddleware() PublicMiddleware { +func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware { return PublicMiddleware{ clockSkew: 5 * time.Minute, disallowFuture: true, @@ -35,39 +36,48 @@ func MakePublicMiddleware() PublicMiddleware { rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), requestCache: cache.NewTTLCache(), now: time.Now, + allowedIP: strings.TrimSpace(allowedIP), + isProduction: isProduction, } } func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { return func(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { - if err := p.guardDependencies(); err != nil { + if err := p.GuardDependencies(); err != nil { return err } - reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader)) + uri := portal.GenerateURL(r) ts := strings.TrimSpace(r.Header.Get(portal.TimestampHeader)) + reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader)) + if reqID == "" || ts == "" { return mwguards.InvalidRequestError("Invalid authentication headers", "") } - ip := portal.ParseClientIP(r) - if ip == "" { - return mwguards.InvalidRequestError("Invalid client IP", "") - } + limiterKey := strings.Join([]string{uri, reqID, ts}, "|") - limiterKey := ip if p.rateLimiter.TooMany(limiterKey) { return mwguards.RateLimitedError("Too many requests", "Too many requests for key: "+limiterKey) } + if err := p.HasInvalidIP(r); err != nil { + p.rateLimiter.Fail(limiterKey) + + return err + } + vt := NewValidTimestamp(ts, p.now) if err := vt.Validate(p.clockSkew, p.disallowFuture); err != nil { + p.rateLimiter.Fail(limiterKey) + return err } - key := strings.Join([]string{limiterKey, reqID, ip}, "|") + key := strings.Join([]string{limiterKey, reqID}, "|") if p.requestCache.UseOnce(key, p.requestTTL) { p.rateLimiter.Fail(limiterKey) + return mwguards.UnauthenticatedError( "Invalid request id", "duplicate request id: "+key, @@ -79,17 +89,35 @@ func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler { } } -func (p PublicMiddleware) guardDependencies() *http.ApiError { +func (p PublicMiddleware) HasInvalidIP(r *baseHttp.Request) *http.ApiError { + ip := portal.ParseClientIP(r) + + if ip == "" { + return mwguards.InvalidRequestError("Clients IPs are required to access this endpoint", "") + } + + if p.isProduction && ip != p.allowedIP { + return mwguards.InvalidRequestError("The given IP is not allowed", "unauthorised ip: "+ip) + } + + return nil +} + +func (p PublicMiddleware) GuardDependencies() *http.ApiError { missing := []string{} + if p.requestCache == nil { missing = append(missing, "requestCache") } + if p.rateLimiter == nil { missing = append(missing, "rateLimiter") } + if len(missing) > 0 { err := fmt.Errorf("public middleware missing dependencies: %s", strings.Join(missing, ",")) return http.LogInternalError("public middleware missing dependencies", err) } + return nil } diff --git a/pkg/middleware/public_middleware_test.go b/pkg/middleware/public_middleware_test.go index 60608382..73fd1c0e 100644 --- a/pkg/middleware/public_middleware_test.go +++ b/pkg/middleware/public_middleware_test.go @@ -13,7 +13,7 @@ import ( ) func TestPublicMiddleware_InvalidHeaders(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) base := time.Unix(1_700_000_000, 0) @@ -58,7 +58,7 @@ func TestPublicMiddleware_InvalidHeaders(t *testing.T) { } func TestPublicMiddleware_TimestampExpired(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) @@ -75,7 +75,7 @@ func TestPublicMiddleware_TimestampExpired(t *testing.T) { } func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { - pm := MakePublicMiddleware() + pm := MakePublicMiddleware("", false) pm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1) base := time.Unix(1_700_000_000, 0) pm.now = func() time.Time { return base } @@ -111,3 +111,46 @@ func TestPublicMiddleware_RateLimitAndReplay(t *testing.T) { t.Fatalf("expected rate limit error, got %#v", err) } } + +func TestPublicMiddleware_IPWhitelist(t *testing.T) { + base := time.Unix(1_700_000_000, 0) + pm := MakePublicMiddleware("31.97.60.190", true) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + + t.Run("allowed ip passes", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "31.97.60.190") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) + + t.Run("other ip rejected in production", func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized, got %#v", err) + } + }) + + t.Run("non-production skips restriction", func(t *testing.T) { + pm := MakePublicMiddleware("31.97.60.190", false) + pm.now = func() time.Time { return base } + handler := pm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "req-1") + req.Header.Set(portal.TimestampHeader, strconv.FormatInt(base.Unix(), 10)) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if err := handler(rec, req); err != nil { + t.Fatalf("expected request to pass, got %#v", err) + } + }) +} diff --git a/pkg/portal/support.go b/pkg/portal/support.go index 99d49704..1d0c8441 100644 --- a/pkg/portal/support.go +++ b/pkg/portal/support.go @@ -94,6 +94,7 @@ func BuildCanonical(method string, u *url.URL, username, public, ts, nonce, body func ParseClientIP(r *baseHttp.Request) string { // prefer X-Forwarded-For if present + xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")) if xff != "" { // take first IP