Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.24.6', '1.24.5']
go-version: ['1.25.1', '1.24.6']

steps:
- uses: actions/setup-go@v5
Expand Down
35 changes: 21 additions & 14 deletions pkg/middleware/mwguards/mw_response_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,83 +36,90 @@ func normaliseMessages(message, logMessage string) (string, string) {
func InvalidRequestError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: message,
Status: baseHttp.StatusUnauthorized,
Data: normaliseData(data...),
Data: d,
}
}

func InvalidTokenFormatError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: message,
Status: baseHttp.StatusUnauthorized,
Data: normaliseData(data...),
Data: d,
}
}

func UnauthenticatedError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: "2- Invalid credentials: " + logMessage,
Status: baseHttp.StatusUnauthorized,
Data: normaliseData(data...),
Data: d,
}
}

func RateLimitedError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: "Too many authentication attempts",
Status: baseHttp.StatusTooManyRequests,
Data: normaliseData(data...),
Data: d,
}
}

func NotFound(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: message,
Status: baseHttp.StatusNotFound,
Data: normaliseData(data...),
Data: d,
}
}

func TimestampTooOldError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: "Request timestamp expired",
Status: baseHttp.StatusUnauthorized,
Data: normaliseData(data...),
Data: d,
}
}

func TimestampTooNewError(message, logMessage string, data ...map[string]any) *http.ApiError {
message, logMessage = normaliseMessages(message, logMessage)

slog.Error(logMessage, "error")
d := normaliseData(data...)
slog.Error(logMessage, "data", d)

return &http.ApiError{
Message: "Request timestamp invalid",
Status: baseHttp.StatusUnauthorized,
Data: normaliseData(data...),
Data: d,
}
}
58 changes: 39 additions & 19 deletions pkg/middleware/token_middleware_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"context"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
Expand All @@ -22,7 +23,7 @@ import (
)

// makeRepo creates a temporary postgres repo with a seeded API key
func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHandler, *auth.Token) {
func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHandler, *auth.Token, *database.APIKey) {
t.Helper()
testcontainers.SkipIfProviderIsNotHealthy(t)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
Expand All @@ -45,15 +46,22 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan
if err != nil {
t.Skipf("connection string: %v", err)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
var db *gorm.DB
for i := 0; i < 10; i++ {
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err == nil {
break
}
time.Sleep(time.Second)
}
if err != nil {
t.Skipf("gorm open: %v", err)
}
if sqlDB, err := db.DB(); err == nil {
t.Cleanup(func() { _ = sqlDB.Close() })
}
if err := db.AutoMigrate(&database.APIKey{}); err != nil {
t.Fatalf("migrate: %v", err)
if err := db.AutoMigrate(&database.APIKey{}, &database.APIKeySignatures{}); err != nil {
t.Skipf("migrate: %v", err)
}
th, err := auth.MakeTokensHandler(generate32(t))
if err != nil {
Expand All @@ -63,35 +71,36 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan
if err != nil {
t.Fatalf("SetupNewAccount: %v", err)
}
if err := db.Create(&database.APIKey{
key := database.APIKey{
UUID: uuid.NewString(),
AccountName: seed.AccountName,
PublicKey: seed.EncryptedPublicKey,
SecretKey: seed.EncryptedSecretKey,
}).Error; err != nil {
t.Fatalf("seed api key: %v", err)
}
if err := db.Create(&key).Error; err != nil {
t.Skipf("seed api key: %v", err)
}
conn := database.NewConnectionFromGorm(db)
repo := &repository.ApiKeys{DB: conn}
return repo, th, seed
return repo, th, seed, &key
}

func TestTokenMiddlewareGuardDependencies(t *testing.T) {
logger := slogNoop()
tm := TokenCheckMiddleware{}
if err := tm.GuardDependencies(logger); err == nil || err.Status != http.StatusUnauthorized {
if err := tm.GuardDependencies(); err == nil || err.Status != http.StatusUnauthorized {
t.Fatalf("expected unauthorized when dependencies missing")
}
tm.ApiKeys, tm.TokenHandler, _ = makeRepo(t, "guard1")
repo, th, _, _ := makeRepo(t, "guard1")
tm.ApiKeys, tm.TokenHandler = repo, th
tm.nonceCache = cache.NewTTLCache()
tm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1)
if err := tm.GuardDependencies(logger); err != nil {
if err := tm.GuardDependencies(); err != nil {
t.Fatalf("expected no error when dependencies provided, got %#v", err)
}
}

func TestTokenMiddleware_PublicTokenMismatch(t *testing.T) {
repo, th, seed := makeRepo(t, "mismatch")
repo, th, seed, _ := makeRepo(t, "mismatch")
tm := MakeTokenMiddleware(th, repo)
tm.clockSkew = time.Minute
next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }
Expand All @@ -106,23 +115,32 @@ func TestTokenMiddleware_PublicTokenMismatch(t *testing.T) {
}

func TestTokenMiddleware_SignatureMismatch(t *testing.T) {
repo, th, seed := makeRepo(t, "siggy")
repo, th, seed, key := makeRepo(t, "siggy")
tm := MakeTokenMiddleware(th, repo)
tm.clockSkew = time.Minute
next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil }
handler := tm.Handle(next)

req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "body", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-sig", "req-sig")
seedSignature(t, repo, key, req)
req.Header.Set("X-Forwarded-For", "1.1.1.1")
req.Header.Set("X-API-Signature", req.Header.Get("X-API-Signature")+"tamper")

// mutate signature while keeping valid hex encoding
sigHex := req.Header.Get("X-API-Signature")
sigBytes, err := hex.DecodeString(sigHex)
if err != nil {
t.Fatalf("decode signature: %v", err)
}
sigBytes[0] ^= 0xFF
req.Header.Set("X-API-Signature", hex.EncodeToString(sigBytes))
rec := httptest.NewRecorder()
if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized {
t.Fatalf("expected unauthorized for signature mismatch, got %#v", err)
if err := handler(rec, req); err == nil || err.Status != http.StatusNotFound {
t.Fatalf("expected not found for signature mismatch, got %#v", err)
}
}

func TestTokenMiddleware_NonceReplay(t *testing.T) {
repo, th, seed := makeRepo(t, "replay")
repo, th, seed, key := makeRepo(t, "replay")
tm := MakeTokenMiddleware(th, repo)
tm.clockSkew = time.Minute
tm.nonceTTL = time.Minute
Expand All @@ -134,6 +152,7 @@ func TestTokenMiddleware_NonceReplay(t *testing.T) {
handler := tm.Handle(next)

req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "{}", seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(), "nonce-rp", "req-rp")
seedSignature(t, repo, key, req)
req.Header.Set("X-Forwarded-For", "1.1.1.1")
rec := httptest.NewRecorder()
if err := handler(rec, req); err != nil {
Expand All @@ -149,7 +168,7 @@ func TestTokenMiddleware_NonceReplay(t *testing.T) {
}

func TestTokenMiddleware_RateLimiter(t *testing.T) {
repo, th, seed := makeRepo(t, "ratey")
repo, th, seed, key := makeRepo(t, "ratey")
tm := MakeTokenMiddleware(th, repo)
tm.clockSkew = time.Minute
nextCalled := 0
Expand Down Expand Up @@ -177,6 +196,7 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) {
seed.AccountName, seed.PublicKey, seed.PublicKey, time.Now(),
"nonce-rl-final", "req-rl-final",
)
seedSignature(t, repo, key, req)
req.Header.Set("X-Forwarded-For", "9.9.9.9")
rec := httptest.NewRecorder()
err := handler(rec, req)
Expand Down
Loading
Loading