Skip to content

Commit

Permalink
Unify URI versus values handling
Browse files Browse the repository at this point in the history
Previously, some places assumed there was a URI, other places only looked at individual `valueOrValueFrom` combinations, and `schemahero shell db-that-uses-values` would die with a mysterious error. This harmonizes those pathways, which should also allow shell for AWS SSM to start to work

It also fixes the bug where the code was blindly dereferencing the `map[string]string` instead of returning an `error` about that circumstance, leading to opaque errors such as

```
failed to plan sync: failed to connect to postgres: failed to connect to postgres: failed to connect to `host=172.17.0.6 user=postgres database=postgres`: server error (FATAL: password authentication failed for user "postgres" (SQLSTATE 28P01))
```

as it just uses the empty string for a missing key

Signed-off-by: Matthew L Daniel <md@stoi.cc>
  • Loading branch information
mdaniel authored and marccampbell committed Mar 19, 2021
1 parent 021110d commit 0e6756b
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 72 deletions.
81 changes: 42 additions & 39 deletions pkg/apis/databases/v1alpha4/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/client-go/kubernetes"
)

// GetConnection returns driver name, uri, and any error
func (d Database) GetConnection(ctx context.Context) (string, string, error) {
isParamBased := false

Expand Down Expand Up @@ -58,80 +59,70 @@ func (d Database) getConnectionFromParams(ctx context.Context) (string, string,
return "", "", errors.Wrap(err, "failed to get database type")
}

cfg, err := config.GetRESTConfig()
if err != nil {
return "", "", errors.Wrap(err, "failed to get config")
}

clientset, err := kubernetes.NewForConfig(cfg)
if err != nil {
return "", "", errors.Wrap(err, "failed to get clientset")
}

uri := ""
if driver == "postgres" {
hostname, err := d.Spec.Connection.Postgres.Host.Read(clientset, d.Namespace)
hostname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.Host)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres hostname")
}

port, err := d.Spec.Connection.Postgres.Port.Read(clientset, d.Namespace)
port, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.Port)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres port")
}

user, err := d.Spec.Connection.Postgres.User.Read(clientset, d.Namespace)
user, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.User)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres user")
}

password, err := d.Spec.Connection.Postgres.Password.Read(clientset, d.Namespace)
password, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.Password)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres password")
}

dbname, err := d.Spec.Connection.Postgres.DBName.Read(clientset, d.Namespace)
dbname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.DBName)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres dbname")
}

uri = fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, hostname, port, dbname)
if !d.Spec.Connection.Postgres.SSLMode.IsEmpty() {
sslMode, err := d.Spec.Connection.Postgres.SSLMode.Read(clientset, d.Namespace)
sslMode, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Postgres.SSLMode)
if err != nil {
return "", "", errors.Wrap(err, "failed to read postgres ssl mode")
}
uri = fmt.Sprintf("%s?sslmode=%s", uri, sslMode)
}
} else if driver == "cockroachdb" {
hostname, err := d.Spec.Connection.CockroachDB.Host.Read(clientset, d.Namespace)
hostname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.Host)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb hostname")
}

port, err := d.Spec.Connection.CockroachDB.Port.Read(clientset, d.Namespace)
port, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.Port)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb port")
}

user, err := d.Spec.Connection.CockroachDB.User.Read(clientset, d.Namespace)
user, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.User)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb user")
}

password, err := d.Spec.Connection.CockroachDB.Password.Read(clientset, d.Namespace)
password, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.Password)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb password")
}

dbname, err := d.Spec.Connection.CockroachDB.DBName.Read(clientset, d.Namespace)
dbname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.DBName)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb dbname")
}

