From d12cedd996c35e77475407d6fcd804b271dfe83d Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 8 Sep 2025 15:14:06 +0800 Subject: [PATCH 1/3] refactor: improve signature handling --- handler/signatures.go | 18 +++++++----------- pkg/middleware/token_middleware.go | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/handler/signatures.go b/handler/signatures.go index 6afee908..a0003729 100644 --- a/handler/signatures.go +++ b/handler/signatures.go @@ -32,17 +32,13 @@ func MakeSignaturesHandler(validator *portal.Validator, ApiKeys *repository.ApiK func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { defer portal.CloseWithLog(r.Body) - var err error - var bodyBytes []byte - - bodyBytes, err = io.ReadAll(r.Body) - - if err != nil { - return http.LogBadRequestError("could not read signatures request body", err) - } + var ( + err error + req payload.SignatureRequest + ) - var req payload.SignatureRequest - if err = json.Unmarshal(bodyBytes, &req); err != nil { + limited := io.LimitReader(r.Body, http.MaxRequestSize) + if err = json.NewDecoder(limited).Decode(&req); err != nil { return http.LogBadRequestError("could not parse the given data.", err) } @@ -77,7 +73,7 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ if err = resp.RespondOk(response); err != nil { slog.Error("Error marshaling JSON for signatures response", "error", err) - return nil + return http.LogInternalError("could not encode signatures response", err) } return nil // A nil return indicates success. diff --git a/pkg/middleware/token_middleware.go b/pkg/middleware/token_middleware.go index 58620fd8..462ea2d9 100644 --- a/pkg/middleware/token_middleware.go +++ b/pkg/middleware/token_middleware.go @@ -215,7 +215,7 @@ func (t TokenCheckMiddleware) HasInvalidSignature(headers AuthTokenHeaders, apiK Key: apiKey, Signature: byteSignature, Origin: headers.IntendedOriginURL, - ServerTime: time.Now(), + ServerTime: t.now(), } signature := t.ApiKeys.FindSignatureFrom(entity) From 7c3ec44d8481996158c7c44afc713afee4e8a5cf Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 8 Sep 2025 15:36:24 +0800 Subject: [PATCH 2/3] test: cover signatures handler parsing and middleware clock --- handler/signatures_test.go | 19 ++++++++++++ .../token_middleware_additional_test.go | 31 +++++++++++++++++-- pkg/middleware/token_middleware_test.go | 9 +++--- 3 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 handler/signatures_test.go diff --git a/handler/signatures_test.go b/handler/signatures_test.go new file mode 100644 index 00000000..178c57eb --- /dev/null +++ b/handler/signatures_test.go @@ -0,0 +1,19 @@ +package handler + +import ( + nethttp "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/oullin/pkg/portal" +) + +func TestSignaturesHandlerGenerate_ParseError(t *testing.T) { + h := SignaturesHandler{Validator: portal.GetDefaultValidator()} + req := httptest.NewRequest("POST", "/signatures", strings.NewReader("{")) + rec := httptest.NewRecorder() + if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest { + t.Fatalf("expected parse error, got %#v", err) + } +} diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index a9e7c282..edffd390 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -122,7 +122,7 @@ func TestTokenMiddleware_SignatureMismatch(t *testing.T) { handler := tm.Handle(next) req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "body", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-sig", "req-sig") - seedSignature(t, repo, key, req) + seedSignature(t, repo, key, req, time.Now().Add(time.Hour)) req.Header.Set("X-Forwarded-For", "1.1.1.1") // mutate signature while keeping valid hex encoding @@ -152,7 +152,7 @@ func TestTokenMiddleware_NonceReplay(t *testing.T) { handler := tm.Handle(next) req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "{}", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-rp", "req-rp") - seedSignature(t, repo, key, req) + seedSignature(t, repo, key, req, time.Now().Add(time.Hour)) req.Header.Set("X-Forwarded-For", "1.1.1.1") rec := httptest.NewRecorder() if err := handler(rec, req); err != nil { @@ -196,7 +196,7 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) { seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-rl-final", "req-rl-final", ) - seedSignature(t, repo, key, req) + seedSignature(t, repo, key, req, time.Now().Add(time.Hour)) req.Header.Set("X-Forwarded-For", "9.9.9.9") rec := httptest.NewRecorder() err := handler(rec, req) @@ -208,3 +208,28 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) { t.Fatalf("expected next not to be invoked when rate limited, got %d calls", nextCalled) } } + +func TestTokenMiddleware_CustomClockValidatesSignature(t *testing.T) { + repo, th, seed, key := makeRepo(t, "clock") + tm := MakeTokenMiddleware(th, repo) + tm.clockSkew = time.Minute + past := time.Now().Add(-10 * time.Minute) + tm.now = func() time.Time { return past } + + nextCalled := false + handler := tm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { + nextCalled = true + return nil + }) + + req := makeSignedRequest(t, http.MethodGet, "https://api.test.local/v1/clock", "", seed.AccountName, seed.PublicKey, seed.PublicKey, past, "nonce-clock", "req-clock") + seedSignature(t, repo, key, req, past.Add(5*time.Minute)) + req.Header.Set("X-Forwarded-For", "1.1.1.1") + rec := httptest.NewRecorder() + if err := handler(rec, req); err != nil { + t.Fatalf("expected success with injected clock, got %#v", err) + } + if !nextCalled { + t.Fatalf("expected next to be called") + } +} diff --git a/pkg/middleware/token_middleware_test.go b/pkg/middleware/token_middleware_test.go index 3416753f..d71c429c 100644 --- a/pkg/middleware/token_middleware_test.go +++ b/pkg/middleware/token_middleware_test.go @@ -177,7 +177,8 @@ func makeSignedRequest(t *testing.T, method, rawURL, body, account, public, sign } // seedSignature stores the request signature for the given API key in the repository. -func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey, req *http.Request) { +// expiresAt allows tests to control the validity window for the stored signature. +func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey, req *http.Request, expiresAt time.Time) { t.Helper() sigHex := req.Header.Get("X-API-Signature") sigBytes, err := hex.DecodeString(sigHex) @@ -188,7 +189,7 @@ func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey, Key: key, Seed: sigBytes, Origin: req.Header.Get("X-API-Intended-Origin"), - ExpiresAt: time.Now().Add(time.Hour), + ExpiresAt: expiresAt, }) if err != nil { t.Skipf("create signature: %v", err) @@ -244,7 +245,7 @@ func TestTokenMiddleware_DB_Integration(t *testing.T) { "nonce-1", "req-001", ) - seedSignature(t, repo, apiKey, req) + seedSignature(t, repo, apiKey, req, time.Now().Add(time.Hour)) rec := httptest.NewRecorder() if err := handler(rec, req); err != nil { t.Fatalf("expected success, got error: %#v", err) @@ -323,7 +324,7 @@ func TestTokenMiddleware_DB_Integration_HappyPath(t *testing.T) { "n-happy-1", "rid-happy-1", ) - seedSignature(t, repo, apiKey, req) + seedSignature(t, repo, apiKey, req, time.Now().Add(time.Hour)) rec := httptest.NewRecorder() if err := handler(rec, req); err != nil { t.Fatalf("happy path failed: %#v", err) From 1fbac2ff599a89193e4615045e3758f78f989081 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 8 Sep 2025 15:41:13 +0800 Subject: [PATCH 3/3] Use MaxBytesReader and strict JSON decoding in signatures handler --- handler/signatures.go | 7 ++++--- handler/signatures_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/handler/signatures.go b/handler/signatures.go index a0003729..00db5800 100644 --- a/handler/signatures.go +++ b/handler/signatures.go @@ -3,7 +3,6 @@ package handler import ( "encoding/json" "fmt" - "io" "log/slog" baseHttp "net/http" "time" @@ -37,8 +36,10 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ req payload.SignatureRequest ) - limited := io.LimitReader(r.Body, http.MaxRequestSize) - if err = json.NewDecoder(limited).Decode(&req); err != nil { + r.Body = baseHttp.MaxBytesReader(w, r.Body, http.MaxRequestSize) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err = dec.Decode(&req); err != nil { return http.LogBadRequestError("could not parse the given data.", err) } diff --git a/handler/signatures_test.go b/handler/signatures_test.go index 178c57eb..7a9ff2aa 100644 --- a/handler/signatures_test.go +++ b/handler/signatures_test.go @@ -1,11 +1,14 @@ package handler import ( + "fmt" nethttp "net/http" "net/http/httptest" "strings" "testing" + "time" + apih "github.com/oullin/pkg/http" "github.com/oullin/pkg/portal" ) @@ -17,3 +20,28 @@ func TestSignaturesHandlerGenerate_ParseError(t *testing.T) { t.Fatalf("expected parse error, got %#v", err) } } + +func TestSignaturesHandlerGenerate_UnknownField(t *testing.T) { + h := SignaturesHandler{Validator: portal.GetDefaultValidator()} + body := fmt.Sprintf(`{"nonce":"%s","public_key":"%s","username":"%s","timestamp":%d,"extra":"nope"}`, + strings.Repeat("a", 32), + strings.Repeat("b", 64), + "validuser", + time.Now().Unix(), + ) + req := httptest.NewRequest("POST", "/signatures", strings.NewReader(body)) + rec := httptest.NewRecorder() + if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest { + t.Fatalf("expected unknown field error, got %#v", err) + } +} + +func TestSignaturesHandlerGenerate_BodyTooLarge(t *testing.T) { + h := SignaturesHandler{Validator: portal.GetDefaultValidator()} + large := `{"nonce":"` + strings.Repeat("a", apih.MaxRequestSize+1) + `"}` + req := httptest.NewRequest("POST", "/signatures", strings.NewReader(large)) + rec := httptest.NewRecorder() + if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest { + t.Fatalf("expected body too large error, got %#v", err) + } +}