Skip to content

Commit

Permalink
Add TLS configuration for PostgreSQL (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxing1292 committed Oct 14, 2020
1 parent 726047b commit 07a3a63
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 106 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ install-schema-mysql: temporal-sql-tool
./temporal-sql-tool --ep 127.0.0.1 -u root --pw root --db temporal_visibility setup-schema -v 0.0
./temporal-sql-tool --ep 127.0.0.1 -u root --pw root --db temporal_visibility update-schema -d ./schema/mysql/v57/visibility/versioned

install-schema-postgres: temporal-sql-tool
install-schema-postgresql: temporal-sql-tool
@printf $(COLOR) "Install Postgres schema..."
./temporal-sql-tool --ep 127.0.0.1 -p 5432 -u temporal -pw temporal --pl postgres create --db temporal
./temporal-sql-tool --ep 127.0.0.1 -p 5432 -u temporal -pw temporal --pl postgres --db temporal setup -v 0.0
Expand Down
66 changes: 1 addition & 65 deletions common/persistence/sql/sqlplugin/mysql/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,13 @@ package mysql

import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/url"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/iancoleman/strcase"
"github.com/jmoiron/sqlx"

"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/persistence/sql"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/service/config"
Expand All @@ -47,7 +41,7 @@ import (
const (
// PluginName is the name of the plugin
PluginName = "mysql"
dsnFmt = "%s:%s@%v(%v)/%s"
dsnFmt = "%v:%v@%v(%v)/%v"
isolationLevelAttrName = "transaction_isolation"
isolationLevelAttrNameLegacy = "tx_isolation"
defaultIsolationLevel = "'READ-COMMITTED'"
Expand Down Expand Up @@ -118,64 +112,6 @@ func (p *plugin) createDBConnection(cfg *config.SQL) (*sqlx.DB, error) {
return db, nil
}

func registerTLSConfig(cfg *config.SQL) error {
if cfg.TLS == nil || !cfg.TLS.Enabled {
return nil
}

host, _, err := net.SplitHostPort(cfg.ConnectAddr)
if err != nil {
return fmt.Errorf("error in host port from ConnectAddr: %v", err)
}

// TODO: create a way to set MinVersion and CipherSuites via cfg.
tlsConfig := auth.NewTLSConfigForServer(host)

if cfg.TLS.CaFile != "" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(cfg.TLS.CaFile)
if err != nil {
return fmt.Errorf("failed to load CA files: %v", err)
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append CA file")
}
tlsConfig.RootCAs = rootCertPool
}

if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair(
cfg.TLS.CertFile,
cfg.TLS.KeyFile,
)
if err != nil {
return fmt.Errorf("failed to load tls x509 key pair: %v", err)
}
clientCert = append(clientCert, certs)
tlsConfig.Certificates = clientCert
}

// In order to use the TLS configuration you need to register it. Once registered you use it by specifying
// `tls` in the connect attributes.
err = mysql.RegisterTLSConfig(customTLSName, tlsConfig)
if err != nil {
return fmt.Errorf("failed to register tls config: %v", err)
}

if cfg.ConnectAttributes == nil {
cfg.ConnectAttributes = map[string]string{}
}

// If no `tls` connect attribute is provided then we override it to our newly registered tls config automatically.
// This allows users to simply provide a tls config without needing to remember to also set the connect attribute
if cfg.ConnectAttributes["tls"] == "" {
cfg.ConnectAttributes["tls"] = customTLSName
}

return nil
}

func buildDSN(cfg *config.SQL) string {
attrs := buildDSNAttrs(cfg)
dsn := fmt.Sprintf(dsnFmt, cfg.User, cfg.Password, cfg.ConnectProtocol, cfg.ConnectAddr, cfg.DatabaseName)
Expand Down
96 changes: 96 additions & 0 deletions common/persistence/sql/sqlplugin/mysql/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package mysql

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"

"github.com/go-sql-driver/mysql"

"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/service/config"
)

func registerTLSConfig(cfg *config.SQL) error {
if cfg.TLS == nil || !cfg.TLS.Enabled {
return nil
}

host, _, err := net.SplitHostPort(cfg.ConnectAddr)
if err != nil {
return fmt.Errorf("error in host port from ConnectAddr: %v", err)
}

// TODO: create a way to set MinVersion and CipherSuites via cfg.
tlsConfig := auth.NewTLSConfigForServer(host)

if cfg.TLS.CaFile != "" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(cfg.TLS.CaFile)
if err != nil {
return fmt.Errorf("failed to load CA files: %v", err)
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append CA file")
}
tlsConfig.RootCAs = rootCertPool
}

if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair(
cfg.TLS.CertFile,
cfg.TLS.KeyFile,
)
if err != nil {
return fmt.Errorf("failed to load tls x509 key pair: %v", err)
}
clientCert = append(clientCert, certs)
tlsConfig.Certificates = clientCert
}

