From d7bc2b6002743917c0e1291412038f72f07bdb6e Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 13:09:38 +0800 Subject: [PATCH 1/9] Add tests for TTL cache UseOnce and pruning --- pkg/cache/ttl_cache_useonce_test.go | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 pkg/cache/ttl_cache_useonce_test.go diff --git a/pkg/cache/ttl_cache_useonce_test.go b/pkg/cache/ttl_cache_useonce_test.go new file mode 100644 index 00000000..cca463ac --- /dev/null +++ b/pkg/cache/ttl_cache_useonce_test.go @@ -0,0 +1,36 @@ +package cache + +import ( + "testing" + "time" +) + +// TestTTLCache_UseOnce verifies the behavior of UseOnce for first use, +// repeated use before expiry and reuse after the TTL has elapsed. +func TestTTLCache_UseOnce(t *testing.T) { + c := NewTTLCache() + key := "nonce" + + if used := c.UseOnce(key, 50*time.Millisecond); used { + t.Fatalf("expected first UseOnce to return false") + } + if used := c.UseOnce(key, 50*time.Millisecond); !used { + t.Fatalf("expected second UseOnce to return true before expiry") + } + time.Sleep(60 * time.Millisecond) + if used := c.UseOnce(key, 50*time.Millisecond); used { + t.Fatalf("expected expired key to be usable again") + } +} + +// TestTTLCache_Mark_PrunesExpiredEntries ensures that calling Mark prunes +// any expired keys in the cache. +func TestTTLCache_Mark_PrunesExpiredEntries(t *testing.T) { + c := NewTTLCache() + c.Mark("old", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + c.Mark("new", 10*time.Millisecond) // should prune "old" + if c.Used("old") { + t.Fatalf("expected expired key to be pruned from cache") + } +} From 05a5348f385c281276fa480194b9bfc50f8ab910 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 13:36:55 +0800 Subject: [PATCH 2/9] Add token middleware tests to boost coverage --- go.mod | 2 + go.sum | 4 + .../token_middleware_additional_test.go | 145 ++++++++++++++++++ 3 files changed, 151 insertions(+) create mode 100644 pkg/middleware/token_middleware_additional_test.go diff --git a/go.mod b/go.mod index fbcbb591..3e62370b 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( golang.org/x/text v0.27.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 + gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.30.1 ) @@ -54,6 +55,7 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.1.0 // indirect github.com/moby/patternmatcher v0.6.0 // indirect diff --git a/go.sum b/go.sum index 28e2059b..0cbb9e0b 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr32 github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -236,6 +238,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= +gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go new file mode 100644 index 00000000..b1ace429 --- /dev/null +++ b/pkg/middleware/token_middleware_additional_test.go @@ -0,0 +1,145 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + "unsafe" + + "github.com/google/uuid" + "github.com/oullin/database" + "github.com/oullin/database/repository" + "github.com/oullin/pkg/auth" + "github.com/oullin/pkg/cache" + pkgHttp "github.com/oullin/pkg/http" + "github.com/oullin/pkg/limiter" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// makeRepo creates an in-memory sqlite repo with a seeded API key +func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHandler, *auth.Token) { + t.Helper() + th, err := auth.MakeTokensHandler(generate32(t)) + if err != nil { + t.Fatalf("MakeTokensHandler: %v", err) + } + seed, err := th.SetupNewAccount(account) + if err != nil { + t.Fatalf("SetupNewAccount: %v", err) + } + db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm open: %v", err) + } + if err := db.AutoMigrate(&database.APIKey{}); err != nil { + t.Fatalf("migrate: %v", err) + } + if err := db.Create(&database.APIKey{ + UUID: uuid.NewString(), + AccountName: seed.AccountName, + PublicKey: seed.EncryptedPublicKey, + SecretKey: seed.EncryptedSecretKey, + }).Error; err != nil { + t.Fatalf("seed api key: %v", err) + } + conn := &database.Connection{} + v := reflect.ValueOf(conn).Elem().FieldByName("driver") + reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem().Set(reflect.ValueOf(db)) + repo := &repository.ApiKeys{DB: conn} + return repo, th, seed +} + +func TestTokenMiddlewareGuardDependencies(t *testing.T) { + logger := slogNoop() + tm := TokenCheckMiddleware{} + if err := tm.guardDependencies(logger); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized when dependencies missing") + } + tm.ApiKeys, tm.TokenHandler, _ = makeRepo(t, "guard1") + tm.nonceCache = cache.NewTTLCache() + tm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1) + if err := tm.guardDependencies(logger); err != nil { + t.Fatalf("expected no error when dependencies provided, got %#v", err) + } +} + +func TestTokenMiddleware_PublicTokenMismatch(t *testing.T) { + repo, th, seed := makeRepo(t, "mismatch") + tm := MakeTokenMiddleware(th, repo) + tm.clockSkew = time.Minute + next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil } + handler := tm.Handle(next) + + req := makeSignedRequest(t, http.MethodGet, "https://api.test.local/v1/x", "", seed.AccountName, "wrong-"+seed.PublicKey, seed.SecretKey, time.Now(), "nonce-mm", "req-mm") + req.Header.Set("X-Forwarded-For", "1.1.1.1") + rec := httptest.NewRecorder() + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized for public token mismatch, got %#v", err) + } +} + +func TestTokenMiddleware_SignatureMismatch(t *testing.T) { + repo, th, seed := makeRepo(t, "siggy") + tm := MakeTokenMiddleware(th, repo) + tm.clockSkew = time.Minute + next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil } + handler := tm.Handle(next) + + req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "body", seed.AccountName, seed.PublicKey, seed.SecretKey, time.Now(), "nonce-sig", "req-sig") + req.Header.Set("X-Forwarded-For", "1.1.1.1") + req.Header.Set("X-API-Signature", req.Header.Get("X-API-Signature")+"tamper") + rec := httptest.NewRecorder() + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized for signature mismatch, got %#v", err) + } +} + +func TestTokenMiddleware_NonceReplay(t *testing.T) { + repo, th, seed := makeRepo(t, "replay") + tm := MakeTokenMiddleware(th, repo) + tm.clockSkew = time.Minute + tm.nonceTTL = time.Minute + nextCalled := 0 + next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { + nextCalled++ + return nil + } + handler := tm.Handle(next) + + req := makeSignedRequest(t, http.MethodPost, "https://api.test.local/v1/x", "{}", seed.AccountName, seed.PublicKey, seed.SecretKey, time.Now(), "nonce-rp", "req-rp") + req.Header.Set("X-Forwarded-For", "1.1.1.1") + rec := httptest.NewRecorder() + if err := handler(rec, req); err != nil { + t.Fatalf("first call failed: %#v", err) + } + if nextCalled != 1 { + t.Fatalf("expected next called once on first request") + } + rec = httptest.NewRecorder() + if err := handler(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized on nonce replay, got %#v", err) + } +} + +func TestTokenMiddleware_RateLimiter(t *testing.T) { + repo, th, seed := makeRepo(t, "ratey") + tm := MakeTokenMiddleware(th, repo) + tm.clockSkew = time.Minute + key := "9.9.9.9|" + strings.ToLower(seed.AccountName) + for i := 0; i < tm.maxFailPerScope; i++ { + tm.rateLimiter.Fail(key) + } + next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil } + handler := tm.Handle(next) + + req := makeSignedRequest(t, http.MethodGet, "https://api.test.local/v1/rl", "", seed.AccountName, seed.PublicKey, seed.SecretKey, time.Now(), "nonce-rl", "req-rl") + req.Header.Set("X-Forwarded-For", "9.9.9.9") + rec := httptest.NewRecorder() + if err := handler(rec, req); err == nil || err.Status != http.StatusTooManyRequests { + t.Fatalf("expected rate limited error, got %#v", err) + } +} From fba0cca39371c14c8b74b999fcfb632b17b74823 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 13:36:58 +0800 Subject: [PATCH 3/9] Use testcontainers for token middleware tests --- go.mod | 2 - go.sum | 4 -- .../token_middleware_additional_test.go | 38 +++++++++++++++---- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 3e62370b..fbcbb591 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( golang.org/x/text v0.27.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.6.0 - gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.30.1 ) @@ -55,7 +54,6 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect github.com/magiconair/properties v1.8.10 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.1.0 // indirect github.com/moby/patternmatcher v0.6.0 // indirect diff --git a/go.sum b/go.sum index 0cbb9e0b..28e2059b 100644 --- a/go.sum +++ b/go.sum @@ -95,8 +95,6 @@ github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr32 github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -238,8 +236,6 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= -gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index b1ace429..dec0bf1a 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "net/http/httptest" "reflect" @@ -16,28 +17,49 @@ import ( "github.com/oullin/pkg/cache" pkgHttp "github.com/oullin/pkg/http" "github.com/oullin/pkg/limiter" - "gorm.io/driver/sqlite" + "github.com/testcontainers/testcontainers-go" + postgrescontainer "github.com/testcontainers/testcontainers-go/modules/postgres" + "gorm.io/driver/postgres" "gorm.io/gorm" ) -// makeRepo creates an in-memory sqlite repo with a seeded API key +// makeRepo creates a temporary postgres repo with a seeded API key func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHandler, *auth.Token) { t.Helper() - th, err := auth.MakeTokensHandler(generate32(t)) + testcontainers.SkipIfProviderIsNotHealthy(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pgC, err := postgrescontainer.RunContainer(ctx, + testcontainers.WithImage("postgres:16-alpine"), + postgrescontainer.WithDatabase("testdb"), + postgrescontainer.WithUsername("test"), + postgrescontainer.WithPassword("test"), + ) if err != nil { - t.Fatalf("MakeTokensHandler: %v", err) + t.Skipf("run postgres container: %v", err) } - seed, err := th.SetupNewAccount(account) + t.Cleanup(func() { + _ = pgC.Terminate(ctx) + }) + dsn, err := pgC.ConnectionString(ctx, "sslmode=disable") if err != nil { - t.Fatalf("SetupNewAccount: %v", err) + t.Skipf("connection string: %v", err) } - db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { - t.Fatalf("gorm open: %v", err) + t.Skipf("gorm open: %v", err) } if err := db.AutoMigrate(&database.APIKey{}); err != nil { t.Fatalf("migrate: %v", err) } + th, err := auth.MakeTokensHandler(generate32(t)) + if err != nil { + t.Fatalf("MakeTokensHandler: %v", err) + } + seed, err := th.SetupNewAccount(account) + if err != nil { + t.Fatalf("SetupNewAccount: %v", err) + } if err := db.Create(&database.APIKey{ UUID: uuid.NewString(), AccountName: seed.AccountName, From fd0e279603fd69781234fc3a0e64c8e621a515a9 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 13:46:18 +0800 Subject: [PATCH 4/9] test: ensure TTL cache pruning removes expired entries --- pkg/cache/ttl_cache_useonce_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/cache/ttl_cache_useonce_test.go b/pkg/cache/ttl_cache_useonce_test.go index cca463ac..e59732a8 100644 --- a/pkg/cache/ttl_cache_useonce_test.go +++ b/pkg/cache/ttl_cache_useonce_test.go @@ -30,7 +30,10 @@ func TestTTLCache_Mark_PrunesExpiredEntries(t *testing.T) { c.Mark("old", 10*time.Millisecond) time.Sleep(20 * time.Millisecond) c.Mark("new", 10*time.Millisecond) // should prune "old" - if c.Used("old") { + + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.data["old"]; ok { t.Fatalf("expected expired key to be pruned from cache") } } From 1a9f3de5507e7f49d0c747b2b2606b04c26142d8 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 14:14:13 +0800 Subject: [PATCH 5/9] Refactor TTL cache UseOnce test --- pkg/cache/ttl_cache_useonce_test.go | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pkg/cache/ttl_cache_useonce_test.go b/pkg/cache/ttl_cache_useonce_test.go index e59732a8..6c095fbd 100644 --- a/pkg/cache/ttl_cache_useonce_test.go +++ b/pkg/cache/ttl_cache_useonce_test.go @@ -10,17 +10,26 @@ import ( func TestTTLCache_UseOnce(t *testing.T) { c := NewTTLCache() key := "nonce" + ttl := 50 * time.Millisecond - if used := c.UseOnce(key, 50*time.Millisecond); used { - t.Fatalf("expected first UseOnce to return false") - } - if used := c.UseOnce(key, 50*time.Millisecond); !used { - t.Fatalf("expected second UseOnce to return true before expiry") - } - time.Sleep(60 * time.Millisecond) - if used := c.UseOnce(key, 50*time.Millisecond); used { - t.Fatalf("expected expired key to be usable again") - } + t.Run("first use", func(t *testing.T) { + if used := c.UseOnce(key, ttl); used { + t.Fatalf("expected first UseOnce to return false") + } + }) + + t.Run("second use before expiry", func(t *testing.T) { + if used := c.UseOnce(key, ttl); !used { + t.Fatalf("expected second UseOnce to return true before expiry") + } + }) + + t.Run("use after expiry", func(t *testing.T) { + time.Sleep(ttl + 10*time.Millisecond) + if used := c.UseOnce(key, ttl); used { + t.Fatalf("expected UseOnce to return false for an expired key") + } + }) } // TestTTLCache_Mark_PrunesExpiredEntries ensures that calling Mark prunes From 85a33fb579554dc02c98b417fe5b548e8a6767c2 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 14:49:02 +0800 Subject: [PATCH 6/9] test: deflake UseOnce TTL cache --- pkg/cache/ttl_cache_useonce_test.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pkg/cache/ttl_cache_useonce_test.go b/pkg/cache/ttl_cache_useonce_test.go index 6c095fbd..91b53452 100644 --- a/pkg/cache/ttl_cache_useonce_test.go +++ b/pkg/cache/ttl_cache_useonce_test.go @@ -8,9 +8,10 @@ import ( // TestTTLCache_UseOnce verifies the behavior of UseOnce for first use, // repeated use before expiry and reuse after the TTL has elapsed. func TestTTLCache_UseOnce(t *testing.T) { - c := NewTTLCache() - key := "nonce" - ttl := 50 * time.Millisecond + t.Parallel() + c := NewTTLCache() + key := "nonce" + ttl := 100 * time.Millisecond t.Run("first use", func(t *testing.T) { if used := c.UseOnce(key, ttl); used { @@ -24,12 +25,12 @@ func TestTTLCache_UseOnce(t *testing.T) { } }) - t.Run("use after expiry", func(t *testing.T) { - time.Sleep(ttl + 10*time.Millisecond) - if used := c.UseOnce(key, ttl); used { - t.Fatalf("expected UseOnce to return false for an expired key") - } - }) + t.Run("use after expiry", func(t *testing.T) { + time.Sleep(ttl + 50*time.Millisecond) + if used := c.UseOnce(key, ttl); used { + t.Fatalf("expected UseOnce to return false for an expired key") + } + }) } // TestTTLCache_Mark_PrunesExpiredEntries ensures that calling Mark prunes From d569d2ba43d0d628d9cb1ea2384b1b7f4d7ea5bd Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 16:20:03 +0800 Subject: [PATCH 7/9] Add test helper for gorm connections? --- database/connection_testhelpers.go | 8 ++++++++ pkg/middleware/token_middleware_additional_test.go | 14 +++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 database/connection_testhelpers.go diff --git a/database/connection_testhelpers.go b/database/connection_testhelpers.go new file mode 100644 index 00000000..2a0b0b11 --- /dev/null +++ b/database/connection_testhelpers.go @@ -0,0 +1,8 @@ +package database + +import "gorm.io/gorm" + +// NewConnectionFromGorm is intended for tests only. +func NewConnectionFromGorm(db *gorm.DB) *Connection { + return &Connection{driver: db} +} diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index dec0bf1a..c1443009 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -4,11 +4,9 @@ import ( "context" "net/http" "net/http/httptest" - "reflect" "strings" "testing" "time" - "unsafe" "github.com/google/uuid" "github.com/oullin/database" @@ -27,7 +25,7 @@ import ( func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHandler, *auth.Token) { t.Helper() testcontainers.SkipIfProviderIsNotHealthy(t) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() pgC, err := postgrescontainer.RunContainer(ctx, testcontainers.WithImage("postgres:16-alpine"), @@ -39,7 +37,9 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan t.Skipf("run postgres container: %v", err) } t.Cleanup(func() { - _ = pgC.Terminate(ctx) + cctx, ccancel := context.WithTimeout(context.Background(), 15*time.Second) + defer ccancel() + _ = pgC.Terminate(cctx) }) dsn, err := pgC.ConnectionString(ctx, "sslmode=disable") if err != nil { @@ -49,6 +49,8 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan if err != nil { t.Skipf("gorm open: %v", err) } + sqlDB, _ := db.DB() + t.Cleanup(func() { _ = sqlDB.Close() }) if err := db.AutoMigrate(&database.APIKey{}); err != nil { t.Fatalf("migrate: %v", err) } @@ -68,9 +70,7 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan }).Error; err != nil { t.Fatalf("seed api key: %v", err) } - conn := &database.Connection{} - v := reflect.ValueOf(conn).Elem().FieldByName("driver") - reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem().Set(reflect.ValueOf(db)) + conn := database.NewConnectionFromGorm(db) repo := &repository.ApiKeys{DB: conn} return repo, th, seed } From c9b603cb9911fd5f377eaac4fa0e2fd8b1e6c872 Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 16:26:32 +0800 Subject: [PATCH 8/9] Close sql DB in token middleware tests --- pkg/middleware/token_middleware_additional_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index c1443009..b571d523 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -49,8 +49,9 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan if err != nil { t.Skipf("gorm open: %v", err) } - sqlDB, _ := db.DB() - t.Cleanup(func() { _ = sqlDB.Close() }) + if sqlDB, err := db.DB(); err == nil { + t.Cleanup(func() { _ = sqlDB.Close() }) + } if err := db.AutoMigrate(&database.APIKey{}); err != nil { t.Fatalf("migrate: %v", err) } From 1bc0e87d6ee504e5d104529097556f7add3b59ff Mon Sep 17 00:00:00 2001 From: Gus Date: Mon, 11 Aug 2025 16:34:35 +0800 Subject: [PATCH 9/9] test: drive token rate limiter through public API --- .../token_middleware_additional_test.go | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index b571d523..19535a3c 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -2,9 +2,9 @@ package middleware import ( "context" + "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" @@ -152,17 +152,39 @@ func TestTokenMiddleware_RateLimiter(t *testing.T) { repo, th, seed := makeRepo(t, "ratey") tm := MakeTokenMiddleware(th, repo) tm.clockSkew = time.Minute - key := "9.9.9.9|" + strings.ToLower(seed.AccountName) - for i := 0; i < tm.maxFailPerScope; i++ { - tm.rateLimiter.Fail(key) + nextCalled := 0 + next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { + nextCalled++ + return nil } - next := func(w http.ResponseWriter, r *http.Request) *pkgHttp.ApiError { return nil } handler := tm.Handle(next) - req := makeSignedRequest(t, http.MethodGet, "https://api.test.local/v1/rl", "", seed.AccountName, seed.PublicKey, seed.SecretKey, time.Now(), "nonce-rl", "req-rl") + // Pre-warm limiter by sending invalid signature requests up to the limit + for i := 0; i < tm.maxFailPerScope; i++ { + req := makeSignedRequest( + t, http.MethodGet, "https://api.test.local/v1/rl", "", + seed.AccountName, seed.PublicKey, "wrong-secret", time.Now(), + fmt.Sprintf("nonce-rl-%d", i), fmt.Sprintf("req-rl-%d", i), + ) + req.Header.Set("X-Forwarded-For", "9.9.9.9") + rec := httptest.NewRecorder() + _ = handler(rec, req) // ignore errors while warming + } + + // Next request with valid signature should be rate limited + req := makeSignedRequest( + t, http.MethodGet, "https://api.test.local/v1/rl", "", + seed.AccountName, seed.PublicKey, seed.SecretKey, time.Now(), + "nonce-rl-final", "req-rl-final", + ) req.Header.Set("X-Forwarded-For", "9.9.9.9") rec := httptest.NewRecorder() - if err := handler(rec, req); err == nil || err.Status != http.StatusTooManyRequests { + err := handler(rec, req) + if err == nil || err.Status != http.StatusTooManyRequests { t.Fatalf("expected rate limited error, got %#v", err) } + + if nextCalled != 0 { + t.Fatalf("expected next not to be invoked when rate limited, got %d calls", nextCalled) + } }