diff --git a/config/http_config.go b/config/http_config.go index 73163206..37aa9667 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -579,8 +579,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT // No need for a RoundTripper that reloads the CA file automatically. return newRT(tlsConfig) } - - return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT) + return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT) } type authorizationCredentialsRoundTripper struct { @@ -750,7 +749,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro if len(rt.config.TLSConfig.CAFile) == 0 { t, _ = tlsTransport(tlsConfig) } else { - t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport) + t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport) if err != nil { return nil, err } @@ -817,6 +816,10 @@ func cloneRequest(r *http.Request) *http.Request { // NewTLSConfig creates a new tls.Config from the given TLSConfig. func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + tlsConfig := &tls.Config{ InsecureSkipVerify: cfg.InsecureSkipVerify, MinVersion: uint16(cfg.MinVersion), @@ -831,7 +834,11 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // If a CA cert is provided then let's read it in so we can validate the // scrape target's certificate properly. - if len(cfg.CAFile) > 0 { + if len(cfg.CA) > 0 { + if !updateRootCA(tlsConfig, []byte(cfg.CA)) { + return nil, fmt.Errorf("unable to use inline CA cert") + } + } else if len(cfg.CAFile) > 0 { b, err := readCAFile(cfg.CAFile) if err != nil { return nil, err @@ -844,12 +851,9 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { if len(cfg.ServerName) > 0 { tlsConfig.ServerName = cfg.ServerName } + // If a client cert & key is provided then configure TLS config accordingly. - if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 { - return nil, fmt.Errorf("client cert file %q specified without client key file", cfg.CertFile) - } else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 { - return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile) - } else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 { + if cfg.usingClientCert() && cfg.usingClientKey() { // Verify that client cert and key are valid. if _, err := cfg.getClientCertificate(nil); err != nil { return nil, err @@ -862,6 +866,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { // TLSConfig configures the options for TLS connections. type TLSConfig struct { + // Text of the CA cert to use for the targets. + CA string `yaml:"ca,omitempty" json:"ca,omitempty"` + // Text of the client cert file for the targets. + Cert string `yaml:"cert,omitempty" json:"cert,omitempty"` + // Text of the client key file for the targets. + Key Secret `yaml:"key,omitempty" json:"key,omitempty"` // The CA cert to use for the targets. CAFile string `yaml:"ca_file,omitempty" json:"ca_file,omitempty"` // The client cert file for the targets. @@ -891,29 +901,77 @@ func (c *TLSConfig) SetDirectory(dir string) { // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain TLSConfig - return unmarshal((*plain)(c)) + if err := unmarshal((*plain)(c)); err != nil { + return err + } + return c.Validate() } -// readCertAndKey reads the cert and key files from the disk. -func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) { - certData, err := os.ReadFile(certFile) - if err != nil { - return nil, nil, err +// Validate validates the TLSConfig to check that only one of the inlined or +// file-based fields for the TLS CA, client certificate, and client key are +// used. +func (c *TLSConfig) Validate() error { + if len(c.CA) > 0 && len(c.CAFile) > 0 { + return fmt.Errorf("at most one of ca and ca_file must be configured") + } + if len(c.Cert) > 0 && len(c.CertFile) > 0 { + return fmt.Errorf("at most one of cert and cert_file must be configured") + } + if len(c.Key) > 0 && len(c.KeyFile) > 0 { + return fmt.Errorf("at most one of key and key_file must be configured") } - keyData, err := os.ReadFile(keyFile) - if err != nil { - return nil, nil, err + if c.usingClientCert() && !c.usingClientKey() { + return fmt.Errorf("exactly one of key or key_file must be configured when a client certificate is configured") + } else if c.usingClientKey() && !c.usingClientCert() { + return fmt.Errorf("exactly one of cert or cert_file must be configured when a client key is configured") } - return certData, keyData, nil + return nil +} + +func (c *TLSConfig) usingClientCert() bool { + return len(c.Cert) > 0 || len(c.CertFile) > 0 +} + +func (c *TLSConfig) usingClientKey() bool { + return len(c.Key) > 0 || len(c.KeyFile) > 0 +} + +func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings { + return TLSRoundTripperSettings{ + CA: c.CA, + CAFile: c.CAFile, + Cert: c.Cert, + CertFile: c.CertFile, + Key: string(c.Key), + KeyFile: c.KeyFile, + } } // getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile) - if err != nil { - return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err) + var ( + certData, keyData []byte + err error + ) + + if c.CertFile != "" { + certData, err = os.ReadFile(c.CertFile) + if err != nil { + return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err) + } + } else { + certData = []byte(c.Cert) + } + + if c.KeyFile != "" { + keyData, err = os.ReadFile(c.KeyFile) + if err != nil { + return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err) + } + } else { + keyData = []byte(c.Key) } cert, err := tls.X509KeyPair(certData, keyData) @@ -946,30 +1004,32 @@ func updateRootCA(cfg *tls.Config, b []byte) bool { // tlsRoundTripper is a RoundTripper that updates automatically its TLS // configuration whenever the content of the CA file changes. type tlsRoundTripper struct { - caFile string - certFile string - keyFile string + settings TLSRoundTripperSettings // newRT returns a new RoundTripper. newRT func(*tls.Config) (http.RoundTripper, error) mtx sync.RWMutex rt http.RoundTripper - hashCAFile []byte - hashCertFile []byte - hashKeyFile []byte + hashCAData []byte + hashCertData []byte + hashKeyData []byte tlsConfig *tls.Config } +type TLSRoundTripperSettings struct { + CA, CAFile string + Cert, CertFile string + Key, KeyFile string +} + func NewTLSRoundTripper( cfg *tls.Config, - caFile, certFile, keyFile string, + settings TLSRoundTripperSettings, newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ - caFile: caFile, - certFile: certFile, - keyFile: keyFile, + settings: settings, newRT: newRT, tlsConfig: cfg, } @@ -979,7 +1039,7 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash() + _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash() if err != nil { return nil, err } @@ -987,36 +1047,66 @@ func NewTLSRoundTripper( return t, nil } -func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) { - b1, err := readCAFile(t.caFile) - if err != nil { - return nil, nil, nil, nil, err +func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) { + var ( + caBytes, certBytes, keyBytes []byte + + err error + ) + + if t.settings.CAFile != "" { + caBytes, err = os.ReadFile(t.settings.CAFile) + if err != nil { + return nil, nil, nil, nil, err + } + } else if t.settings.CA != "" { + caBytes = []byte(t.settings.CA) + } + + if t.settings.CertFile != "" { + certBytes, err = os.ReadFile(t.settings.CertFile) + if err != nil { + return nil, nil, nil, nil, err + } + } else if t.settings.Cert != "" { + certBytes = []byte(t.settings.Cert) } - h1 := sha256.Sum256(b1) - var h2, h3 [32]byte - if t.certFile != "" { - b2, b3, err := readCertAndKey(t.certFile, t.keyFile) + if t.settings.KeyFile != "" { + keyBytes, err = os.ReadFile(t.settings.KeyFile) if err != nil { return nil, nil, nil, nil, err } - h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3) + } else if t.settings.Key != "" { + keyBytes = []byte(t.settings.Key) + } + + var caHash, certHash, keyHash [32]byte + + if len(caBytes) > 0 { + caHash = sha256.Sum256(caBytes) + } + if len(certBytes) > 0 { + certHash = sha256.Sum256(certBytes) + } + if len(keyBytes) > 0 { + keyHash = sha256.Sum256(keyBytes) } - return b1, h1[:], h2[:], h3[:], nil + return caBytes, caHash[:], certHash[:], keyHash[:], nil } // RoundTrip implements the http.RoundTrip interface. func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash() + caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash() if err != nil { return nil, err } t.mtx.RLock() - equal := bytes.Equal(caHash[:], t.hashCAFile) && - bytes.Equal(certHash[:], t.hashCertFile) && - bytes.Equal(keyHash[:], t.hashKeyFile) + equal := bytes.Equal(caHash[:], t.hashCAData) && + bytes.Equal(certHash[:], t.hashCertData) && + bytes.Equal(keyHash[:], t.hashKeyData) rt := t.rt t.mtx.RUnlock() if equal { @@ -1029,7 +1119,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // using GetClientCertificate. tlsConfig := t.tlsConfig.Clone() if !updateRootCA(tlsConfig, caData) { - return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile) + return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile) } rt, err = t.newRT(tlsConfig) if err != nil { @@ -1039,9 +1129,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { t.mtx.Lock() t.rt = rt - t.hashCAFile = caHash[:] - t.hashCertFile = certHash[:] - t.hashKeyFile = keyHash[:] + t.hashCAData = caHash[:] + t.hashCertData = certHash[:] + t.hashKeyData = keyHash[:] t.mtx.Unlock() return rt.RoundTrip(req) diff --git a/config/http_config_test.go b/config/http_config_test.go index 638a1332..ca2ed71a 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -774,7 +774,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, - errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath), + errorMessage: fmt.Sprintf("unable to read specified client cert (%s):", MissingCert), }, { configTLSConfig: TLSConfig{ CAFile: "", @@ -782,7 +782,27 @@ func TestTLSConfigInvalidCA(t *testing.T) { KeyFile: MissingKey, ServerName: "", InsecureSkipVerify: false}, - errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey), + errorMessage: fmt.Sprintf("unable to read specified client key (%s):", MissingKey), + }, + { + configTLSConfig: TLSConfig{ + CAFile: "", + Cert: readFile(t, ClientCertificatePath), + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false}, + errorMessage: "at most one of cert and cert_file must be configured", + }, + { + configTLSConfig: TLSConfig{ + CAFile: "", + CertFile: ClientCertificatePath, + Key: Secret(readFile(t, ClientKeyNoPassPath)), + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false}, + errorMessage: "at most one of key and key_file must be configured", }, } @@ -1046,6 +1066,127 @@ func TestTLSRoundTripper(t *testing.T) { } } +func TestTLSRoundTripper_Inline(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ExpectedMessage) + } + testServer, err := newTestServer(handler) + if err != nil { + t.Fatal(err.Error()) + } + defer testServer.Close() + + testCases := []struct { + caText, caFile string + certText, certFile string + keyText, keyFile string + + errMsg string + }{ + { + // File-based everything. + caFile: TLSCAChainPath, + certFile: ClientCertificatePath, + keyFile: ClientKeyNoPassPath, + }, + { + // Inline CA. + caText: readFile(t, TLSCAChainPath), + certFile: ClientCertificatePath, + keyFile: ClientKeyNoPassPath, + }, + { + // Inline cert. + caFile: TLSCAChainPath, + certText: readFile(t, ClientCertificatePath), + keyFile: ClientKeyNoPassPath, + }, + { + // Inline key. + caFile: TLSCAChainPath, + certFile: ClientCertificatePath, + keyText: readFile(t, ClientKeyNoPassPath), + }, + { + // Inline everything. + caText: readFile(t, TLSCAChainPath), + certText: readFile(t, ClientCertificatePath), + keyText: readFile(t, ClientKeyNoPassPath), + }, + + { + // Invalid inline CA. + caText: "badca", + certText: readFile(t, ClientCertificatePath), + keyText: readFile(t, ClientKeyNoPassPath), + + errMsg: "unable to use inline CA cert", + }, + { + // Invalid cert. + caText: readFile(t, TLSCAChainPath), + certText: "badcert", + keyText: readFile(t, ClientKeyNoPassPath), + + errMsg: "failed to find any PEM data in certificate input", + }, + { + // Invalid key. + caText: readFile(t, TLSCAChainPath), + certText: readFile(t, ClientCertificatePath), + keyText: "badkey", + + errMsg: "failed to find any PEM data in key input", + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(strconv.Itoa(i), func(t *testing.T) { + cfg := HTTPClientConfig{ + TLSConfig: TLSConfig{ + CA: tc.caText, + CAFile: tc.caFile, + Cert: tc.certText, + CertFile: tc.certFile, + Key: Secret(tc.keyText), + KeyFile: tc.keyFile, + InsecureSkipVerify: false}, + } + + c, err := NewClientFromConfig(cfg, "test") + if tc.errMsg != "" { + if !strings.Contains(err.Error(), tc.errMsg) { + t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err) + } + return + } else if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + if err != nil { + t.Fatalf("Error creating HTTP request: %v", err) + } + r, err := c.Do(req) + if err != nil { + t.Fatalf("Can't connect to the test server") + } + + b, err := io.ReadAll(r.Body) + r.Body.Close() + if err != nil { + t.Errorf("Can't read the server response body") + } + + got := strings.TrimSpace(string(b)) + if ExpectedMessage != got { + t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got) + } + }) + } +} + func TestTLSRoundTripperRaces(t *testing.T) { bs := getCertificateBlobs(t) @@ -1838,3 +1979,14 @@ no_proxy: promcon.io,cncf.io`, proxyServer.URL), }) } } + +func readFile(t *testing.T, filename string) string { + t.Helper() + + content, err := os.ReadFile(filename) + if err != nil { + t.Fatalf("Failed to read file %q: %s", filename, err) + } + + return string(content) +}