Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log"
ENV_APP_LOGS_DATE_FORMAT="2006_02_01"
ENV_APP_MASTER_KEY=

# --- Public middleware
# Optional IP whitelist for production-only public routes
ENV_PUBLIC_ALLOWED_IP=

# --- DB
ENV_DB_USER_NAME="gus"
ENV_DB_USER_PASSWORD="password"
Expand Down
1 change: 1 addition & 0 deletions .env.gh.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ ENV_HTTP_PORT=8080
ENV_DOCKER_USER=gocanto
ENV_DOCKER_USER_GROUP=ggroup
CADDY_LOGS_PATH=./storage/logs/caddy
ENV_PUBLIC_ALLOWED_IP=
3 changes: 3 additions & 0 deletions .env.prod.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ ENV_APP_LOG_LEVEL=debug
ENV_APP_LOGS_DIR="./storage/logs/logs_%s.log"
ENV_APP_LOGS_DATE_FORMAT="2006_02_01"

# --- Public middleware
ENV_PUBLIC_ALLOWED_IP=

# --- DB
ENV_DB_PORT=
ENV_DB_HOST=
Expand Down
20 changes: 0 additions & 20 deletions handler/signatures.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ
receivedAt := time.Unix(req.Timestamp, 0)
req.Origin = r.Header.Get(portal.IntendedOriginHeader)

if err = s.isRequestWithinTimeframe(serverTime, receivedAt); err != nil {
return http.LogBadRequestError(err.Error(), err)
}

var keySignature *database.APIKeySignatures
if keySignature, err = s.CreateSignature(req, serverTime); err != nil {
return http.LogInternalError(err.Error(), err)
Expand All @@ -80,22 +76,6 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ
return nil // A nil return indicates success.
}

func (s *SignaturesHandler) isRequestWithinTimeframe(serverTime, receivedAt time.Time) error {
skew := 15 * time.Second

earliestValidTime := serverTime.Add(-skew)
if receivedAt.Before(earliestValidTime) {
return fmt.Errorf("the request timestamp [%s] is too old", receivedAt.Format(portal.DatesLayout))
}

latestValidTime := serverTime.Add(skew)
if receivedAt.After(latestValidTime) {
return fmt.Errorf("the request timestamp [%s] is from the future", receivedAt.Format(portal.DatesLayout))
}

return nil
}