uri = fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, password, hostname, port, dbname)
if !d.Spec.Connection.CockroachDB.SSLMode.IsEmpty() {
sslMode, err := d.Spec.Connection.CockroachDB.SSLMode.Read(clientset, d.Namespace)
sslMode, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.CockroachDB.SSLMode)
if err != nil {
return "", "", errors.Wrap(err, "failed to read cockroachdb ssl mode")
}
Expand All @@ -140,27 +131,27 @@ func (d Database) getConnectionFromParams(ctx context.Context) (string, string,
} else if driver == "cassandra" {
return "", "", errors.New("not implemented")
} else if driver == "mysql" {
hostname, err := d.Spec.Connection.Mysql.Host.Read(clientset, d.Namespace)
hostname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Mysql.Host)
if err != nil {
return "", "", errors.Wrap(err, "failed to read mysql hostname")
}

port, err := d.Spec.Connection.Mysql.Port.Read(clientset, d.Namespace)
port, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Mysql.Port)
if err != nil {
return "", "", errors.Wrap(err, "failed to read mysql port")
}

user, err := d.Spec.Connection.Mysql.User.Read(clientset, d.Namespace)
user, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Mysql.User)
if err != nil {
return "", "", errors.Wrap(err, "failed to read mysql user")
}

password, err := d.Spec.Connection.Mysql.Password.Read(clientset, d.Namespace)
password, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Mysql.Password)
if err != nil {
return "", "", errors.Wrap(err, "failed to read mysql password")
}

dbname, err := d.Spec.Connection.Mysql.DBName.Read(clientset, d.Namespace)
dbname, err := d.getValueFromValueOrValueFrom(ctx, driver, d.Spec.Connection.Mysql.DBName)
if err != nil {
return "", "", errors.Wrap(err, "failed to read mysql dbname")
}
Expand All @@ -174,15 +165,14 @@ func (d Database) getConnectionFromParams(ctx context.Context) (string, string,
return driver, uri, nil
}

// getConnectionFromURI will return a valid connection string for the database. This
// getConnectionFromURI will return the driver, and a valid connection string for the database. This
// is compatible with any way that the uri was set.
// TODO refactor this to be shorter, simpler and more testable
func (d Database) getConnectionFromURI(ctx context.Context) (string, string, error) {
driver, err := d.getDbType()
if err != nil {
return "", "", errors.Wrap(err, "failed to get database type")
}

var valueOrValueFrom ValueOrValueFrom
if driver == "postgres" {
valueOrValueFrom = d.Spec.Connection.Postgres.URI
Expand All @@ -193,40 +183,53 @@ func (d Database) getConnectionFromURI(ctx context.Context) (string, string, err
} else if driver == "mysql" {
valueOrValueFrom = d.Spec.Connection.Mysql.URI
}
value, err := d.getValueFromValueOrValueFrom(ctx, driver, valueOrValueFrom)
return driver, value, err
}

// getValueFromValueOrValueFrom returns the resolved value, or an error
func (d Database) getValueFromValueOrValueFrom(ctx context.Context, driver string, valueOrValueFrom ValueOrValueFrom) (string, error) {

// if the value is static, return it
if valueOrValueFrom.Value != "" {
return driver, valueOrValueFrom.Value, nil
return valueOrValueFrom.Value, nil
}

// for other types, we need to talk to the kubernetes api
cfg, err := config.GetRESTConfig()
if err != nil {
return "", "", errors.Wrap(err, "failed to get config")
return "", errors.Wrap(err, "failed to get config")
}

clientset, err := kubernetes.NewForConfig(cfg)
if err != nil {
return "", "", errors.Wrap(err, "failed to get clientset")
return "", errors.Wrap(err, "failed to get clientset")
}

// if the value is in a secret, look it up and return it
if valueOrValueFrom.ValueFrom.SecretKeyRef != nil {
secret, err := clientset.CoreV1().Secrets(d.Namespace).Get(ctx, valueOrValueFrom.ValueFrom.SecretKeyRef.Name, metav1.GetOptions{})
secretKeyRefName := valueOrValueFrom.ValueFrom.SecretKeyRef.Name
secret, err := clientset.CoreV1().Secrets(d.Namespace).Get(ctx, secretKeyRefName, metav1.GetOptions{})
if err != nil {
return "", "", errors.Wrap(err, "failed to get secret")
return "", errors.Wrap(err, "failed to get secret")
}

return driver, string(secret.Data[valueOrValueFrom.ValueFrom.SecretKeyRef.Key]), nil
keyName := valueOrValueFrom.ValueFrom.SecretKeyRef.Key
keyData, ok := secret.Data[keyName]
if !ok {
return "", fmt.Errorf("expected Secret \"%s\" to contain key \"%s\"", secretKeyRefName, keyName)
}
return string(keyData), nil
}

if valueOrValueFrom.ValueFrom.Vault != nil {
return d.getVaultConnection(ctx, clientset, driver, valueOrValueFrom)
_, value, err := d.getVaultConnection(ctx, clientset, driver, valueOrValueFrom)
return value, err
}

if valueOrValueFrom.ValueFrom.SSM != nil {
return d.getSSMConnection(ctx, clientset, driver, valueOrValueFrom)
_, value, err := d.getSSMConnection(ctx, clientset, driver, valueOrValueFrom)
return value, err
}

return "", "", errors.New("unable to get connection")
return "", errors.New("unable to get value for driver")
}
1 change: 1 addition & 0 deletions pkg/apis/databases/v1alpha4/ssm_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"k8s.io/client-go/kubernetes"
)

