From 3b28635ceb27773213b2ad22c84a1a36d7697316 Mon Sep 17 00:00:00 2001 From: Fei Huang Date: Thu, 9 Mar 2023 12:19:41 -0800 Subject: [PATCH] Extract cluster config logic into ConfigureCassandraCluster() (#4034) --- .../nosqlplugin/cassandra/gocql/client.go | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go index 33897639518..7e435cd00af 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go @@ -51,7 +51,19 @@ func NewCassandraCluster( for _, host := range parseHosts(cfg.Hosts) { resolvedHosts = append(resolvedHosts, resolver.Resolve(host)...) } + cluster := gocql.NewCluster(resolvedHosts...) + if err := ConfigureCassandraCluster(cfg, cluster); err != nil { + return nil, err + } + + return cluster, nil +} + +// Modifies the input cluster config in place. +// +//nolint:revive // cognitive complexity 61 (> max enabled 25) +func ConfigureCassandraCluster(cfg config.Cassandra, cluster *gocql.ClusterConfig) error { cluster.ProtoVersion = 4 if cfg.Port > 0 { cluster.Port = cfg.Port @@ -70,15 +82,15 @@ func NewCassandraCluster( } if cfg.TLS != nil && cfg.TLS.Enabled { if cfg.TLS.CertData != "" && cfg.TLS.CertFile != "" { - return nil, errors.New("only one of certData or certFile properties should be specified") + return errors.New("only one of certData or certFile properties should be specified") } if cfg.TLS.KeyData != "" && cfg.TLS.KeyFile != "" { - return nil, errors.New("only one of keyData or keyFile properties should be specified") + return errors.New("only one of keyData or keyFile properties should be specified") } if cfg.TLS.CaData != "" && cfg.TLS.CaFile != "" { - return nil, errors.New("only one of caData or caFile properties should be specified") + return errors.New("only one of caData or caFile properties should be specified") } cluster.SslOpts = &gocql.SslOptions{ @@ -94,31 +106,31 @@ func NewCassandraCluster( if cfg.TLS.CertFile != "" { certBytes, err = os.ReadFile(cfg.TLS.CertFile) if err != nil { - return nil, fmt.Errorf("error reading client certificate file: %w", err) + return fmt.Errorf("error reading client certificate file: %w", err) } } else if cfg.TLS.CertData != "" { certBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.CertData) if err != nil { - return nil, fmt.Errorf("client certificate could not be decoded: %w", err) + return fmt.Errorf("client certificate could not be decoded: %w", err) } } if cfg.TLS.KeyFile != "" { keyBytes, err = os.ReadFile(cfg.TLS.KeyFile) if err != nil { - return nil, fmt.Errorf("error reading client certificate private key file: %w", err) + return fmt.Errorf("error reading client certificate private key file: %w", err) } } else if cfg.TLS.KeyData != "" { keyBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.KeyData) if err != nil { - return nil, fmt.Errorf("client certificate private key could not be decoded: %w", err) + return fmt.Errorf("client certificate private key could not be decoded: %w", err) } } if len(certBytes) > 0 { clientCert, err := tls.X509KeyPair(certBytes, keyBytes) if err != nil { - return nil, fmt.Errorf("unable to generate x509 key pair: %w", err) + return fmt.Errorf("unable to generate x509 key pair: %w", err) } cluster.SslOpts.Certificates = []tls.Certificate{clientCert} @@ -128,10 +140,10 @@ func NewCassandraCluster( cluster.SslOpts.RootCAs = x509.NewCertPool() pem, err := base64.StdEncoding.DecodeString(cfg.TLS.CaData) if err != nil { - return nil, fmt.Errorf("caData could not be decoded: %w", err) + return fmt.Errorf("caData could not be decoded: %w", err) } if !cluster.SslOpts.RootCAs.AppendCertsFromPEM(pem) { - return nil, errors.New("failed to load decoded CA Cert as PEM") + return errors.New("failed to load decoded CA Cert as PEM") } } } @@ -164,15 +176,15 @@ func NewCassandraCluster( if cfg.AddressTranslator != nil && cfg.AddressTranslator.Translator != "" { addressTranslator, err := translator.LookupTranslator(cfg.AddressTranslator.Translator) if err != nil { - return nil, err + return err } cluster.AddressTranslator, err = addressTranslator.GetTranslator(&cfg) if err != nil { - return nil, err + return err } } - return cluster, nil + return nil } // parseHosts returns parses a list of hosts separated by comma