diff --git a/docs/configuration.md b/docs/configuration.md index 617ca4e5..60c14b72 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -16,7 +16,12 @@ This is the database password This is the data source information for the MySQL or Oracle database. For Oracle the format can be in the form of '(DESCRIPTION=(ADDRESS_LIST=(ADDRESS=(PROTOCOL=tcp)(HOST=hostname)(PORT=port)))(CONNECT_DATA=(SERVICE_NAME=sn)))'. Or it can be a name of an entry in tnsnames.ora. Please see the Oracle documentation for more details. + We use the same environment name for MySQL. For example, the value can be tcp(127.0.0.1:3306)/myschema. +Failover uses two pipes to separate entries, +tcp(127.0.0.1:3306)/myschema?timeout=9s||tcp(127.0.0.2:3306)/myschema . +Set environment variable certdir to load all the pem files that you can +specify as certificate authorities for the mysql worker to accept. For sharding case, we need to define multiple datasources, one for each shard. The convention is to define the datasource for the first shard in TWO_TASK_0 environment variable, for the second shard in TWO_TASK_1, etc. diff --git a/lib/workerclient.go b/lib/workerclient.go index 1664f8ea..ea7b8e71 100644 --- a/lib/workerclient.go +++ b/lib/workerclient.go @@ -201,6 +201,7 @@ func (worker *WorkerClient) StartWorker() (err error) { return errors.New("Invalid module name, must be like hera- ") } + var twoTask string switch worker.Type { case wtypeStdBy: if GetConfig().EnableSharding { @@ -223,7 +224,7 @@ func (worker *WorkerClient) StartWorker() (err error) { envUpsert(&attr, envHeraName, fmt.Sprintf("%s_taf", worker.moduleName)) twoTaskEnv := fmt.Sprintf("TWO_TASK_STANDBY0_%d", worker.shardID) - twoTask := os.Getenv(twoTaskEnv) + twoTask = os.Getenv(twoTaskEnv) if twoTask == "" { if worker.shardID != 0 { logger.GetLogger().Log(logger.Alert, twoTaskEnv, "is not defined") @@ -260,7 +261,7 @@ func (worker *WorkerClient) StartWorker() (err error) { envUpsert(&attr, envHeraName, worker.moduleName) twoTaskEnv := fmt.Sprintf("TWO_TASK_READ_%d", worker.shardID) - twoTask := os.Getenv(twoTaskEnv) + twoTask = os.Getenv(twoTaskEnv) if twoTask == "" { if worker.shardID != 0 { logger.GetLogger().Log(logger.Alert, twoTaskEnv, "is not defined") @@ -299,7 +300,7 @@ func (worker *WorkerClient) StartWorker() (err error) { envUpsert(&attr, envHeraName, worker.moduleName) twoTaskEnv := fmt.Sprintf("TWO_TASK_%d", worker.shardID) - twoTask := os.Getenv(twoTaskEnv) + twoTask = os.Getenv(twoTaskEnv) if twoTask == "" { if worker.shardID != 0 { logger.GetLogger().Log(logger.Alert, twoTaskEnv, "is not defined") @@ -322,7 +323,7 @@ func (worker *WorkerClient) StartWorker() (err error) { return errors.New("TWO_TASK is not defined") } } - envUpsert(&attr, "mysql_datasource", os.Getenv(envTwoTask)) + envUpsert(&attr, "mysql_datasource", twoTask) socketPair, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) if err != nil { diff --git a/tests/unittest/coordinator_basic/main_test.go b/tests/unittest/coordinator_basic/main_test.go index 5bf391a1..b24b25ac 100644 --- a/tests/unittest/coordinator_basic/main_test.go +++ b/tests/unittest/coordinator_basic/main_test.go @@ -12,6 +12,20 @@ import ( "github.com/paypal/hera/utility/logger" ) +/* +To run the test +export DB_USER=x +export DB_PASSWORD=x +export DB_DATASOURCE=x +export username=realU +export password=realU-pwd +export TWO_TASK='tcp(mysql.example.com:3306)/someSchema?timeout=60s&tls=preferred||tcp(failover.example.com:3306)/someSchema' +export TWO_TASK_READ='tcp(mysqlr.example.com:3306)/someSchema?timeout=6s&tls=preferred||tcp(failover.example.com:3306)/someSchema' +$GOROOT/bin/go install .../worker/{mysql,oracle}worker +ln -s $GOPATH/bin/{mysql,oracle}worker . +$GOROOT/bin/go test -c .../tests/unittest/coordinator_basic && ./coordinator_basic.test +*/ + var mx testutil.Mux var tableName string diff --git a/worker/mysqlworker/adapter.go b/worker/mysqlworker/adapter.go index 6ec5a23d..89489d87 100644 --- a/worker/mysqlworker/adapter.go +++ b/worker/mysqlworker/adapter.go @@ -18,10 +18,13 @@ package main import ( + "context" "database/sql" "errors" "fmt" "os" + "strings" + "time" _ "github.com/go-sql-driver/mysql" "github.com/paypal/hera/utility/logger" @@ -47,10 +50,65 @@ func (adapter *mysqlAdapter) InitDB() (*sql.DB, error) { return nil, errors.New("Can't get 'mysql_datasource' from env") } - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, "connect string:", fmt.Sprintf("%s:%s@%s", user, pass, ds)) + var db *sql.DB + var err error + for idx, curDs := range strings.Split(ds, "||") { + db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s", user, pass, curDs)) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, user+" failed to connect to "+curDs+fmt.Sprintf(" %d", idx)) + } + continue + } + ctx, _ /*cancel*/ := context.WithTimeout(context.Background(), 10*time.Second) + conn, err := db.Conn(ctx) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "could not get connection "+err.Error()) + } + continue + } + if strings.HasPrefix(os.Getenv("logger.LOG_PREFIX"), "WORKER ") { + stmt, err := conn.PrepareContext(ctx, "select @@global.read_only") + //stmt, err := conn.PrepareContext(ctx, "show variables where variable_name='read_only'") + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "query ro check err ", err.Error()) + } + } + rows, err := stmt.Query() + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "ro check err ", err.Error()) + } + } + writable := false + countRows := 0 + if rows.Next() { + countRows++ + var readOnly int + /*var nom string + rows.Scan(&nom, &readOnly) // */ + rows.Scan(&readOnly) + if readOnly == 0 { + writable = true + } + } + rows.Close() + stmt.Close() + conn.Close() + if !writable { + // read only connection + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "recycling, got read-only conn "+curDs) + } + db.Close() + continue + } + } + return db, err } - return sql.Open("mysql", fmt.Sprintf("%s:%s@%s", user, pass, ds)) + return db, err } // UseBindNames return false because the SQL string uses ? for bind parameters diff --git a/worker/mysqlworker/main.go b/worker/mysqlworker/main.go index 2b6acedd..407b54fc 100644 --- a/worker/mysqlworker/main.go +++ b/worker/mysqlworker/main.go @@ -19,9 +19,73 @@ package main import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "log" + "os" + "strings" + + "github.com/go-sql-driver/mysql" workerservice "github.com/paypal/hera/worker/shared" ) func main() { + certdir := os.Getenv("certdir") + finfos, err := ioutil.ReadDir(certdir) + if err != nil { + log.Print("could not read dir " + certdir) + } + for _, finfo := range finfos { + if !strings.HasSuffix(finfo.Name(), ".pem") { + continue + } + shortName := finfo.Name()[:len(finfo.Name())-4] + certfile := certdir + "/" + finfo.Name() + data, err := ioutil.ReadFile(certfile) + if err != nil { + log.Print("could not read " + certfile) + continue + } + rootCertPool := x509.NewCertPool() + if ok := rootCertPool.AppendCertsFromPEM(data); !ok { + log.Print("could not add rt pem " + certfile) + continue + } + mysql.RegisterTLSConfig(shortName, &tls.Config{RootCAs: rootCertPool}) + } workerservice.Start(&mysqlAdapter{}) } +/* +To test DB cert validation, I put the db's cert in $certdir/certOrCa.pem +export certdir=/path/to/dir/with/certs +export TWO_TASK='tcp(db.example.com:3306)/clocschema?timeout=9s&tls=certOrCa' + +To generate a DB cert: +cd /etc/mysql + +cat << EOF > db-cert.cfg +[ req ] +prompt = no +distinguished_name = ca_dn + +[ ca_dn ] +organizationName = "Hera Test DB Cert" +commonName = "hera test db" +countryName = "US" +stateOrProvinceName = "California" +EOF +openssl req -x509 -nodes -config db-cert.cfg -newkey rsa:3072 -keyout server-key0.pem -out server-cert.pem -days 3000 + +openssl rsa -in server-key0.pem -out server-key.pem + +if ! grep -q ^ssl-key mysql.conf.c/mysqld.cnf +then + sed -e 's/^# ssl-key/ssl-key/;s/^# ssl-cert/ssl-cert/' -i mysql.conf.d/mysqld.cnf +fi + +# for some installations, you'll also have to edit the bind_address to +# 0.0.0.0 in mysqld.conf and use a mysql client to adjust grants or permissions +# to allow the user to login from other ip's. + +*/