diff --git a/tools/cassandra/handler.go b/tools/cassandra/handler.go index a29c5ddc143..b5b2f51e871 100644 --- a/tools/cassandra/handler.go +++ b/tools/cassandra/handler.go @@ -170,6 +170,7 @@ func newCQLClientConfig(cli *cli.Context) (*CQLClientConfig, error) { CertFile: cli.GlobalString(schema.CLIFlagTLSCertFile), KeyFile: cli.GlobalString(schema.CLIFlagTLSKeyFile), CaFile: cli.GlobalString(schema.CLIFlagTLSCaFile), + ServerName: cli.GlobalString(schema.CLIFlagTLSHostName), EnableHostVerification: !cli.GlobalBool(schema.CLIFlagTLSDisableHostVerification), } } diff --git a/tools/cassandra/main.go b/tools/cassandra/main.go index 545f80de080..8b71ea2c692 100644 --- a/tools/cassandra/main.go +++ b/tools/cassandra/main.go @@ -136,6 +136,12 @@ func buildCLIOptions() *cli.App { Usage: "TLS CA file", EnvVar: "CASSANDRA_TLS_CA", }, + cli.StringFlag{ + Name: schema.CLIFlagTLSHostName, + Value: "", + Usage: "override for target server name", + EnvVar: "CASSANDRA_TLS_SERVER_NAME", + }, cli.BoolFlag{ Name: schema.CLIFlagTLSDisableHostVerification, Usage: "disable tls host name verification (tls must be enabled)", diff --git a/tools/cli/adminCommands.go b/tools/cli/adminCommands.go index 713898f6ca1..8cbeb2d9d91 100644 --- a/tools/cli/adminCommands.go +++ b/tools/cli/adminCommands.go @@ -306,6 +306,7 @@ func connectToCassandra(c *cli.Context) gocql.Session { CertFile: c.String(FlagTLSCertPath), KeyFile: c.String(FlagTLSKeyPath), CaFile: c.String(FlagTLSCaPath), + ServerName: c.String(FlagTLSServerName), EnableHostVerification: !c.Bool(FlagTLSDisableHostVerification), } } diff --git a/tools/cli/factory.go b/tools/cli/factory.go index 8226924bb4f..5559bab1843 100644 --- a/tools/cli/factory.go +++ b/tools/cli/factory.go @@ -144,7 +144,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) { certPath := c.GlobalString(FlagTLSCertPath) keyPath := c.GlobalString(FlagTLSKeyPath) caPath := c.GlobalString(FlagTLSCaPath) - hostNameVerification := !c.GlobalBool(FlagTLSDisableHostVerification) + disableHostNameVerification := c.GlobalBool(FlagTLSDisableHostVerification) serverName := c.GlobalString(FlagTLSServerName) var host string @@ -171,9 +171,6 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) { if caPool != nil || cert != nil { if serverName != "" { host = serverName - // If server name is provided, we enable host verification - // because that's the only reason for providing server name - hostNameVerification = true } else { hostPort := c.GlobalString(FlagAddress) if hostPort == "" { @@ -182,7 +179,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) { // Ignoring error as we'll fail to dial anyway, and that will produce a meaningful error host, _, _ = net.SplitHostPort(hostPort) } - tlsConfig := auth.NewTLSConfigForServer(host, hostNameVerification) + tlsConfig := auth.NewTLSConfigForServer(host, !disableHostNameVerification) if caPool != nil { tlsConfig.RootCAs = caPool } @@ -195,10 +192,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) { // If we are given a server name, set the TLS server name for DNS resolution if serverName != "" { host = serverName - // If server name is provided, we enable host verification - // because that's the only reason for providing server name - hostNameVerification = true - tlsConfig := auth.NewTLSConfigForServer(host, hostNameVerification) + tlsConfig := auth.NewTLSConfigForServer(host, !disableHostNameVerification) return tlsConfig, nil } diff --git a/tools/cli/persistenceUtil.go b/tools/cli/persistenceUtil.go index f83d69feb19..4713b31f409 100644 --- a/tools/cli/persistenceUtil.go +++ b/tools/cli/persistenceUtil.go @@ -85,6 +85,7 @@ func CreateDefaultDBConfig(c *cli.Context) (config.DataStore, error) { CertFile: c.String(FlagTLSCertPath), KeyFile: c.String(FlagTLSKeyPath), CaFile: c.String(FlagTLSCaPath), + ServerName: c.String(FlagTLSServerName), EnableHostVerification: !c.Bool(FlagTLSDisableHostVerification), } } diff --git a/tools/common/schema/types.go b/tools/common/schema/types.go index 28107c7a896..2c88fe4170b 100644 --- a/tools/common/schema/types.go +++ b/tools/common/schema/types.go @@ -161,6 +161,8 @@ const ( CLIFlagTLSCaFile = "tls-ca-file" // CLIFlagTLSDisableHostVerification disable tls host verification (tls must be enabled) CLIFlagTLSDisableHostVerification = "tls-disable-host-verification" + // CLIFlagTLSHostName specifies the host name for host name verification + CLIFlagTLSHostName = "tls-server-name" ) var rmspaceRegex = regexp.MustCompile(`\s+`) diff --git a/tools/sql/handler.go b/tools/sql/handler.go index 3d45d33282a..f04f7023477 100644 --- a/tools/sql/handler.go +++ b/tools/sql/handler.go @@ -167,6 +167,7 @@ func parseConnectConfig(cli *cli.Context) (*config.SQL, error) { CertFile: cli.GlobalString(schema.CLIFlagTLSCertFile), KeyFile: cli.GlobalString(schema.CLIFlagTLSKeyFile), CaFile: cli.GlobalString(schema.CLIFlagTLSCaFile), + ServerName: cli.GlobalString(schema.CLIFlagTLSHostName), EnableHostVerification: !cli.GlobalBool(schema.CLIFlagTLSDisableHostVerification), } } diff --git a/tools/sql/main.go b/tools/sql/main.go index 88facf74599..2e063b6509c 100644 --- a/tools/sql/main.go +++ b/tools/sql/main.go @@ -124,6 +124,12 @@ func BuildCLIOptions() *cli.App { Usage: "sql tls client ca file (tls must be enabled)", EnvVar: "SQL_TLS_CA_FILE", }, + cli.StringFlag{ + Name: schema.CLIFlagTLSHostName, + Value: "", + Usage: "override for target server name", + EnvVar: "SQL_TLS_SERVER_NAME", + }, cli.BoolFlag{ Name: schema.CLIFlagTLSDisableHostVerification, Usage: "disable tls host name verification (tls must be enabled)",