func (s *SignaturesHandler) CreateSignature(request payload.SignatureRequest, serverTime time.Time) (*database.APIKeySignatures, error) {
var err error
var token *database.APIKey
Expand Down
6 changes: 4 additions & 2 deletions metal/env/network.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package env

type NetEnvironment struct {
HttpHost string `validate:"required,lowercase,min=7"`
HttpPort string `validate:"required,numeric,oneof=8080"`
HttpHost string `validate:"required,lowercase,min=7"`
HttpPort string `validate:"required,numeric,oneof=8080"`
PublicAllowedIP string `validate:"required_if=IsProduction true,omitempty,ip"`
IsProduction bool `validate:"-"`
}

func (e NetEnvironment) GetHttpPort() string {
Expand Down
20 changes: 12 additions & 8 deletions metal/kernel/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,38 @@ type App struct {
db *database.Connection
}

func MakeApp(env *env.Environment, validator *portal.Validator) (*App, error) {
func MakeApp(e *env.Environment, validator *portal.Validator) (*App, error) {
tokenHandler, err := auth.MakeTokensHandler(
[]byte(env.App.MasterKey),
[]byte(e.App.MasterKey),
)

if err != nil {
return nil, fmt.Errorf("bootstrapping error > could not create a token handler: %w", err)
}

db := MakeDbConnection(env)
db := MakeDbConnection(e)

app := App{
env: env,
env: e,
validator: validator,
logs: MakeLogs(env),
sentry: MakeSentry(env),
logs: MakeLogs(e),
sentry: MakeSentry(e),
db: db,
}

router := Router{
Env: env,
Env: e,
Db: db,
Mux: baseHttp.NewServeMux(),
validator: validator,
Pipeline: middleware.Pipeline{
Env: env,
Env: e,
ApiKeys: &repository.ApiKeys{DB: db},
TokenHandler: tokenHandler,
PublicMiddleware: middleware.MakePublicMiddleware(
e.Network.PublicAllowedIP,
e.Network.IsProduction,
),
},
}

Expand Down
6 changes: 4 additions & 2 deletions metal/kernel/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ func MakeEnv(validate *portal.Validator) *env.Environment {
}

net := env.NetEnvironment{
HttpHost: env.GetEnvVar("ENV_HTTP_HOST"),
HttpPort: env.GetEnvVar("ENV_HTTP_PORT"),
HttpHost: env.GetEnvVar("ENV_HTTP_HOST"),
HttpPort: env.GetEnvVar("ENV_HTTP_PORT"),
PublicAllowedIP: env.GetEnvVar("ENV_PUBLIC_ALLOWED_IP"),
IsProduction: app.IsProduction(), // --- only needed for validation purposes
}

sentryEnvironment := env.SentryEnvironment{
Expand Down
28 changes: 24 additions & 4 deletions metal/kernel/kernel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func validEnvVars(t *testing.T) {
t.Setenv("ENV_HTTP_PORT", "8080")
t.Setenv("ENV_SENTRY_DSN", "dsn")
t.Setenv("ENV_SENTRY_CSP", "csp")
t.Setenv("ENV_PUBLIC_ALLOWED_IP", "1.2.3.4")
}

func TestMakeEnv(t *testing.T) {
Expand All @@ -43,6 +44,24 @@ func TestMakeEnv(t *testing.T) {
if env.App.Name != "guss" {
t.Fatalf("env not loaded")
}

if env.Network.PublicAllowedIP != "1.2.3.4" {
t.Fatalf("expected public allowed ip to be loaded")
}
}

func TestMakeEnvRequiresIPInProduction(t *testing.T) {
validEnvVars(t)
t.Setenv("ENV_APP_ENV_TYPE", "production")
t.Setenv("ENV_PUBLIC_ALLOWED_IP", "")

defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic")
}
}()

MakeEnv(portal.GetDefaultValidator())
}

func TestIgnite(t *testing.T) {
Expand Down Expand Up @@ -96,7 +115,7 @@ func TestAppHelpers(t *testing.T) {
app := &App{}

mux := http.NewServeMux()
r := Router{Mux: mux}
r := Router{Mux: mux, Pipeline: middleware.Pipeline{PublicMiddleware: middleware.MakePublicMiddleware("", false)}}

app.SetRouter(r)

Expand Down Expand Up @@ -137,9 +156,10 @@ func TestAppBootRoutes(t *testing.T) {
Env: env,
Mux: http.NewServeMux(),
Pipeline: middleware.Pipeline{
Env: env,
ApiKeys: &repository.ApiKeys{DB: &database.Connection{}},
TokenHandler: handler,
Env: env,
ApiKeys: &repository.ApiKeys{DB: &database.Connection{}},
TokenHandler: handler,
PublicMiddleware: middleware.MakePublicMiddleware("", false),
},
Db: &database.Connection{},
}
Expand Down
1 change: 1 addition & 0 deletions metal/kernel/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerF
return http.MakeApiHandler(
r.Pipeline.Chain(
apiHandler,
r.Pipeline.PublicMiddleware.Handle,
),
)
}
Expand Down
65 changes: 65 additions & 0 deletions metal/kernel/router_signature_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package kernel

import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/oullin/pkg/middleware"
"github.com/oullin/pkg/portal"
)

func TestSignatureRoute_PublicMiddleware(t *testing.T) {
r := Router{
Mux: http.NewServeMux(),
Pipeline: middleware.Pipeline{
PublicMiddleware: middleware.MakePublicMiddleware("", false),
},
validator: portal.GetDefaultValidator(),
}
r.Signature()

t.Run("request without public headers is unauthorized", func(t *testing.T) {
req := httptest.NewRequest("POST", "/generate-signature", nil)
rec := httptest.NewRecorder()
r.Mux.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
}
})

t.Run("request with public headers but invalid body is bad request", func(t *testing.T) {
req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{"))
req.Header.Set(portal.RequestIDHeader, "req-1")
req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix()))
rec := httptest.NewRecorder()
r.Mux.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
}
})

t.Run("production rejects requests from non-whitelisted IP", func(t *testing.T) {
r := Router{
Mux: http.NewServeMux(),
Pipeline: middleware.Pipeline{
PublicMiddleware: middleware.MakePublicMiddleware("31.97.60.190", true),
},
validator: portal.GetDefaultValidator(),
}
r.Signature()

req := httptest.NewRequest("POST", "/generate-signature", strings.NewReader("{"))
req.Header.Set(portal.RequestIDHeader, "req-1")
req.Header.Set(portal.TimestampHeader, fmt.Sprintf("%d", time.Now().Unix()))
req.Header.Set("X-Forwarded-For", "1.2.3.4")
rec := httptest.NewRecorder()
r.Mux.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
}
})
}
7 changes: 4 additions & 3 deletions pkg/middleware/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
)

