From 026247786cfd7c814d1c4a201264e9ddf87b05ab Mon Sep 17 00:00:00 2001 From: Daniel Wood Date: Sun, 1 Mar 2026 19:50:09 -0500 Subject: [PATCH 1/8] fix(wasm): harden ingestion and require explicit unverified export (#105) * fix(wasm): harden ingest and make unverified export explicit * fix: address PR 105 reliability review feedback * fix(web): enforce upstream timeout through body reads * fix: address latest PR 105 review feedback * fix(crl): clarify read errors and harden size-limit coverage --- CHANGELOG.md | 11 +++ cmd/certkit/crl.go | 5 +- cmd/wasm/aia.go | 51 ++++++++++---- cmd/wasm/export.go | 34 +++++++--- cmd/wasm/inspect.go | 21 +++++- cmd/wasm/main.go | 96 ++++++++++++++++++++++++-- crl.go | 51 +++++++++++++- crl_test.go | 61 +++++++++++++++++ web/functions/api/fetch.test.ts | 58 ++++++++++++++++ web/functions/api/fetch.ts | 117 +++++++++++++++++++++----------- web/public/app.js | 44 ++++++++++-- 11 files changed, 467 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 661165e5..85e66fca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `connect` automatically checks OCSP revocation status on the leaf certificate (best-effort; shows "skipped" or "unavailable" when check cannot complete) ([#78]) - Add `--crl` flag to `connect` for opt-in CRL revocation checking via distribution points ([#78]) - Add `FetchCRL` library function for downloading CRLs from HTTP URLs with SSRF validation ([#78]) +- Add `ReadCRLFile` library function for reading local CRL files with the same 10 MB size cap as `FetchCRL` ([#105]) - `connect` exits with code 2 when OCSP or CRL reports a revoked certificate ([#78]) - `connect --crl` verifies CRL signatures against the issuer certificate — rejects CRLs signed by a different CA ([#78]) - `connect --crl` rejects expired CRLs (past `NextUpdate`) to prevent replay of stale revocation data ([#78]) @@ -67,6 +68,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Require verified WASM bundle export by default; retrying export without verification is now an explicit user action surfaced in the web UI ([#105]) - Prefer user-provided passwords for PKCS#12/JKS outputs while keeping `changeit` as the default fallback for compatibility ([#87]) - **Breaking:** Standardize certificate serial number formatting to `0x`-prefixed hex across CLI/JSON output ([#87]) - Move local pre-commit hook definitions from repo config into the shared `sensiblebit/.github` hook set, and pin this branch to the shared commit so all repositories can consume the same workflow checks and Node tool bootstrapping behavior from one source ([#85]) @@ -87,6 +89,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security +- Enforce bounded per-file and total upload limits in WASM `addFiles` and `inspect` ingestion paths to prevent unbounded memory growth ([#105]) +- Enforce local CRL file size limits for `certkit crl` and shared CRL readers to reject oversized inputs early ([#105]) - Prevent bundle export path traversal by sanitizing bundle folder names and enforcing safe output paths ([#87]) - Enforce size limits on input reads to avoid unbounded memory usage ([#87]) - Add SSRF validation (`ValidateAIAURL`) to OCSP responder URLs and CRL distribution point URLs — previously only AIA certificate URLs were validated ([#78]) @@ -95,6 +99,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix CRL read errors to include `reading CRL data` context before caller wrapping, improving nested error diagnostics ([#105]) +- Fix WASM ingestion promises to recover from internal panics instead of crashing asynchronous file processing ([#105]) +- Fix WASM AIA fetch callback lifecycle to release JS callbacks on cancellation paths after promise completion ([#105]) +- Fix web AIA proxy upstream handling to enforce explicit fetch timeout/abort behavior and return 504 timeout errors ([#105]) +- Fix web AIA proxy timeout handling to keep abort timers active through response body reads, including stalled-after-headers upstream responses ([#105]) - Fix verify JSON chain output to use `not_after` for consistency with other commands ([#87]) - Fix Certificate Transparency availability handling to preserve parsed SCT candidates when the log list cannot be loaded and mark them as unavailable instead of dropping them ([#86]) - Fix chain conversion failures in Certificate Transparency checks to report SCTs as `unavailable` instead of `invalid` and keep diagnostics as warnings ([#86]) @@ -210,6 +219,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Tests +- Consolidate CRL oversize-input coverage into one table-driven test for HTTP and local-file sources, asserting `ErrCRLTooLarge` behaviorally ([#105]) - Remove `TestBuildLegacyClientHelloMsg` — behavioral coverage exists through `TestLegacyFallbackConnect` per T-11 ([`6492fa5`]) - Remove `TestParseCertificateMessage` — behavioral coverage exists through `TestReadServerCertificates` per T-11 ([#82]) - Fix `_, _` error discards in `TestLegacyFallbackConnect` mock server goroutine — replaced with `slog.Debug` per ERR-5 ([#82]) @@ -937,6 +947,7 @@ Initial release. [#85]: https://github.com/sensiblebit/certkit/pull/85 [#86]: https://github.com/sensiblebit/certkit/pull/86 [#87]: https://github.com/sensiblebit/certkit/pull/87 +[#105]: https://github.com/sensiblebit/certkit/pull/105 [#73]: https://github.com/sensiblebit/certkit/pull/73 [#64]: https://github.com/sensiblebit/certkit/pull/64 [#63]: https://github.com/sensiblebit/certkit/pull/63 diff --git a/cmd/certkit/crl.go b/cmd/certkit/crl.go index fa6c4fe0..5837f8d8 100644 --- a/cmd/certkit/crl.go +++ b/cmd/certkit/crl.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "fmt" - "os" "strings" "github.com/sensiblebit/certkit" @@ -66,9 +65,9 @@ func runCRL(cmd *cobra.Command, args []string) error { return fmt.Errorf("fetching CRL: %w", err) } } else { - data, err = os.ReadFile(source) + data, err = certkit.ReadCRLFile(source) if err != nil { - return fmt.Errorf("reading CRL file: %w", err) + return fmt.Errorf("reading CRL file %q: %w", source, err) } } diff --git a/cmd/wasm/aia.go b/cmd/wasm/aia.go index 903a2773..fbba6a47 100644 --- a/cmd/wasm/aia.go +++ b/cmd/wasm/aia.go @@ -5,7 +5,9 @@ package main import ( "context" "fmt" + "sync" "syscall/js" + "time" "github.com/sensiblebit/certkit/internal/certstore" ) @@ -40,6 +42,10 @@ func resolveAIA(ctx context.Context, s *certstore.MemStore) []string { // direct fetch with automatic CORS proxy fallback. Blocks until the JS // Promise resolves or rejects, or the context is cancelled. func jsFetchURL(ctx context.Context, url string) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + fetchFn := js.Global().Get("certkitFetchURL") if fetchFn.Type() != js.TypeFunction { return nil, fmt.Errorf("certkitFetchURL not defined") @@ -50,25 +56,52 @@ func jsFetchURL(ctx context.Context, url string) ([]byte, error) { err error } ch := make(chan result, 1) + var releaseOnce sync.Once + releaseCallbacks := func(thenCb js.Func, catchCb js.Func) { + releaseOnce.Do(func() { + thenCb.Release() + catchCb.Release() + }) + } + sendResult := func(r result) { + select { + case ch <- r: + default: + } + } - promise := fetchFn.Invoke(url) + timeoutMillis := 10_000 + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline).Milliseconds() + if remaining <= 0 { + return nil, context.DeadlineExceeded + } + timeoutMillis = int(remaining) + } + + promise := fetchFn.Invoke(url, timeoutMillis) const maxAIAResponseSize = 1 << 20 // 1MB, consistent with CLI httpAIAFetcher - thenCb := js.FuncOf(func(_ js.Value, args []js.Value) any { + var thenCb js.Func + var catchCb js.Func + + thenCb = js.FuncOf(func(_ js.Value, args []js.Value) any { + defer releaseCallbacks(thenCb, catchCb) uint8Array := args[0] size := uint8Array.Length() if size > maxAIAResponseSize { - ch <- result{err: fmt.Errorf("AIA response too large (%d bytes, max %d)", size, maxAIAResponseSize)} + sendResult(result{err: fmt.Errorf("AIA response too large (%d bytes, max %d)", size, maxAIAResponseSize)}) return nil } data := make([]byte, size) js.CopyBytesToGo(data, uint8Array) - ch <- result{data: data} + sendResult(result{data: data}) return nil }) - catchCb := js.FuncOf(func(_ js.Value, args []js.Value) any { + catchCb = js.FuncOf(func(_ js.Value, args []js.Value) any { + defer releaseCallbacks(thenCb, catchCb) val := args[0] var errMsg string if val.Type() == js.TypeObject || val.Type() == js.TypeFunction { @@ -76,7 +109,7 @@ func jsFetchURL(ctx context.Context, url string) ([]byte, error) { } else { errMsg = val.String() } - ch <- result{err: fmt.Errorf("AIA fetch: %s", errMsg)} + sendResult(result{err: fmt.Errorf("AIA fetch: %s", errMsg)}) return nil }) @@ -84,14 +117,8 @@ func jsFetchURL(ctx context.Context, url string) ([]byte, error) { select { case r := <-ch: - thenCb.Release() - catchCb.Release() return r.data, r.err case <-ctx.Done(): - // Do NOT release callbacks here. The JS promise is still pending and - // will eventually invoke one of them. Calling a released js.Func panics. - // The buffered channel (cap 1) absorbs the late send harmlessly. - // The callbacks leak, but that is preferable to a crash. return nil, ctx.Err() } } diff --git a/cmd/wasm/export.go b/cmd/wasm/export.go index 04d7f4b1..a0175647 100644 --- a/cmd/wasm/export.go +++ b/cmd/wasm/export.go @@ -6,6 +6,7 @@ import ( "archive/zip" "bytes" "context" + "errors" "fmt" "time" @@ -13,19 +14,29 @@ import ( "github.com/sensiblebit/certkit/internal/certstore" ) +var errVerifiedExportFailed = errors.New("verified export failed") + // exportBundles generates a ZIP file containing organized certificate bundles. // If filterSKIs is non-empty, only pairs whose colon-hex SKI appears in the // list are included. Otherwise all matched pairs are exported. -func exportBundles(ctx context.Context, s *certstore.MemStore, filterSKIs []string, p12Password string) ([]byte, error) { - matched := s.MatchedPairs() +// When AllowUnverifiedExport is true, chain verification is disabled explicitly. +type exportBundlesInput struct { + Store *certstore.MemStore + FilterSKIs []string + P12Password string + AllowUnverifiedExport bool +} + +func exportBundles(ctx context.Context, input exportBundlesInput) ([]byte, error) { + matched := input.Store.MatchedPairs() if len(matched) == 0 { return nil, fmt.Errorf("no matched key-certificate pairs found") } // Build a lookup set from the colon-hex formatted filter list. - if len(filterSKIs) > 0 { - allowed := make(map[string]bool, len(filterSKIs)) - for _, ski := range filterSKIs { + if len(input.FilterSKIs) > 0 { + allowed := make(map[string]bool, len(input.FilterSKIs)) + for _, ski := range input.FilterSKIs { allowed[ski] = true } var filtered []string @@ -47,18 +58,21 @@ func exportBundles(ctx context.Context, s *certstore.MemStore, filterSKIs []stri opts := certkit.BundleOptions{ FetchAIA: false, TrustStore: "mozilla", - Verify: true, + Verify: !input.AllowUnverifiedExport, } if err := certstore.ExportMatchedBundles(ctx, certstore.ExportMatchedBundleInput{ - Store: s, + Store: input.Store, SKIs: matched, BundleOpts: opts, Writer: &zipBundleWriter{zw: zw}, - RetryNoVerify: true, - P12Password: p12Password, + RetryNoVerify: false, + P12Password: input.P12Password, }); err != nil { - return nil, err + if opts.Verify { + return nil, fmt.Errorf("%w: %w", errVerifiedExportFailed, err) + } + return nil, fmt.Errorf("unverified export failed: %w", err) } if err := zw.Close(); err != nil { diff --git a/cmd/wasm/inspect.go b/cmd/wasm/inspect.go index dc08dbaf..92fab222 100644 --- a/cmd/wasm/inspect.go +++ b/cmd/wasm/inspect.go @@ -25,6 +25,9 @@ func inspectFiles(_ js.Value, args []js.Value) any { filesArg := args[0] length := filesArg.Length() + if length > wasmMaxInputFiles { + return jsError(fmt.Sprintf("too many files: %d (max %d)", length, wasmMaxInputFiles)) + } var passwords []string if len(args) >= 2 && args[1].Type() == js.TypeString { @@ -51,6 +54,7 @@ func inspectFiles(_ js.Value, args []js.Value) any { defer cancel() var allResults []internal.InspectResult + var totalBytes int64 for i := range length { select { case <-ctx.Done(): @@ -59,9 +63,20 @@ func inspectFiles(_ js.Value, args []js.Value) any { default: } file := filesArg.Index(i) - dataJS := file.Get("data") - data := make([]byte, dataJS.Length()) - js.CopyBytesToGo(data, dataJS) + name := file.Get("name").String() + if name == "" { + name = fmt.Sprintf("file[%d]", i) + } + + data, err := readWASMFileData(readWASMFileDataInput{ + DataJS: file.Get("data"), + Name: name, + TotalBytes: &totalBytes, + }) + if err != nil { + reject.Invoke(js.Global().Get("Error").New(err.Error())) + return + } results := internal.InspectData(data, passwords) allResults = append(allResults, results...) diff --git a/cmd/wasm/main.go b/cmd/wasm/main.go index cf255319..ab10f9e8 100644 --- a/cmd/wasm/main.go +++ b/cmd/wasm/main.go @@ -9,6 +9,7 @@ import ( "context" "encoding/hex" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -19,6 +20,12 @@ import ( "github.com/sensiblebit/certkit/internal/certstore" ) +const ( + wasmMaxInputFiles = 200 + wasmMaxInputFileBytes = 10 * 1024 * 1024 + wasmMaxInputTotalBytes = 50 * 1024 * 1024 +) + // version is set at build time via -ldflags "-X main.version=v0.6.1". var version = "dev" @@ -39,6 +46,41 @@ func main() { select {} } +type readWASMFileDataInput struct { + DataJS js.Value + Name string + TotalBytes *int64 +} + +// readWASMFileData copies a JS Uint8Array into Go memory with hard size caps. +func readWASMFileData(input readWASMFileDataInput) ([]byte, error) { + if input.DataJS.Type() != js.TypeObject { + return nil, fmt.Errorf("file %q has invalid data payload", input.Name) + } + + size := input.DataJS.Length() + if size < 0 { + return nil, fmt.Errorf("file %q has invalid size", input.Name) + } + + if size > wasmMaxInputFileBytes { + return nil, fmt.Errorf("file %q exceeds max size (%d bytes)", input.Name, wasmMaxInputFileBytes) + } + + nextTotal := *input.TotalBytes + int64(size) + if nextTotal > wasmMaxInputTotalBytes { + return nil, fmt.Errorf("total upload exceeds max size (%d bytes)", wasmMaxInputTotalBytes) + } + + data := make([]byte, size) + copied := js.CopyBytesToGo(data, input.DataJS) + if copied != size { + return nil, fmt.Errorf("file %q read incomplete data: expected %d bytes, got %d", input.Name, size, copied) + } + *input.TotalBytes = nextTotal + return data, nil +} + // addFiles processes an array of {name, data} objects with optional passwords. // JS signature: certkitAddFiles(files: Array<{name: string, data: Uint8Array}>, passwords: string) → Promise func addFiles(_ js.Value, args []js.Value) any { @@ -48,6 +90,9 @@ func addFiles(_ js.Value, args []js.Value) any { filesArg := args[0] length := filesArg.Length() + if length > wasmMaxInputFiles { + return jsError(fmt.Sprintf("too many files: %d (max %d)", length, wasmMaxInputFiles)) + } var passwords []string if len(args) >= 2 && args[1].Type() == js.TypeString { @@ -64,12 +109,19 @@ func addFiles(_ js.Value, args []js.Value) any { resolve := promiseArgs[0] reject := promiseArgs[1] go func() { + defer func() { + if r := recover(); r != nil { + reject.Invoke(js.Global().Get("Error").New(fmt.Sprintf("internal error: %v", r))) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() storeMu.Lock() defer storeMu.Unlock() var results []map[string]any + var totalBytes int64 for i := range length { select { case <-ctx.Done(): @@ -79,11 +131,26 @@ func addFiles(_ js.Value, args []js.Value) any { } file := filesArg.Index(i) name := file.Get("name").String() - dataJS := file.Get("data") - data := make([]byte, dataJS.Length()) - js.CopyBytesToGo(data, dataJS) + if name == "" { + name = fmt.Sprintf("file[%d]", i) + } + + data, err := readWASMFileData(readWASMFileDataInput{ + DataJS: file.Get("data"), + Name: name, + TotalBytes: &totalBytes, + }) + if err != nil { + slog.Debug("skipping file due to read error", "name", name, "error", err) + results = append(results, map[string]any{ + "name": name, + "status": "error", + "error": err.Error(), + }) + continue + } - err := certstore.ProcessData(certstore.ProcessInput{ + err = certstore.ProcessData(certstore.ProcessInput{ Data: data, Path: name, Passwords: passwords, @@ -281,7 +348,7 @@ func getState(_ js.Value, _ []js.Value) any { } // exportBundlesJS generates a ZIP and returns it as a Uint8Array. -// JS signature: certkitExportBundles(skis: string[], p12Password?: string) → Promise +// JS signature: certkitExportBundles(skis: string[], p12Password?: string, allowUnverifiedExport?: boolean) → Promise // Only bundles for the specified SKIs are included. func exportBundlesJS(_ js.Value, args []js.Value) any { // Parse the SKI filter list from the JS array argument. @@ -301,6 +368,11 @@ func exportBundlesJS(_ js.Value, args []js.Value) any { } } + allowUnverifiedExport := false + if len(args) >= 3 && args[2].Type() == js.TypeBoolean { + allowUnverifiedExport = args[2].Bool() + } + handler := js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { resolve := promiseArgs[0] reject := promiseArgs[1] @@ -310,9 +382,21 @@ func exportBundlesJS(_ js.Value, args []js.Value) any { storeMu.RLock() defer storeMu.RUnlock() - zipData, err := exportBundles(ctx, globalStore, filterSKIs, p12Password) + zipData, err := exportBundles(ctx, exportBundlesInput{ + Store: globalStore, + FilterSKIs: filterSKIs, + P12Password: p12Password, + AllowUnverifiedExport: allowUnverifiedExport, + }) if err != nil { + if errors.Is(err, errVerifiedExportFailed) { + errObj := js.Global().Get("Object").New() + errObj.Set("code", "VERIFY_FAILED") + errObj.Set("message", err.Error()) + reject.Invoke(errObj) + return + } reject.Invoke(js.Global().Get("Error").New(err.Error())) return } diff --git a/crl.go b/crl.go index 25a4a81a..ce49c0e6 100644 --- a/crl.go +++ b/crl.go @@ -4,12 +4,21 @@ import ( "context" "crypto/x509" "encoding/pem" + "errors" "fmt" "io" + "log/slog" "net/http" + "os" + "strconv" "time" ) +const maxCRLBytes int64 = 10 << 20 + +// ErrCRLTooLarge indicates that CRL input exceeded the maximum allowed size. +var ErrCRLTooLarge = errors.New("CRL data exceeds max size") + // CRLInfo contains parsed CRL details for display. type CRLInfo struct { // Issuer is the CRL issuer distinguished name. @@ -80,13 +89,53 @@ func FetchCRL(ctx context.Context, input FetchCRLInput) ([]byte, error) { return nil, fmt.Errorf("CRL server returned HTTP %d from %s", resp.StatusCode, input.URL) } - data, err := io.ReadAll(io.LimitReader(resp.Body, 10<<20)) // 10MB limit + if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { + parsedLength, err := strconv.ParseInt(contentLength, 10, 64) + if err == nil && parsedLength > maxCRLBytes { + return nil, fmt.Errorf("CRL response exceeds max size (%d bytes): %w", maxCRLBytes, ErrCRLTooLarge) + } + } + + data, err := readCRLData(resp.Body) if err != nil { return nil, fmt.Errorf("reading CRL response: %w", err) } return data, nil } +// ReadCRLFile reads a local CRL file with the same hard size cap as FetchCRL. +func ReadCRLFile(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("opening CRL file: %w", err) + } + defer func() { _ = f.Close() }() + + if info, err := f.Stat(); err != nil { + slog.Debug("stat failed on CRL file, skipping size pre-check", "path", path, "err", err) + } else if info.Size() > maxCRLBytes { + return nil, fmt.Errorf("CRL file exceeds max size (%d bytes): %w", maxCRLBytes, ErrCRLTooLarge) + } + + data, err := readCRLData(f) + if err != nil { + return nil, fmt.Errorf("reading CRL file: %w", err) + } + return data, nil +} + +func readCRLData(r io.Reader) ([]byte, error) { + limited := io.LimitReader(r, maxCRLBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, fmt.Errorf("reading CRL data: %w", err) + } + if int64(len(data)) > maxCRLBytes { + return nil, fmt.Errorf("%w (%d bytes)", ErrCRLTooLarge, maxCRLBytes) + } + return data, nil +} + // ParseCRL parses a CRL from PEM or DER data. Returns the parsed // RevocationList from the stdlib. func ParseCRL(data []byte) (*x509.RevocationList, error) { diff --git a/crl_test.go b/crl_test.go index ad18e6d4..4789c1d3 100644 --- a/crl_test.go +++ b/crl_test.go @@ -1,13 +1,17 @@ package certkit import ( + "bytes" "context" "crypto/rand" "crypto/x509" "encoding/pem" + "errors" "math/big" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" "time" @@ -309,6 +313,63 @@ func TestFetchCRL_AllowPrivateNetworks(t *testing.T) { } } +func TestCRLSizeLimit(t *testing.T) { + t.Parallel() + + const tooLargeBytes = 10<<20 + 1 + + tests := []struct { + name string + run func(t *testing.T) error + }{ + { + name: "http response exceeds limit", + run: func(t *testing.T) error { + t.Helper() + + tooLargeBody := bytes.Repeat([]byte("x"), tooLargeBytes) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Length", "10485761") + _, _ = w.Write(tooLargeBody) + })) + t.Cleanup(srv.Close) + + url := strings.Replace(srv.URL, "127.0.0.1", "localhost", 1) + _, err := FetchCRL(context.Background(), FetchCRLInput{URL: url}) + return err + }, + }, + { + name: "local file exceeds limit", + run: func(t *testing.T) error { + t.Helper() + + path := filepath.Join(t.TempDir(), "oversize.crl") + if err := os.WriteFile(path, bytes.Repeat([]byte("x"), tooLargeBytes), 0o600); err != nil { + t.Fatalf("writing oversized CRL file: %v", err) + } + + _, err := ReadCRLFile(path) + return err + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := tc.run(t) + if err == nil { + t.Fatal("expected size-limit error, got nil") + } + if !errors.Is(err, ErrCRLTooLarge) { + t.Fatalf("error = %v, want ErrCRLTooLarge", err) + } + }) + } +} + func TestCheckLeafCRL(t *testing.T) { t.Parallel() diff --git a/web/functions/api/fetch.test.ts b/web/functions/api/fetch.test.ts index a73130e0..e7d3a201 100644 --- a/web/functions/api/fetch.test.ts +++ b/web/functions/api/fetch.test.ts @@ -429,6 +429,64 @@ describe("fetch behavior", () => { expect(await errorMsg(resp)).toMatch(/Fetch failed/); }); + it("returns 504 when upstream fetch aborts", async () => { + const abortError = Object.assign(new Error("timed out"), { + name: "AbortError", + }); + vi.stubGlobal("fetch", vi.fn().mockRejectedValue(abortError)); + + const resp = await callGet("http://crl.disa.mil/cert.p7c"); + expect(resp.status).toBe(504); + expect(await errorMsg(resp)).toMatch(/timed out/); + }); + + it("returns 504 when upstream abort is a non-Error object", async () => { + vi.stubGlobal("fetch", vi.fn().mockRejectedValue({ name: "AbortError" })); + + const resp = await callGet("http://crl.disa.mil/cert.p7c"); + expect(resp.status).toBe(504); + expect(await errorMsg(resp)).toMatch(/timed out/); + }); + + it("returns 504 when upstream body read stalls after headers", async () => { + vi.useFakeTimers(); + try { + vi.stubGlobal( + "fetch", + vi.fn().mockImplementation((_url: string, init?: RequestInit) => { + const signal = init?.signal; + const body = new ReadableStream({ + start(controller): void { + signal?.addEventListener( + "abort", + () => { + controller.error({ name: "AbortError" }); + }, + { once: true }, + ); + }, + }); + + return Promise.resolve( + new Response(body, { + status: 200, + headers: { "Content-Length": "16" }, + }), + ); + }), + ); + + const respPromise = callGet("http://crl.disa.mil/cert.p7c"); + await vi.advanceTimersByTimeAsync(8_001); + + const resp = await respPromise; + expect(resp.status).toBe(504); + expect(await errorMsg(resp)).toMatch(/timed out/); + } finally { + vi.useRealTimers(); + } + }); + it("falls back from HTTPS to HTTP when HTTPS fails", async () => { const mockFetch = vi .fn() diff --git a/web/functions/api/fetch.ts b/web/functions/api/fetch.ts index 27ec2e7c..520b6000 100644 --- a/web/functions/api/fetch.ts +++ b/web/functions/api/fetch.ts @@ -6,6 +6,7 @@ // Usage: GET /api/fetch?url=https://cacerts.digicert.com/... const MAX_RESPONSE_SIZE = 256 * 1024; // 256KB — certs are small +const UPSTREAM_TIMEOUT_MS = 8_000; // Allowed origins for CORS. The proxy only serves requests from these origins. const ALLOWED_ORIGINS: string[] = [ @@ -306,39 +307,66 @@ export function isAllowedDomain(hostname: string): boolean { // arbitrary URLs. const MAX_REDIRECTS = 5; -async function safeFetch(url: string): Promise { +type safeFetchResult = { + response: Response; + release: () => void; +}; + +async function safeFetch(url: string): Promise { let currentURL = url; for (let i = 0; i <= MAX_REDIRECTS; i++) { - const resp = await fetch(currentURL, { - headers: { "User-Agent": "certkit AIA proxy/1.0" }, - redirect: "manual", - }); + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), UPSTREAM_TIMEOUT_MS); + const release = (): void => { + clearTimeout(timer); + controller.abort(); + }; + + let resp: Response; + try { + resp = await fetch(currentURL, { + headers: { "User-Agent": "certkit AIA proxy/1.0" }, + redirect: "manual", + signal: controller.signal, + }); + } catch (err) { + release(); + throw err; + } // Not a redirect — return as-is. if (resp.status < 300 || resp.status >= 400) { - return resp; + return { response: resp, release }; } const location = resp.headers.get("Location"); if (!location) { - return resp; + return { response: resp, release }; } const target = new URL(location, currentURL); if (target.protocol !== "https:" && target.protocol !== "http:") { + release(); throw new Error("Redirect to non-HTTP protocol"); } if (!isAllowedDomain(target.hostname)) { + release(); throw new Error(`Redirect to disallowed domain '${target.hostname}'`); } // Sanitize redirect URL — only keep protocol, host, and path. currentURL = `${target.protocol}//${target.hostname}${target.pathname}`; + release(); } throw new Error("Too many redirects"); } +function isAbortError(err: unknown): boolean { + const anyErr = err as { name?: unknown } | null | undefined; + return anyErr?.name === "AbortError"; +} + export const onRequestOptions: PagesFunction = async ({ request }) => { const origin = request.headers.get("Origin"); return new Response(null, { status: 204, headers: corsHeaders(origin) }); @@ -441,42 +469,51 @@ export const onRequestGet: PagesFunction = async ({ request }) => { for (const tryURL of urlsToTry) { try { - const upstream = await safeFetch(tryURL); - - if (!upstream.ok) { - lastStatus = upstream.status; - lastMessage = `Upstream returned ${upstream.status}`; - continue; // try next URL (HTTP fallback) + const { response: upstream, release } = await safeFetch(tryURL); + + try { + if (!upstream.ok) { + lastStatus = upstream.status; + lastMessage = `Upstream returned ${upstream.status}`; + continue; // try next URL (HTTP fallback) + } + + const contentLength = upstream.headers.get("content-length"); + if (contentLength && parseInt(contentLength, 10) > MAX_RESPONSE_SIZE) { + return errorResponse(413, "Response too large", origin); + } + + const body = await upstream.arrayBuffer(); + if (body.byteLength > MAX_RESPONSE_SIZE) { + return errorResponse(413, "Response too large", origin); + } + + if (body.byteLength === 0) { + lastStatus = 502; + lastMessage = "Upstream returned empty response"; + continue; + } + + const responseHeaders = new Headers(corsHeaders(origin)); + responseHeaders.set("Content-Type", "application/octet-stream"); + responseHeaders.set("Content-Length", body.byteLength.toString()); + responseHeaders.set("X-Content-Type-Options", "nosniff"); + // Cache forever — AIA certificates are immutable (same URL = same cert) + responseHeaders.set( + "Cache-Control", + "public, max-age=31536000, immutable", + ); + + return new Response(body, { status: 200, headers: responseHeaders }); + } finally { + release(); } - - const contentLength = upstream.headers.get("content-length"); - if (contentLength && parseInt(contentLength, 10) > MAX_RESPONSE_SIZE) { - return errorResponse(413, "Response too large", origin); - } - - const body = await upstream.arrayBuffer(); - if (body.byteLength > MAX_RESPONSE_SIZE) { - return errorResponse(413, "Response too large", origin); - } - - if (body.byteLength === 0) { - lastStatus = 502; - lastMessage = "Upstream returned empty response"; + } catch (err) { + if (isAbortError(err)) { + lastStatus = 504; + lastMessage = `Upstream fetch timed out after ${UPSTREAM_TIMEOUT_MS}ms for ${tryURL}`; continue; } - - const responseHeaders = new Headers(corsHeaders(origin)); - responseHeaders.set("Content-Type", "application/octet-stream"); - responseHeaders.set("Content-Length", body.byteLength.toString()); - responseHeaders.set("X-Content-Type-Options", "nosniff"); - // Cache forever — AIA certificates are immutable (same URL = same cert) - responseHeaders.set( - "Cache-Control", - "public, max-age=31536000, immutable", - ); - - return new Response(body, { status: 200, headers: responseHeaders }); - } catch { lastStatus = 502; lastMessage = `Fetch failed for ${tryURL}`; } diff --git a/web/public/app.js b/web/public/app.js index 0dc08ffa..510b0a98 100644 --- a/web/public/app.js +++ b/web/public/app.js @@ -61,13 +61,25 @@ const STATUS_ICONS = { // certkitFetchURL is called from Go (WASM) to fetch AIA certificates. // Tries direct fetch first, then falls back to our own /api/fetch proxy // (same-origin, no CORS issues). -window.certkitFetchURL = async function (url) { +window.certkitFetchURL = async function (url, timeoutMs = 10000) { + const fetchBytesWithTimeout = async (targetURL) => { + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), timeoutMs); + try { + const resp = await fetch(targetURL, { signal: controller.signal }); + const body = await resp.arrayBuffer(); + return { resp, data: new Uint8Array(body) }; + } finally { + clearTimeout(timer); + } + }; + // 1. Try direct fetch (works if CA serves CORS headers) try { - const resp = await fetch(url); + const { resp, data } = await fetchBytesWithTimeout(url); if (resp.ok) { console.log("certkit: AIA direct fetch succeeded:", url); - return new Uint8Array(await resp.arrayBuffer()); + return data; } } catch (e) { console.log("certkit: AIA direct fetch failed:", url, e.message); @@ -76,13 +88,13 @@ window.certkitFetchURL = async function (url) { // 2. Proxy through our own /api/fetch endpoint const proxiedURL = "/api/fetch?url=" + encodeURIComponent(url); console.log("certkit: AIA proxy fetch:", proxiedURL); - const resp = await fetch(proxiedURL); + const { resp, data } = await fetchBytesWithTimeout(proxiedURL); if (!resp.ok) { - const body = await resp.text(); + const body = new TextDecoder().decode(data); throw new Error(`Proxy returned ${resp.status}: ${body}`); } console.log("certkit: AIA proxy fetch succeeded for", url); - return new Uint8Array(await resp.arrayBuffer()); + return data; }; // --- WASM Loading --- @@ -1184,7 +1196,25 @@ exportBtn.addEventListener("click", async () => { downloadBlob(zipData, "certkit-bundles.zip", "application/zip"); hideStatus(); } catch (err) { - showStatus(`Export error: ${err.message}`, true); + if ( + err?.code === "VERIFY_FAILED" && + window.confirm( + "Verified export failed. Retry without certificate chain verification?", + ) + ) { + try { + const zipData = await certkitExportBundles(skis, undefined, true); + downloadBlob(zipData, "certkit-bundles.zip", "application/zip"); + showStatus( + "Export completed without chain verification. Verify trust before use.", + false, + ); + } catch (retryErr) { + showStatus(`Export error: ${retryErr.message}`, true); + } + } else { + showStatus(`Export error: ${err.message}`, true); + } } finally { exportBtn.disabled = false; updateExportBtn(); From 7ec9a4b111357ce9464bce35ddc5ce4faee7fa22 Mon Sep 17 00:00:00 2001 From: Daniel Wood Date: Sun, 1 Mar 2026 19:52:42 -0500 Subject: [PATCH 2/8] fix: harden core parsing and identity selection (#107) * fix: harden parsing and issuer/key selection correctness * docs(changelog): reference PR for unreleased parsing fixes * fix: prevent parser fallback and JKS identity regressions * fix(inspect): preserve valid keys when PEM bundle has malformed blocks * fix: address remaining PR 107 review feedback * fix(jks): surface skipped-entry reasons in debug logs --- CHANGELOG.md | 10 + certkit.go | 269 ++++++++++++++++++++------- certkit_test.go | 202 +++++++++++++++++++- cmd/certkit/ocsp.go | 17 +- internal/certstore/container.go | 78 ++++++-- internal/certstore/container_test.go | 169 +++++++++++++++-- internal/certstore/memstore.go | 19 +- internal/certstore/memstore_test.go | 51 +++++ internal/certstore/sqlite.go | 9 +- internal/certstore/sqlite_test.go | 64 +++++++ internal/inspect.go | 39 +++- internal/inspect_test.go | 107 ++++++++++- jks.go | 111 +++++++---- jks_test.go | 100 ++++++++++ sign.go | 11 ++ sign_test.go | 55 ++++++ 16 files changed, 1161 insertions(+), 150 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85e66fca..e026686b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix inspect/convert/container parsing to continue past malformed PEM certificate blocks so valid certificates are still processed, and add DER private-key detection for key-only inputs ([#107]) +- Fix inspect PEM key parsing to continue past malformed private-key blocks so valid keys in the same bundle are still reported ([#107]) +- Fix JKS container selection to keep private key entries paired with their own leaf certificate and chain instead of selecting unrelated trusted entries ([#107]) +- Fix `DecodeJKSKeyEntries` to emit debug logs when skipping non-private-key aliases and malformed PKCS#8 private-key payloads (ERR-5) ([#107]) +- Fix JKS/issuer parsing edge cases by adding debug logging for skipped JKS entry errors, requiring issuer DN match during OCSP issuer auto-selection, and consolidating duplicate CSR-scan tests into a single table-driven case ([#107]) +- Fix certificate identity deduplication for certificates without AKI by falling back to issuer+serial identity, and align SQLite persistence keys with the same identity to prevent dropped certs across different issuers ([#107]) +- Fix CSR signing to reject CA certificate/key mismatches before issuing certificates ([#107]) +- Fix OCSP issuer auto-selection to choose a certificate that actually signs the leaf (with AKI/SKI preference) instead of defaulting to the first extra certificate ([#107]) - Fix CRL read errors to include `reading CRL data` context before caller wrapping, improving nested error diagnostics ([#105]) - Fix WASM ingestion promises to recover from internal panics instead of crashing asynchronous file processing ([#105]) - Fix WASM AIA fetch callback lifecycle to release JS callbacks on cancellation paths after promise completion ([#105]) @@ -219,6 +227,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Tests +- Add behavior-focused edge-case coverage for malformed PEM + valid cert preservation, DER private-key parsing, JKS key-entry alias pairing, issuer selection, and CA cert/key mismatch validation; remove assertions that depended on old tautological error-path expectations ([#107]) - Consolidate CRL oversize-input coverage into one table-driven test for HTTP and local-file sources, asserting `ErrCRLTooLarge` behaviorally ([#105]) - Remove `TestBuildLegacyClientHelloMsg` — behavioral coverage exists through `TestLegacyFallbackConnect` per T-11 ([`6492fa5`]) - Remove `TestParseCertificateMessage` — behavioral coverage exists through `TestReadServerCertificates` per T-11 ([#82]) @@ -947,6 +956,7 @@ Initial release. [#85]: https://github.com/sensiblebit/certkit/pull/85 [#86]: https://github.com/sensiblebit/certkit/pull/86 [#87]: https://github.com/sensiblebit/certkit/pull/87 +[#107]: https://github.com/sensiblebit/certkit/pull/107 [#105]: https://github.com/sensiblebit/certkit/pull/105 [#73]: https://github.com/sensiblebit/certkit/pull/73 [#64]: https://github.com/sensiblebit/certkit/pull/64 diff --git a/certkit.go b/certkit.go index d33672ab..6387f681 100644 --- a/certkit.go +++ b/certkit.go @@ -19,6 +19,7 @@ import ( "encoding/pem" "errors" "fmt" + "log/slog" "math/big" "strings" "time" @@ -29,6 +30,7 @@ import ( // ParsePEMCertificates parses all certificates from a PEM bundle. func ParsePEMCertificates(pemData []byte) ([]*x509.Certificate, error) { var certs []*x509.Certificate + var firstErr error rest := pemData for { var block *pem.Block @@ -41,11 +43,18 @@ func ParsePEMCertificates(pemData []byte) ([]*x509.Certificate, error) { } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { - return nil, fmt.Errorf("parsing certificate: %w", err) + if firstErr == nil { + firstErr = fmt.Errorf("parsing certificate: %w", err) + } + slog.Debug("skipping malformed CERTIFICATE PEM block", "error", err) + continue } certs = append(certs, cert) } if len(certs) == 0 { + if firstErr != nil { + return nil, firstErr + } return nil, errors.New("no certificates found in PEM data") } return certs, nil @@ -95,38 +104,32 @@ func normalizeKey(key crypto.PrivateKey) crypto.PrivateKey { // For "PRIVATE KEY" blocks it tries PKCS#8 first, then falls back to PKCS#1 // and EC parsers to handle mislabeled keys (e.g., from pkcs12.ToPEM). func ParsePEMPrivateKey(pemData []byte) (crypto.PrivateKey, error) { - block, _ := pem.Decode(pemData) - if block == nil { - return nil, errors.New("no PEM block found in private key data") - } - - switch block.Type { - case "RSA PRIVATE KEY": - return x509.ParsePKCS1PrivateKey(block.Bytes) - case "EC PRIVATE KEY": - return x509.ParseECPrivateKey(block.Bytes) - case "PRIVATE KEY": - if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { - return normalizeKey(key), nil + rest := pemData + var firstErr error + for { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + break } - // Fall back: some tools (e.g., pkcs12.ToPEM) label PKCS#1 keys as "PRIVATE KEY" - if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { - return key, nil + if !keyBlockTypes[block.Type] { + continue } - if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + + singlePEM := pem.EncodeToMemory(block) + key, err := parsePEMPrivateKeyBlock(singlePEM, block) + if err == nil { return key, nil } - return nil, errors.New("parsing PRIVATE KEY block with any known format") - case "OPENSSH PRIVATE KEY": - // OpenSSH format uses a proprietary encoding; delegate to x/crypto/ssh - key, err := ssh.ParseRawPrivateKey(pemData) - if err != nil { - return nil, fmt.Errorf("parsing OpenSSH private key: %w", err) + if firstErr == nil { + firstErr = err } - return normalizeKey(key), nil - default: - return nil, fmt.Errorf("unsupported PEM block type %q", block.Type) } + + if firstErr != nil { + return nil, firstErr + } + return nil, errors.New("no private keys found in PEM data") } // DefaultPasswords returns the list of passwords tried by default when decrypting @@ -157,54 +160,102 @@ func DeduplicatePasswords(extra []string) []string { // order. Returns the first successfully decrypted key, or an error if all // passwords fail. func ParsePEMPrivateKeyWithPasswords(pemData []byte, passwords []string) (crypto.PrivateKey, error) { - // Try unencrypted first - if key, err := ParsePEMPrivateKey(pemData); err == nil { - return key, nil - } + rest := pemData + var firstErr error + for { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + break + } + if !keyBlockTypes[block.Type] { + continue + } - block, _ := pem.Decode(pemData) - if block == nil { - return nil, errors.New("no PEM block found in private key data") - } + singlePEM := pem.EncodeToMemory(block) - // OpenSSH keys use their own encryption format, not legacy RFC 1423 - if block.Type == "OPENSSH PRIVATE KEY" { - for _, password := range passwords { - if password == "" { - continue // already tried unencrypted above + key, parseErr := parsePEMPrivateKeyBlock(singlePEM, block) + if parseErr == nil { + return key, nil + } + + // OpenSSH uses a proprietary encrypted format. + if block.Type == "OPENSSH PRIVATE KEY" { + if len(passwords) == 0 { + if firstErr == nil { + firstErr = parseErr + } + slog.Debug("skipping OpenSSH private key block with no passwords", "error", parseErr) + continue } - key, err := ssh.ParseRawPrivateKeyWithPassphrase(pemData, []byte(password)) - if err == nil { - return normalizeKey(key), nil + + var openSSHErr error + for _, password := range passwords { + if password == "" { + continue + } + key, err := ssh.ParseRawPrivateKeyWithPassphrase(singlePEM, []byte(password)) + if err == nil { + return normalizeKey(key), nil + } + if openSSHErr == nil { + openSSHErr = fmt.Errorf("parsing OpenSSH private key with provided passwords: %w", err) + } + slog.Debug("failed OpenSSH private key passphrase", "error", err) + } + if openSSHErr == nil { + openSSHErr = parseErr } + if firstErr == nil { + firstErr = openSSHErr + } + slog.Debug("skipping OpenSSH private key block after password attempts", "error", openSSHErr) + continue } - return nil, errors.New("parsing OpenSSH private key with any provided password") - } - - //nolint:staticcheck // x509.IsEncryptedPEMBlock is deprecated but needed for legacy encrypted PEM support - if !x509.IsEncryptedPEMBlock(block) { - // Not encrypted and unencrypted parse failed — return the original error - _, err := ParsePEMPrivateKey(pemData) - return nil, err - } - for _, password := range passwords { - //nolint:staticcheck // x509.DecryptPEMBlock is deprecated but needed for legacy encrypted PEM support - decrypted, err := x509.DecryptPEMBlock(block, []byte(password)) - if err != nil { + //nolint:staticcheck // x509.IsEncryptedPEMBlock is deprecated but needed for legacy encrypted PEM support + if !x509.IsEncryptedPEMBlock(block) { + if firstErr == nil { + firstErr = parseErr + } + slog.Debug("skipping unparseable unencrypted private key PEM block", "block_type", block.Type, "error", parseErr) continue } - clearPEM := pem.EncodeToMemory(&pem.Block{ - Type: block.Type, - Bytes: decrypted, - }) - if key, err := ParsePEMPrivateKey(clearPEM); err == nil { - return key, nil + var encryptedErr error + for _, password := range passwords { + //nolint:staticcheck // x509.DecryptPEMBlock is deprecated but needed for legacy encrypted PEM support + decrypted, err := x509.DecryptPEMBlock(block, []byte(password)) + if err != nil { + if encryptedErr == nil { + encryptedErr = fmt.Errorf("decrypting private key with provided passwords: %w", err) + } + slog.Debug("failed decrypting encrypted private key block", "block_type", block.Type, "error", err) + continue + } + clearPEM := pem.EncodeToMemory(&pem.Block{Type: block.Type, Bytes: decrypted}) + key, err := ParsePEMPrivateKey(clearPEM) + if err == nil { + return key, nil + } + if encryptedErr == nil { + encryptedErr = fmt.Errorf("parsing decrypted private key: %w", err) + } + slog.Debug("failed parsing decrypted private key block", "block_type", block.Type, "error", err) + } + if encryptedErr != nil && firstErr == nil { + firstErr = encryptedErr + } + if firstErr == nil { + firstErr = errors.New("decrypting private key with any provided password") } + slog.Debug("skipping encrypted private key block after password attempts", "block_type", block.Type, "error", firstErr) } - return nil, errors.New("decrypting private key with any provided password") + if firstErr != nil { + return nil, firstErr + } + return nil, errors.New("no private keys found in PEM data") } // keyBlockTypes is the set of PEM block types that represent private keys. @@ -249,18 +300,60 @@ func ParsePEMPrivateKeys(pemData []byte, passwords []string) ([]crypto.PrivateKe // ParsePEMCertificateRequest parses a single certificate request from PEM data. func ParsePEMCertificateRequest(pemData []byte) (*x509.CertificateRequest, error) { - block, _ := pem.Decode(pemData) - if block == nil { - return nil, errors.New("no PEM block found in certificate request data") + rest := pemData + var firstErr error + for { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + break + } + if block.Type != "CERTIFICATE REQUEST" && block.Type != "NEW CERTIFICATE REQUEST" { + continue + } + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("parsing certificate request: %w", err) + } + slog.Debug("skipping malformed certificate request PEM block", "error", err) + continue + } + return csr, nil } - if block.Type != "CERTIFICATE REQUEST" && block.Type != "NEW CERTIFICATE REQUEST" { - return nil, fmt.Errorf("expected CERTIFICATE REQUEST PEM block, got %q", block.Type) + if firstErr != nil { + return nil, firstErr } - csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - return nil, fmt.Errorf("parsing certificate request: %w", err) + return nil, errors.New("no certificate request found in PEM data") +} + +func parsePEMPrivateKeyBlock(singlePEM []byte, block *pem.Block) (crypto.PrivateKey, error) { + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "PRIVATE KEY": + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + return normalizeKey(key), nil + } + // Fall back: some tools (e.g., pkcs12.ToPEM) label PKCS#1 keys as "PRIVATE KEY" + if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + return nil, errors.New("parsing PRIVATE KEY block with any known format") + case "OPENSSH PRIVATE KEY": + key, err := ssh.ParseRawPrivateKey(singlePEM) + if err != nil { + return nil, fmt.Errorf("parsing OpenSSH private key: %w", err) + } + return normalizeKey(key), nil + default: + return nil, fmt.Errorf("unsupported PEM block type %q", block.Type) } - return csr, nil } // CertToPEM encodes a certificate as PEM. @@ -528,6 +621,38 @@ func KeyMatchesCert(priv crypto.PrivateKey, cert *x509.Certificate) (bool, error return eq.Equal(cert.PublicKey), nil } +// SelectIssuerCertificate chooses the best issuer for cert from candidates. +// It requires both issuer DN match and a valid signature relationship, and +// prefers AKI/SKI matches when available. Returns nil when no candidate meets +// those criteria. +func SelectIssuerCertificate(cert *x509.Certificate, candidates []*x509.Certificate) *x509.Certificate { + if cert == nil { + return nil + } + + var fallback *x509.Certificate + for _, candidate := range candidates { + if candidate == nil { + continue + } + if !bytes.Equal(cert.RawIssuer, candidate.RawSubject) { + continue + } + if err := cert.CheckSignatureFrom(candidate); err != nil { + slog.Debug("skipping candidate with invalid issuer signature", "error", err) + continue + } + if len(cert.AuthorityKeyId) > 0 && len(candidate.SubjectKeyId) > 0 && bytes.Equal(cert.AuthorityKeyId, candidate.SubjectKeyId) { + return candidate + } + if fallback == nil { + fallback = candidate + } + } + + return fallback +} + // IsPEM returns true if the data appears to contain PEM-encoded content. func IsPEM(data []byte) bool { return bytes.Contains(data, []byte("-----BEGIN")) diff --git a/certkit_test.go b/certkit_test.go index f405e0b6..50c0e2c0 100644 --- a/certkit_test.go +++ b/certkit_test.go @@ -116,6 +116,43 @@ func TestParsePEMCertificates_invalidDER(t *testing.T) { } } +func TestParsePEMCertificates_PreservesValidWhenMalformedPresent(t *testing.T) { + // WHY: Mixed-quality bundles are common in the wild. A malformed + // CERTIFICATE block must not discard other valid certificates. + t.Parallel() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + certTemplate := &x509.Certificate{ + SerialNumber: randomSerial(t), + Subject: pkix.Name{CommonName: "valid-cert.example.com"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + } + certDER, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &key.PublicKey, key) + if err != nil { + t.Fatal(err) + } + + pemData := slices.Concat( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bad-der")}), + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), + ) + + certs, err := ParsePEMCertificates(pemData) + if err != nil { + t.Fatalf("ParsePEMCertificates: %v", err) + } + if len(certs) != 1 { + t.Fatalf("expected 1 valid cert, got %d", len(certs)) + } + if certs[0].Subject.CommonName != "valid-cert.example.com" { + t.Errorf("CN=%q, want valid-cert.example.com", certs[0].Subject.CommonName) + } +} + func TestCertKeyIdEmbedded_NilExtensions(t *testing.T) { // WHY: Nil SubjectKeyId/AuthorityKeyId must return empty string gracefully, // not panic. Populated cases are tautological (ColonHex(x) == ColonHex(x)) @@ -336,6 +373,35 @@ func TestParsePEMPrivateKey_MislabeledBlockType(t *testing.T) { } } +func TestParsePEMPrivateKey_SkipsNonKeyBlocks(t *testing.T) { + // WHY: ParsePEMPrivateKey is used in key-only paths and must find the first + // key block even when certificate blocks appear first. + t.Parallel() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + keyDER := x509.MarshalPKCS1PrivateKey(key) + + pemData := slices.Concat( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-cert")}), + pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}), + ) + + parsed, err := ParsePEMPrivateKey(pemData) + if err != nil { + t.Fatalf("ParsePEMPrivateKey: %v", err) + } + rsaParsed, ok := parsed.(*rsa.PrivateKey) + if !ok { + t.Fatalf("parsed key type = %T, want *rsa.PrivateKey", parsed) + } + if !key.Equal(rsaParsed) { + t.Error("parsed key does not Equal original") + } +} + func TestParsePEMPrivateKeyWithPasswords_Encrypted(t *testing.T) { // WHY: Encrypted PEM keys must decrypt with the correct password, fail // clearly with wrong passwords, iterate all candidates, and handle edge @@ -510,12 +576,12 @@ func TestParsePEMCertificateRequest_errors(t *testing.T) { { name: "invalid PEM", input: []byte("not valid PEM"), - wantErr: "no PEM block found", + wantErr: "no certificate request found", }, { name: "wrong block type", input: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("whatever")}), - wantErr: "expected CERTIFICATE REQUEST", + wantErr: "no certificate request found", }, { name: "invalid DER", @@ -537,6 +603,57 @@ func TestParsePEMCertificateRequest_errors(t *testing.T) { } } +func TestParsePEMCertificateRequest_SkipsBadBlocksBeforeValidCSR(t *testing.T) { + // WHY: CSR parsing must continue scanning when earlier PEM blocks are either + // wrong block types or malformed CSR DER. + t.Parallel() + + leaf, key := generateLeafWithSANs(t) + csrPEM, _, err := GenerateCSR(leaf, key) + if err != nil { + t.Fatal(err) + } + csrBlock, _ := pem.Decode([]byte(csrPEM)) + if csrBlock == nil { + t.Fatal("failed to decode generated CSR") + } + + tests := []struct { + name string + blockType string + blockDER []byte + }{ + { + name: "skips non-CSR block before valid CSR", + blockType: "CERTIFICATE", + blockDER: []byte("not-a-csr"), + }, + { + name: "skips malformed CSR block before valid CSR", + blockType: "CERTIFICATE REQUEST", + blockDER: []byte("bad-csr-der"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + pemData := slices.Concat( + pem.EncodeToMemory(&pem.Block{Type: tt.blockType, Bytes: tt.blockDER}), + pem.EncodeToMemory(csrBlock), + ) + + csr, err := ParsePEMCertificateRequest(pemData) + if err != nil { + t.Fatalf("ParsePEMCertificateRequest: %v", err) + } + if csr.Subject.CommonName != "test.example.com" { + t.Errorf("CN=%q, want test.example.com", csr.Subject.CommonName) + } + }) + } +} + func TestParsePEMCertificateRequest_LegacyBlockType(t *testing.T) { // WHY: Older tools (Netscape, MSIE) emit "NEW CERTIFICATE REQUEST" instead of // "CERTIFICATE REQUEST". The DER payload is identical; rejecting the legacy type @@ -741,6 +858,83 @@ func TestKeyMatchesCert(t *testing.T) { } } +func TestSelectIssuerCertificate(t *testing.T) { + // WHY: Issuer auto-selection must choose a candidate that actually signed + // the leaf and prefer AKI/SKI matches to avoid wrong-issuer OCSP checks. + t.Parallel() + + caPEM, interPEM, leafPEM := generateTestPKI(t) + ca, err := ParsePEMCertificate([]byte(caPEM)) + if err != nil { + t.Fatal(err) + } + intermediate, err := ParsePEMCertificate([]byte(interPEM)) + if err != nil { + t.Fatal(err) + } + leaf, err := ParsePEMCertificate([]byte(leafPEM)) + if err != nil { + t.Fatal(err) + } + + wrongCAKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + wrongCATmpl := &x509.Certificate{ + SerialNumber: randomSerial(t), + Subject: leaf.Issuer, // same issuer DN, different key + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign, + } + wrongCADER, err := x509.CreateCertificate(rand.Reader, wrongCATmpl, wrongCATmpl, &wrongCAKey.PublicKey, wrongCAKey) + if err != nil { + t.Fatal(err) + } + wrongCA, err := x509.ParseCertificate(wrongCADER) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + candidates []*x509.Certificate + wantIssuer *x509.Certificate + }{ + { + name: "prefers AKI SKI matched valid signer", + candidates: []*x509.Certificate{wrongCA, ca, intermediate}, + wantIssuer: intermediate, + }, + { + name: "returns nil when no candidate signs leaf", + candidates: []*x509.Certificate{wrongCA, ca}, + wantIssuer: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + issuer := SelectIssuerCertificate(leaf, tt.candidates) + if tt.wantIssuer == nil { + if issuer != nil { + t.Errorf("expected nil issuer, got CN=%q", issuer.Subject.CommonName) + } + return + } + if issuer == nil { + t.Fatal("expected issuer, got nil") + } + if !issuer.Equal(tt.wantIssuer) { + t.Errorf("selected issuer CN = %q, want %q", issuer.Subject.CommonName, tt.wantIssuer.Subject.CommonName) + } + }) + } +} + func TestCertExpiresWithin(t *testing.T) { // WHY: Expiry window detection drives renewal warnings and the // --allow-expired filter. Covers within/outside window, already-expired, @@ -1082,10 +1276,10 @@ func TestParsePEMPrivateKey_ErrorPaths(t *testing.T) { input []byte wantInErr string }{ - {"empty input", nil, "no PEM block"}, + {"empty input", nil, "no private keys found"}, {"corrupt OpenSSH body", corruptOpenSSH, "OpenSSH"}, {"garbage PRIVATE KEY block", garbagePKCS8, "parsing PRIVATE KEY"}, - {"unsupported block type (DSA)", dsaPEM, "unsupported PEM block type"}, + {"unsupported block type (DSA)", dsaPEM, "no private keys found"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/cmd/certkit/ocsp.go b/cmd/certkit/ocsp.go index cd635723..0493e02f 100644 --- a/cmd/certkit/ocsp.go +++ b/cmd/certkit/ocsp.go @@ -1,6 +1,7 @@ package main import ( + "crypto/x509" "encoding/json" "fmt" "os" @@ -69,7 +70,7 @@ func runOCSP(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("reading issuer certificate: %w", err) } - issuerCert, err := certkit.ParsePEMCertificate(issuerData) + issuerCert, err := parseAnyCertificate(issuerData) if err != nil { return fmt.Errorf("parsing issuer certificate: %w", err) } @@ -78,10 +79,13 @@ func runOCSP(cmd *cobra.Command, args []string) error { Issuer: issuerCert, } } else if len(contents.ExtraCerts) > 0 { - // Use first extra cert as issuer (typically the immediate issuer) + issuerCert := certkit.SelectIssuerCertificate(contents.Leaf, contents.ExtraCerts) + if issuerCert == nil { + return fmt.Errorf("no matching issuer certificate found in input; use --issuer to provide one") + } ocspInput = &certkit.CheckOCSPInput{ Cert: contents.Leaf, - Issuer: contents.ExtraCerts[0], + Issuer: issuerCert, } } else { return fmt.Errorf("no issuer certificate found; use --issuer to provide one") @@ -134,3 +138,10 @@ func runOCSP(cmd *cobra.Command, args []string) error { return nil } + +func parseAnyCertificate(data []byte) (*x509.Certificate, error) { + if certkit.IsPEM(data) { + return certkit.ParsePEMCertificate(data) + } + return x509.ParseCertificate(data) +} diff --git a/internal/certstore/container.go b/internal/certstore/container.go index 12baa391..2ceadc31 100644 --- a/internal/certstore/container.go +++ b/internal/certstore/container.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "slices" "strings" "github.com/sensiblebit/certkit" @@ -34,19 +35,38 @@ func ParseContainerData(data []byte, passwords []string) (*ContainerContents, er } // Try JKS - if certs, keys, err := certkit.DecodeJKS(data, passwords); err == nil { - var leaf *x509.Certificate - var extras []*x509.Certificate - if len(certs) > 0 { - leaf = certs[0] - extras = certs[1:] - } - var key crypto.PrivateKey - if len(keys) > 0 { - key = keys[0] + if keyEntries, trustedCerts, err := certkit.DecodeJKSKeyEntries(data, passwords); err == nil { + if len(keyEntries) > 0 { + entry := keyEntries[0] + if idx := slices.IndexFunc(keyEntries, func(candidate certkit.DecodedJKSKeyEntry) bool { + return len(candidate.Chain) > 0 + }); idx >= 0 { + entry = keyEntries[idx] + } + + if len(entry.Chain) > 0 { + leaf, chainExtras := selectLeafAndExtras(entry.Chain, entry.Key) + allExtras := slices.Concat(chainExtras, trustedCerts) + return &ContainerContents{Leaf: leaf, Key: entry.Key, ExtraCerts: allExtras}, nil + } + + if trustedMatchIdx := slices.IndexFunc(trustedCerts, func(cert *x509.Certificate) bool { + ok, matchErr := certkit.KeyMatchesCert(entry.Key, cert) + return matchErr == nil && ok + }); trustedMatchIdx >= 0 { + leaf := trustedCerts[trustedMatchIdx] + extras := make([]*x509.Certificate, 0, len(trustedCerts)-1) + extras = append(extras, trustedCerts[:trustedMatchIdx]...) + extras = append(extras, trustedCerts[trustedMatchIdx+1:]...) + return &ContainerContents{Leaf: leaf, Key: entry.Key, ExtraCerts: extras}, nil + } + + return &ContainerContents{Key: entry.Key, ExtraCerts: trustedCerts}, nil } + + leaf, extras := selectLeafAndExtras(trustedCerts, nil) if leaf != nil { - return &ContainerContents{Leaf: leaf, Key: key, ExtraCerts: extras}, nil + return &ContainerContents{Leaf: leaf, ExtraCerts: extras}, nil } } @@ -75,9 +95,45 @@ func ParseContainerData(data []byte, passwords []string) (*ContainerContents, er return &ContainerContents{Leaf: cert}, nil } + // Try DER private key (PKCS#8, PKCS#1, SEC1). + if key, keyErr := x509.ParsePKCS8PrivateKey(data); keyErr == nil { + return &ContainerContents{Key: key}, nil + } + if key, keyErr := x509.ParsePKCS1PrivateKey(data); keyErr == nil { + return &ContainerContents{Key: key}, nil + } + if key, keyErr := x509.ParseECPrivateKey(data); keyErr == nil { + return &ContainerContents{Key: key}, nil + } + return nil, fmt.Errorf("could not parse as PEM, DER, PKCS#12, JKS, or PKCS#7") } +// selectLeafAndExtras picks a leaf certificate from certs and returns that leaf +// plus remaining certificates as extras. When a key is present, it prefers a +// certificate that matches the key to preserve JKS private-key entry pairing. +func selectLeafAndExtras(certs []*x509.Certificate, key crypto.PrivateKey) (*x509.Certificate, []*x509.Certificate) { + if len(certs) == 0 { + return nil, nil + } + + leafIdx := 0 + if key != nil { + if idx := slices.IndexFunc(certs, func(cert *x509.Certificate) bool { + ok, err := certkit.KeyMatchesCert(key, cert) + return err == nil && ok + }); idx >= 0 { + leafIdx = idx + } + } + + leaf := certs[leafIdx] + extras := make([]*x509.Certificate, 0, len(certs)-1) + extras = append(extras, certs[:leafIdx]...) + extras = append(extras, certs[leafIdx+1:]...) + return leaf, extras +} + // findPEMPrivateKey iterates over PEM blocks in data looking for a private key block. // Returns the first successfully parsed key, or nil if none found. func findPEMPrivateKey(data []byte, passwords []string) crypto.PrivateKey { diff --git a/internal/certstore/container_test.go b/internal/certstore/container_test.go index db4cbdb9..8328ad48 100644 --- a/internal/certstore/container_test.go +++ b/internal/certstore/container_test.go @@ -194,20 +194,11 @@ func TestParseContainerData_PEMKeyOnly(t *testing.T) { func TestParseContainerData_UnparseableInputs(t *testing.T) { // WHY: Data that doesn't match any container format (garbage bytes, DER - // private keys, empty JKS) must produce a clear "could not parse" error. - // DER keys are ProcessData's job, not ParseContainerData's. Empty JKS - // falls through all parsers after DecodeJKS returns no leaf. + // garbage, empty JKS) must produce a clear "could not parse" error. + // Empty JKS falls through all parsers after DecodeJKS returns no usable + // entries. t.Parallel() - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } - pkcs8DER, err := x509.MarshalPKCS8PrivateKey(key) - if err != nil { - t.Fatal(err) - } - ks := keystore.New() var buf bytes.Buffer if err := ks.Store(&buf, []byte("changeit")); err != nil { @@ -220,9 +211,7 @@ func TestParseContainerData_UnparseableInputs(t *testing.T) { data []byte passwords []string }{ - // "garbage data" removed — exercises the same "nothing matched" fallthrough - // as "DER private key" (T-14). DER key is a more realistic input. - {"DER private key", pkcs8DER, nil}, + {"garbage data", []byte("not-a-container"), nil}, {"empty JKS", emptyJKSData, []string{"changeit"}}, } for _, tt := range tests { @@ -238,3 +227,153 @@ func TestParseContainerData_UnparseableInputs(t *testing.T) { }) } } + +func TestParseContainerData_DERPrivateKey(t *testing.T) { + // WHY: Convert and inspect use ParseContainerData; DER private keys must be + // recognized as key-only input instead of being treated as unparseable. + t.Parallel() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + pkcs8DER, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatal(err) + } + + contents, err := ParseContainerData(pkcs8DER, nil) + if err != nil { + t.Fatalf("ParseContainerData(DER private key): %v", err) + } + if contents.Key == nil { + t.Fatal("expected Key to be set") + } + if contents.Leaf != nil { + t.Errorf("expected nil Leaf for DER private key, got %q", contents.Leaf.Subject.CommonName) + } + if len(contents.ExtraCerts) != 0 { + t.Errorf("expected 0 extra certs, got %d", len(contents.ExtraCerts)) + } +} + +func TestParseContainerData_JKSPreservesKeyEntryChain(t *testing.T) { + // WHY: When JKS contains trusted cert entries plus private-key entries, + // ParseContainerData must select the leaf and chain from the key entry, + // not an unrelated trusted certificate. + t.Parallel() + + ca1 := newRSACA(t) + leaf1 := newRSALeaf(t, ca1, "jks-primary.example.com", []string{"jks-primary.example.com"}) + ca2 := newECDSACA(t) + + keyEntries := []certkit.JKSEntry{{ + PrivateKey: leaf1.key, + Leaf: leaf1.cert, + CACerts: []*x509.Certificate{ca1.cert}, + Alias: "primary", + }} + jksData, err := certkit.EncodeJKSEntries(keyEntries, "changeit") + if err != nil { + t.Fatalf("EncodeJKSEntries: %v", err) + } + + ks := keystore.New() + if err := ks.Load(bytes.NewReader(jksData), []byte("changeit")); err != nil { + t.Fatalf("load JKS: %v", err) + } + if err := ks.SetTrustedCertificateEntry("trusted-unrelated", keystore.TrustedCertificateEntry{ + CreationTime: time.Now(), + Certificate: keystore.Certificate{Type: "X.509", Content: ca2.cert.Raw}, + }); err != nil { + t.Fatalf("set trusted entry: %v", err) + } + var withTrusted bytes.Buffer + if err := ks.Store(&withTrusted, []byte("changeit")); err != nil { + t.Fatalf("store updated JKS: %v", err) + } + + contents, err := ParseContainerData(withTrusted.Bytes(), []string{"changeit"}) + if err != nil { + t.Fatalf("ParseContainerData(JKS): %v", err) + } + if contents.Key == nil || contents.Leaf == nil { + t.Fatal("expected key and leaf from JKS private key entry") + } + if contents.Leaf.Subject.CommonName != "jks-primary.example.com" { + t.Errorf("selected leaf CN = %q, want jks-primary.example.com", contents.Leaf.Subject.CommonName) + } + if len(contents.ExtraCerts) != 2 { + t.Fatalf("expected 2 extra certs (key-entry chain + trusted), got %d", len(contents.ExtraCerts)) + } + if contents.ExtraCerts[0].Subject.CommonName != ca1.cert.Subject.CommonName { + t.Errorf("extra cert CN = %q, want %q", contents.ExtraCerts[0].Subject.CommonName, ca1.cert.Subject.CommonName) + } + if contents.ExtraCerts[1].Subject.CommonName != ca2.cert.Subject.CommonName { + t.Errorf("extra cert CN = %q, want %q", contents.ExtraCerts[1].Subject.CommonName, ca2.cert.Subject.CommonName) + } +} + +func TestParseContainerData_JKSPrefersLaterKeyEntryWithChain(t *testing.T) { + // WHY: A JKS may contain multiple key entries where early entries have no + // chain. ParseContainerData should keep scanning and use the first key entry + // that has a usable leaf+chain. + t.Parallel() + + unusedKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + unusedKeyPKCS8, err := x509.MarshalPKCS8PrivateKey(unusedKey) + if err != nil { + t.Fatal(err) + } + + ca := newRSACA(t) + leaf := newRSALeaf(t, ca, "jks-chained.example.com", []string{"jks-chained.example.com"}) + leafKeyPKCS8, err := x509.MarshalPKCS8PrivateKey(leaf.key) + if err != nil { + t.Fatal(err) + } + + ks := keystore.New() + if err := ks.SetPrivateKeyEntry("a-no-chain", keystore.PrivateKeyEntry{ + CreationTime: time.Now(), + PrivateKey: unusedKeyPKCS8, + CertificateChain: []keystore.Certificate{}, + }, []byte("changeit")); err != nil { + t.Fatalf("set key entry without chain: %v", err) + } + if err := ks.SetPrivateKeyEntry("b-with-chain", keystore.PrivateKeyEntry{ + CreationTime: time.Now(), + PrivateKey: leafKeyPKCS8, + CertificateChain: []keystore.Certificate{ + {Type: "X.509", Content: leaf.cert.Raw}, + {Type: "X.509", Content: ca.cert.Raw}, + }, + }, []byte("changeit")); err != nil { + t.Fatalf("set key entry with chain: %v", err) + } + + var buf bytes.Buffer + if err := ks.Store(&buf, []byte("changeit")); err != nil { + t.Fatalf("store JKS: %v", err) + } + + contents, err := ParseContainerData(buf.Bytes(), []string{"changeit"}) + if err != nil { + t.Fatalf("ParseContainerData(JKS): %v", err) + } + if contents.Key == nil || contents.Leaf == nil { + t.Fatal("expected key and leaf from chained JKS entry") + } + if contents.Leaf.Subject.CommonName != "jks-chained.example.com" { + t.Errorf("selected leaf CN = %q, want jks-chained.example.com", contents.Leaf.Subject.CommonName) + } + if len(contents.ExtraCerts) != 1 { + t.Fatalf("expected 1 extra cert from chosen chain, got %d", len(contents.ExtraCerts)) + } + if contents.ExtraCerts[0].Subject.CommonName != ca.cert.Subject.CommonName { + t.Errorf("extra cert CN = %q, want %q", contents.ExtraCerts[0].Subject.CommonName, ca.cert.Subject.CommonName) + } +} diff --git a/internal/certstore/memstore.go b/internal/certstore/memstore.go index f53dbde6..02dd10cc 100644 --- a/internal/certstore/memstore.go +++ b/internal/certstore/memstore.go @@ -40,10 +40,16 @@ type KeyRecord struct { Source string // filename that contributed this key } -// certID returns the composite key for deduplication, matching the SQLite -// primary key of (serial_number, authority_key_identifier). +// certID returns the composite key for deduplication. +// +// RFC 5280 certificate identity is (issuer DN, serial number). We also include +// AKI when present to keep stable behavior for certificates that embed it while +// avoiding false dedup collisions for certificates that omit AKI. func certID(cert *x509.Certificate) string { - return cert.SerialNumber.String() + "\x00" + hex.EncodeToString(cert.AuthorityKeyId) + if len(cert.AuthorityKeyId) > 0 { + return cert.SerialNumber.String() + "\x00" + hex.EncodeToString(cert.AuthorityKeyId) + } + return cert.SerialNumber.String() + "\x00" + string(cert.RawIssuer) } // MemStore is an in-memory certificate and key store that implements @@ -64,9 +70,10 @@ func NewMemStore() *MemStore { } // HandleCertificate computes the SKI and stores the certificate. Certificates -// are deduplicated by (serial, AKI) — the same composite key the SQLite schema -// uses. Multiple certificates with the same SKI but different serials (key -// reuse across renewals) are all retained. +// are deduplicated by serial plus authority identity (AKI when present, +// otherwise raw issuer), matching RFC 5280 identity for missing-AKI +// certificates. Multiple certificates with the same SKI but different serials +// (key reuse across renewals) are all retained. func (s *MemStore) HandleCertificate(cert *x509.Certificate, source string) error { if cert == nil { return errors.New("certificate is nil") diff --git a/internal/certstore/memstore_test.go b/internal/certstore/memstore_test.go index 6fd46403..84d02cd1 100644 --- a/internal/certstore/memstore_test.go +++ b/internal/certstore/memstore_test.go @@ -82,6 +82,57 @@ func TestMemStore_HandleCertificate_DuplicateIgnored(t *testing.T) { } } +func TestMemStore_HandleCertificate_MissingAKIDUsesIssuerIdentity(t *testing.T) { + // WHY: Certificates without AKI should deduplicate by issuer+serial, not + // serial alone. Different issuers may legitimately issue the same serial. + t.Parallel() + + store := NewMemStore() + + ca1 := newRSACA(t) + ca2 := newECDSACA(t) + + newLeafNoAKI := func(t *testing.T, parent *x509.Certificate, signer any) *x509.Certificate { + t.Helper() + leafKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(9001), // same serial in both leaves + Subject: pkix.Name{CommonName: "same-serial.example.com"}, + DNSNames: []string{"same-serial.example.com"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + // Intentionally omit AuthorityKeyId to exercise missing-AKI identity logic. + } + leafDER, err := x509.CreateCertificate(rand.Reader, tmpl, parent, &leafKey.PublicKey, signer) + if err != nil { + t.Fatal(err) + } + leafCert, err := x509.ParseCertificate(leafDER) + if err != nil { + t.Fatal(err) + } + return leafCert + } + + leafFromCA1 := newLeafNoAKI(t, ca1.cert, ca1.key) + leafFromCA2 := newLeafNoAKI(t, ca2.cert, ca2.key) + + if err := store.HandleCertificate(leafFromCA1, "ca1-leaf.pem"); err != nil { + t.Fatal(err) + } + if err := store.HandleCertificate(leafFromCA2, "ca2-leaf.pem"); err != nil { + t.Fatal(err) + } + + if len(store.AllCertsFlat()) != 2 { + t.Fatalf("expected 2 distinct cert identities, got %d", len(store.AllCertsFlat())) + } +} + func TestMemStore_MatchedPairs(t *testing.T) { // WHY: MatchedPairs must only return SKIs with both a leaf cert and a key; // non-leaf certs must be excluded even if they have keys. Uses an diff --git a/internal/certstore/sqlite.go b/internal/certstore/sqlite.go index 981c26f2..f40ac15a 100644 --- a/internal/certstore/sqlite.go +++ b/internal/certstore/sqlite.go @@ -49,6 +49,13 @@ type sqliteKeyRow struct { KeyData []byte `db:"key_data"` } +func certificateIdentityAuthorityKeyIdentifier(cert *x509.Certificate) string { + if len(cert.AuthorityKeyId) > 0 { + return hex.EncodeToString(cert.AuthorityKeyId) + } + return "issuer:" + hex.EncodeToString(cert.RawIssuer) +} + // openMemDB creates an in-memory SQLite database with the certkit schema. func openMemDB() (*sqlx.DB, error) { dsn := "file::memory:?_pragma=temp_store(2)&_pragma=journal_mode(off)&_pragma=synchronous(off)" @@ -216,7 +223,7 @@ func SaveToSQLite(store *MemStore, dbPath string) error { row := sqliteCertRow{ SerialNumber: rec.Cert.SerialNumber.String(), SubjectKeyIdentifier: rec.SKI, - AuthorityKeyIdentifier: hex.EncodeToString(rec.Cert.AuthorityKeyId), + AuthorityKeyIdentifier: certificateIdentityAuthorityKeyIdentifier(rec.Cert), CertType: rec.CertType, KeyType: rec.KeyType, Expiry: rec.NotAfter, diff --git a/internal/certstore/sqlite_test.go b/internal/certstore/sqlite_test.go index fb2cdc62..68b82f82 100644 --- a/internal/certstore/sqlite_test.go +++ b/internal/certstore/sqlite_test.go @@ -3,11 +3,17 @@ package certstore import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/hex" + "math/big" "net" "path/filepath" "strings" "testing" + "time" "github.com/sensiblebit/certkit" ) @@ -232,3 +238,61 @@ func TestSaveToSQLite_DoesNotMutateDNSNames(t *testing.T) { backing[2:]) } } + +func TestSaveToSQLite_PreservesMissingAKISerialAcrossIssuers(t *testing.T) { + // WHY: MemStore deduplicates missing-AKI certificates by issuer+serial. + // SQLite persistence must use the same identity to avoid dropping one cert + // when two issuers reuse the same serial. + t.Parallel() + + newLeafNoAKI := func(t *testing.T, parent *x509.Certificate, signer any, cn string) *x509.Certificate { + t.Helper() + leafKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(4242), + Subject: pkix.Name{CommonName: cn}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + } + leafDER, err := x509.CreateCertificate(rand.Reader, tmpl, parent, &leafKey.PublicKey, signer) + if err != nil { + t.Fatal(err) + } + leafCert, err := x509.ParseCertificate(leafDER) + if err != nil { + t.Fatal(err) + } + return leafCert + } + + ca1 := newRSACA(t) + ca2 := newECDSACA(t) + leaf1 := newLeafNoAKI(t, ca1.cert, ca1.key, "issuer-a.example.com") + leaf2 := newLeafNoAKI(t, ca2.cert, ca2.key, "issuer-b.example.com") + + store := NewMemStore() + if err := store.HandleCertificate(leaf1, "leaf1.pem"); err != nil { + t.Fatalf("store leaf1: %v", err) + } + if err := store.HandleCertificate(leaf2, "leaf2.pem"); err != nil { + t.Fatalf("store leaf2: %v", err) + } + + dbPath := filepath.Join(t.TempDir(), "missing-aki.db") + if err := SaveToSQLite(store, dbPath); err != nil { + t.Fatalf("SaveToSQLite: %v", err) + } + + loaded := NewMemStore() + if err := LoadFromSQLite(loaded, dbPath); err != nil { + t.Fatalf("LoadFromSQLite: %v", err) + } + + if got := len(loaded.AllCertsFlat()); got != 2 { + t.Fatalf("expected 2 certs after round-trip, got %d", got) + } +} diff --git a/internal/inspect.go b/internal/inspect.go index ab1806c3..e08be148 100644 --- a/internal/inspect.go +++ b/internal/inspect.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "encoding/asn1" "encoding/json" + "encoding/pem" "fmt" "log/slog" "slices" @@ -94,13 +95,37 @@ func inspectPEMData(data []byte, passwords []string) []InspectResult { } // Try private key - if key, err := certkit.ParsePEMPrivateKeyWithPasswords(data, passwords); err == nil { + rest := data + for { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + break + } + if !isPrivateKeyPEMBlockType(block.Type) { + continue + } + + key, err := certkit.ParsePEMPrivateKeyWithPasswords(pem.EncodeToMemory(block), passwords) + if err != nil { + slog.Debug("skipping malformed private key PEM block during inspect", "block_type", block.Type, "error", err) + continue + } results = append(results, inspectKey(key)) } return results } +func isPrivateKeyPEMBlockType(blockType string) bool { + switch blockType { + case "RSA PRIVATE KEY", "EC PRIVATE KEY", "PRIVATE KEY", "ENCRYPTED PRIVATE KEY", "OPENSSH PRIVATE KEY": + return true + default: + return false + } +} + func inspectDERData(data []byte, passwords []string) []InspectResult { var results []InspectResult @@ -122,6 +147,18 @@ func inspectDERData(data []byte, passwords []string) []InspectResult { return results } + // Try PKCS#1 RSA + if key, err := x509.ParsePKCS1PrivateKey(data); err == nil { + results = append(results, inspectKey(key)) + return results + } + + // Try SEC1 EC + if key, err := x509.ParseECPrivateKey(data); err == nil { + results = append(results, inspectKey(key)) + return results + } + // Try PKCS#7 if certs, err := certkit.DecodePKCS7(data); err == nil { for _, cert := range certs { diff --git a/internal/inspect_test.go b/internal/inspect_test.go index 2690cbfa..ef47b077 100644 --- a/internal/inspect_test.go +++ b/internal/inspect_test.go @@ -345,9 +345,7 @@ func TestInspectFile_MultiplePEMObjects(t *testing.T) { ca := newRSACA(t) leaf := newRSALeaf(t, ca, "multi-pem.example.com", []string{"multi-pem.example.com"}, nil) - // Put the key PEM BEFORE the cert PEM so ParsePEMPrivateKey finds the key - // (it only parses the first PEM block). ParsePEMCertificates iterates all blocks - // and will find the cert block. + // Put the key PEM BEFORE the cert PEM to verify mixed ordering is handled. combined := slices.Concat(leaf.keyPEM, leaf.certPEM) dir := t.TempDir() @@ -384,6 +382,109 @@ func TestInspectFile_MultiplePEMObjects(t *testing.T) { } } +func TestInspectFile_PEMMalformedCertAndValidCert(t *testing.T) { + // WHY: Inspect should retain valid certificates even when the same PEM file + // also contains malformed CERTIFICATE blocks. + t.Parallel() + + ca := newRSACA(t) + leaf := newRSALeaf(t, ca, "inspect-malformed.example.com", []string{"inspect-malformed.example.com"}, nil) + + combined := slices.Concat( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bad-der")}), + leaf.certPEM, + ) + + dir := t.TempDir() + path := filepath.Join(dir, "mixed-malformed.pem") + if err := os.WriteFile(path, combined, 0644); err != nil { + t.Fatal(err) + } + + results, err := InspectFile(path, nil) + if err != nil { + t.Fatalf("InspectFile: %v", err) + } + foundLeaf := false + for _, r := range results { + if r.Type == "certificate" && strings.Contains(r.Subject, "inspect-malformed.example.com") { + foundLeaf = true + } + } + if !foundLeaf { + t.Error("expected valid certificate to be present despite malformed CERTIFICATE block") + } +} + +func TestInspectFile_PEMMalformedKeyAndValidKey(t *testing.T) { + // WHY: Inspect should retain valid private keys even when the same PEM file + // also contains malformed private key PEM blocks. + t.Parallel() + + combined := slices.Concat( + pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: []byte("bad-der")}), + rsaKeyPEM(t), + ) + + dir := t.TempDir() + path := filepath.Join(dir, "mixed-malformed-key.pem") + if err := os.WriteFile(path, combined, 0600); err != nil { + t.Fatal(err) + } + + results, err := InspectFile(path, nil) + if err != nil { + t.Fatalf("InspectFile: %v", err) + } + + var keyCount int + for _, r := range results { + if r.Type == "private_key" { + keyCount++ + if r.KeyType != "RSA" { + t.Errorf("KeyType = %q, want RSA", r.KeyType) + } + } + } + if keyCount != 1 { + t.Fatalf("expected exactly 1 private_key result, got %d", keyCount) + } +} + +func TestInspectFile_DERRSAPrivateKey(t *testing.T) { + // WHY: Inspect DER parsing should recognize PKCS#1 RSA keys, not only PKCS#8. + t.Parallel() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + der := x509.MarshalPKCS1PrivateKey(key) + + dir := t.TempDir() + path := filepath.Join(dir, "key.der") + if err := os.WriteFile(path, der, 0600); err != nil { + t.Fatal(err) + } + + results, err := InspectFile(path, nil) + if err != nil { + t.Fatalf("InspectFile: %v", err) + } + foundKey := false + for _, r := range results { + if r.Type == "private_key" { + foundKey = true + if r.KeyType != "RSA" { + t.Errorf("KeyType = %q, want RSA", r.KeyType) + } + } + } + if !foundKey { + t.Fatal("expected private_key result for PKCS#1 DER input") + } +} + func TestInspectFile_GarbageData(t *testing.T) { // WHY: Garbage data must produce a descriptive "no certificates, keys, or CSRs found" error, not a cryptic parsing failure or panic. t.Parallel() diff --git a/jks.go b/jks.go index ae129d19..2dae16d2 100644 --- a/jks.go +++ b/jks.go @@ -6,22 +6,27 @@ import ( "crypto/x509" "errors" "fmt" + "log/slog" "strings" "time" "github.com/pavlo-v-chernykh/keystore-go/v4" ) -// DecodeJKS decodes a Java KeyStore (JKS) and returns the certificates and -// private keys it contains. Passwords are tried in order to open the store. -// For private key entries, all passwords are tried independently since the -// key password may differ from the store password. -// -// TrustedCertificateEntry entries yield certificates. PrivateKeyEntry entries -// yield PKCS#8 private keys and their certificate chains. Individual entry -// errors are skipped; an error is returned only if the store cannot be loaded -// or no usable entries are found. -func DecodeJKS(data []byte, passwords []string) ([]*x509.Certificate, []crypto.PrivateKey, error) { +// DecodedJKSKeyEntry represents one decoded JKS PrivateKeyEntry with its alias +// and certificate chain. +type DecodedJKSKeyEntry struct { + Alias string + Key crypto.PrivateKey + Chain []*x509.Certificate +} + +// DecodeJKSKeyEntries decodes a Java KeyStore (JKS) and returns decoded +// private key entries (with alias + chain) and trusted-certificate entries. +// Passwords are tried in order to open the store, and each private key entry is +// attempted with all provided passwords to support different store/key +// passwords. +func DecodeJKSKeyEntries(data []byte, passwords []string) ([]DecodedJKSKeyEntry, []*x509.Certificate, error) { ks := keystore.New() var loaded bool @@ -35,53 +40,91 @@ func DecodeJKS(data []byte, passwords []string) ([]*x509.Certificate, []crypto.P return nil, nil, fmt.Errorf("loading JKS: none of the provided passwords worked") } - var certs []*x509.Certificate - var keys []crypto.PrivateKey + var keyEntries []DecodedJKSKeyEntry + var trustedCerts []*x509.Certificate for _, alias := range ks.Aliases() { if ks.IsTrustedCertificateEntry(alias) { entry, err := ks.GetTrustedCertificateEntry(alias) if err != nil { + slog.Debug("skipping unreadable JKS trusted certificate entry", "alias", alias, "error", err) continue } cert, err := x509.ParseCertificate(entry.Certificate.Content) if err != nil { + slog.Debug("skipping malformed JKS trusted certificate entry", "alias", alias, "error", err) continue } - certs = append(certs, cert) + trustedCerts = append(trustedCerts, cert) } - if ks.IsPrivateKeyEntry(alias) { - for _, pw := range passwords { - entry, err := ks.GetPrivateKeyEntry(alias, []byte(pw)) - if err != nil { - continue - } + if !ks.IsPrivateKeyEntry(alias) { + slog.Debug("skipping non-private-key JKS entry", "alias", alias) + continue + } + + for _, pw := range passwords { + entry, err := ks.GetPrivateKeyEntry(alias, []byte(pw)) + if err != nil { + slog.Debug("skipping JKS private key entry password attempt", "alias", alias, "error", err) + continue + } + + key, err := x509.ParsePKCS8PrivateKey(entry.PrivateKey) + if err != nil { + slog.Debug("skipping JKS private key entry with bad key data", "alias", alias, "error", err) + break // key data is bad, no point trying other passwords + } - // Parse the PKCS#8 private key - key, err := x509.ParsePKCS8PrivateKey(entry.PrivateKey) + var chain []*x509.Certificate + for _, certEntry := range entry.CertificateChain { + cert, err := x509.ParseCertificate(certEntry.Content) if err != nil { - break // key data is bad, no point trying other passwords - } - keys = append(keys, normalizeKey(key)) - - // Parse the certificate chain - for _, certEntry := range entry.CertificateChain { - cert, err := x509.ParseCertificate(certEntry.Content) - if err != nil { - continue - } - certs = append(certs, cert) + slog.Debug("skipping malformed certificate in JKS private key chain", "alias", alias, "error", err) + continue } - break + chain = append(chain, cert) } + + keyEntries = append(keyEntries, DecodedJKSKeyEntry{ + Alias: alias, + Key: normalizeKey(key), + Chain: chain, + }) + break } } - if len(certs) == 0 && len(keys) == 0 { + if len(trustedCerts) == 0 && len(keyEntries) == 0 { return nil, nil, errors.New("JKS contains no usable certificates or keys") } + return keyEntries, trustedCerts, nil +} + +// DecodeJKS decodes a Java KeyStore (JKS) and returns the certificates and +// private keys it contains. Passwords are tried in order to open the store. +// For private key entries, all passwords are tried independently since the +// key password may differ from the store password. +// +// TrustedCertificateEntry entries yield certificates. PrivateKeyEntry entries +// yield PKCS#8 private keys and their certificate chains. Individual entry +// errors are skipped; an error is returned only if the store cannot be loaded +// or no usable entries are found. +func DecodeJKS(data []byte, passwords []string) ([]*x509.Certificate, []crypto.PrivateKey, error) { + keyEntries, trustedCerts, err := DecodeJKSKeyEntries(data, passwords) + if err != nil { + return nil, nil, err + } + + var certs []*x509.Certificate + var keys []crypto.PrivateKey + certs = append(certs, trustedCerts...) + for _, entry := range keyEntries { + keys = append(keys, entry.Key) + certs = append(certs, entry.Chain...) + } + return certs, keys, nil } diff --git a/jks_test.go b/jks_test.go index 30ae9538..2ba8a42a 100644 --- a/jks_test.go +++ b/jks_test.go @@ -649,6 +649,106 @@ func TestDecodeJKS_MultiplePrivateKeyEntries(t *testing.T) { } } +func TestDecodeJKSKeyEntries_PreservesAliasChainPairing(t *testing.T) { + // WHY: Key-entry decoding must preserve alias -> key -> cert-chain pairing + // so callers can select a coherent leaf+chain for each private key entry. + t.Parallel() + + password := "changeit" + serverKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + newSelfSigned := func(t *testing.T, cn string, pub any, signer any) []byte { + t.Helper() + tmpl := &x509.Certificate{ + SerialNumber: randomSerial(t), + Subject: pkix.Name{CommonName: cn}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, signer) + if err != nil { + t.Fatal(err) + } + return der + } + + serverCertDER := newSelfSigned(t, "server-entry", &serverKey.PublicKey, serverKey) + clientCertDER := newSelfSigned(t, "client-entry", &clientKey.PublicKey, clientKey) + + serverPKCS8, err := x509.MarshalPKCS8PrivateKey(serverKey) + if err != nil { + t.Fatal(err) + } + clientPKCS8, err := x509.MarshalPKCS8PrivateKey(clientKey) + if err != nil { + t.Fatal(err) + } + + ks := keystore.New() + if err := ks.SetPrivateKeyEntry("server", keystore.PrivateKeyEntry{ + CreationTime: time.Now(), + PrivateKey: serverPKCS8, + CertificateChain: []keystore.Certificate{{Type: "X.509", Content: serverCertDER}}, + }, []byte(password)); err != nil { + t.Fatal(err) + } + if err := ks.SetPrivateKeyEntry("client", keystore.PrivateKeyEntry{ + CreationTime: time.Now(), + PrivateKey: clientPKCS8, + CertificateChain: []keystore.Certificate{{Type: "X.509", Content: clientCertDER}}, + }, []byte(password)); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + if err := ks.Store(&buf, []byte(password)); err != nil { + t.Fatal(err) + } + + entries, trusted, err := DecodeJKSKeyEntries(buf.Bytes(), []string{password}) + if err != nil { + t.Fatalf("DecodeJKSKeyEntries: %v", err) + } + if len(trusted) != 0 { + t.Fatalf("expected 0 trusted cert entries, got %d", len(trusted)) + } + if len(entries) != 2 { + t.Fatalf("expected 2 key entries, got %d", len(entries)) + } + + byAlias := make(map[string]DecodedJKSKeyEntry, len(entries)) + for _, entry := range entries { + byAlias[entry.Alias] = entry + } + serverEntry, ok := byAlias["server"] + if !ok { + t.Fatal("missing server alias") + } + if len(serverEntry.Chain) != 1 || serverEntry.Chain[0].Subject.CommonName != "server-entry" { + t.Fatalf("server entry chain mismatch: %+v", serverEntry.Chain) + } + if !serverKey.Equal(serverEntry.Key) { + t.Error("server key mismatch") + } + + clientEntry, ok := byAlias["client"] + if !ok { + t.Fatal("missing client alias") + } + if len(clientEntry.Chain) != 1 || clientEntry.Chain[0].Subject.CommonName != "client-entry" { + t.Fatalf("client entry chain mismatch: %+v", clientEntry.Chain) + } + if !clientKey.Equal(clientEntry.Key) { + t.Error("client key mismatch") + } +} + func TestDecodeJKS_PrivateKeyEntry_EmptyCertChain(t *testing.T) { // WHY: A JKS PrivateKeyEntry with a valid key but an empty certificate // chain should still return the key. The cert chain is optional in the diff --git a/sign.go b/sign.go index 77fe2a43..bf31fe0e 100644 --- a/sign.go +++ b/sign.go @@ -5,11 +5,15 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "errors" "fmt" "math/big" "time" ) +// ErrCAKeyMismatch indicates the signing CA private key does not match the CA certificate. +var ErrCAKeyMismatch = errors.New("CA key does not match CA certificate") + // SelfSignedInput contains parameters for self-signed certificate generation. type SelfSignedInput struct { // Signer is the private key used to sign the certificate. @@ -95,6 +99,13 @@ func SignCSR(input SignCSRInput) (*x509.Certificate, error) { if input.CAKey == nil { return nil, fmt.Errorf("signing CSR: CA key is required") } + caKeyMatches, err := KeyMatchesCert(input.CAKey, input.CACert) + if err != nil { + return nil, fmt.Errorf("validating CA certificate and key: %w", err) + } + if !caKeyMatches { + return nil, fmt.Errorf("validating CA certificate and key: %w", ErrCAKeyMismatch) + } if err := input.CSR.CheckSignature(); err != nil { return nil, fmt.Errorf("verifying CSR signature: %w", err) diff --git a/sign_test.go b/sign_test.go index 8b2f38d3..99e11635 100644 --- a/sign_test.go +++ b/sign_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "errors" "net" "testing" "time" @@ -324,3 +325,57 @@ func TestSignCSR_ChainVerifies(t *testing.T) { t.Fatalf("chain verification failed: %v", err) } } + +func TestSignCSR_CACertKeyMismatch(t *testing.T) { + // WHY: Signing must fail fast when the CA private key does not match the + // CA certificate to prevent issuing certs under the wrong identity. + t.Parallel() + + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + ca, err := CreateSelfSigned(SelfSignedInput{ + Signer: caKey, + Subject: pkix.Name{CommonName: "Mismatch CA"}, + Days: 365, + IsCA: true, + }) + if err != nil { + t.Fatal(err) + } + + wrongKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + csrKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + csrDER, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: "leaf.example.com"}, + }, csrKey) + if err != nil { + t.Fatal(err) + } + csr, err := x509.ParseCertificateRequest(csrDER) + if err != nil { + t.Fatal(err) + } + + _, err = SignCSR(SignCSRInput{ + CSR: csr, + CACert: ca, + CAKey: wrongKey, + Days: 30, + CopySANs: true, + }) + if err == nil { + t.Fatal("expected error for CA cert/key mismatch") + } + if !errors.Is(err, ErrCAKeyMismatch) { + t.Errorf("unexpected error: %v", err) + } +} From 0790bc7f0de40f18c007a06b5db7cb20f1b1c15b Mon Sep 17 00:00:00 2001 From: Daniel Wood Date: Sun, 1 Mar 2026 19:54:04 -0500 Subject: [PATCH 3/8] fix(scan): harden scan boundaries and export text summary (#106) * fix(scan): keep traversal bounded and restore export summaries * fix(scan): fail fast on walker processing errors * fix(scan): use typed max-size errors in read paths * fix(scan): reject invalid export formats consistently * fix(scan): keep export destination off stdout --- CHANGELOG.md | 9 ++ cmd/certkit/scan.go | 146 +++++++++++++++---------------- internal/format.go | 34 ++++++++ internal/format_test.go | 56 +++++++++++- internal/io.go | 15 +++- internal/io_test.go | 99 +++++++++++---------- internal/scanwalk.go | 142 ++++++++++++++++++++++++++++++ internal/scanwalk_test.go | 164 +++++++++++++++++++++++++++++++++++ internal/testhelpers_test.go | 9 ++ 9 files changed, 549 insertions(+), 125 deletions(-) create mode 100644 internal/scanwalk.go create mode 100644 internal/scanwalk_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e026686b..e17fbd7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix `scan` directory traversal boundaries and resilience: symlinks that point outside the scan root are skipped, max-file-size checks now apply to symlink targets and archive reads, and transient walk errors no longer prune unrelated files during traversal ([#91]) +- Fix `scan --bundle-path` text/default output to print a useful post-export summary (certificate/key counts plus export path) while keeping JSON export output unchanged ([#95]) +- Fix `scan` to fail fast when per-file processing or size-check `stat` calls fail during directory traversal, instead of logging and silently continuing ([#106]) +- Fix max-size read failures to return stable wrapped error causes instead of requiring string matching in internal scan/read paths ([#106]) +- Fix `scan --bundle-path` to reject unsupported `--format` values with the same validation used by non-export scans (for example, `--format yaml` now errors instead of silently falling back to text) ([#106]) +- Fix `scan --bundle-path` text output to keep scan summary on stdout while sending the export destination path to stderr per CLI output conventions ([#106]) - Fix inspect/convert/container parsing to continue past malformed PEM certificate blocks so valid certificates are still processed, and add DER private-key detection for key-only inputs ([#107]) - Fix inspect PEM key parsing to continue past malformed private-key blocks so valid keys in the same bundle are still reported ([#107]) - Fix JKS container selection to keep private key entries paired with their own leaf certificate and chain instead of selecting unrelated trusted entries ([#107]) @@ -956,6 +962,9 @@ Initial release. [#85]: https://github.com/sensiblebit/certkit/pull/85 [#86]: https://github.com/sensiblebit/certkit/pull/86 [#87]: https://github.com/sensiblebit/certkit/pull/87 +[#91]: https://github.com/sensiblebit/certkit/pull/91 +[#95]: https://github.com/sensiblebit/certkit/pull/95 +[#106]: https://github.com/sensiblebit/certkit/pull/106 [#107]: https://github.com/sensiblebit/certkit/pull/107 [#105]: https://github.com/sensiblebit/certkit/pull/105 [#73]: https://github.com/sensiblebit/certkit/pull/73 diff --git a/cmd/certkit/scan.go b/cmd/certkit/scan.go index bd9058e7..e0d213ab 100644 --- a/cmd/certkit/scan.go +++ b/cmd/certkit/scan.go @@ -7,11 +7,9 @@ import ( "encoding/pem" "fmt" "io" - "io/fs" "log/slog" "net/http" "os" - "path/filepath" "strings" "time" @@ -106,68 +104,41 @@ func runScan(cmd *cobra.Command, args []string) error { return fmt.Errorf("input path %s: %w", inputPath, err) } - err := filepath.WalkDir(inputPath, func(path string, d fs.DirEntry, err error) error { - if err != nil { - slog.Warn("skipping inaccessible path", "path", path, "error", err) - return filepath.SkipDir - } - if d.IsDir() { - if internal.IsSkippableDir(d.Name()) { - slog.Debug("skipping directory", "path", path) - return filepath.SkipDir - } - return nil - } - // Resolve symlinks: skip broken links and links to directories - if d.Type()&fs.ModeSymlink != 0 { - fi, err := os.Stat(path) - if err != nil { - slog.Debug("skipping broken symlink", "path", path) - return nil - } - if fi.IsDir() { - slog.Debug("skipping symlink to directory", "path", path) - return nil - } - } - if scanMaxFileSize > 0 { - if info, err := d.Info(); err == nil && info.Size() > scanMaxFileSize { - slog.Debug("skipping large file", "path", path, "size", info.Size(), "max", scanMaxFileSize) - return nil - } - } - // Check for archive formats before falling through to ProcessFile - if archiveFormat := internal.ArchiveFormat(path); archiveFormat != "" { - data, readErr := os.ReadFile(path) - if readErr != nil { - slog.Warn("reading archive", "path", path, "error", readErr) + err := internal.WalkScanFiles(internal.WalkScanFilesInput{ + RootPath: inputPath, + MaxFileSize: scanMaxFileSize, + OnFile: func(path string) error { + if archiveFormat := internal.ArchiveFormat(path); archiveFormat != "" { + data, readErr := internal.ReadFileLimited(path, scanMaxFileSize) + if readErr != nil { + return fmt.Errorf("reading archive %s: %w", path, readErr) + } + limits := internal.DefaultArchiveLimits() + if scanMaxFileSize > 0 { + limits.MaxEntrySize = scanMaxFileSize + } + if _, archiveErr := internal.ProcessArchive(internal.ProcessArchiveInput{ + ArchivePath: path, + Data: data, + Format: archiveFormat, + Limits: limits, + Store: store, + Passwords: passwords, + }); archiveErr != nil { + return fmt.Errorf("processing archive %s: %w", path, archiveErr) + } return nil } - limits := internal.DefaultArchiveLimits() - if scanMaxFileSize > 0 { - limits.MaxEntrySize = scanMaxFileSize - } - if _, archiveErr := internal.ProcessArchive(internal.ProcessArchiveInput{ - ArchivePath: path, - Data: data, - Format: archiveFormat, - Limits: limits, - Store: store, - Passwords: passwords, - }); archiveErr != nil { - slog.Warn("processing archive", "path", path, "error", archiveErr) + if err := internal.ProcessFile(internal.ProcessFileInput{ + Path: path, + Store: store, + Passwords: passwords, + MaxBytes: scanMaxFileSize, + }); err != nil { + return fmt.Errorf("processing file %s: %w", path, err) } return nil - } - if err := internal.ProcessFile(internal.ProcessFileInput{ - Path: path, - Store: store, - Passwords: passwords, - MaxBytes: scanMaxFileSize, - }); err != nil { - slog.Warn("processing file", "path", path, "error", err) - } - return nil + }, }) if err != nil { return fmt.Errorf("walking input path: %w", err) @@ -287,7 +258,8 @@ func runScan(cmd *cobra.Command, args []string) error { return fmt.Errorf("exporting bundles: %w", err) } store.DumpDebug() - if format == "json" { + switch format { + case "json": mozillaPool, err := certkit.MozillaRootPool() if err != nil { return fmt.Errorf("loading Mozilla root pool: %w", err) @@ -304,6 +276,32 @@ func runScan(cmd *cobra.Command, args []string) error { return fmt.Errorf("marshaling JSON: %w", err) } fmt.Println(string(data)) + case "text": + mozillaPool, err := certkit.MozillaRootPool() + if err != nil { + return fmt.Errorf("loading Mozilla root pool: %w", err) + } + summary := store.ScanSummary(certstore.ScanSummaryInput{ + RootPool: mozillaPool, + }) + fmt.Print(internal.FormatScanTextSummary(internal.ScanTextSummaryInput{ + Roots: summary.Roots, + Intermediates: summary.Intermediates, + Leaves: summary.Leaves, + Keys: summary.Keys, + Matched: summary.Matched, + ExpiredRoots: summary.ExpiredRoots, + ExpiredIntermediates: summary.ExpiredIntermediates, + ExpiredLeaves: summary.ExpiredLeaves, + UntrustedRoots: summary.UntrustedRoots, + UntrustedIntermediates: summary.UntrustedIntermediates, + UntrustedLeaves: summary.UntrustedLeaves, + })) + if _, err := fmt.Fprintf(os.Stderr, "Exported bundles to %s\n", scanBundlePath); err != nil { + return fmt.Errorf("writing export status: %w", err) + } + default: + return fmt.Errorf("unsupported output format %q (use text or json)", format) } } else { // Print summary with trust and expiry annotations @@ -335,19 +333,19 @@ func runScan(cmd *cobra.Command, args []string) error { fmt.Println(string(data)) } case "text": - total := summary.Roots + summary.Intermediates + summary.Leaves - fmt.Printf("\nFound %d certificate(s) and %d key(s)\n", total, summary.Keys) - if total > 0 { - fmt.Printf(" Roots: %d%s\n", summary.Roots, - internal.CertAnnotation(summary.ExpiredRoots, summary.UntrustedRoots)) - fmt.Printf(" Intermediates: %d%s\n", summary.Intermediates, - internal.CertAnnotation(summary.ExpiredIntermediates, summary.UntrustedIntermediates)) - fmt.Printf(" Leaves: %d%s\n", summary.Leaves, - internal.CertAnnotation(summary.ExpiredLeaves, summary.UntrustedLeaves)) - } - if summary.Keys > 0 { - fmt.Printf(" Key-cert pairs: %d\n", summary.Matched) - } + fmt.Print(internal.FormatScanTextSummary(internal.ScanTextSummaryInput{ + Roots: summary.Roots, + Intermediates: summary.Intermediates, + Leaves: summary.Leaves, + Keys: summary.Keys, + Matched: summary.Matched, + ExpiredRoots: summary.ExpiredRoots, + ExpiredIntermediates: summary.ExpiredIntermediates, + ExpiredLeaves: summary.ExpiredLeaves, + UntrustedRoots: summary.UntrustedRoots, + UntrustedIntermediates: summary.UntrustedIntermediates, + UntrustedLeaves: summary.UntrustedLeaves, + })) if verbose { printScanVerboseText(store) } diff --git a/internal/format.go b/internal/format.go index d0b9f3a8..276a5ab9 100644 --- a/internal/format.go +++ b/internal/format.go @@ -5,6 +5,21 @@ import ( "strings" ) +// ScanTextSummaryInput holds fields needed for text scan summaries. +type ScanTextSummaryInput struct { + Roots int + Intermediates int + Leaves int + Keys int + Matched int + ExpiredRoots int + ExpiredIntermediates int + ExpiredLeaves int + UntrustedRoots int + UntrustedIntermediates int + UntrustedLeaves int +} + // CertAnnotation returns a parenthetical annotation like " (2 expired, 1 untrusted)" // for non-zero counts, or an empty string if both are zero. func CertAnnotation(expired, untrusted int) string { @@ -20,3 +35,22 @@ func CertAnnotation(expired, untrusted int) string { } return " (" + strings.Join(parts, ", ") + ")" } + +// FormatScanTextSummary renders the user-facing scan summary for text output. +func FormatScanTextSummary(input ScanTextSummaryInput) string { + total := input.Roots + input.Intermediates + input.Leaves + var out strings.Builder + _, _ = fmt.Fprintf(&out, "\nFound %d certificate(s) and %d key(s)\n", total, input.Keys) + if total > 0 { + _, _ = fmt.Fprintf(&out, " Roots: %d%s\n", input.Roots, + CertAnnotation(input.ExpiredRoots, input.UntrustedRoots)) + _, _ = fmt.Fprintf(&out, " Intermediates: %d%s\n", input.Intermediates, + CertAnnotation(input.ExpiredIntermediates, input.UntrustedIntermediates)) + _, _ = fmt.Fprintf(&out, " Leaves: %d%s\n", input.Leaves, + CertAnnotation(input.ExpiredLeaves, input.UntrustedLeaves)) + } + if input.Keys > 0 { + _, _ = fmt.Fprintf(&out, " Key-cert pairs: %d\n", input.Matched) + } + return out.String() +} diff --git a/internal/format_test.go b/internal/format_test.go index 32f18e1e..ef5fb847 100644 --- a/internal/format_test.go +++ b/internal/format_test.go @@ -1,6 +1,9 @@ package internal -import "testing" +import ( + "strings" + "testing" +) func TestCertAnnotation(t *testing.T) { // WHY: CertAnnotation formats the parenthetical trust/expiry annotations @@ -28,3 +31,54 @@ func TestCertAnnotation(t *testing.T) { }) } } + +func TestFormatScanTextSummary(t *testing.T) { + // WHY: Scan text output must include counts and trust/expiry annotations + // without mixing side-effect status messages into stdout. + t.Parallel() + + tests := []struct { + name string + input ScanTextSummaryInput + contains []string + }{ + { + name: "summary without export", + input: ScanTextSummaryInput{ + Roots: 1, + Intermediates: 2, + Leaves: 3, + Keys: 4, + Matched: 3, + ExpiredRoots: 1, + UntrustedIntermediates: 2, + }, + contains: []string{ + "Found 6 certificate(s) and 4 key(s)", + "Roots: 1 (1 expired)", + "Intermediates: 2 (2 untrusted)", + "Leaves: 3", + "Key-cert pairs: 3", + }, + }, + { + name: "summary with zero counts", + input: ScanTextSummaryInput{}, + contains: []string{ + "Found 0 certificate(s) and 0 key(s)", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := FormatScanTextSummary(tt.input) + for _, want := range tt.contains { + if !strings.Contains(got, want) { + t.Fatalf("FormatScanTextSummary() missing %q in %q", want, got) + } + } + }) + } +} diff --git a/internal/io.go b/internal/io.go index 45815e45..939158a7 100644 --- a/internal/io.go +++ b/internal/io.go @@ -1,6 +1,7 @@ package internal import ( + "errors" "fmt" "io" "log/slog" @@ -10,6 +11,11 @@ import ( const defaultMaxInputBytes int64 = 10 * 1024 * 1024 +var ( + errInputExceedsMaxSize = errors.New("input exceeds max size") + errFileExceedsMaxSize = errors.New("file exceeds max size") +) + // readAllLimited reads from r with an optional hard byte limit. // When maxBytes <= 0, no limit is applied. // The maxBytes+1 sentinel detects oversized input without truncating silently. @@ -27,7 +33,7 @@ func readAllLimited(r io.Reader, maxBytes int64) ([]byte, error) { return nil, fmt.Errorf("reading input: %w", err) } if int64(len(data)) > maxBytes { - return nil, fmt.Errorf("input exceeds max size (%d bytes)", maxBytes) + return nil, fmt.Errorf("%w (%d bytes)", errInputExceedsMaxSize, maxBytes) } return data, nil } @@ -41,7 +47,7 @@ func readFileLimited(path string, maxBytes int64) ([]byte, error) { return nil, fmt.Errorf("stat %s: %w", path, err) } if info.Size() > maxBytes { - return nil, fmt.Errorf("file exceeds max size (%d bytes)", maxBytes) + return nil, fmt.Errorf("%w (%d bytes)", errFileExceedsMaxSize, maxBytes) } } @@ -61,3 +67,8 @@ func readFileLimited(path string, maxBytes int64) ([]byte, error) { } return data, nil } + +// ReadFileLimited reads a file with an optional hard byte limit. +func ReadFileLimited(path string, maxBytes int64) ([]byte, error) { + return readFileLimited(path, maxBytes) +} diff --git a/internal/io_test.go b/internal/io_test.go index 077dfe76..e2ca6487 100644 --- a/internal/io_test.go +++ b/internal/io_test.go @@ -1,71 +1,74 @@ package internal import ( - "bytes" + "errors" "os" "path/filepath" - "strings" "testing" ) -func TestReadAllLimited(t *testing.T) { - // WHY: readAllLimited enforces input-size limits that protect ingest paths - // from unbounded memory growth. Verify exact-limit, over-limit, and no-limit paths. +func TestReadFileLimited(t *testing.T) { + // WHY: size limits must be enforced consistently for direct files and symlinks, + // and disabling limits must still read full content via the exported API. t.Parallel() tests := []struct { - name string - data []byte - maxBytes int64 - wantErr string + name string + path func(t *testing.T) string }{ - {name: "exact limit", data: []byte("abcd"), maxBytes: 4}, - {name: "over limit", data: []byte("abcde"), maxBytes: 4, wantErr: "input exceeds max size"}, - {name: "no limit", data: bytes.Repeat([]byte("x"), 1024), maxBytes: 0}, + { + name: "direct file", + path: func(t *testing.T) string { + t.Helper() + dir := t.TempDir() + file := filepath.Join(dir, "input.bin") + if err := os.WriteFile(file, []byte("abcde"), 0644); err != nil { + t.Fatalf("write file: %v", err) + } + return file + }, + }, + { + name: "symlink target", + path: func(t *testing.T) string { + t.Helper() + dir := t.TempDir() + target := filepath.Join(dir, "target.bin") + if err := os.WriteFile(target, []byte("abcde"), 0644); err != nil { + t.Fatalf("write target file: %v", err) + } + link := filepath.Join(dir, "target-link.bin") + createSymlinkOrSkip(t, target, link) + return link + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - out, err := readAllLimited(bytes.NewReader(tt.data), tt.maxBytes) - if tt.wantErr == "" { - if err != nil { - t.Fatalf("readAllLimited error: %v", err) - } - if !bytes.Equal(out, tt.data) { - t.Fatalf("readAllLimited output mismatch") - } - return - } - if err == nil { - t.Fatalf("expected error containing %q", tt.wantErr) - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErr) - } - }) - } -} -func TestReadFileLimited(t *testing.T) { - // WHY: readFileLimited performs stat pre-check + bounded read; both must enforce limits. - t.Parallel() + path := tt.path(t) - dir := t.TempDir() - file := filepath.Join(dir, "input.bin") - if err := os.WriteFile(file, []byte("abcde"), 0644); err != nil { - t.Fatalf("write file: %v", err) - } + if _, err := ReadFileLimited(path, 4); !errors.Is(err, errFileExceedsMaxSize) { + t.Fatalf("error = %v, want errors.Is(_, %v)", err, errFileExceedsMaxSize) + } - if _, err := readFileLimited(file, 4); err == nil || !strings.Contains(err.Error(), "file exceeds max size") { - t.Fatalf("expected size error, got %v", err) - } + data, err := ReadFileLimited(path, 5) + if err != nil { + t.Fatalf("ReadFileLimited error: %v", err) + } + if string(data) != "abcde" { + t.Fatalf("data = %q, want %q", string(data), "abcde") + } - data, err := readFileLimited(file, 5) - if err != nil { - t.Fatalf("readFileLimited error: %v", err) - } - if string(data) != "abcde" { - t.Fatalf("data = %q, want %q", string(data), "abcde") + data, err = ReadFileLimited(path, 0) + if err != nil { + t.Fatalf("ReadFileLimited no-limit error: %v", err) + } + if string(data) != "abcde" { + t.Fatalf("data = %q, want %q", string(data), "abcde") + } + }) } } diff --git a/internal/scanwalk.go b/internal/scanwalk.go new file mode 100644 index 00000000..253d2a10 --- /dev/null +++ b/internal/scanwalk.go @@ -0,0 +1,142 @@ +package internal + +import ( + "fmt" + "io/fs" + "log/slog" + "os" + "path/filepath" + "strings" +) + +// WalkScanFilesInput configures WalkScanFiles. +type WalkScanFilesInput struct { + RootPath string + MaxFileSize int64 + OnFile func(path string) error +} + +// WalkScanFiles iterates scan-eligible files under RootPath. +func WalkScanFiles(input WalkScanFilesInput) error { + if input.RootPath == "" { + return fmt.Errorf("root path is required") + } + if input.OnFile == nil { + return fmt.Errorf("file handler is required") + } + + info, err := os.Stat(input.RootPath) + if err != nil { + return fmt.Errorf("stat %s: %w", input.RootPath, err) + } + if !info.IsDir() { + exceedsLimit, err := exceedsSizeLimit(input.RootPath, input.MaxFileSize) + if err != nil { + return fmt.Errorf("checking file size %s: %w", input.RootPath, err) + } + if exceedsLimit { + return nil + } + if err := input.OnFile(input.RootPath); err != nil { + return fmt.Errorf("handling file %s: %w", input.RootPath, err) + } + return nil + } + + rootBoundary, err := scanRootBoundary(input.RootPath) + if err != nil { + return fmt.Errorf("resolving root boundary: %w", err) + } + + return filepath.WalkDir(input.RootPath, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + slog.Warn("skipping inaccessible path", "path", path, "error", walkErr) + if d != nil && d.IsDir() { + return filepath.SkipDir + } + return nil + } + if d.IsDir() { + if IsSkippableDir(d.Name()) { + slog.Debug("skipping directory", "path", path) + return filepath.SkipDir + } + return nil + } + + if d.Type()&fs.ModeSymlink != 0 { + resolvedPath, resolveErr := filepath.EvalSymlinks(path) + if resolveErr != nil { + slog.Debug("skipping broken symlink", "path", path) + return nil + } + resolvedInfo, resolvedErr := os.Stat(resolvedPath) + if resolvedErr != nil { + slog.Debug("skipping broken symlink", "path", path) + return nil + } + if resolvedInfo.IsDir() { + slog.Debug("skipping symlink to directory", "path", path) + return nil + } + if !pathWithinBoundary(resolvedPath, rootBoundary) { + slog.Debug("skipping symlink outside scan root", "path", path, "target", resolvedPath) + return nil + } + } + + exceedsLimit, err := exceedsSizeLimit(path, input.MaxFileSize) + if err != nil { + return fmt.Errorf("checking file size %s: %w", path, err) + } + if exceedsLimit { + return nil + } + + if err := input.OnFile(path); err != nil { + return fmt.Errorf("handling file %s: %w", path, err) + } + return nil + }) +} + +func scanRootBoundary(root string) (string, error) { + resolved, err := filepath.EvalSymlinks(root) + if err == nil { + return filepath.Abs(resolved) + } + return filepath.Abs(root) +} + +func pathWithinBoundary(path, rootBoundary string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + rel, err := filepath.Rel(rootBoundary, absPath) + if err != nil { + return false + } + if rel == "." { + return true + } + if rel == ".." { + return false + } + return !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func exceedsSizeLimit(path string, maxFileSize int64) (bool, error) { + if maxFileSize <= 0 { + return false, nil + } + info, err := os.Stat(path) + if err != nil { + return false, fmt.Errorf("stat %s: %w", path, err) + } + if info.Size() <= maxFileSize { + return false, nil + } + slog.Debug("skipping large file", "path", path, "size", info.Size(), "max", maxFileSize) + return true, nil +} diff --git a/internal/scanwalk_test.go b/internal/scanwalk_test.go new file mode 100644 index 00000000..59d0328a --- /dev/null +++ b/internal/scanwalk_test.go @@ -0,0 +1,164 @@ +package internal + +import ( + "errors" + "os" + "path/filepath" + "slices" + "testing" +) + +func TestWalkScanFiles_SkipsSymlinkOutsideRoot(t *testing.T) { + // WHY: Directory scans must stay within the requested root and avoid + // ingesting symlink targets from unrelated paths. + t.Parallel() + + root := t.TempDir() + outsideDir := t.TempDir() + + insideFile := filepath.Join(root, "inside.pem") + if err := os.WriteFile(insideFile, []byte("inside"), 0644); err != nil { + t.Fatalf("write inside file: %v", err) + } + + outsideFile := filepath.Join(outsideDir, "outside.pem") + if err := os.WriteFile(outsideFile, []byte("outside"), 0644); err != nil { + t.Fatalf("write outside file: %v", err) + } + + symlinkPath := filepath.Join(root, "outside-link.pem") + createSymlinkOrSkip(t, outsideFile, symlinkPath) + + var visited []string + err := WalkScanFiles(WalkScanFilesInput{ + RootPath: root, + OnFile: func(path string) error { + visited = append(visited, filepath.Base(path)) + return nil + }, + }) + if err != nil { + t.Fatalf("WalkScanFiles error: %v", err) + } + + if !slices.Contains(visited, "inside.pem") { + t.Fatalf("inside file not visited: %v", visited) + } + if slices.Contains(visited, "outside-link.pem") { + t.Fatalf("outside symlink should be skipped: %v", visited) + } +} + +func TestWalkScanFiles_UsesTargetSizeForSymlink(t *testing.T) { + // WHY: max file size must be enforced against the symlink target size, + // not the symlink inode size. + t.Parallel() + + root := t.TempDir() + + smallFile := filepath.Join(root, "small.pem") + if err := os.WriteFile(smallFile, []byte("small"), 0644); err != nil { + t.Fatalf("write small file: %v", err) + } + + largeTarget := filepath.Join(root, "large-target.pem") + if err := os.WriteFile(largeTarget, []byte("this file is definitely larger than ten bytes"), 0644); err != nil { + t.Fatalf("write large file: %v", err) + } + + largeLink := filepath.Join(root, "large-link.pem") + createSymlinkOrSkip(t, largeTarget, largeLink) + + var visited []string + err := WalkScanFiles(WalkScanFilesInput{ + RootPath: root, + MaxFileSize: 10, + OnFile: func(path string) error { + visited = append(visited, filepath.Base(path)) + return nil + }, + }) + if err != nil { + t.Fatalf("WalkScanFiles error: %v", err) + } + + if !slices.Contains(visited, "small.pem") { + t.Fatalf("small file not visited: %v", visited) + } + if slices.Contains(visited, "large-link.pem") { + t.Fatalf("large symlink target should be skipped: %v", visited) + } +} + +func TestWalkScanFiles_WalkErrorDoesNotPruneSiblings(t *testing.T) { + // WHY: A single walk error must not skip unrelated entries in the same + // parent directory. + t.Parallel() + + root := t.TempDir() + dir := filepath.Join(root, "input") + if err := os.MkdirAll(filepath.Join(dir, "sub"), 0755); err != nil { + t.Fatalf("mkdir input: %v", err) + } + + first := filepath.Join(dir, "a-first.pem") + removed := filepath.Join(dir, "b-removed.pem") + nested := filepath.Join(dir, "sub", "c-nested.pem") + for _, p := range []string{first, removed, nested} { + if err := os.WriteFile(p, []byte("x"), 0644); err != nil { + t.Fatalf("write %s: %v", p, err) + } + } + + var visited []string + err := WalkScanFiles(WalkScanFilesInput{ + RootPath: root, + OnFile: func(path string) error { + if path == first { + if removeErr := os.Remove(removed); removeErr != nil { + t.Fatalf("remove %s: %v", removed, removeErr) + } + } + visited = append(visited, path) + return nil + }, + }) + if err != nil { + t.Fatalf("WalkScanFiles error: %v", err) + } + + if !slices.Contains(visited, first) { + t.Fatalf("first file not visited: %v", visited) + } + if !slices.Contains(visited, nested) { + t.Fatalf("nested file should still be visited: %v", visited) + } +} + +func TestWalkScanFiles_PropagatesOnFileError(t *testing.T) { + // WHY: Scan must fail fast when processing a discovered file fails. + t.Parallel() + + root := t.TempDir() + inputFile := filepath.Join(root, "input.pem") + if err := os.WriteFile(inputFile, []byte("x"), 0644); err != nil { + t.Fatalf("write input file: %v", err) + } + + wantErr := errors.New("onfile failed") + err := WalkScanFiles(WalkScanFilesInput{ + RootPath: root, + OnFile: func(path string) error { + if path == inputFile { + return wantErr + } + return nil + }, + }) + if err == nil { + t.Fatalf("expected WalkScanFiles to return an error") + } + if !errors.Is(err, wantErr) { + t.Fatalf("error = %v, want wrapped %v", err, wantErr) + } +} diff --git a/internal/testhelpers_test.go b/internal/testhelpers_test.go index 42ec62e7..12261c7d 100644 --- a/internal/testhelpers_test.go +++ b/internal/testhelpers_test.go @@ -14,6 +14,7 @@ import ( "encoding/pem" "math/big" "net" + "os" "testing" "time" @@ -423,3 +424,11 @@ func createTestTarGz(t *testing.T, files map[string][]byte) []byte { } return buf.Bytes() } + +// createSymlinkOrSkip creates a symlink for tests or skips when unsupported. +func createSymlinkOrSkip(t *testing.T, target, link string) { + t.Helper() + if err := os.Symlink(target, link); err != nil { + t.Skipf("skipping symlink-dependent test: %v", err) + } +} From aea6d432aea44fea8fa4171f67fcbcd48e099e26 Mon Sep 17 00:00:00 2001 From: Daniel Wood Date: Sun, 1 Mar 2026 19:56:24 -0500 Subject: [PATCH 4/8] fix(network): harden SSRF fetch validation and apply default connect timeout (#108) * fix(network): harden revocation fetch SSRF checks and connect timeout defaults * fix(network): propagate SSRF validation deadlines and unblock inspect AIA opt-in * fix(bundle): restore private-network opt-in for AIA chain fetches * fix(network): address remaining PR feedback for inspect AIA handling * fix(wasm): keep AIA resolution working without DNS lookups --- CHANGELOG.md | 4 + README.md | 90 +++++++++--------- bundle.go | 116 +++++++++++++++++------ bundle_lookup_default.go | 16 ++++ bundle_lookup_js.go | 16 ++++ bundle_test.go | 11 ++- certkit_test.go | 163 ++++++++++++++++++++++++++++----- cmd/certkit/bundle.go | 13 ++- cmd/certkit/connect.go | 26 ++++-- cmd/certkit/inspect.go | 10 +- cmd/certkit/ocsp.go | 19 ++-- cmd/certkit/scan.go | 78 +++++++++------- cmd/certkit/verify.go | 41 +++++---- cmd/wasm/aia.go | 9 +- cmd/wasm/inspect.go | 12 ++- cmd/wasm/main.go | 9 +- connect.go | 48 +++++++--- connect_test.go | 135 ++++++++++++++++++--------- crl.go | 10 +- crl_test.go | 43 +++++---- internal/certstore/aia.go | 13 +-- internal/certstore/aia_test.go | 10 +- internal/inspect.go | 10 +- internal/verify.go | 31 ++++--- internal/verify_test.go | 96 ++++++++++--------- ocsp.go | 6 +- ocsp_test.go | 27 +++++- web/public/app.js | 8 ++ web/public/index.html | 12 +++ 29 files changed, 745 insertions(+), 337 deletions(-) create mode 100644 bundle_lookup_default.go create mode 100644 bundle_lookup_js.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e17fbd7b..efe080fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Enforce bounded per-file and total upload limits in WASM `addFiles` and `inspect` ingestion paths to prevent unbounded memory growth ([#105]) - Enforce local CRL file size limits for `certkit crl` and shared CRL readers to reject oversized inputs early ([#105]) +- Harden AIA/OCSP/CRL SSRF checks by validating DNS-resolved hostnames against private/internal address ranges by default, and add explicit `--allow-private-network` opt-in flags for internal PKI endpoints in `connect`, `verify`, `ocsp`, `scan`, `inspect`, and `bundle` ([#108]) - Prevent bundle export path traversal by sanitizing bundle folder names and enforcing safe output paths ([#87]) - Enforce size limits on input reads to avoid unbounded memory usage ([#87]) - Add SSRF validation (`ValidateAIAURL`) to OCSP responder URLs and CRL distribution point URLs — previously only AIA certificate URLs were validated ([#78]) @@ -99,6 +100,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix WASM `inspect` AIA resolution to expose an explicit private-network opt-in so internal PKI intermediates can still be fetched when needed ([#108]) +- Apply a default 10-second timeout in `ConnectTLS` when callers provide a context without a deadline, preventing indefinite hangs during TCP/TLS connect and handshake operations ([#108]) - Fix `scan` directory traversal boundaries and resilience: symlinks that point outside the scan root are skipped, max-file-size checks now apply to symlink targets and archive reads, and transient walk errors no longer prune unrelated files during traversal ([#91]) - Fix `scan --bundle-path` text/default output to print a useful post-export summary (certificate/key counts plus export path) while keeping JSON export output unchanged ([#95]) - Fix `scan` to fail fast when per-file processing or size-check `stat` calls fail during directory traversal, instead of logging and silently continuing ([#106]) @@ -962,6 +965,7 @@ Initial release. [#85]: https://github.com/sensiblebit/certkit/pull/85 [#86]: https://github.com/sensiblebit/certkit/pull/86 [#87]: https://github.com/sensiblebit/certkit/pull/87 +[#108]: https://github.com/sensiblebit/certkit/pull/108 [#91]: https://github.com/sensiblebit/certkit/pull/91 [#95]: https://github.com/sensiblebit/certkit/pull/95 [#106]: https://github.com/sensiblebit/certkit/pull/106 diff --git a/README.md b/README.md index 506380a1..464f94d4 100644 --- a/README.md +++ b/README.md @@ -131,23 +131,25 @@ Common passwords (`""`, `"password"`, `"changeit"`, `"keypassword"`) are always ### Inspect Flags -| Flag | Default | Description | -| ---------- | ------- | ------------------------- | -| `--format` | `text` | Output format: text, json | +| Flag | Default | Description | +| ------------------------- | ------- | ----------------------------------------------- | +| `--allow-private-network` | `false` | Allow AIA fetches to private/internal endpoints | +| `--format` | `text` | Output format: text, json | ### Verify Flags -| Flag | Default | Description | -| ---------------- | --------- | ------------------------------------------------------- | -| `--crl` | `false` | Check CRL distribution points for revocation | -| `--diagnose` | `false` | Show diagnostics when chain verification fails | -| `--expiry`, `-e` | | Check if cert expires within duration (e.g., 30d, 720h) | -| `--format` | `text` | Output format: text, json | -| `--key` | | Private key file to check against the certificate | -| `--ocsp` | `false` | Check OCSP revocation status | -| `--trust-store` | `mozilla` | Trust store: system, mozilla | +| Flag | Default | Description | +| ------------------------- | --------- | -------------------------------------------------------- | +| `--allow-private-network` | `false` | Allow AIA/OCSP/CRL fetches to private/internal endpoints | +| `--crl` | `false` | Check CRL distribution points for revocation | +| `--diagnose` | `false` | Show diagnostics when chain verification fails | +| `--expiry`, `-e` | | Check if cert expires within duration (e.g., 30d, 720h) | +| `--format` | `text` | Output format: text, json | +| `--key` | | Private key file to check against the certificate | +| `--ocsp` | `false` | Check OCSP revocation status | +| `--trust-store` | `mozilla` | Trust store: system, mozilla | Chain verification is always performed. When the input contains an embedded private key (PKCS#12, JKS), key match is checked automatically. Use `--ocsp` and/or `--crl` to check revocation status (requires network access and a valid chain). @@ -155,13 +157,14 @@ Chain verification is always performed. When the input contains an embedded priv ### Connect Flags -| Flag | Default | Description | -| -------------- | ------- | ----------------------------------------------------------- | -| `--ciphers` | `false` | Enumerate all supported cipher suites with security ratings | -| `--crl` | `false` | Check CRL distribution points for revocation | -| `--format` | `text` | Output format: text, json | -| `--no-ocsp` | `false` | Disable automatic OCSP revocation check | -| `--servername` | | Override SNI hostname (defaults to host) | +| Flag | Default | Description | +| ------------------------- | ------- | ----------------------------------------------------------- | +| `--allow-private-network` | `false` | Allow AIA/OCSP/CRL fetches to private/internal endpoints | +| `--ciphers` | `false` | Enumerate all supported cipher suites with security ratings | +| `--crl` | `false` | Check CRL distribution points for revocation | +| `--format` | `text` | Output format: text, json | +| `--no-ocsp` | `false` | Disable automatic OCSP revocation check | +| `--servername` | | Override SNI hostname (defaults to host) | Port defaults to 443 if not specified. OCSP revocation status is checked automatically (best-effort); use `--no-ocsp` to disable. Use `--verbose` for extended details (serial, key info, signature algorithm, key usage, EKU). @@ -169,13 +172,14 @@ Port defaults to 443 if not specified. OCSP revocation status is checked automat ### Bundle Flags -| Flag | Default | Description | -| ------------------ | ---------- | ---------------------------------------------- | -| `--force`, `-f` | `false` | Skip chain verification | -| `--format` | `pem` | Output format: pem, chain, fullchain, p12, jks | -| `--key` | | Private key file (PEM) | -| `--out-file`, `-o` | _(stdout)_ | Output file | -| `--trust-store` | `mozilla` | Trust store: system, mozilla | +| Flag | Default | Description | +| ------------------------- | ---------- | ----------------------------------------------- | +| `--allow-private-network` | `false` | Allow AIA fetches to private/internal endpoints | +| `--force`, `-f` | `false` | Skip chain verification | +| `--format` | `pem` | Output format: pem, chain, fullchain, p12, jks | +| `--key` | | Private key file (PEM) | +| `--out-file`, `-o` | _(stdout)_ | Output file | +| `--trust-store` | `mozilla` | Trust store: system, mozilla | ### Convert Flags @@ -217,18 +221,19 @@ Input format is auto-detected. ### Scan Flags -| Flag | Default | Description | -| ----------------- | ---------------- | -------------------------------------------------------- | -| `--bundle-path` | | Export bundles to this directory | -| `--config`, `-c` | `./bundles.yaml` | Path to bundle config YAML | -| `--dump-certs` | | Dump all discovered certificates to a single PEM file | -| `--dump-keys` | | Dump all discovered keys to a single PEM file | -| `--duplicates` | `false` | Export all certificates per bundle, not just the newest | -| `--force`, `-f` | `false` | Allow export of untrusted certificate bundles | -| `--format` | `text` | Output format: text, json | -| `--load-db` | | Load an existing database into memory before scanning | -| `--max-file-size` | `10485760` | Skip files larger than this size in bytes (0 to disable) | -| `--save-db` | | Save the in-memory database to disk after scanning | +| Flag | Default | Description | +| ------------------------- | ---------------- | -------------------------------------------------------- | +| `--allow-private-network` | `false` | Allow AIA fetches to private/internal endpoints | +| `--bundle-path` | | Export bundles to this directory | +| `--config`, `-c` | `./bundles.yaml` | Path to bundle config YAML | +| `--dump-certs` | | Dump all discovered certificates to a single PEM file | +| `--dump-keys` | | Dump all discovered keys to a single PEM file | +| `--duplicates` | `false` | Export all certificates per bundle, not just the newest | +| `--force`, `-f` | `false` | Allow export of untrusted certificate bundles | +| `--format` | `text` | Output format: text, json | +| `--load-db` | | Load an existing database into memory before scanning | +| `--max-file-size` | `10485760` | Skip files larger than this size in bytes (0 to disable) | +| `--save-db` | | Save the in-memory database to disk after scanning | ### Keygen Flags @@ -264,10 +269,11 @@ Exactly one of `--template`, `--from-cert`, or `--from-csr` is required. ### OCSP Flags -| Flag | Default | Description | -| ---------- | ------- | ------------------------------------------------------------------ | -| `--format` | `text` | Output format: text, json | -| `--issuer` | | Issuer certificate file (PEM); auto-resolved from input if omitted | +| Flag | Default | Description | +| ------------------------- | ------- | ------------------------------------------------------------------ | +| `--allow-private-network` | `false` | Allow OCSP fetches to private/internal endpoints | +| `--format` | `text` | Output format: text, json | +| `--issuer` | | Issuer certificate file (PEM); auto-resolved from input if omitted | The OCSP responder URL is read from the certificate's AIA extension. diff --git a/bundle.go b/bundle.go index 5bfc8bcb..e5fb2163 100644 --- a/bundle.go +++ b/bundle.go @@ -34,8 +34,11 @@ var ( // address space. Parsed once at init to avoid repeated net.ParseCIDR calls. var privateNetworks []*net.IPNet +const aiaURLResolveTimeout = 2 * time.Second + func init() { for _, cidr := range []string{ + "0.0.0.0/8", // RFC 791 "this network" "10.0.0.0/8", // RFC 1918 "172.16.0.0/12", // RFC 1918 "192.168.0.0/16", // RFC 1918 @@ -166,23 +169,40 @@ func IsIssuedByMozillaRoot(cert *x509.Certificate) bool { return MozillaRootSubjects()[string(cert.RawIssuer)] } -// ValidateAIAURL checks whether a URL is safe to fetch for AIA certificate -// resolution. It rejects non-HTTP(S) schemes and literal private/loopback/ -// link-local IP addresses to prevent SSRF. -// -// Known limitation: hostnames that resolve to private IPs are intentionally -// allowed. This means DNS rebinding (a hostname resolving to a public IP at -// validation time, then to a private IP at connection time) is theoretically -// possible. We accept this because: +// ValidateAIAURLInput holds parameters for ValidateAIAURLWithOptions. +type ValidateAIAURLInput struct { + // URL is the candidate URL to validate. + URL string + // AllowPrivateNetworks bypasses private/internal IP checks. + AllowPrivateNetworks bool + + lookupIPAddresses lookupIPAddressesFunc +} + +type lookupIPAddressesFunc func(ctx context.Context, host string) ([]net.IP, error) + +func ipBlockedForAIA(ip net.IP) error { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() { + return fmt.Errorf("blocked address %s (loopback, link-local, or unspecified)", ip.String()) + } + for _, network := range privateNetworks { + if network.Contains(ip) { + return fmt.Errorf("blocked private address %s", ip.String()) + } + } + return nil +} + +// ValidateAIAURLWithOptions checks whether a URL is safe to fetch for AIA, +// OCSP, and CRL HTTP requests. // -// 1. certkit is a short-lived CLI process — the window between ValidateAIAURL -// and the HTTP request is ~2ms, making rebinding impractical to exploit. -// 2. Blocking hostnames that resolve to private IPs would break legitimate -// internal CAs whose AIA endpoints are on private networks. -// 3. Adding net.Dialer.Control to check resolved IPs doesn't help: if we -// allow private IPs for internal CAs, the check is the same TOCTOU race. -func ValidateAIAURL(rawURL string) error { - parsed, err := url.Parse(rawURL) +// By default, it rejects non-HTTP(S) schemes plus literal and DNS-resolved +// private/loopback/link-local/unspecified addresses to reduce SSRF risk. Set +// AllowPrivateNetworks to bypass IP restrictions. This check does not fully +// prevent DNS-rebind TOCTOU attacks between validation-time DNS and dial-time +// DNS. +func ValidateAIAURLWithOptions(ctx context.Context, input ValidateAIAURLInput) error { + parsed, err := url.Parse(input.URL) if err != nil { return fmt.Errorf("parsing URL: %w", err) } @@ -193,21 +213,56 @@ func ValidateAIAURL(rawURL string) error { return fmt.Errorf("unsupported scheme %q (only http and https are allowed)", parsed.Scheme) } host := parsed.Hostname() + if host == "" { + return fmt.Errorf("missing hostname in URL") + } + + if input.AllowPrivateNetworks { + return nil + } + ip := net.ParseIP(host) - if ip == nil { - return nil // hostname, not a literal IP — allow (see doc comment) + if ip != nil { + if blockedErr := ipBlockedForAIA(ip); blockedErr != nil { + return blockedErr + } + return nil } - if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() { - return fmt.Errorf("blocked address %s (loopback, link-local, or unspecified)", host) + + lookup := input.lookupIPAddresses + if lookup == nil { + if !aiaDNSResolutionAvailable() { + return nil + } + lookup = defaultLookupIPAddresses } - for _, network := range privateNetworks { - if network.Contains(ip) { - return fmt.Errorf("blocked private address %s", host) + + resolveCtx, cancel := context.WithTimeout(ctx, aiaURLResolveTimeout) + defer cancel() + + ips, err := lookup(resolveCtx, host) + if err != nil { + return fmt.Errorf("resolving host %q: %w", host, err) + } + if len(ips) == 0 { + return fmt.Errorf("resolving host %q: no IP addresses returned", host) + } + for _, resolvedIP := range ips { + if blockedErr := ipBlockedForAIA(resolvedIP); blockedErr != nil { + return fmt.Errorf("host %q resolved to %s: %w", host, resolvedIP.String(), blockedErr) } } + return nil } +// ValidateAIAURL checks whether a URL is safe to fetch for AIA, OCSP, and CRL +// requests. It rejects non-HTTP(S) schemes plus literal and DNS-resolved +// private/loopback/link-local/unspecified addresses. +func ValidateAIAURL(rawURL string) error { + return ValidateAIAURLWithOptions(context.Background(), ValidateAIAURLInput{URL: rawURL}) +} + // VerifyChainTrustInput holds parameters for VerifyChainTrust. type VerifyChainTrustInput struct { Cert *x509.Certificate @@ -277,6 +332,8 @@ type BundleOptions struct { Verify bool // ExcludeRoot omits the root certificate from the result. ExcludeRoot bool + // AllowPrivateNetworks allows AIA fetches to private/internal endpoints. + AllowPrivateNetworks bool } // DefaultOptions returns sensible defaults. @@ -352,6 +409,8 @@ type FetchAIACertificatesInput struct { Timeout time.Duration // MaxDepth is the maximum number of AIA hops to follow. MaxDepth int + // AllowPrivateNetworks allows AIA fetches to private/internal endpoints. + AllowPrivateNetworks bool } // FetchAIACertificates follows AIA CA Issuers URLs to fetch intermediate certificates. @@ -369,7 +428,7 @@ func FetchAIACertificates(ctx context.Context, input FetchAIACertificatesInput) if len(via) >= maxRedirects { return fmt.Errorf("stopped after %d redirects", maxRedirects) } - if err := ValidateAIAURL(req.URL.String()); err != nil { + if err := ValidateAIAURLWithOptions(req.Context(), ValidateAIAURLInput{URL: req.URL.String(), AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { return fmt.Errorf("redirect blocked: %w", err) } return nil @@ -388,7 +447,7 @@ func FetchAIACertificates(ctx context.Context, input FetchAIACertificatesInput) } seen[aiaURL] = true - if err := ValidateAIAURL(aiaURL); err != nil { + if err := ValidateAIAURLWithOptions(ctx, ValidateAIAURLInput{URL: aiaURL, AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { warnings = append(warnings, fmt.Sprintf("AIA URL rejected for %s: %v", aiaURL, err)) continue } @@ -532,9 +591,10 @@ func Bundle(ctx context.Context, input BundleInput) (*BundleResult, error) { if opts.FetchAIA { aiaCerts, warnings := FetchAIACertificates(ctx, FetchAIACertificatesInput{ - Cert: leaf, - Timeout: opts.AIATimeout, - MaxDepth: opts.AIAMaxDepth, + Cert: leaf, + Timeout: opts.AIATimeout, + MaxDepth: opts.AIAMaxDepth, + AllowPrivateNetworks: opts.AllowPrivateNetworks, }) result.Warnings = append(result.Warnings, warnings...) for _, cert := range aiaCerts { diff --git a/bundle_lookup_default.go b/bundle_lookup_default.go new file mode 100644 index 00000000..74cf4293 --- /dev/null +++ b/bundle_lookup_default.go @@ -0,0 +1,16 @@ +//go:build !js + +package certkit + +import ( + "context" + "net" +) + +func defaultLookupIPAddresses(ctx context.Context, host string) ([]net.IP, error) { + return net.DefaultResolver.LookupIP(ctx, "ip", host) +} + +func aiaDNSResolutionAvailable() bool { + return true +} diff --git a/bundle_lookup_js.go b/bundle_lookup_js.go new file mode 100644 index 00000000..85a8cc0e --- /dev/null +++ b/bundle_lookup_js.go @@ -0,0 +1,16 @@ +//go:build js + +package certkit + +import ( + "context" + "net" +) + +func defaultLookupIPAddresses(_ context.Context, _ string) ([]net.IP, error) { + return nil, nil +} + +func aiaDNSResolutionAvailable() bool { + return false +} diff --git a/bundle_test.go b/bundle_test.go index 7237f7c6..672d905c 100644 --- a/bundle_test.go +++ b/bundle_test.go @@ -598,8 +598,8 @@ func TestFetchAIACertificates_duplicateURLs(t *testing.T) { if err != nil { t.Fatal(err) } - // Replace 127.0.0.1 with localhost to avoid ValidateAIAURL SSRF blocking - // of literal loopback IPs. Hostname-based URLs pass SSRF validation. + // Replace 127.0.0.1 with localhost and opt in to private networks for this + // local integration test. srvURL := strings.Replace(srv.URL, "127.0.0.1", "localhost", 1) leafTemplate := &x509.Certificate{ @@ -623,9 +623,10 @@ func TestFetchAIACertificates_duplicateURLs(t *testing.T) { } fetched, _ := FetchAIACertificates(context.Background(), FetchAIACertificatesInput{ - Cert: leafCert, - Timeout: 2 * time.Second, - MaxDepth: 5, + Cert: leafCert, + Timeout: 2 * time.Second, + MaxDepth: 5, + AllowPrivateNetworks: true, }) if len(fetched) != 1 { t.Errorf("expected 1 fetched cert (deduped), got %d", len(fetched)) diff --git a/certkit_test.go b/certkit_test.go index 50c0e2c0..5fc00f7e 100644 --- a/certkit_test.go +++ b/certkit_test.go @@ -1,6 +1,7 @@ package certkit import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -14,7 +15,9 @@ import ( "encoding/asn1" "encoding/hex" "encoding/pem" + "errors" "fmt" + "net" "slices" "strings" "testing" @@ -1965,39 +1968,46 @@ func TestVerifyChainTrust(t *testing.T) { func TestValidateAIAURL(t *testing.T) { // WHY: ValidateAIAURL prevents SSRF by rejecting non-HTTP schemes and - // literal private/loopback/link-local IP addresses. Each case covers a - // distinct rejection rule or an allowed pattern. + // private/loopback/link-local/unspecified IPs. t.Parallel() tests := []struct { - name string - url string - wantErr bool - errSub string + name string + url string + allowPrivate bool + wantErr bool + errSub string }{ - {"valid http", "http://ca.example.com/issuer.cer", false, ""}, - {"valid https", "https://ca.example.com/issuer.cer", false, ""}, - {"ftp rejected", "ftp://ca.example.com/issuer.cer", true, "unsupported scheme"}, - {"file rejected", "file:///etc/passwd", true, "unsupported scheme"}, - {"empty scheme rejected", "://foo", true, "parsing URL"}, - {"loopback IPv4", "http://127.0.0.1/ca.cer", true, "loopback"}, - {"loopback IPv6", "http://[::1]/ca.cer", true, "loopback"}, - {"link-local IPv4", "http://169.254.1.1/ca.cer", true, "loopback, link-local, or unspecified"}, - {"unspecified IPv4", "http://0.0.0.0/ca.cer", true, "loopback, link-local, or unspecified"}, - {"unspecified IPv6", "http://[::]/ca.cer", true, "loopback, link-local, or unspecified"}, - {"private IPv6 ULA", "http://[fd12::1]/ca.cer", true, "blocked private"}, - {"private 10.x", "http://10.0.0.1/ca.cer", true, "blocked private"}, - {"private 172.16.x", "http://172.16.0.1/ca.cer", true, "blocked private"}, - {"private 192.168.x", "http://192.168.1.1/ca.cer", true, "blocked private"}, - {"CGN 100.64.x", "http://100.64.0.1/ca.cer", true, "blocked private"}, - {"public IP allowed", "http://8.8.8.8/ca.cer", false, ""}, - {"hostname allowed even if resolves to private", "http://internal.company.com/ca.cer", false, ""}, + {"valid public IPv4 http", "http://8.8.8.8/issuer.cer", false, false, ""}, + {"valid public IPv4 https", "https://8.8.8.8/issuer.cer", false, false, ""}, + {"ftp rejected", "ftp://ca.example.com/issuer.cer", false, true, "unsupported scheme"}, + {"file rejected", "file:///etc/passwd", false, true, "unsupported scheme"}, + {"empty scheme rejected", "://foo", false, true, "parsing URL"}, + {"missing hostname rejected", "https:///issuer.cer", false, true, "missing hostname"}, + {"loopback IPv4", "http://127.0.0.1/ca.cer", false, true, "loopback"}, + {"loopback IPv6", "http://[::1]/ca.cer", false, true, "loopback"}, + {"localhost hostname", "http://localhost/ca.cer", false, true, "resolved"}, + {"link-local IPv4", "http://169.254.1.1/ca.cer", false, true, "loopback, link-local, or unspecified"}, + {"unspecified IPv4", "http://0.0.0.0/ca.cer", false, true, "loopback, link-local, or unspecified"}, + {"this network IPv4 range", "http://0.1.2.3/ca.cer", false, true, "blocked private"}, + {"unspecified IPv6", "http://[::]/ca.cer", false, true, "loopback, link-local, or unspecified"}, + {"private IPv6 ULA", "http://[fd12::1]/ca.cer", false, true, "blocked private"}, + {"private 10.x", "http://10.0.0.1/ca.cer", false, true, "blocked private"}, + {"private 172.16.x", "http://172.16.0.1/ca.cer", false, true, "blocked private"}, + {"private 192.168.x", "http://192.168.1.1/ca.cer", false, true, "blocked private"}, + {"CGN 100.64.x", "http://100.64.0.1/ca.cer", false, true, "blocked private"}, + {"allow private network option", "http://127.0.0.1/ca.cer", true, false, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - err := ValidateAIAURL(tt.url) + var err error + if tt.allowPrivate { + err = ValidateAIAURLWithOptions(context.Background(), ValidateAIAURLInput{URL: tt.url, AllowPrivateNetworks: true}) + } else { + err = ValidateAIAURL(tt.url) + } if tt.wantErr { if err == nil { t.Fatalf("expected error for %q", tt.url) @@ -2014,6 +2024,111 @@ func TestValidateAIAURL(t *testing.T) { } } +func TestValidateAIAURLWithOptions_HostnameResolution(t *testing.T) { + t.Parallel() + + lookup := func(_ context.Context, host string) ([]net.IP, error) { + switch host { + case "public.example": + return []net.IP{net.ParseIP("93.184.216.34")}, nil + case "mixed.example": + return []net.IP{net.ParseIP("93.184.216.34"), net.ParseIP("10.0.0.10")}, nil + case "empty.example": + return nil, nil + default: + return nil, fmt.Errorf("lookup failed") + } + } + + tests := []struct { + name string + input ValidateAIAURLInput + wantErr string + }{ + { + name: "public resolution allowed", + input: ValidateAIAURLInput{ + URL: "https://public.example/issuer.cer", + lookupIPAddresses: lookup, + }, + }, + { + name: "mixed public and private blocked", + input: ValidateAIAURLInput{ + URL: "https://mixed.example/issuer.cer", + lookupIPAddresses: lookup, + }, + wantErr: "blocked private address", + }, + { + name: "empty DNS answer blocked", + input: ValidateAIAURLInput{ + URL: "https://empty.example/issuer.cer", + lookupIPAddresses: lookup, + }, + wantErr: "no IP addresses returned", + }, + { + name: "resolver error blocked", + input: ValidateAIAURLInput{ + URL: "https://error.example/issuer.cer", + lookupIPAddresses: lookup, + }, + wantErr: "resolving host", + }, + { + name: "allow private bypasses DNS checks", + input: ValidateAIAURLInput{ + URL: "https://mixed.example/issuer.cer", + AllowPrivateNetworks: true, + lookupIPAddresses: lookup, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateAIAURLWithOptions(context.Background(), tt.input) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestValidateAIAURLWithOptions_ContextDeadline(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + lookup := func(ctx context.Context, _ string) ([]net.IP, error) { + <-ctx.Done() + return nil, ctx.Err() + } + + err := ValidateAIAURLWithOptions(ctx, ValidateAIAURLInput{ + URL: "https://example.com/issuer.cer", + lookupIPAddresses: lookup, + }) + if err == nil { + t.Fatal("expected context deadline error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("error = %v, want context.DeadlineExceeded", err) + } +} + func TestAlgorithmName(t *testing.T) { // WHY: KeyAlgorithmName and PublicKeyAlgorithmName produce display strings // for CLI output and JSON; wrong names would confuse users and break JSON diff --git a/cmd/certkit/bundle.go b/cmd/certkit/bundle.go index 5ceaea06..140b662a 100644 --- a/cmd/certkit/bundle.go +++ b/cmd/certkit/bundle.go @@ -17,11 +17,12 @@ import ( ) var ( - bundleKeyPath string - bundleOutFile string - bundleFormat string - bundleForce bool - bundleTrustStore string + bundleKeyPath string + bundleOutFile string + bundleFormat string + bundleForce bool + bundleAllowPrivateNetwork bool + bundleTrustStore string ) var bundleCmd = &cobra.Command{ @@ -49,6 +50,7 @@ func init() { bundleCmd.Flags().StringVarP(&bundleOutFile, "out-file", "o", "", "Output file") bundleCmd.Flags().StringVar(&bundleFormat, "format", "pem", "Output format: pem, chain, fullchain, p12, jks") bundleCmd.Flags().BoolVarP(&bundleForce, "force", "f", false, "Skip chain verification") + bundleCmd.Flags().BoolVar(&bundleAllowPrivateNetwork, "allow-private-network", false, "Allow AIA fetches to private/internal endpoints") bundleCmd.Flags().StringVar(&bundleTrustStore, "trust-store", "mozilla", "Trust store: system, mozilla") bundleCmd.Flags().Lookup("out-file").Annotations = map[string][]string{"readme_default": {"_(stdout)_"}} @@ -98,6 +100,7 @@ func runBundle(cmd *cobra.Command, args []string) error { opts := certkit.DefaultOptions() opts.TrustStore = bundleTrustStore opts.ExtraIntermediates = extraCerts + opts.AllowPrivateNetworks = bundleAllowPrivateNetwork if bundleForce { opts.Verify = false } diff --git a/cmd/certkit/connect.go b/cmd/certkit/connect.go index 8ae6e092..12af2779 100644 --- a/cmd/certkit/connect.go +++ b/cmd/certkit/connect.go @@ -18,11 +18,12 @@ import ( ) var ( - connectServerName string - connectFormat string - connectCRL bool - connectNoOCSP bool - connectCiphers bool + connectServerName string + connectFormat string + connectCRL bool + connectNoOCSP bool + connectCiphers bool + connectAllowPrivateNetwork bool ) var connectCmd = &cobra.Command{ @@ -36,6 +37,9 @@ automatically (best-effort). Use --no-ocsp to disable. Use --crl to also check CRL distribution points. Use --ciphers to enumerate all cipher suites the server supports with security ratings. +Network fetches for AIA/OCSP/CRL block private/internal endpoints by default. +Use --allow-private-network to opt in for internal PKI environments. + Exits with code 2 if chain verification fails or the certificate is revoked.`, Example: ` certkit connect example.com certkit connect example.com:8443 @@ -53,6 +57,7 @@ func init() { connectCmd.Flags().BoolVar(&connectCRL, "crl", false, "Check CRL distribution points for revocation") connectCmd.Flags().BoolVar(&connectNoOCSP, "no-ocsp", false, "Disable automatic OCSP revocation check") connectCmd.Flags().BoolVar(&connectCiphers, "ciphers", false, "Enumerate all supported cipher suites with security ratings") + connectCmd.Flags().BoolVar(&connectAllowPrivateNetwork, "allow-private-network", false, "Allow AIA/OCSP/CRL fetches to private/internal endpoints") registerCompletion(connectCmd, completionInput{"format", fixedCompletion("text", "json")}) } @@ -114,11 +119,12 @@ func runConnect(cmd *cobra.Command, args []string) error { defer cancel() result, err := certkit.ConnectTLS(ctx, certkit.ConnectTLSInput{ - Host: host, - Port: port, - ServerName: connectServerName, - DisableOCSP: connectNoOCSP, - CheckCRL: connectCRL, + Host: host, + Port: port, + ServerName: connectServerName, + DisableOCSP: connectNoOCSP, + CheckCRL: connectCRL, + AllowPrivateNetworks: connectAllowPrivateNetwork, }) if err != nil { spin.Stop() diff --git a/cmd/certkit/inspect.go b/cmd/certkit/inspect.go index 4c6c2b07..f91c1529 100644 --- a/cmd/certkit/inspect.go +++ b/cmd/certkit/inspect.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log/slog" "slices" @@ -10,6 +11,7 @@ import ( ) var inspectFormat string +var inspectAllowPrivateNetwork bool var inspectCmd = &cobra.Command{ Use: "inspect ", @@ -24,6 +26,7 @@ var inspectCmd = &cobra.Command{ func init() { inspectCmd.Flags().StringVar(&inspectFormat, "format", "text", "Output format: text, json") + inspectCmd.Flags().BoolVar(&inspectAllowPrivateNetwork, "allow-private-network", false, "Allow AIA fetches to private/internal endpoints") registerCompletion(inspectCmd, completionInput{"format", fixedCompletion("text", "json")}) } @@ -41,8 +44,11 @@ func runInspect(cmd *cobra.Command, args []string) error { // Resolve missing intermediates via AIA before trust annotation. results, aiaWarnings := internal.ResolveInspectAIA(cmd.Context(), internal.ResolveInspectAIAInput{ - Results: results, - Fetch: httpAIAFetcher, + Results: results, + AllowPrivateNetworks: inspectAllowPrivateNetwork, + Fetch: func(ctx context.Context, rawURL string) ([]byte, error) { + return fetchAIAURL(ctx, fetchAIAURLInput{rawURL: rawURL, allowPrivateNetworks: inspectAllowPrivateNetwork}) + }, }) for _, w := range aiaWarnings { slog.Warn("AIA resolution", "warning", w) diff --git a/cmd/certkit/ocsp.go b/cmd/certkit/ocsp.go index 0493e02f..d689b985 100644 --- a/cmd/certkit/ocsp.go +++ b/cmd/certkit/ocsp.go @@ -12,8 +12,9 @@ import ( ) var ( - ocspIssuerPath string - ocspFormat string + ocspIssuerPath string + ocspFormat string + ocspAllowPrivateNetwork bool ) var ocspCmd = &cobra.Command{ @@ -23,7 +24,8 @@ var ocspCmd = &cobra.Command{ The OCSP responder URL is read from the certificate's AIA extension. Use --issuer to provide the issuer certificate if it is not embedded -in the input file. +in the input file. Private/internal OCSP endpoints are blocked by default; +use --allow-private-network to opt in. Exits with code 2 if the certificate is revoked.`, Example: ` certkit ocsp cert.pem --issuer issuer.pem @@ -36,6 +38,7 @@ Exits with code 2 if the certificate is revoked.`, func init() { ocspCmd.Flags().StringVar(&ocspIssuerPath, "issuer", "", "Issuer certificate file (PEM); auto-resolved from input if omitted") ocspCmd.Flags().StringVar(&ocspFormat, "format", "text", "Output format: text, json") + ocspCmd.Flags().BoolVar(&ocspAllowPrivateNetwork, "allow-private-network", false, "Allow OCSP fetches to private/internal endpoints") registerCompletion(ocspCmd, completionInput{"issuer", fileCompletion}) registerCompletion(ocspCmd, completionInput{"format", fixedCompletion("text", "json")}) @@ -75,8 +78,9 @@ func runOCSP(cmd *cobra.Command, args []string) error { return fmt.Errorf("parsing issuer certificate: %w", err) } ocspInput = &certkit.CheckOCSPInput{ - Cert: contents.Leaf, - Issuer: issuerCert, + Cert: contents.Leaf, + Issuer: issuerCert, + AllowPrivateNetworks: ocspAllowPrivateNetwork, } } else if len(contents.ExtraCerts) > 0 { issuerCert := certkit.SelectIssuerCertificate(contents.Leaf, contents.ExtraCerts) @@ -84,8 +88,9 @@ func runOCSP(cmd *cobra.Command, args []string) error { return fmt.Errorf("no matching issuer certificate found in input; use --issuer to provide one") } ocspInput = &certkit.CheckOCSPInput{ - Cert: contents.Leaf, - Issuer: issuerCert, + Cert: contents.Leaf, + Issuer: issuerCert, + AllowPrivateNetworks: ocspAllowPrivateNetwork, } } else { return fmt.Errorf("no issuer certificate found; use --issuer to provide one") diff --git a/cmd/certkit/scan.go b/cmd/certkit/scan.go index e0d213ab..aae3058a 100644 --- a/cmd/certkit/scan.go +++ b/cmd/certkit/scan.go @@ -20,16 +20,17 @@ import ( ) var ( - scanLoadDB string - scanSaveDB string - scanConfigPath string - scanBundlePath string - scanForceExport bool - scanDuplicates bool - scanDumpKeys string - scanDumpCerts string - scanMaxFileSize int64 - scanFormat string + scanLoadDB string + scanSaveDB string + scanConfigPath string + scanBundlePath string + scanForceExport bool + scanDuplicates bool + scanDumpKeys string + scanDumpCerts string + scanMaxFileSize int64 + scanFormat string + scanAllowPrivateNetwork bool ) var scanCmd = &cobra.Command{ @@ -53,6 +54,7 @@ func init() { scanCmd.Flags().StringVar(&scanDumpCerts, "dump-certs", "", "Dump all discovered certificates to a single PEM file") scanCmd.Flags().Int64Var(&scanMaxFileSize, "max-file-size", 10*1024*1024, "Skip files larger than this size in bytes (0 to disable)") scanCmd.Flags().StringVar(&scanFormat, "format", "text", "Output format: text, json") + scanCmd.Flags().BoolVar(&scanAllowPrivateNetwork, "allow-private-network", false, "Allow AIA fetches to private/internal endpoints") scanCmd.Flags().StringVar(&scanSaveDB, "save-db", "", "Save the in-memory database to disk after scanning") scanCmd.Flags().StringVar(&scanLoadDB, "load-db", "", "Load an existing database into memory before scanning") @@ -154,8 +156,9 @@ func runScan(cmd *cobra.Command, args []string) error { if certstore.HasUnresolvedIssuers(store) { slog.Info("resolving certificate chains") aiaWarnings := certstore.ResolveAIA(cmd.Context(), certstore.ResolveAIAInput{ - Store: store, - Fetch: httpAIAFetcher, + Store: store, + Fetch: httpAIAFetcher, + AllowPrivateNetworks: scanAllowPrivateNetwork, }) for _, w := range aiaWarnings { slog.Warn("AIA resolution", "warning", w) @@ -471,45 +474,56 @@ func printScanVerboseText(store *certstore.MemStore) { } } -// aiaHTTPClient is reused across AIA fetches to enable TCP connection reuse. +// newAIAHTTPClient creates an HTTP client for AIA fetches. // Redirects are limited to 3 and validated against SSRF rules. -var aiaHTTPClient = &http.Client{ - Timeout: 2 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 3 { - return fmt.Errorf("stopped after 3 redirects") - } - if err := certkit.ValidateAIAURL(req.URL.String()); err != nil { - return fmt.Errorf("redirect blocked: %w", err) - } - return nil - }, +func newAIAHTTPClient(allowPrivateNetworks bool) *http.Client { + return &http.Client{ + Timeout: 2 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 3 { + return fmt.Errorf("stopped after 3 redirects") + } + if err := certkit.ValidateAIAURLWithOptions(req.Context(), certkit.ValidateAIAURLInput{URL: req.URL.String(), AllowPrivateNetworks: allowPrivateNetworks}); err != nil { + return fmt.Errorf("redirect blocked: %w", err) + } + return nil + }, + } } -// httpAIAFetcher fetches raw certificate bytes from a URL via HTTP. -func httpAIAFetcher(ctx context.Context, rawURL string) ([]byte, error) { - if err := certkit.ValidateAIAURL(rawURL); err != nil { +type fetchAIAURLInput struct { + rawURL string + allowPrivateNetworks bool +} + +func fetchAIAURL(ctx context.Context, input fetchAIAURLInput) ([]byte, error) { + if err := certkit.ValidateAIAURLWithOptions(ctx, certkit.ValidateAIAURLInput{URL: input.rawURL, AllowPrivateNetworks: input.allowPrivateNetworks}); err != nil { return nil, fmt.Errorf("AIA URL rejected: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, input.rawURL, nil) if err != nil { return nil, fmt.Errorf("creating AIA request: %w", err) } - resp, err := aiaHTTPClient.Do(req) + resp, err := newAIAHTTPClient(input.allowPrivateNetworks).Do(req) if err != nil { - return nil, fmt.Errorf("fetching AIA URL %s: %w", rawURL, err) + return nil, fmt.Errorf("fetching AIA URL %s: %w", input.rawURL, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, rawURL) + return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, input.rawURL) } data, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit if err != nil { - return nil, fmt.Errorf("reading AIA response from %s: %w", rawURL, err) + return nil, fmt.Errorf("reading AIA response from %s: %w", input.rawURL, err) } return data, nil } +// httpAIAFetcher fetches raw certificate bytes from a URL via HTTP. +func httpAIAFetcher(ctx context.Context, rawURL string) ([]byte, error) { + return fetchAIAURL(ctx, fetchAIAURLInput{rawURL: rawURL, allowPrivateNetworks: scanAllowPrivateNetwork}) +} + // formatDN formats a pkix.Name as a one-line distinguished name string // matching the OpenSSL one-line format (e.g. "CN=example.com, O=Acme, C=US"). func formatDN(name pkix.Name) string { diff --git a/cmd/certkit/verify.go b/cmd/certkit/verify.go index e3485380..da90584a 100644 --- a/cmd/certkit/verify.go +++ b/cmd/certkit/verify.go @@ -15,13 +15,14 @@ import ( ) var ( - verifyKeyPath string - verifyExpiry string - verifyTrustStore string - verifyFormat string - verifyDiagnose bool - verifyOCSP bool - verifyCRL bool + verifyKeyPath string + verifyExpiry string + verifyTrustStore string + verifyFormat string + verifyDiagnose bool + verifyOCSP bool + verifyCRL bool + verifyAllowPrivateNetwork bool ) var verifyCmd = &cobra.Command{ @@ -36,7 +37,9 @@ is checked automatically. Use --key to check against an external key file. Use --ocsp to check OCSP revocation status, and --crl to check CRL distribution points. Both require network access and a valid chain (the issuer certificate -is needed to verify the response). Exits with code 2 if verification finds any errors (including revocation).`, +is needed to verify the response). Network fetches for AIA/OCSP/CRL block +private/internal endpoints by default; use --allow-private-network to opt in. +Exits with code 2 if verification finds any errors (including revocation).`, Example: ` certkit verify cert.pem certkit verify cert.pem --key key.pem certkit verify cert.pem --expiry 30d @@ -57,6 +60,7 @@ func init() { verifyCmd.Flags().BoolVar(&verifyDiagnose, "diagnose", false, "Show diagnostics when chain verification fails") verifyCmd.Flags().BoolVar(&verifyOCSP, "ocsp", false, "Check OCSP revocation status") verifyCmd.Flags().BoolVar(&verifyCRL, "crl", false, "Check CRL distribution points for revocation") + verifyCmd.Flags().BoolVar(&verifyAllowPrivateNetwork, "allow-private-network", false, "Allow AIA/OCSP/CRL fetches to private/internal endpoints") registerCompletion(verifyCmd, completionInput{"format", fixedCompletion("text", "json")}) registerCompletion(verifyCmd, completionInput{"trust-store", fixedCompletion("system", "mozilla")}) @@ -128,16 +132,17 @@ func runVerify(cmd *cobra.Command, args []string) error { } input := &internal.VerifyInput{ - Cert: contents.Leaf, - Key: key, - ExtraCerts: contents.ExtraCerts, - CheckKeyMatch: key != nil, - CheckChain: true, // Always verify chain - ExpiryDuration: expiryDuration, - TrustStore: verifyTrustStore, - Verbose: verbose, - CheckOCSP: verifyOCSP, - CheckCRL: verifyCRL, + Cert: contents.Leaf, + Key: key, + ExtraCerts: contents.ExtraCerts, + CheckKeyMatch: key != nil, + CheckChain: true, // Always verify chain + ExpiryDuration: expiryDuration, + TrustStore: verifyTrustStore, + Verbose: verbose, + CheckOCSP: verifyOCSP, + CheckCRL: verifyCRL, + AllowPrivateNetworks: verifyAllowPrivateNetwork, } result, err := internal.VerifyCert(cmd.Context(), input) diff --git a/cmd/wasm/aia.go b/cmd/wasm/aia.go index fbba6a47..586a2f72 100644 --- a/cmd/wasm/aia.go +++ b/cmd/wasm/aia.go @@ -18,11 +18,12 @@ import ( // // Progress is dispatched to JS via setTimeout so the browser event loop can // update the progress bar without blocking the AIA goroutine. -func resolveAIA(ctx context.Context, s *certstore.MemStore) []string { +func resolveAIA(ctx context.Context, s *certstore.MemStore, allowPrivateNetworks bool) []string { return certstore.ResolveAIA(ctx, certstore.ResolveAIAInput{ - Store: s, - Fetch: jsFetchURL, - Concurrency: 50, + Store: s, + Fetch: jsFetchURL, + Concurrency: 50, + AllowPrivateNetworks: allowPrivateNetworks, OnProgress: func(completed, total int) { var cb js.Func cb = js.FuncOf(func(_ js.Value, _ []js.Value) any { diff --git a/cmd/wasm/inspect.go b/cmd/wasm/inspect.go index 92fab222..0a30c532 100644 --- a/cmd/wasm/inspect.go +++ b/cmd/wasm/inspect.go @@ -17,7 +17,7 @@ import ( // inspectFiles performs stateless inspection of certificate, key, and CSR data. // Unlike addFiles, it does not accumulate into the global MemStore. -// JS signature: certkitInspect(files: Array<{name: string, data: Uint8Array}>, passwords: string) → Promise +// JS signature: certkitInspect(files: Array<{name: string, data: Uint8Array}>, passwords: string, allowPrivateNetwork?: boolean) → Promise func inspectFiles(_ js.Value, args []js.Value) any { if len(args) < 1 { return jsError("certkitInspect requires at least 1 argument") @@ -40,6 +40,11 @@ func inspectFiles(_ js.Value, args []js.Value) any { } passwords = certkit.DeduplicatePasswords(passwords) + allowPrivateNetworks := false + if len(args) >= 3 && args[2].Type() == js.TypeBoolean { + allowPrivateNetworks = args[2].Bool() + } + handler := js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { resolve := promiseArgs[0] reject := promiseArgs[1] @@ -89,8 +94,9 @@ func inspectFiles(_ js.Value, args []js.Value) any { // Resolve missing intermediates via AIA before trust annotation. allResults, aiaWarnings := internal.ResolveInspectAIA(ctx, internal.ResolveInspectAIAInput{ - Results: allResults, - Fetch: jsFetchURL, + Results: allResults, + Fetch: jsFetchURL, + AllowPrivateNetworks: allowPrivateNetworks, }) for _, w := range aiaWarnings { slog.Warn("AIA resolution", "warning", w) diff --git a/cmd/wasm/main.go b/cmd/wasm/main.go index ab10f9e8..687a8e22 100644 --- a/cmd/wasm/main.go +++ b/cmd/wasm/main.go @@ -82,7 +82,7 @@ func readWASMFileData(input readWASMFileDataInput) ([]byte, error) { } // addFiles processes an array of {name, data} objects with optional passwords. -// JS signature: certkitAddFiles(files: Array<{name: string, data: Uint8Array}>, passwords: string) → Promise +// JS signature: certkitAddFiles(files: Array<{name: string, data: Uint8Array}>, passwords: string, allowPrivateNetwork?: boolean) → Promise func addFiles(_ js.Value, args []js.Value) any { if len(args) < 1 { return jsError("certkitAddFiles requires at least 1 argument") @@ -105,6 +105,11 @@ func addFiles(_ js.Value, args []js.Value) any { } passwords = certkit.DeduplicatePasswords(passwords) + allowPrivateNetworks := false + if len(args) >= 3 && args[2].Type() == js.TypeBoolean { + allowPrivateNetworks = args[2].Bool() + } + handler := js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { resolve := promiseArgs[0] reject := promiseArgs[1] @@ -187,7 +192,7 @@ func addFiles(_ js.Value, args []js.Value) any { storeMu.Lock() defer storeMu.Unlock() - warnings := resolveAIA(ctx, globalStore) + warnings := resolveAIA(ctx, globalStore, allowPrivateNetworks) if warnings == nil { warnings = []string{} diff --git a/connect.go b/connect.go index 47999944..b74c1f7b 100644 --- a/connect.go +++ b/connect.go @@ -17,6 +17,8 @@ import ( "time" ) +const defaultConnectTimeout = 10 * time.Second + // ChainDiagnostic describes a single chain configuration issue found during connection probing. type ChainDiagnostic struct { // Check is the diagnostic identifier (e.g. "root-in-chain", "duplicate-cert", "missing-intermediate"). @@ -167,6 +169,8 @@ type ConnectTLSInput struct { Host string // Port is the TCP port (default: "443"). Port string + // ConnectTimeout is used when ctx has no deadline (default: 10s). + ConnectTimeout time.Duration // ServerName overrides the SNI hostname (defaults to Host). ServerName string // DisableAIA disables automatic AIA certificate fetching when chain verification fails. @@ -184,6 +188,8 @@ type ConnectTLSInput struct { // RootCAs overrides system roots for chain verification. When nil, // the system root pool is used. Useful for testing against private CAs. RootCAs *x509.CertPool + // AllowPrivateNetworks allows AIA/OCSP/CRL fetches to private/internal endpoints. + AllowPrivateNetworks bool } // ClientAuthInfo describes the server's client certificate request (mTLS). @@ -275,8 +281,19 @@ func ConnectTLS(ctx context.Context, input ConnectTLSInput) (*ConnectResult, err addr := net.JoinHostPort(input.Host, port) + connectCtx := ctx + connectCancel := func() {} + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + connectTimeout := input.ConnectTimeout + if connectTimeout == 0 { + connectTimeout = defaultConnectTimeout + } + connectCtx, connectCancel = context.WithTimeout(ctx, connectTimeout) + } + defer connectCancel() + dialer := &net.Dialer{} - conn, err := dialer.DialContext(ctx, "tcp", addr) + conn, err := dialer.DialContext(connectCtx, "tcp", addr) if err != nil { return nil, fmt.Errorf("connecting to %s: %w", addr, err) } @@ -312,13 +329,13 @@ func ConnectTLS(ctx context.Context, input ConnectTLSInput) (*ConnectResult, err tlsConn := tls.Client(conn, tlsConf) defer func() { _ = tlsConn.Close() }() - if deadline, ok := ctx.Deadline(); ok { + if deadline, ok := connectCtx.Deadline(); ok { if err := tlsConn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting deadline: %w", err) } } - handshakeErr := tlsConn.HandshakeContext(ctx) + handshakeErr := tlsConn.HandshakeContext(connectCtx) var tlsAlert tls.AlertError if handshakeErr != nil && clientAuth == nil && errors.As(handshakeErr, &tlsAlert) { // Close the failed TLS connection before opening a new one. @@ -330,7 +347,7 @@ func ConnectTLS(ctx context.Context, input ConnectTLSInput) (*ConnectResult, err // negotiation failure), not for network errors or certificate errors. // Use a dedicated timeout so a stalling server can't hold the // fallback connection open indefinitely. - fallbackCtx, fallbackCancel := context.WithTimeout(ctx, 5*time.Second) + fallbackCtx, fallbackCancel := context.WithTimeout(connectCtx, 5*time.Second) defer fallbackCancel() legacyResult, legacyErr := legacyFallbackConnect(fallbackCtx, legacyFallbackInput{ addr: addr, @@ -419,9 +436,10 @@ func (result *ConnectResult) populate(ctx context.Context, input ConnectTLSInput aiaTimeout = 5 * time.Second } aiaCerts, aiaWarnings := FetchAIACertificates(ctx, FetchAIACertificatesInput{ - Cert: leaf, - Timeout: aiaTimeout, - MaxDepth: 5, + Cert: leaf, + Timeout: aiaTimeout, + MaxDepth: 5, + AllowPrivateNetworks: input.AllowPrivateNetworks, }) for _, w := range aiaWarnings { slog.Debug("AIA fetch warning", "warning", w) @@ -508,8 +526,9 @@ func (result *ConnectResult) populate(ctx context.Context, input ConnectTLSInput } ocspCtx, ocspCancel := context.WithTimeout(ctx, ocspTimeout) ocspResult, ocspErr := CheckOCSP(ocspCtx, CheckOCSPInput{ - Cert: leaf, - Issuer: issuer, + Cert: leaf, + Issuer: issuer, + AllowPrivateNetworks: input.AllowPrivateNetworks, }) ocspCancel() if ocspErr != nil { @@ -527,9 +546,10 @@ func (result *ConnectResult) populate(ctx context.Context, input ConnectTLSInput // Opt-in CRL check on the leaf certificate. if input.CheckCRL && issuer != nil { result.CRL = CheckLeafCRL(ctx, CheckLeafCRLInput{ - Leaf: leaf, - Issuer: issuer, - Timeout: input.CRLTimeout, + Leaf: leaf, + Issuer: issuer, + Timeout: input.CRLTimeout, + AllowPrivateNetworks: input.AllowPrivateNetworks, }) } else if input.CheckCRL { result.CRL = &CRLCheckResult{ @@ -547,6 +567,8 @@ type CheckLeafCRLInput struct { Issuer *x509.Certificate // Timeout is the timeout for fetching the CRL (default: 5s). Timeout time.Duration + // AllowPrivateNetworks allows CRL fetches to private/internal endpoints. + AllowPrivateNetworks bool } // CheckLeafCRL fetches the first HTTP CRL distribution point and checks whether @@ -589,7 +611,7 @@ func CheckLeafCRL(ctx context.Context, input CheckLeafCRLInput) *CRLCheckResult crlCtx, crlCancel := context.WithTimeout(ctx, timeout) defer crlCancel() - data, err := FetchCRL(crlCtx, FetchCRLInput{URL: cdpURL}) + data, err := FetchCRL(crlCtx, FetchCRLInput{URL: cdpURL, AllowPrivateNetworks: input.AllowPrivateNetworks}) if err != nil { slog.Debug("CRL fetch failed", "url", cdpURL, "error", err) return &CRLCheckResult{ diff --git a/connect_test.go b/connect_test.go index 4f533cf8..850692b9 100644 --- a/connect_test.go +++ b/connect_test.go @@ -592,6 +592,48 @@ func TestConnectTLS_CancelledContext(t *testing.T) { } } +func TestConnectTLS_UsesConnectTimeoutWhenContextHasNoDeadline(t *testing.T) { + // WHY: When callers pass a context without a deadline, ConnectTLS should + // honor ConnectTimeout to avoid hanging on stalled handshakes. + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = listener.Close() }) + + go func() { + for { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + // Hold the socket open without speaking TLS until the client times out. + time.Sleep(250 * time.Millisecond) + _ = conn.Close() + } + }() + + _, portStr, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + start := time.Now() + _, err = ConnectTLS(context.Background(), ConnectTLSInput{ + Host: "127.0.0.1", + Port: portStr, + ConnectTimeout: 50 * time.Millisecond, + }) + if err == nil { + t.Fatal("expected timeout error") + } + if elapsed := time.Since(start); elapsed > 200*time.Millisecond { + t.Fatalf("ConnectTLS took too long without context deadline: %s", elapsed) + } +} + func TestConnectTLS_IPv6Loopback(t *testing.T) { // WHY: ConnectTLS should accept IPv6 hosts (with ServerName override) when available. t.Parallel() @@ -1393,10 +1435,11 @@ func TestConnectTLS_AIAFetch(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: portStr, - AIATimeout: 5 * time.Second, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: portStr, + AIATimeout: 5 * time.Second, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -1496,9 +1539,10 @@ func TestConnectTLS_RootInChainDiagnostic(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -1553,9 +1597,10 @@ func TestConnectTLS_AIAFetch_FallbackURL(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -1728,10 +1773,11 @@ func TestConnectTLS_AIAFetch_Failure(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: portStr, - AIATimeout: tc.aiaTimeout, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: portStr, + AIATimeout: tc.aiaTimeout, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -1791,9 +1837,10 @@ func TestConnectTLS_AIAFetch_WrongIssuer(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -2054,9 +2101,10 @@ func TestConnectTLS_OCSP(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -2155,10 +2203,11 @@ func TestConnectTLS_OCSP_SkipAndFailure(t *testing.T) { rootPool.AddCert(ca.Cert) result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, - DisableOCSP: tc.disableOCSP, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + DisableOCSP: tc.disableOCSP, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -2256,9 +2305,10 @@ func TestConnectTLS_OCSP_InvalidResponses(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -2412,11 +2462,12 @@ func TestConnectTLS_CRL(t *testing.T) { rootPool.AddCert(ca.Cert) result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - RootCAs: rootPool, - CheckCRL: true, - DisableOCSP: true, + Host: "127.0.0.1", + Port: port, + RootCAs: rootPool, + CheckCRL: true, + DisableOCSP: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -2617,11 +2668,12 @@ func TestConnectTLS_CRL_AIAFetchedIssuer(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - CheckCRL: true, - DisableOCSP: true, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + CheckCRL: true, + DisableOCSP: true, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) @@ -3263,10 +3315,11 @@ func TestConnectTLS_CRL_DuplicateLeafInChain(t *testing.T) { defer cancel() result, err := ConnectTLS(ctx, ConnectTLSInput{ - Host: "127.0.0.1", - Port: port, - CheckCRL: true, - RootCAs: rootPool, + Host: "127.0.0.1", + Port: port, + CheckCRL: true, + RootCAs: rootPool, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("ConnectTLS failed: %v", err) diff --git a/crl.go b/crl.go index ce49c0e6..9a5563dd 100644 --- a/crl.go +++ b/crl.go @@ -49,13 +49,13 @@ type FetchCRLInput struct { } // FetchCRL downloads a CRL from an HTTP or HTTPS URL. -// By default, the URL is validated against SSRF (literal private/loopback IPs -// are blocked; hostnames are allowed). Set AllowPrivateNetworks to bypass this -// for user-provided URLs. +// By default, the URL is validated against SSRF (literal and DNS-resolved +// private/loopback/link-local/unspecified IPs are blocked). Set +// AllowPrivateNetworks to bypass this for user-provided URLs. // The response is limited to 10 MB. func FetchCRL(ctx context.Context, input FetchCRLInput) ([]byte, error) { if !input.AllowPrivateNetworks { - if err := ValidateAIAURL(input.URL); err != nil { + if err := ValidateAIAURLWithOptions(ctx, ValidateAIAURLInput{URL: input.URL, AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { return nil, fmt.Errorf("validating CRL URL: %w", err) } } @@ -68,7 +68,7 @@ func FetchCRL(ctx context.Context, input FetchCRLInput) ([]byte, error) { return fmt.Errorf("stopped after %d redirects", maxRedirects) } if !input.AllowPrivateNetworks { - if err := ValidateAIAURL(req.URL.String()); err != nil { + if err := ValidateAIAURLWithOptions(req.Context(), ValidateAIAURLInput{URL: req.URL.String(), AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { return fmt.Errorf("redirect blocked: %w", err) } } diff --git a/crl_test.go b/crl_test.go index 4789c1d3..d2d6bee6 100644 --- a/crl_test.go +++ b/crl_test.go @@ -199,32 +199,28 @@ func TestFetchCRL(t *testing.T) { } tests := []struct { - name string - handler http.HandlerFunc // nil = no server needed (use overrideURL) - overrideURL string // direct URL (bypasses test server) - wantErr string - wantLength int + name string + handler http.HandlerFunc // nil = no server needed (use overrideURL) + overrideURL string // direct URL (bypasses test server) + allowPrivate bool + wantErr string + wantLength int }{ { name: "success", handler: func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(crlDER) }, - wantLength: len(crlDER), + allowPrivate: true, + wantLength: len(crlDER), }, { name: "non-200 status", handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }, - wantErr: "HTTP 404", - }, - { - name: "redirect to private IP blocked", - handler: func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "http://127.0.0.1/crl", http.StatusFound) - }, - wantErr: "redirect blocked", + allowPrivate: true, + wantErr: "HTTP 404", }, { name: "too many redirects", @@ -232,13 +228,19 @@ func TestFetchCRL(t *testing.T) { // Redirect back to self — after 3 hops the client stops. http.Redirect(w, r, r.URL.String(), http.StatusFound) }, - wantErr: "stopped after 3 redirects", + allowPrivate: true, + wantErr: "stopped after 3 redirects", }, { name: "SSRF blocked loopback IP", overrideURL: "http://127.0.0.1/crl", wantErr: "validating CRL URL", }, + { + name: "SSRF blocked localhost hostname", + overrideURL: "http://localhost/crl", + wantErr: "validating CRL URL", + }, { name: "invalid scheme", overrideURL: "ftp://example.com/crl", @@ -254,11 +256,11 @@ func TestFetchCRL(t *testing.T) { if tc.handler != nil { srv := httptest.NewServer(tc.handler) t.Cleanup(srv.Close) - // Replace 127.0.0.1 with localhost to pass SSRF validation. + // Use localhost for deterministic loopback testing. fetchURL = strings.Replace(srv.URL, "127.0.0.1", "localhost", 1) } - data, err := FetchCRL(context.Background(), FetchCRLInput{URL: fetchURL}) + data, err := FetchCRL(context.Background(), FetchCRLInput{URL: fetchURL, AllowPrivateNetworks: tc.allowPrivate}) if tc.wantErr != "" { if err == nil { t.Fatalf("expected error containing %q, got nil", tc.wantErr) @@ -335,7 +337,7 @@ func TestCRLSizeLimit(t *testing.T) { t.Cleanup(srv.Close) url := strings.Replace(srv.URL, "127.0.0.1", "localhost", 1) - _, err := FetchCRL(context.Background(), FetchCRLInput{URL: url}) + _, err := FetchCRL(context.Background(), FetchCRLInput{URL: url, AllowPrivateNetworks: true}) return err }, }, @@ -501,8 +503,9 @@ func TestCheckLeafCRL(t *testing.T) { } result := CheckLeafCRL(context.Background(), CheckLeafCRLInput{ - Leaf: leaf, - Issuer: ca.Cert, + Leaf: leaf, + Issuer: ca.Cert, + AllowPrivateNetworks: true, }) if result.Status != tc.wantStatus { t.Errorf("Status = %q, want %q", result.Status, tc.wantStatus) diff --git a/internal/certstore/aia.go b/internal/certstore/aia.go index cb76872e..b499233a 100644 --- a/internal/certstore/aia.go +++ b/internal/certstore/aia.go @@ -16,11 +16,12 @@ type AIAFetcher func(ctx context.Context, url string) ([]byte, error) // ResolveAIAInput holds parameters for ResolveAIA. type ResolveAIAInput struct { - Store *MemStore - Fetch AIAFetcher - MaxDepth int // 0 defaults to 5 - Concurrency int // 0 defaults to 20; max parallel fetches per round - OnProgress func(completed, total int) // optional; called after each cert's AIA URLs are processed + Store *MemStore + Fetch AIAFetcher + MaxDepth int // 0 defaults to 5 + Concurrency int // 0 defaults to 20; max parallel fetches per round + OnProgress func(completed, total int) // optional; called after each cert's AIA URLs are processed + AllowPrivateNetworks bool // AllowPrivateNetworks allows AIA fetches to private/internal endpoints. } // HasUnresolvedIssuers reports whether any non-root certificate in the store @@ -134,7 +135,7 @@ func ResolveAIA(ctx context.Context, input ResolveAIAInput) []string { } seen[aiaURL] = true - if err := certkit.ValidateAIAURL(aiaURL); err != nil { + if err := certkit.ValidateAIAURLWithOptions(ctx, certkit.ValidateAIAURLInput{URL: aiaURL, AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { warnings = append(warnings, fmt.Sprintf( "AIA URL rejected for %q: %v", rec.Cert.Subject.CommonName, err, diff --git a/internal/certstore/aia_test.go b/internal/certstore/aia_test.go index 06ddc467..21247336 100644 --- a/internal/certstore/aia_test.go +++ b/internal/certstore/aia_test.go @@ -221,8 +221,9 @@ func TestResolveAIA_FetchesMissingIssuer(t *testing.T) { } warnings := ResolveAIA(context.Background(), ResolveAIAInput{ - Store: store, - Fetch: fetcher, + Store: store, + Fetch: fetcher, + AllowPrivateNetworks: true, }) if len(warnings) != 0 { @@ -709,8 +710,9 @@ func TestResolveAIA_PKCS7Response(t *testing.T) { } warnings := ResolveAIA(context.Background(), ResolveAIAInput{ - Store: store, - Fetch: fetcher, + Store: store, + Fetch: fetcher, + AllowPrivateNetworks: true, }) if len(warnings) != 0 { diff --git a/internal/inspect.go b/internal/inspect.go index e08be148..617cdad4 100644 --- a/internal/inspect.go +++ b/internal/inspect.go @@ -346,8 +346,9 @@ func privateKeySize(key any) string { // ResolveInspectAIAInput holds parameters for ResolveInspectAIA. type ResolveInspectAIAInput struct { - Results []InspectResult - Fetch certstore.AIAFetcher + Results []InspectResult + Fetch certstore.AIAFetcher + AllowPrivateNetworks bool } // ResolveInspectAIA fetches missing intermediate certificates via AIA for the @@ -377,8 +378,9 @@ func ResolveInspectAIA(ctx context.Context, input ResolveInspectAIAInput) ([]Ins } warnings := certstore.ResolveAIA(ctx, certstore.ResolveAIAInput{ - Store: store, - Fetch: input.Fetch, + Store: store, + Fetch: input.Fetch, + AllowPrivateNetworks: input.AllowPrivateNetworks, }) for _, rec := range store.AllCertsFlat() { diff --git a/internal/verify.go b/internal/verify.go index 26a160c3..824da201 100644 --- a/internal/verify.go +++ b/internal/verify.go @@ -16,17 +16,18 @@ import ( // VerifyInput holds the parsed certificate data and verification options. type VerifyInput struct { - Cert *x509.Certificate - Key crypto.PrivateKey - ExtraCerts []*x509.Certificate - CustomRoots []*x509.Certificate - CheckKeyMatch bool - CheckChain bool - ExpiryDuration time.Duration - TrustStore string - Verbose bool - CheckOCSP bool - CheckCRL bool + Cert *x509.Certificate + Key crypto.PrivateKey + ExtraCerts []*x509.Certificate + CustomRoots []*x509.Certificate + CheckKeyMatch bool + CheckChain bool + ExpiryDuration time.Duration + TrustStore string + Verbose bool + CheckOCSP bool + CheckCRL bool + AllowPrivateNetworks bool } // ChainCert holds display information for one certificate in the chain. @@ -143,6 +144,7 @@ func VerifyCert(ctx context.Context, input *VerifyInput) (*VerifyResult, error) opts.TrustStore = input.TrustStore opts.ExtraIntermediates = input.ExtraCerts opts.CustomRoots = input.CustomRoots + opts.AllowPrivateNetworks = input.AllowPrivateNetworks var bundleErr error bundle, bundleErr = certkit.Bundle(ctx, certkit.BundleInput{ Leaf: cert, @@ -171,7 +173,7 @@ func VerifyCert(ctx context.Context, input *VerifyInput) (*VerifyResult, error) } if issuer != nil { if input.CheckOCSP { - result.OCSP = checkVerifyOCSP(ctx, certkit.CheckOCSPInput{Cert: cert, Issuer: issuer}) + result.OCSP = checkVerifyOCSP(ctx, certkit.CheckOCSPInput{Cert: cert, Issuer: issuer, AllowPrivateNetworks: input.AllowPrivateNetworks}) if result.OCSP.Status == "revoked" { msg := "certificate is revoked (OCSP)" if result.OCSP.RevokedAt != nil { @@ -185,8 +187,9 @@ func VerifyCert(ctx context.Context, input *VerifyInput) (*VerifyResult, error) } if input.CheckCRL { result.CRL = certkit.CheckLeafCRL(ctx, certkit.CheckLeafCRLInput{ - Leaf: cert, - Issuer: issuer, + Leaf: cert, + Issuer: issuer, + AllowPrivateNetworks: input.AllowPrivateNetworks, }) if result.CRL.Status == "revoked" { result.Errors = append(result.Errors, fmt.Sprintf("certificate is revoked (CRL, %s)", result.CRL.Detail)) diff --git a/internal/verify_test.go b/internal/verify_test.go index cfa162a0..7748552f 100644 --- a/internal/verify_test.go +++ b/internal/verify_test.go @@ -1680,12 +1680,13 @@ func TestVerifyCert_RevocationBehavior(t *testing.T) { } result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckOCSP: tc.checkOCSP, - CheckCRL: tc.checkCRL, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckOCSP: tc.checkOCSP, + CheckCRL: tc.checkCRL, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -1816,13 +1817,14 @@ func TestVerifyCert_RevocationIssuerIntermediate(t *testing.T) { leaf.cert.CRLDistributionPoints = []string{strings.Replace(crlServer.URL, "127.0.0.1", "localhost", 1)} result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{root.cert}, - ExtraCerts: []*x509.Certificate{intermediate.cert}, - CheckOCSP: true, - CheckCRL: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{root.cert}, + ExtraCerts: []*x509.Certificate{intermediate.cert}, + CheckOCSP: true, + CheckCRL: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -1854,11 +1856,12 @@ func TestVerifyCert_RevocationWithoutChain(t *testing.T) { leaf := newRSALeaf(t, ca, "nochain.example.com", []string{"nochain.example.com"}, nil) result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckOCSP: true, - CheckCRL: true, - CheckChain: false, - TrustStore: "custom", + Cert: leaf.cert, + CheckOCSP: true, + CheckCRL: true, + CheckChain: false, + TrustStore: "custom", + AllowPrivateNetworks: true, CustomRoots: []*x509.Certificate{ ca.cert, }, @@ -2076,11 +2079,12 @@ func TestVerifyCert_OCSPStatus(t *testing.T) { defer cancel() } result, err := VerifyCert(ctx, &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckOCSP: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckOCSP: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -2140,11 +2144,12 @@ func TestVerifyCert_OCSPStatus_ECDSA(t *testing.T) { leaf.cert.OCSPServer = []string{strings.Replace(server.URL, "127.0.0.1", "localhost", 1)} result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckOCSP: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckOCSP: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -2326,11 +2331,12 @@ func TestVerifyCert_CRLStatus(t *testing.T) { defer cancel() } result, err := VerifyCert(ctx, &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckCRL: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckCRL: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -2386,11 +2392,12 @@ func TestVerifyCert_CRLStatus_ECDSA(t *testing.T) { leaf.cert.CRLDistributionPoints = []string{strings.Replace(server.URL, "127.0.0.1", "localhost", 1)} result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckCRL: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckCRL: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) @@ -2493,12 +2500,13 @@ func TestVerifyCert_RevocationCombined(t *testing.T) { leaf.cert.CRLDistributionPoints = []string{strings.Replace(crlServer.URL, "127.0.0.1", "localhost", 1)} result, err := VerifyCert(context.Background(), &VerifyInput{ - Cert: leaf.cert, - CheckChain: true, - TrustStore: "custom", - CustomRoots: []*x509.Certificate{ca.cert}, - CheckOCSP: true, - CheckCRL: true, + Cert: leaf.cert, + CheckChain: true, + TrustStore: "custom", + CustomRoots: []*x509.Certificate{ca.cert}, + CheckOCSP: true, + CheckCRL: true, + AllowPrivateNetworks: true, }) if err != nil { t.Fatal(err) diff --git a/ocsp.go b/ocsp.go index 8b9b35c7..968d9922 100644 --- a/ocsp.go +++ b/ocsp.go @@ -18,6 +18,8 @@ type CheckOCSPInput struct { Cert *x509.Certificate // Issuer is the issuer certificate (used to build the OCSP request). Issuer *x509.Certificate + // AllowPrivateNetworks allows OCSP requests to private/internal endpoints. + AllowPrivateNetworks bool } // OCSPResult contains the OCSP response details. @@ -56,7 +58,7 @@ func CheckOCSP(ctx context.Context, input CheckOCSPInput) (*OCSPResult, error) { responderURL := input.Cert.OCSPServer[0] - if err := ValidateAIAURL(responderURL); err != nil { + if err := ValidateAIAURLWithOptions(ctx, ValidateAIAURLInput{URL: responderURL, AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { return nil, fmt.Errorf("validating OCSP responder URL: %w", err) } @@ -78,7 +80,7 @@ func CheckOCSP(ctx context.Context, input CheckOCSPInput) (*OCSPResult, error) { if len(via) >= maxRedirects { return fmt.Errorf("stopped after %d redirects", maxRedirects) } - if err := ValidateAIAURL(req.URL.String()); err != nil { + if err := ValidateAIAURLWithOptions(req.Context(), ValidateAIAURLInput{URL: req.URL.String(), AllowPrivateNetworks: input.AllowPrivateNetworks}); err != nil { return fmt.Errorf("redirect blocked: %w", err) } return nil diff --git a/ocsp_test.go b/ocsp_test.go index 9072c271..aa1969e1 100644 --- a/ocsp_test.go +++ b/ocsp_test.go @@ -89,8 +89,9 @@ func TestCheckOCSP_MockResponse(t *testing.T) { defer cancel() result, err := CheckOCSP(ctx, CheckOCSPInput{ - Cert: leafCert, - Issuer: ca.Cert, + Cert: leafCert, + Issuer: ca.Cert, + AllowPrivateNetworks: true, }) if err != nil { t.Fatalf("CheckOCSP failed: %v", err) @@ -169,3 +170,25 @@ func TestFormatOCSPResult(t *testing.T) { } } } + +func TestCheckOCSP_PrivateEndpointBlockedByDefault(t *testing.T) { + t.Parallel() + + ca := generateTestCA(t, "OCSP Private Endpoint CA") + leaf := generateTestLeafCert(t, ca, withOCSPServer("http://localhost/ocsp")) + leafCert, err := x509.ParseCertificate(leaf.DER) + if err != nil { + t.Fatal(err) + } + + _, err = CheckOCSP(context.Background(), CheckOCSPInput{ + Cert: leafCert, + Issuer: ca.Cert, + }) + if err == nil { + t.Fatal("expected error for private OCSP endpoint") + } + if !strings.Contains(err.Error(), "validating OCSP responder URL") { + t.Fatalf("error = %q, want validating OCSP responder URL", err.Error()) + } +} diff --git a/web/public/app.js b/web/public/app.js index 510b0a98..bb652d87 100644 --- a/web/public/app.js +++ b/web/public/app.js @@ -24,11 +24,17 @@ const filterExpired = document.getElementById("filter-expired"); const filterUnmatched = document.getElementById("filter-unmatched"); const filterUntrusted = document.getElementById("filter-untrusted"); const selectAll = document.getElementById("select-all"); +const scanAllowPrivateNetwork = document.getElementById( + "scan-allow-private-network", +); // DOM references — Inspect page const inspectDropZone = document.getElementById("inspect-drop-zone"); const inspectFileInput = document.getElementById("inspect-file-input"); const inspectPasswordsInput = document.getElementById("inspect-passwords"); +const inspectAllowPrivateNetwork = document.getElementById( + "inspect-allow-private-network", +); const inspectStatusBar = document.getElementById("inspect-status"); const inspectStatusText = document.getElementById("inspect-status-text"); const inspectResultsSection = document.getElementById("inspect-results"); @@ -303,6 +309,7 @@ async function addFileObjects(fileObjects, statusMessage) { const resultJSON = await certkitAddFiles( fileObjects, passwordsInput.value.trim(), + scanAllowPrivateNetwork.checked, ); const results = JSON.parse(resultJSON); @@ -436,6 +443,7 @@ async function inspectFileObjects(fileObjects) { const resultJSON = await certkitInspect( fileObjects, inspectPasswordsInput.value.trim(), + inspectAllowPrivateNetwork.checked, ); const results = JSON.parse(resultJSON); diff --git a/web/public/index.html b/web/public/index.html index a3960d9b..89451cd7 100644 --- a/web/public/index.html +++ b/web/public/index.html @@ -144,6 +144,12 @@

certkit

placeholder="password1, password2, ..." /> +
+ +