// getSSMConnection returns the driver, the resolved value, and any error
func (d *Database) getSSMConnection(ctx context.Context, clientset *kubernetes.Clientset, driver string, valueOrValueFrom ValueOrValueFrom) (string, string, error) {
region := valueOrValueFrom.ValueFrom.SSM.Region
if region == "" {
Expand Down
31 changes: 0 additions & 31 deletions pkg/apis/databases/v1alpha4/value_or_value_from.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package v1alpha4

import (
"context"

"github.com/pkg/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
)

type ValueOrValueFrom struct {
Expand Down Expand Up @@ -44,30 +40,3 @@ func (v *ValueOrValueFrom) GetVaultDetails() (*Vault, error) {

return nil, errors.New("No Vault secret configured")
}

func (v *ValueOrValueFrom) Read(clientset *kubernetes.Clientset, namespace string) (string, error) {
if v.Value != "" {
return v.Value, nil
}

if v.ValueFrom == nil {
return "", errors.New("value and valueFrom cannot both be nil/empty")
}

if v.ValueFrom.SecretKeyRef != nil {
secret, err := clientset.CoreV1().Secrets(namespace).Get(context.Background(), v.ValueFrom.SecretKeyRef.Name, metav1.GetOptions{})
if err != nil {
return "", errors.Wrap(err, "failed to get secret")
}

return string(secret.Data[v.ValueFrom.SecretKeyRef.Key]), nil
}

if v.ValueFrom.Vault != nil {
// this feels wrong, but also doesn't make sense to return a
// a URI ref as a connection URI?
return "", nil
}

return "", errors.New("unable to find supported valueFrom")
}
1 change: 1 addition & 0 deletions pkg/apis/databases/v1alpha4/vault_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"k8s.io/client-go/kubernetes"
)

// getVaultConnection returns the driver, the resolved URI, or an error
func (d *Database) getVaultConnection(ctx context.Context, clientset kubernetes.Interface, driver string, valueOrValueFrom ValueOrValueFrom) (string, string, error) {
// if the value is in vault and we are using the vault injector, just read the file
if valueOrValueFrom.ValueFrom.Vault.AgentInject {
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/schemaherokubectlcli/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func ShellCmd() *cobra.Command {
podImage = "postgres:11"
}

connectionURI, err := database.Spec.Connection.Postgres.URI.Read(clientset, namespace)
_, connectionURI, err := database.GetConnection(ctx)
if err != nil {
return err
}
Expand All @@ -98,7 +98,7 @@ func ShellCmd() *cobra.Command {
podImage = "mysql:latest"
}

connectionURI, err := database.Spec.Connection.Mysql.URI.Read(clientset, namespace)
_, connectionURI, err := database.GetConnection(ctx)
if err != nil {
return err
}
Expand Down

0 comments on commit 0e6756b

Please sign in to comment.