type Pipeline struct {
Env *env.Environment
ApiKeys *repository.ApiKeys
TokenHandler *auth.TokenHandler
Env *env.Environment
ApiKeys *repository.ApiKeys
TokenHandler *auth.TokenHandler
PublicMiddleware PublicMiddleware
}

func (m Pipeline) Chain(h http.ApiHandler, handlers ...http.Middleware) http.ApiHandler {
Expand Down
50 changes: 39 additions & 11 deletions pkg/middleware/public_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,50 +24,60 @@ type PublicMiddleware struct {
rateLimiter *limiter.MemoryLimiter
requestCache *cache.TTLCache
now func() time.Time
allowedIP string
isProduction bool
}

// MakePublicMiddleware constructs a PublicMiddleware with sane defaults.
func MakePublicMiddleware() PublicMiddleware {
func MakePublicMiddleware(allowedIP string, isProduction bool) PublicMiddleware {
return PublicMiddleware{
clockSkew: 5 * time.Minute,
disallowFuture: true,
requestTTL: 5 * time.Minute,
rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10),
requestCache: cache.NewTTLCache(),
now: time.Now,
allowedIP: strings.TrimSpace(allowedIP),
isProduction: isProduction,
}
}

func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler {
return func(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError {
if err := p.guardDependencies(); err != nil {
if err := p.GuardDependencies(); err != nil {
return err
}

reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader))
uri := portal.GenerateURL(r)
ts := strings.TrimSpace(r.Header.Get(portal.TimestampHeader))
reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader))

if reqID == "" || ts == "" {
return mwguards.InvalidRequestError("Invalid authentication headers", "")
}

ip := portal.ParseClientIP(r)
if ip == "" {
return mwguards.InvalidRequestError("Invalid client IP", "")
}
limiterKey := strings.Join([]string{uri, reqID, ts}, "|")

limiterKey := ip
if p.rateLimiter.TooMany(limiterKey) {
return mwguards.RateLimitedError("Too many requests", "Too many requests for key: "+limiterKey)
}

if err := p.HasInvalidIP(r); err != nil {
p.rateLimiter.Fail(limiterKey)

return err
}

vt := NewValidTimestamp(ts, p.now)
if err := vt.Validate(p.clockSkew, p.disallowFuture); err != nil {
p.rateLimiter.Fail(limiterKey)

return err
}

key := strings.Join([]string{limiterKey, reqID, ip}, "|")
key := strings.Join([]string{limiterKey, reqID}, "|")
if p.requestCache.UseOnce(key, p.requestTTL) {
p.rateLimiter.Fail(limiterKey)

return mwguards.UnauthenticatedError(
"Invalid request id",
"duplicate request id: "+key,
Expand All @@ -79,17 +89,35 @@ func (p PublicMiddleware) Handle(next http.ApiHandler) http.ApiHandler {
}
}

func (p PublicMiddleware) guardDependencies() *http.ApiError {
func (p PublicMiddleware) HasInvalidIP(r *baseHttp.Request) *http.ApiError {
ip := portal.ParseClientIP(r)

if ip == "" {
return mwguards.InvalidRequestError("Clients IPs are required to access this endpoint", "")
}

if p.isProduction && ip != p.allowedIP {
return mwguards.InvalidRequestError("The given IP is not allowed", "unauthorised ip: "+ip)
}

return nil
}

func (p PublicMiddleware) GuardDependencies() *http.ApiError {
missing := []string{}

if p.requestCache == nil {
missing = append(missing, "requestCache")
}

if p.rateLimiter == nil {
missing = append(missing, "rateLimiter")
}

if len(missing) > 0 {
err := fmt.Errorf("public middleware missing dependencies: %s", strings.Join(missing, ","))
return http.LogInternalError("public middleware missing dependencies", err)
}

return nil
}
Loading
Loading