Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions handler/signatures.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handler
import (
"encoding/json"
"fmt"
"io"
"log/slog"
baseHttp "net/http"
"time"
Expand Down Expand Up @@ -32,17 +31,15 @@ func MakeSignaturesHandler(validator *portal.Validator, ApiKeys *repository.ApiK
func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError {
defer portal.CloseWithLog(r.Body)

var err error
var bodyBytes []byte

bodyBytes, err = io.ReadAll(r.Body)

if err != nil {
return http.LogBadRequestError("could not read signatures request body", err)
}
var (
err error
req payload.SignatureRequest
)

var req payload.SignatureRequest
if err = json.Unmarshal(bodyBytes, &req); err != nil {
r.Body = baseHttp.MaxBytesReader(w, r.Body, http.MaxRequestSize)
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err = dec.Decode(&req); err != nil {
return http.LogBadRequestError("could not parse the given data.", err)
}

Expand Down Expand Up @@ -77,7 +74,7 @@ func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Requ

if err = resp.RespondOk(response); err != nil {
slog.Error("Error marshaling JSON for signatures response", "error", err)
return nil
return http.LogInternalError("could not encode signatures response", err)
}

return nil // A nil return indicates success.
Expand Down
47 changes: 47 additions & 0 deletions handler/signatures_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package handler

import (
"fmt"
nethttp "net/http"
"net/http/httptest"
"strings"
"testing"
"time"

apih "github.com/oullin/pkg/http"
"github.com/oullin/pkg/portal"
)

func TestSignaturesHandlerGenerate_ParseError(t *testing.T) {
h := SignaturesHandler{Validator: portal.GetDefaultValidator()}
req := httptest.NewRequest("POST", "/signatures", strings.NewReader("{"))
rec := httptest.NewRecorder()
if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest {
t.Fatalf("expected parse error, got %#v", err)
}
}

func TestSignaturesHandlerGenerate_UnknownField(t *testing.T) {
h := SignaturesHandler{Validator: portal.GetDefaultValidator()}
body := fmt.Sprintf(`{"nonce":"%s","public_key":"%s","username":"%s","timestamp":%d,"extra":"nope"}`,
strings.Repeat("a", 32),
strings.Repeat("b", 64),
"validuser",
time.Now().Unix(),
)
req := httptest.NewRequest("POST", "/signatures", strings.NewReader(body))
rec := httptest.NewRecorder()
if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest {
t.Fatalf("expected unknown field error, got %#v", err)
}
}

func TestSignaturesHandlerGenerate_BodyTooLarge(t *testing.T) {
h := SignaturesHandler{Validator: portal.GetDefaultValidator()}
large := `{"nonce":"` + strings.Repeat("a", apih.MaxRequestSize+1) + `"}`
req := httptest.NewRequest("POST", "/signatures", strings.NewReader(large))
rec := httptest.NewRecorder()
if err := h.Generate(rec, req); err == nil || err.Status != nethttp.StatusBadRequest {
t.Fatalf("expected body too large error, got %#v", err)
}
}
2 changes: 1 addition & 1 deletion pkg/middleware/token_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (t TokenCheckMiddleware) HasInvalidSignature(headers AuthTokenHeaders, apiK
Key: apiKey,
Signature: byteSignature,
Origin: headers.IntendedOriginURL,
ServerTime: time.Now(),
ServerTime: t.now(),
}

signature := t.ApiKeys.FindSignatureFrom(entity)
Expand Down
31 changes: 28 additions & 3 deletions pkg/middleware/token_middleware_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func TestTokenMiddleware_SignatureMismatch(t *testing.T) {
handler := tm.Handle(next)

req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "body", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-sig", "req-sig")
seedSignature(t, repo, key, req)
seedSignature(t, repo, key, req, time.Now().Add(time.Hour))
req.Header.Set("X-Forwarded-For", "1.1.1.1")

// mutate signature while keeping valid hex encoding
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestTokenMiddleware_NonceReplay(t *testing.T) {
handler := tm.Handle(next)

req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "{}", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-rp", "req-rp")
seedSignature(t, repo, key, req)
seedSignature(t, repo, key, req, time.Now().Add(time.Hour))
req.Header.Set("X-Forwarded-For", "1.1.1.1")
rec := httptest.NewRecorder()
if err := handler(rec, req); err != nil {
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) {
seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(),
"nonce-rl-final", "req-rl-final",
)
seedSignature(t, repo, key, req)
seedSignature(t, repo, key, req, time.Now().Add(time.Hour))
req.Header.Set("X-Forwarded-For", "9.9.9.9")
rec := httptest.NewRecorder()
err := handler(rec, req)
Expand All @@ -208,3 +208,28 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) {
t.Fatalf("expected next not to be invoked when rate limited, got %d calls", nextCalled)
}
}

func TestTokenMiddleware_CustomClockValidatesSignature(t *testing.T) {
repo, th, seed, key := makeRepo(t, "clock")
tm := MakeTokenMiddleware(th, repo)
tm.clockSkew = time.Minute
past := time.Now().Add(-10 * time.Minute)
tm.now = func() time.Time { return past }

nextCalled := false
handler := tm.Handle(func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError {
nextCalled = true
return nil
})

req := makeSignedRequest(t, http.MethodGet, "https://api.test.local/v1/clock", "", seed.AccountName, seed.PublicKey, seed.PublicKey, past, "nonce-clock", "req-clock")
seedSignature(t, repo, key, req, past.Add(5*time.Minute))
req.Header.Set("X-Forwarded-For", "1.1.1.1")
rec := httptest.NewRecorder()
if err := handler(rec, req); err != nil {
t.Fatalf("expected success with injected clock, got %#v", err)
}
if !nextCalled {
t.Fatalf("expected next to be called")
}
}
9 changes: 5 additions & 4 deletions pkg/middleware/token_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ func makeSignedRequest(t *testing.T, method, rawURL, body, account, public, sign
}

// seedSignature stores the request signature for the given API key in the repository.
func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey, req *http.Request) {
// expiresAt allows tests to control the validity window for the stored signature.
func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey, req *http.Request, expiresAt time.Time) {
t.Helper()
sigHex := req.Header.Get("X-API-Signature")
sigBytes, err := hex.DecodeString(sigHex)
Expand All @@ -188,7 +189,7 @@ func seedSignature(t *testing.T, repo *repository.ApiKeys, key *database.APIKey,
Key: key,
Seed: sigBytes,
Origin: req.Header.Get("X-API-Intended-Origin"),
ExpiresAt: time.Now().Add(time.Hour),
ExpiresAt: expiresAt,
})
if err != nil {
t.Skipf("create signature: %v", err)
Expand Down Expand Up @@ -244,7 +245,7 @@ func TestTokenMiddleware_DB_Integration(t *testing.T) {
"nonce-1",
"req-001",
)
seedSignature(t, repo, apiKey, req)
seedSignature(t, repo, apiKey, req, time.Now().Add(time.Hour))
rec := httptest.NewRecorder()
if err := handler(rec, req); err != nil {
t.Fatalf("expected success, got error: %#v", err)
Expand Down Expand Up @@ -323,7 +324,7 @@ func TestTokenMiddleware_DB_Integration_HappyPath(t *testing.T) {
"n-happy-1",
"rid-happy-1",
)
seedSignature(t, repo, apiKey, req)
seedSignature(t, repo, apiKey, req, time.Now().Add(time.Hour))
rec := httptest.NewRecorder()
if err := handler(rec, req); err != nil {
t.Fatalf("happy path failed: %#v", err)
Expand Down
Loading