diff --git a/README.md b/README.md index 4f40391..a5661e2 100644 --- a/README.md +++ b/README.md @@ -152,16 +152,29 @@ If the private key is password protected, the password can be provided via the C environment variable or will be prompted on stdin. Examples: - https-wrench certinfo --tls-endpoint example.com:443 + + # Print info about local certificates and keys + # with optional CA and public key match validation + https-wrench certinfo --cert-bundle ./bundle.pem --key-file ./key.pem https-wrench certinfo --cert-bundle ./bundle.pem https-wrench certinfo --key-file ./key.pem + https-wrench certinfo --ca-bundle ./ca-bundle.pem --cert-bundle ./bundle.pem --key-file ./key.pem + + # Print info about remote certificates + # with optional CA and public key match validation + + https-wrench certinfo --tls-endpoint example.com:443 https-wrench certinfo --tls-endpoint example.com:443 --key-file ./key.pem https-wrench certinfo --tls-endpoint example.com:443 --cert-bundle ./bundle.pem --key-file ./key.pem https-wrench certinfo --tls-endpoint example.com:443 --tls-servername www.example.com https-wrench certinfo --tls-endpoint [2001:db8::1]:443 --tls-insecure https-wrench certinfo --ca-bundle ./ca-bundle.pem --tls-endpoint example.com:443 - https-wrench certinfo --ca-bundle ./ca-bundle.pem --cert-bundle ./bundle.pem --key-file ./key.pem + + # Print info about remote certificates + # with optional display of negotiated and supported TLS protocols and ciphers + + https-wrench certinfo --tls-endpoint example.com:443 --tls-info Usage: https-wrench certinfo [flags] @@ -175,6 +188,7 @@ Flags: --tls-endpoint string TLS enabled endpoint exposing certificates to fetch. Forms: 'host:port', '[host]:port'. IPv6 addresses must be enclosed in square brackets, as in '[::1]:80' + --tls-info Show negotiated TLS info and probe supported protocols/ciphers --tls-insecure Skip certificate validation when connecting to a TLS endpoint --tls-servername string ServerName to use when connecting to an SNI enabled TLS endpoint diff --git a/internal/certinfo/certinfo.go b/internal/certinfo/certinfo.go index aaa28fd..a62026f 100644 --- a/internal/certinfo/certinfo.go +++ b/internal/certinfo/certinfo.go @@ -50,6 +50,16 @@ type Config struct { TLSServerName string // TLSInsecure indicates if certificate verification should be skipped. TLSInsecure bool + // TLSInfoRequested indicates if negotiated TLS info and supported protocol/cipher scan was requested. + TLSInfoRequested bool + // NegotiatedProtocol is the TLS protocol version negotiated in the primary connection. + NegotiatedProtocol string + // NegotiatedCipher is the TLS cipher suite negotiated in the primary connection. + NegotiatedCipher string + // ProbedProtocols maps a TLS protocol name to whether the remote endpoint supports it. + ProbedProtocols map[string]bool + // ProbedCiphers is a slice of ciphers that were probed against the endpoint. + ProbedCiphers []ProbedCipher } // Reader defines an interface for reading files and passwords. @@ -198,3 +208,18 @@ func (c *Config) SetTLSServerName(serverName string) *Config { return c } + +// ProbedCipher holds the result of a single cipher suite probe. +type ProbedCipher struct { + ID uint16 + Name string + Protocol string + Insecure bool + Supported bool +} + +// SetTLSInfoRequested sets whether to probe and print remote TLS protocol/cipher information. +func (c *Config) SetTLSInfoRequested(requested bool) *Config { + c.TLSInfoRequested = requested + return c +} diff --git a/internal/certinfo/certinfo_handlers.go b/internal/certinfo/certinfo_handlers.go index cab2d8c..ec58a52 100644 --- a/internal/certinfo/certinfo_handlers.go +++ b/internal/certinfo/certinfo_handlers.go @@ -5,6 +5,7 @@ Copyright © 2025 Zeno Belli xeno@os76.xyz package certinfo import ( + "cmp" "crypto/sha256" "crypto/tls" "crypto/x509" @@ -12,8 +13,10 @@ import ( "fmt" "io" "net" + "slices" "strconv" "strings" + "sync" "time" "github.com/charmbracelet/lipgloss" @@ -45,6 +48,11 @@ func (c *Config) PrintData(w io.Writer) error { return err } + if c.TLSInfoRequested { + _ = c.ProbeTLSInfo() + c.printTLSInfo(w, ks, sl, sv) + } + return c.printCACerts(w, ks, sl, sv) } @@ -197,6 +205,8 @@ func (c *Config) GetRemoteCerts() error { cs := conn.ConnectionState() c.TLSEndpointCerts = cs.PeerCertificates + c.NegotiatedProtocol = tlsVersionToString(cs.Version) + c.NegotiatedCipher = tls.CipherSuiteName(cs.CipherSuite) // do not verify server certificates if TLSInsecure if c.TLSInsecure { @@ -277,3 +287,276 @@ func CertsToTables(w io.Writer, certs []*x509.Certificate) { t.ClearRows() } } + +// tlsVersionToString converts TLS version uint16 to standard string representation. +func tlsVersionToString(version uint16) string { + switch version { + case tls.VersionTLS10: + return "TLS 1.0" + + case tls.VersionTLS11: + return "TLS 1.1" + + case tls.VersionTLS12: + return "TLS 1.2" + + case tls.VersionTLS13: + return "TLS 1.3" + + default: + return fmt.Sprintf("Unknown (0x%04x)", version) + } +} + +// probeProtocol tests whether the TLS endpoint supports a specific TLS protocol version. +func (c *Config) probeProtocol(version uint16) bool { + tlsConfig := &tls.Config{ + MinVersion: version, + MaxVersion: version, + InsecureSkipVerify: true, + } + + if c.TLSServerName != emptyString { + tlsConfig.ServerName = c.TLSServerName + } + + serverAddr := net.JoinHostPort(c.TLSEndpointHost, c.TLSEndpointPort) + + dialer := &net.Dialer{ + Timeout: TLSTimeout, + } + + conn, err := tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) + if err == nil { + conn.Close() + + return true + } + + return false +} + +// probeCipher tests whether a specific TLS 1.0-1.2 cipher suite is supported. +func (c *Config) probeCipher(suite *tls.CipherSuite) (bool, string) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS12, + CipherSuites: []uint16{suite.ID}, + InsecureSkipVerify: true, + } + + if c.TLSServerName != emptyString { + tlsConfig.ServerName = c.TLSServerName + } + + serverAddr := net.JoinHostPort(c.TLSEndpointHost, c.TLSEndpointPort) + + dialer := &net.Dialer{ + Timeout: TLSTimeout, + } + + conn, err := tls.DialWithDialer(dialer, "tcp", serverAddr, tlsConfig) + if err == nil { + state := conn.ConnectionState() + + conn.Close() + + return true, tlsVersionToString(state.Version) + } + + return false, "" +} + +// ProbeTLSInfo concurrently scans the endpoint for supported TLS versions and cipher suites. +func (c *Config) ProbeTLSInfo() error { + if c.TLSEndpoint == emptyString { + return nil + } + + c.ProbedProtocols = make(map[string]bool) + + // 1. Probe protocols + versions := []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} + + for _, v := range versions { + supported := c.probeProtocol(v) + + c.ProbedProtocols[tlsVersionToString(v)] = supported + } + + // 2. Probe ciphers concurrently + suites := append(tls.CipherSuites(), tls.InsecureCipherSuites()...) + + c.ProbedCiphers = c.probeCiphersConcurrently(suites) + + return nil +} + +// probeCiphersConcurrently manages the worker pool to concurrently scan cipher suites. +// +//nolint:gocognit,revive,wsl +func (c *Config) probeCiphersConcurrently(suites []*tls.CipherSuite) []ProbedCipher { + type job struct { + suite *tls.CipherSuite + } + + type result struct { + probed ProbedCipher + } + + numJobs := len(suites) + jobs := make(chan job, numJobs) + results := make(chan result, numJobs) + + // Start 10 concurrent workers + numWorkers := 10 + if numWorkers > numJobs { + numWorkers = numJobs + } + + var wg sync.WaitGroup + + for w := 0; w < numWorkers; w++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for j := range jobs { + suite := j.suite + isTLS13 := false + + for _, v := range suite.SupportedVersions { + if v == tls.VersionTLS13 { + isTLS13 = true + + break + } + } + + var ( + supported bool + protoName string + ) + + if isTLS13 { + supported = c.ProbedProtocols["TLS 1.3"] + protoName = "TLS 1.3" + } else { + ok, name := c.probeCipher(suite) + + supported = ok + protoName = name + } + + results <- result{ + probed: ProbedCipher{ + ID: suite.ID, + Name: suite.Name, + Protocol: protoName, + Insecure: suite.Insecure, + Supported: supported, + }, + } + } + }() + } + + // Queue up all jobs + for _, s := range suites { + jobs <- job{suite: s} + } + + close(jobs) + + // Wait for workers to finish + wg.Wait() + + close(results) + + // Collect results + var list []ProbedCipher + + for r := range results { + list = append(list, r.probed) + } + + // Sort ciphers by Name for stable output + slices.SortFunc(list, func(a, b ProbedCipher) int { + return cmp.Compare(a.Name, b.Name) + }) + + return list +} + +// printTLSInfo formats and prints the scanned TLS info tables. +func (c *Config) printTLSInfo(w io.Writer, ks, _, _ lipgloss.Style) { + if !c.TLSInfoRequested { + return + } + + // 1. Render Negotiated Connection details + fmt.Fprintln(w, style.LgSprintf(ks, "Negotiated TLS Connection")) + + t1 := table.New().Border(style.LGDefBorder) + t1.Row(style.CertKeyP4.Render("Protocol Version"), style.CertValue.Render(c.NegotiatedProtocol)) + t1.Row(style.CertKeyP4.Render("Cipher Suite"), style.CertValue.Render(c.NegotiatedCipher)) + fmt.Fprintln(w, t1.Render()) + + // 2. Render Supported Protocol Versions Scan + fmt.Fprintln(w, style.LgSprintf(ks, "Protocol Support Scan")) + + t2 := table.New().Border(style.LGDefBorder) + protoOrder := []string{"TLS 1.3", "TLS 1.2", "TLS 1.1", "TLS 1.0"} + + for _, protoName := range protoOrder { + supported := c.ProbedProtocols[protoName] + statusStr, statusStyle := "No", style.BoolFalse.Render + + if supported { + statusStr, statusStyle = "Yes", style.BoolTrue.Render + } + + t2.Row(style.CertKeyP4.Render(protoName), statusStyle(statusStr)) + } + + fmt.Fprintln(w, t2.Render()) + + // 3. Render Probed Cipher Suites + fmt.Fprintln(w, style.LgSprintf(ks, "Cipher Suite Scan")) + + slRender := style.CertKeyP4.Bold(true).Render + slNoPadRender := style.CertKeyP4.PaddingLeft(0).Bold(true).Render + t3 := table.New().Border(style.LGDefBorder).Headers( + slRender("Cipher Suite Name"), + slNoPadRender("Protocol"), + slNoPadRender("Status"), + slNoPadRender("Security"), + ) + + var hasSupported bool + + for _, pc := range c.ProbedCiphers { + if pc.Supported { + hasSupported = true + secStr, secStyle := "Secure", style.BoolTrue.Render + + if pc.Insecure { + secStr, secStyle = "Insecure", style.Warn.Render + } + + t3.Row( + style.CertKeyP4.Render(pc.Name), + style.CertValue.Render(pc.Protocol), + style.BoolTrue.Render("Yes"), + secStyle(secStr), + ) + } + } + + if !hasSupported { + t3.Row(style.CertKeyP4.Render("No supported cipher suites found"), "", "", "") + } + + fmt.Fprintln(w, t3.Render()) +} diff --git a/internal/certinfo/certinfo_handlers_test.go b/internal/certinfo/certinfo_handlers_test.go index 6365efe..51b58cd 100644 --- a/internal/certinfo/certinfo_handlers_test.go +++ b/internal/certinfo/certinfo_handlers_test.go @@ -25,7 +25,7 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { srvCfg: demoHTTPServerConfig{ serverAddr: "localhost:46301", serverName: "example.com", - serverCertFile: RSASampleCertFile, + serverCertFile: RSASampleCertBundleFile, serverKeyFile: RSASampleCertKeyFile, }, caCertFile: RSACaCertFile, @@ -132,8 +132,6 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { for _, tc := range tests { tt := tc t.Run(tt.desc, func(t *testing.T) { - t.Parallel() - ts, err := NewHTTPSTestServer(tt.srvCfg) require.NoError(t, err) @@ -472,8 +470,6 @@ type printDataTestCase struct { } func runPrintDataSubtest(t *testing.T, tt printDataTestCase) { - t.Parallel() - buffer := bytes.Buffer{} cc, err := New() diff --git a/internal/certinfo/certinfo_test.go b/internal/certinfo/certinfo_test.go index 8ac7181..fd33470 100644 --- a/internal/certinfo/certinfo_test.go +++ b/internal/certinfo/certinfo_test.go @@ -1,9 +1,15 @@ package certinfo import ( + "bytes" + "crypto/tls" "fmt" + "net/http" + "net/http/httptest" + "net/url" "testing" + "github.com/charmbracelet/lipgloss" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" ) @@ -412,3 +418,262 @@ func TestCertinfo_SetTLSEndpoint(t *testing.T) { }) } } + +func TestCertinfo_ProbeTLSInfo(t *testing.T) { + // Start a local TLS server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Parse host and port from server URL + u, err := url.Parse(server.URL) + require.NoError(t, err) + + cc, err := New() + require.NoError(t, err) + + cc.SetTLSInfoRequested(true) + require.True(t, cc.TLSInfoRequested) + + // Skip verification to allow connection to the self-signed test server + cc.SetTLSInsecure(true) + cc.SetTLSServerName("example.com") + + err = cc.SetTLSEndpoint(u.Host) + require.NoError(t, err) + + err = cc.ProbeTLSInfo() + require.NoError(t, err) + + // Since it's a local TLS server run by Go's httptest, it supports TLS 1.3 or TLS 1.2 + hasSupported := false + + for _, supported := range cc.ProbedProtocols { + if supported { + hasSupported = true + + break + } + } + + require.True(t, hasSupported) + require.NotEmpty(t, cc.ProbedCiphers) +} + +func TestCertinfo_ProbeTLSInfo_NotRequested(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.SetTLSInfoRequested(false) + require.False(t, cc.TLSInfoRequested) + + err = cc.ProbeTLSInfo() + require.NoError(t, err) + require.Empty(t, cc.NegotiatedProtocol) +} + +func TestCertinfo_ProbeTLSInfo_NoEndpoint(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.SetTLSInfoRequested(true) + + err = cc.ProbeTLSInfo() + require.NoError(t, err) + require.Empty(t, cc.ProbedProtocols) +} + +func TestCertinfo_ProbeTLSInfo_Unreachable(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.SetTLSInfoRequested(true) + cc.SetTLSInsecure(true) + + // Manually populate fields to bypass pre-flight certificate fetch in SetTLSEndpoint + cc.TLSEndpoint = "127.0.0.1:54321" + cc.TLSEndpointHost = "127.0.0.1" + cc.TLSEndpointPort = "54321" + + err = cc.ProbeTLSInfo() + require.NoError(t, err) + + // When unreachable, all scanned protocols should be unsupported + for _, supported := range cc.ProbedProtocols { + require.False(t, supported) + } +} + +func TestCertinfo_GettersAndSetters(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.NegotiatedProtocol = "TLS 1.3" + require.Equal(t, "TLS 1.3", cc.NegotiatedProtocol) + + cc.NegotiatedCipher = "TLS_AES_128_GCM_SHA256" + require.Equal(t, "TLS_AES_128_GCM_SHA256", cc.NegotiatedCipher) +} + +func TestCertinfo_PrintTLSInfo_NotRequested(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.TLSInfoRequested = false + + var buf bytes.Buffer + + ks := lipgloss.NewStyle() + sl := lipgloss.NewStyle() + sv := lipgloss.NewStyle() + + cc.printTLSInfo(&buf, ks, sl, sv) + require.Empty(t, buf.String()) +} + +func TestCertinfo_PrintTLSInfo_HappyPath(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.TLSInfoRequested = true + cc.NegotiatedProtocol = "TLS 1.3" + cc.NegotiatedCipher = "TLS_AES_128_GCM_SHA256" + cc.ProbedProtocols = map[string]bool{ + "TLS 1.3": true, + "TLS 1.2": true, + "TLS 1.1": false, + "TLS 1.0": false, + } + cc.ProbedCiphers = []ProbedCipher{ + { + Name: "TLS_AES_128_GCM_SHA256", + Protocol: "TLS 1.3", + Supported: true, + Insecure: false, + }, + { + Name: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + Protocol: "TLS 1.2", + Supported: true, + Insecure: true, + }, + } + + var buf bytes.Buffer + + ks := lipgloss.NewStyle() + sl := lipgloss.NewStyle() + sv := lipgloss.NewStyle() + + cc.printTLSInfo(&buf, ks, sl, sv) + + got := buf.String() + require.Contains(t, got, "Negotiated TLS Connection") + require.Contains(t, got, "TLS 1.3") + require.Contains(t, got, "TLS_AES_128_GCM_SHA256") + require.Contains(t, got, "Protocol Support Scan") + require.Contains(t, got, "Cipher Suite Scan") + require.Contains(t, got, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256") + require.Contains(t, got, "Insecure") + require.Contains(t, got, "Secure") +} + +func TestCertinfo_PrintTLSInfo_NoSupportedCiphers(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.TLSInfoRequested = true + cc.NegotiatedProtocol = "TLS 1.2" + cc.NegotiatedCipher = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" + cc.ProbedProtocols = map[string]bool{ + "TLS 1.3": false, + "TLS 1.2": true, + "TLS 1.1": false, + "TLS 1.0": false, + } + cc.ProbedCiphers = []ProbedCipher{ + { + Name: "TLS_AES_128_GCM_SHA256", + Protocol: "TLS 1.3", + Supported: false, + Insecure: false, + }, + } + + var buf bytes.Buffer + + ks := lipgloss.NewStyle() + sl := lipgloss.NewStyle() + sv := lipgloss.NewStyle() + + cc.printTLSInfo(&buf, ks, sl, sv) + + got := buf.String() + require.Contains(t, got, "Negotiated TLS Connection") + require.Contains(t, got, "No supported cipher suites found") +} + +func TestCertinfo_TLSVersionToString_Unknown(t *testing.T) { + t.Parallel() + + res := tlsVersionToString(0x1234) + require.Equal(t, "Unknown (0x1234)", res) +} + +func TestCertinfo_ProbeTLSInfo_SingleCipher(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.TLSInfoRequested = true + cc.TLSEndpoint = "127.0.0.1:54321" + cc.TLSEndpointHost = "127.0.0.1" + cc.TLSEndpointPort = "54321" + + // Mock only 1 cipher suite to trigger numWorkers > numJobs inside probeCiphersConcurrently + ciphers := []*tls.CipherSuite{ + { + ID: tls.TLS_AES_128_GCM_SHA256, + Name: "TLS_AES_128_GCM_SHA256", + Insecure: false, + }, + } + + res := cc.probeCiphersConcurrently(ciphers) + require.Len(t, res, 1) + require.Equal(t, "TLS_AES_128_GCM_SHA256", res[0].Name) + require.False(t, res[0].Supported) +} + +func TestCertinfo_PrintData_WithTLSInfo(t *testing.T) { + t.Parallel() + + cc, err := New() + require.NoError(t, err) + + cc.TLSInfoRequested = true + cc.NegotiatedProtocol = "TLS 1.3" + cc.NegotiatedCipher = "TLS_AES_128_GCM_SHA256" + + var buf bytes.Buffer + + err = cc.PrintData(&buf) + require.NoError(t, err) + require.Contains(t, buf.String(), "Negotiated TLS Connection") +} diff --git a/internal/certinfo/main_test.go b/internal/certinfo/main_test.go index b7515ba..6c799ca 100644 --- a/internal/certinfo/main_test.go +++ b/internal/certinfo/main_test.go @@ -108,13 +108,13 @@ func TestMain(m *testing.M) { panic(errDataDir) } - // Cleanup (register early so panics in setup still clean up what was created) defer func() { filesToDel := []string{ RSACaCertKeyFile, RSACaCertFile, RSASampleCertFile, RSASampleCertKeyFile, + RSASampleCertBundleFile, } for _, fileToDel := range filesToDel { err := os.Remove(fileToDel) @@ -200,6 +200,15 @@ func generateRSACertificateData() { if err != nil { fmt.Print(err) } + + RSASampleCertBundleFile, err = createTmpFileWithContent( + testdataDir, + "RSASampleCertBundle", + []byte(RSASampleCertPEMString+RSACaCertPEMString), + ) + if err != nil { + fmt.Print(err) + } } func generateRSACaData() { diff --git a/internal/cmd/certinfo.go b/internal/cmd/certinfo.go index 6e370c5..35cdc45 100644 --- a/internal/cmd/certinfo.go +++ b/internal/cmd/certinfo.go @@ -14,6 +14,7 @@ var ( tlsEndpoint string tlsServerName string tlsInsecure bool + tlsInfo bool keyPwEnvVar = "CERTINFO_PKEY_PW" ) @@ -33,16 +34,29 @@ If the private key is password protected, the password can be provided via the C environment variable or will be prompted on stdin. Examples: - https-wrench certinfo --tls-endpoint example.com:443 + + # Print info about local certificates and keys + # with optional CA and public key match validation + https-wrench certinfo --cert-bundle ./bundle.pem --key-file ./key.pem https-wrench certinfo --cert-bundle ./bundle.pem https-wrench certinfo --key-file ./key.pem + https-wrench certinfo --ca-bundle ./ca-bundle.pem --cert-bundle ./bundle.pem --key-file ./key.pem + + # Print info about remote certificates + # with optional CA and public key match validation + + https-wrench certinfo --tls-endpoint example.com:443 https-wrench certinfo --tls-endpoint example.com:443 --key-file ./key.pem https-wrench certinfo --tls-endpoint example.com:443 --cert-bundle ./bundle.pem --key-file ./key.pem https-wrench certinfo --tls-endpoint example.com:443 --tls-servername www.example.com https-wrench certinfo --tls-endpoint [2001:db8::1]:443 --tls-insecure https-wrench certinfo --ca-bundle ./ca-bundle.pem --tls-endpoint example.com:443 - https-wrench certinfo --ca-bundle ./ca-bundle.pem --cert-bundle ./bundle.pem --key-file ./key.pem + + # Print info about remote certificates + # with optional display of negotiated and supported TLS protocols and ciphers + + https-wrench certinfo --tls-endpoint example.com:443 --tls-info `, Run: func(cmd *cobra.Command, _ []string) { caBundleValue := viper.GetString("ca-bundle") @@ -55,6 +69,11 @@ Examples: return } + if tlsInfo && tlsEndpoint == "" { + cmd.Print("Error: --tls-info requires --tls-endpoint\n") + return + } + // display the help if none of the main flags is set if len(caBundleValue+certBundleValue+keyFileValue+tlsEndpoint) == 0 { _ = cmd.Help() @@ -75,13 +94,14 @@ Examples: cmd.Printf("Error importing Certificate bundle from file: %s", err) } - certinfoCfg.SetTLSInsecure(tlsInsecure).SetTLSServerName(tlsServerName) + certinfoCfg.SetTLSInsecure(tlsInsecure).SetTLSServerName(tlsServerName).SetTLSInfoRequested(tlsInfo) // SetTLSEndpoint may need the SNI/ServerName and insecure options to be set // before being able to ask details about the certificate we want to a // webserver using self-signed and valid certificates if err = certinfoCfg.SetTLSEndpoint(tlsEndpoint); err != nil { cmd.Printf("Error setting TLS endpoint: %s", err) + return } if err = certinfoCfg.SetPrivateKeyFromFile( @@ -114,5 +134,9 @@ IPv6 addresses must be enclosed in square brackets, as in '[::1]:80'`) "tls-insecure", false, "Skip certificate validation when connecting to a TLS endpoint") + certinfoCmd.Flags().BoolVar(&tlsInfo, + "tls-info", + false, + "Show negotiated TLS info and probe supported protocols/ciphers") rootCmd.AddCommand(certinfoCmd) } diff --git a/internal/cmd/certinfo_test.go b/internal/cmd/certinfo_test.go index 19f0b12..c7519e8 100644 --- a/internal/cmd/certinfo_test.go +++ b/internal/cmd/certinfo_test.go @@ -3,6 +3,9 @@ package cmd import ( "bytes" _ "embed" + "net/http" + "net/http/httptest" + "net/url" "testing" _ "github.com/breml/rootcerts" @@ -97,11 +100,29 @@ func TestCertinfoCmd(t *testing.T) { }, { //nolint:revive - name: "invalid files and endpoints", + name: "invalid files", //nolint:revive - args: []string{"certinfo", "--ca-bundle", "non_existent.pem", "--cert-bundle", "non_existent.pem", "--key-file", "non_existent.pem", "--tls-endpoint", "invalid://"}, + args: []string{"certinfo", "--ca-bundle", "non_existent.pem", "--cert-bundle", "non_existent.pem", "--key-file", "non_existent.pem"}, expectError: false, - expected: []string{"Error importing CA Certificate bundle", "Error importing Certificate bundle", "Error importing key", "Error setting TLS endpoint"}, + expected: []string{"Error importing CA Certificate bundle", "Error importing Certificate bundle", "Error importing key"}, + }, + { + name: "invalid tls-endpoint", + args: []string{"certinfo", "--tls-endpoint", "invalid://"}, + expectError: false, + expected: []string{"Error setting TLS endpoint"}, + }, + { + name: "tls-info flag without tls-endpoint", + args: []string{"certinfo", "--tls-info"}, + expectError: false, + expected: []string{"Error: --tls-info requires --tls-endpoint"}, + }, + { + name: "tls-info flag with invalid tls-endpoint", + args: []string{"certinfo", "--tls-info", "--tls-endpoint", "invalid://"}, + expectError: false, + expected: []string{"Error setting TLS endpoint"}, }, } @@ -115,6 +136,7 @@ func TestCertinfoCmd(t *testing.T) { require.NoError(t, certinfoCmd.Flags().Set("tls-endpoint", "")) require.NoError(t, certinfoCmd.Flags().Set("tls-servername", "")) require.NoError(t, certinfoCmd.Flags().Set("tls-insecure", "false")) + require.NoError(t, certinfoCmd.Flags().Set("tls-info", "false")) require.NoError(t, certinfoCmd.Flags().Set("cert-bundle", "")) require.NoError(t, certinfoCmd.Flags().Set("key-file", "")) }) @@ -147,3 +169,42 @@ func TestCertinfoCmd(t *testing.T) { }) } } + +func TestCertinfoCmd_WithTLSInfo(t *testing.T) { + // Start a local TLS server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Parse host and port from server URL + u, err := url.Parse(server.URL) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, rootCmd.PersistentFlags().Set("version", "false")) + require.NoError(t, certinfoCmd.Flags().Set("ca-bundle", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-endpoint", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-servername", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-insecure", "false")) + require.NoError(t, certinfoCmd.Flags().Set("tls-info", "false")) + require.NoError(t, certinfoCmd.Flags().Set("cert-bundle", "")) + require.NoError(t, certinfoCmd.Flags().Set("key-file", "")) + }) + + reqOut := new(bytes.Buffer) + reqCmd := rootCmd + + reqCmd.SetOut(reqOut) + reqCmd.SetErr(reqOut) + reqCmd.SetArgs([]string{"certinfo", "--tls-endpoint", u.Host, "--tls-info", "--tls-insecure"}) + + err = reqCmd.Execute() + require.NoError(t, err) + + got := reqOut.String() + + require.Contains(t, got, "Negotiated TLS Connection") + require.Contains(t, got, "Protocol Support Scan") + require.Contains(t, got, "Cipher Suite Scan") +}