// In order to use the TLS configuration you need to register it. Once registered you use it by specifying
// `tls` in the connect attributes.
err = mysql.RegisterTLSConfig(customTLSName, tlsConfig)
if err != nil {
return fmt.Errorf("failed to register tls config: %v", err)
}

if cfg.ConnectAttributes == nil {
cfg.ConnectAttributes = map[string]string{}
}

// If no `tls` connect attribute is provided then we override it to our newly registered tls config automatically.
// This allows users to simply provide a tls config without needing to remember to also set the connect attribute
if cfg.ConnectAttributes["tls"] == "" {
cfg.ConnectAttributes["tls"] = customTLSName
}

return nil
}
61 changes: 21 additions & 40 deletions common/persistence/sql/sqlplugin/postgresql/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ package postgresql
import (
"errors"
"fmt"
"net"
"strings"

"github.com/iancoleman/strcase"
"github.com/jmoiron/sqlx"
Expand All @@ -41,6 +39,7 @@ import (
const (
// PluginName is the name of the plugin
PluginName = "postgres"
dsnFmt = "postgres://%v:%v@%v/%v?%v"
)

var errTLSNotImplemented = errors.New("tls for postgresql has not been implemented")
Expand Down Expand Up @@ -73,45 +72,13 @@ func (d *plugin) CreateAdminDB(cfg *config.SQL) (sqlplugin.AdminDB, error) {
return db, nil
}

func composeConnectionString(user, password, host, port, dbName string) string {
composeSegment := func(paramName string, paramValue string) string {
paramValue = strings.TrimSpace(paramValue)
if paramValue != "" {
return fmt.Sprintf("%s=%s ", paramName, paramValue)
}
return ""
}

return composeSegment("user", user) +
composeSegment("password", password) +
composeSegment("host", host) +
composeSegment("port", port) +
composeSegment("dbname", dbName) +
composeSegment("sslmode", "disable")
}

// CreateDBConnection creates a returns a reference to a logical connection to the
// underlying SQL database. The returned object is to tied to a single
// SQL database and the object can be used to perform CRUD operations on
// the tables in the database
func (d *plugin) createDBConnection(cfg *config.SQL) (*sqlx.DB, error) {
err := registerTLSConfig(cfg)
if err != nil {
return nil, err
}

host, port, err := net.SplitHostPort(cfg.ConnectAddr)
if err != nil {
return nil, fmt.Errorf("invalid connect address, it must be in host:port format, %v, err: %v", cfg.ConnectAddr, err)
}

dbName := cfg.DatabaseName
//NOTE: postgresql doesn't allow to connect with empty dbName, the admin dbName is "postgres"
if dbName == "" {
dbName = "postgres"
}

db, err := sqlx.Connect(PluginName, composeConnectionString(cfg.User, cfg.Password, host, port, dbName))
db, err := sqlx.Connect(PluginName, buildDSN(cfg))
if err != nil {
return nil, err
}
Expand All @@ -130,10 +97,24 @@ func (d *plugin) createDBConnection(cfg *config.SQL) (*sqlx.DB, error) {
return db, nil
}

// TODO: implement postgresql specific support for TLS
func registerTLSConfig(cfg *config.SQL) error {
if cfg.TLS == nil || !cfg.TLS.Enabled {
return nil
func buildDSN(cfg *config.SQL) string {
tlsAttrs := dsnTSL(cfg).Encode()
dsn := fmt.Sprintf(
dsnFmt,
cfg.User,
cfg.Password,
cfg.ConnectAddr,
databaseName(cfg.DatabaseName),
tlsAttrs,
)
fmt.Println(dsn)
return dsn
}

func databaseName(dbName string) string {
//NOTE: postgres doesn't allow to connect with empty dbName, the admin dbName is "postgres"
if dbName == "" {
return "postgres"
}
return errTLSNotImplemented
return dbName
}
44 changes: 44 additions & 0 deletions common/persistence/sql/sqlplugin/postgresql/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package postgresql

import (
"net/url"

"go.temporal.io/server/common/service/config"
)

func dsnTSL(cfg *config.SQL) url.Values {
sslParams := url.Values{}
if cfg.TLS != nil && cfg.TLS.Enabled {
sslParams.Set("sslmode", "verify-ca")
sslParams.Set("sslrootcert", cfg.TLS.CaFile)
sslParams.Set("sslkey", cfg.TLS.KeyFile)
sslParams.Set("sslcert", cfg.TLS.CertFile)
} else {
sslParams.Set("sslmode", "disable")
}
return sslParams
}

0 comments on commit 07a3a63

Please sign in to comment.