Skip to content

Commit

Permalink
Add TLS host name override config for Cassandra / SQL tool (#1480)
Browse files Browse the repository at this point in the history
* Allow overriding host name for TLS host name verification in schema tools
* Disable TLS host name verification when `tls_disable_host_verification` or `tls-disable-host-verification` is set
  • Loading branch information
wxing1292 committed Apr 20, 2021
1 parent 4bd74c5 commit 4aa0c51
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions tools/cassandra/handler.go
Expand Up @@ -170,6 +170,7 @@ func newCQLClientConfig(cli *cli.Context) (*CQLClientConfig, error) {
CertFile: cli.GlobalString(schema.CLIFlagTLSCertFile),
KeyFile: cli.GlobalString(schema.CLIFlagTLSKeyFile),
CaFile: cli.GlobalString(schema.CLIFlagTLSCaFile),
ServerName: cli.GlobalString(schema.CLIFlagTLSHostName),
EnableHostVerification: !cli.GlobalBool(schema.CLIFlagTLSDisableHostVerification),
}
}
Expand Down
6 changes: 6 additions & 0 deletions tools/cassandra/main.go
Expand Up @@ -136,6 +136,12 @@ func buildCLIOptions() *cli.App {
Usage: "TLS CA file",
EnvVar: "CASSANDRA_TLS_CA",
},
cli.StringFlag{
Name: schema.CLIFlagTLSHostName,
Value: "",
Usage: "override for target server name",
EnvVar: "CASSANDRA_TLS_SERVER_NAME",
},
cli.BoolFlag{
Name: schema.CLIFlagTLSDisableHostVerification,
Usage: "disable tls host name verification (tls must be enabled)",
Expand Down
1 change: 1 addition & 0 deletions tools/cli/adminCommands.go
Expand Up @@ -306,6 +306,7 @@ func connectToCassandra(c *cli.Context) gocql.Session {
CertFile: c.String(FlagTLSCertPath),
KeyFile: c.String(FlagTLSKeyPath),
CaFile: c.String(FlagTLSCaPath),
ServerName: c.String(FlagTLSServerName),
EnableHostVerification: !c.Bool(FlagTLSDisableHostVerification),
}
}
Expand Down
12 changes: 3 additions & 9 deletions tools/cli/factory.go
Expand Up @@ -144,7 +144,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) {
certPath := c.GlobalString(FlagTLSCertPath)
keyPath := c.GlobalString(FlagTLSKeyPath)
caPath := c.GlobalString(FlagTLSCaPath)
hostNameVerification := !c.GlobalBool(FlagTLSDisableHostVerification)
disableHostNameVerification := c.GlobalBool(FlagTLSDisableHostVerification)
serverName := c.GlobalString(FlagTLSServerName)

var host string
Expand All @@ -171,9 +171,6 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) {
if caPool != nil || cert != nil {
if serverName != "" {
host = serverName
// If server name is provided, we enable host verification
// because that's the only reason for providing server name
hostNameVerification = true
} else {
hostPort := c.GlobalString(FlagAddress)
if hostPort == "" {
Expand All @@ -182,7 +179,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) {
// Ignoring error as we'll fail to dial anyway, and that will produce a meaningful error
host, _, _ = net.SplitHostPort(hostPort)
}
tlsConfig := auth.NewTLSConfigForServer(host, hostNameVerification)
tlsConfig := auth.NewTLSConfigForServer(host, !disableHostNameVerification)
if caPool != nil {
tlsConfig.RootCAs = caPool
}
Expand All @@ -195,10 +192,7 @@ func (b *clientFactory) createTLSConfig(c *cli.Context) (*tls.Config, error) {
// If we are given a server name, set the TLS server name for DNS resolution
if serverName != "" {
host = serverName
// If server name is provided, we enable host verification
// because that's the only reason for providing server name
hostNameVerification = true
tlsConfig := auth.NewTLSConfigForServer(host, hostNameVerification)
tlsConfig := auth.NewTLSConfigForServer(host, !disableHostNameVerification)
return tlsConfig, nil
}

Expand Down
1 change: 1 addition & 0 deletions tools/cli/persistenceUtil.go
Expand Up @@ -85,6 +85,7 @@ func CreateDefaultDBConfig(c *cli.Context) (config.DataStore, error) {
CertFile: c.String(FlagTLSCertPath),
KeyFile: c.String(FlagTLSKeyPath),
CaFile: c.String(FlagTLSCaPath),
ServerName: c.String(FlagTLSServerName),
EnableHostVerification: !c.Bool(FlagTLSDisableHostVerification),
}
}
Expand Down
2 changes: 2 additions & 0 deletions tools/common/schema/types.go
Expand Up @@ -161,6 +161,8 @@ const (
CLIFlagTLSCaFile = "tls-ca-file"
// CLIFlagTLSDisableHostVerification disable tls host verification (tls must be enabled)
CLIFlagTLSDisableHostVerification = "tls-disable-host-verification"
// CLIFlagTLSHostName specifies the host name for host name verification
CLIFlagTLSHostName = "tls-server-name"
)

var rmspaceRegex = regexp.MustCompile(`\s+`)
Expand Down
1 change: 1 addition & 0 deletions tools/sql/handler.go
Expand Up @@ -167,6 +167,7 @@ func parseConnectConfig(cli *cli.Context) (*config.SQL, error) {
CertFile: cli.GlobalString(schema.CLIFlagTLSCertFile),
KeyFile: cli.GlobalString(schema.CLIFlagTLSKeyFile),
CaFile: cli.GlobalString(schema.CLIFlagTLSCaFile),
ServerName: cli.GlobalString(schema.CLIFlagTLSHostName),
EnableHostVerification: !cli.GlobalBool(schema.CLIFlagTLSDisableHostVerification),
}
}
Expand Down
6 changes: 6 additions & 0 deletions tools/sql/main.go
Expand Up @@ -124,6 +124,12 @@ func BuildCLIOptions() *cli.App {
Usage: "sql tls client ca file (tls must be enabled)",
EnvVar: "SQL_TLS_CA_FILE",
},
cli.StringFlag{
Name: schema.CLIFlagTLSHostName,
Value: "",
Usage: "override for target server name",
EnvVar: "SQL_TLS_SERVER_NAME",
},
cli.BoolFlag{
Name: schema.CLIFlagTLSDisableHostVerification,
Usage: "disable tls host name verification (tls must be enabled)",
Expand Down

0 comments on commit 4aa0c51

Please sign in to comment.