Skip to content

Commit

Permalink
SNOW-857829 Fix username and password requiredness (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jul 14, 2023
1 parent 7d6e39a commit 3623a16
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 109 deletions.
22 changes: 15 additions & 7 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,16 +390,11 @@ func fillMissingConfigParameters(cfg *Config) error {
return ErrEmptyAccount
}

if cfg.Authenticator != AuthTypeOAuth && strings.Trim(cfg.User, " ") == "" {
// oauth does not require a username
if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" {
return ErrEmptyUsername
}

if cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeJwt &&
strings.Trim(cfg.Password, " ") == "" {
// no password parameter is required for EXTERNALBROWSER, OAUTH or JWT.
if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" {
return ErrEmptyPassword
}
if strings.Trim(cfg.Protocol, " ") == "" {
Expand Down Expand Up @@ -467,6 +462,19 @@ func fillMissingConfigParameters(cfg *Config) error {
return nil
}

func authRequiresUser(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser
}

func authRequiresPassword(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
cfg.Authenticator != AuthTypeExternalBrowser &&
cfg.Authenticator != AuthTypeJwt
}

// transformAccountToHost transforms host to account name
func transformAccountToHost(cfg *Config) (err error) {
if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" {
Expand Down
263 changes: 161 additions & 102 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,112 +582,171 @@ func TestParseDSN(t *testing.T) {
},
}

for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
testcases = append(testcases, tcParseDSN{
dsn: fmt.Sprintf("@host:777/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())),
config: &Config{
Account: "ac", User: "", Password: "",
Protocol: "http", Host: "host", Port: 777,
Database: "db", Schema: "schema",
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
Authenticator: at,
},
ocspMode: ocspModeFailOpen,
err: nil,
})
}

for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA, AuthTypeJwt} {
testcases = append(testcases, tcParseDSN{
dsn: fmt.Sprintf("@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())),
config: &Config{
Account: "ac", User: "", Password: "",
Protocol: "http", Host: "host", Port: 888,
Database: "db", Schema: "schema",
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
Authenticator: at,
},
ocspMode: ocspModeFailOpen,
err: ErrEmptyUsername,
})
}

for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA} {
testcases = append(testcases, tcParseDSN{
dsn: fmt.Sprintf("user@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())),
config: &Config{
Account: "ac", User: "user", Password: "",
Protocol: "http", Host: "host", Port: 888,
Database: "db", Schema: "schema",
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
Authenticator: at,
},
ocspMode: ocspModeFailOpen,
err: ErrEmptyPassword,
})
}

for i, test := range testcases {
// t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn)
cfg, err := ParseDSN(test.dsn)
switch {
case test.err == nil:
if err != nil {
t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err)
}
if test.config.Host != cfg.Host {
t.Fatalf("%d: Failed to match host. expected: %v, got: %v",
i, test.config.Host, cfg.Host)
}
if test.config.Account != cfg.Account {
t.Fatalf("%d: Failed to match account. expected: %v, got: %v",
i, test.config.Account, cfg.Account)
}
if test.config.User != cfg.User {
t.Fatalf("%d: Failed to match user. expected: %v, got: %v",
i, test.config.User, cfg.User)
}
if test.config.Password != cfg.Password {
t.Fatalf("%d: Failed to match password. expected: %v, got: %v",
i, test.config.Password, cfg.Password)
}
if test.config.Database != cfg.Database {
t.Fatalf("%d: Failed to match database. expected: %v, got: %v",
i, test.config.Database, cfg.Database)
}
if test.config.Schema != cfg.Schema {
t.Fatalf("%d: Failed to match schema. expected: %v, got: %v",
i, test.config.Schema, cfg.Schema)
}
if test.config.Warehouse != cfg.Warehouse {
t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v",
i, test.config.Warehouse, cfg.Warehouse)
}
if test.config.Role != cfg.Role {
t.Fatalf("%d: Failed to match role. expected: %v, got: %v",
i, test.config.Role, cfg.Role)
}
if test.config.Region != cfg.Region {
t.Fatalf("%d: Failed to match region. expected: %v, got: %v",
i, test.config.Region, cfg.Region)
}
if test.config.Protocol != cfg.Protocol {
t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v",
i, test.config.Protocol, cfg.Protocol)
}
if test.config.Passcode != cfg.Passcode {
t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v",
i, test.config.Passcode, cfg.Passcode)
}
if test.config.PasscodeInPassword != cfg.PasscodeInPassword {
t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v",
i, test.config.PasscodeInPassword, cfg.PasscodeInPassword)
}
if test.config.Authenticator != cfg.Authenticator {
t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v",
i, test.config.Authenticator.String(), cfg.Authenticator.String())
}
if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL {
t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v",
i, test.config.OktaURL, cfg.OktaURL)
}
if test.config.OCSPFailOpen != cfg.OCSPFailOpen {
t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v",
i, test.config.OCSPFailOpen, cfg.OCSPFailOpen)
}
if test.ocspMode != cfg.ocspMode() {
t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v",
i, test.ocspMode, cfg.ocspMode())
}
if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters {
t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v",
i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters)
}
if test.config.ClientTimeout != cfg.ClientTimeout {
t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v",
i, test.config.ClientTimeout, cfg.ClientTimeout)
}
if test.config.JWTClientTimeout != cfg.JWTClientTimeout {
t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v",
i, test.config.JWTClientTimeout, cfg.JWTClientTimeout)
}
if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout {
t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v",
i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout)
}
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
driverErrG, okG := err.(*SnowflakeError)
if okE && !okG || !okE && okG {
t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err)
}
if okE && okG {
if driverErrE.Number != driverErrG.Number {
t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number)
t.Run("TestParseDSN", func(t *testing.T) {
cfg, err := ParseDSN(test.dsn)
switch {
case test.err == nil:
if err != nil {
t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err)
}
if test.config.Host != cfg.Host {
t.Fatalf("%d: Failed to match host. expected: %v, got: %v",
i, test.config.Host, cfg.Host)
}
if test.config.Account != cfg.Account {
t.Fatalf("%d: Failed to match account. expected: %v, got: %v",
i, test.config.Account, cfg.Account)
}
if test.config.User != cfg.User {
t.Fatalf("%d: Failed to match user. expected: %v, got: %v",
i, test.config.User, cfg.User)
}
} else {
t1 := reflect.TypeOf(err)
t2 := reflect.TypeOf(test.err)
if t1 != t2 {
t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err)
if test.config.Password != cfg.Password {
t.Fatalf("%d: Failed to match password. expected: %v, got: %v",
i, test.config.Password, cfg.Password)
}
if test.config.Database != cfg.Database {
t.Fatalf("%d: Failed to match database. expected: %v, got: %v",
i, test.config.Database, cfg.Database)
}
if test.config.Schema != cfg.Schema {
t.Fatalf("%d: Failed to match schema. expected: %v, got: %v",
i, test.config.Schema, cfg.Schema)
}
if test.config.Warehouse != cfg.Warehouse {
t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v",
i, test.config.Warehouse, cfg.Warehouse)
}
if test.config.Role != cfg.Role {
t.Fatalf("%d: Failed to match role. expected: %v, got: %v",
i, test.config.Role, cfg.Role)
}
if test.config.Region != cfg.Region {
t.Fatalf("%d: Failed to match region. expected: %v, got: %v",
i, test.config.Region, cfg.Region)
}
if test.config.Protocol != cfg.Protocol {
t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v",
i, test.config.Protocol, cfg.Protocol)
}
if test.config.Passcode != cfg.Passcode {
t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v",
i, test.config.Passcode, cfg.Passcode)
}
if test.config.PasscodeInPassword != cfg.PasscodeInPassword {
t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v",
i, test.config.PasscodeInPassword, cfg.PasscodeInPassword)
}
if test.config.Authenticator != cfg.Authenticator {
t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v",
i, test.config.Authenticator.String(), cfg.Authenticator.String())
}
if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL {
t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v",
i, test.config.OktaURL, cfg.OktaURL)
}
if test.config.OCSPFailOpen != cfg.OCSPFailOpen {
t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v",
i, test.config.OCSPFailOpen, cfg.OCSPFailOpen)
}
if test.ocspMode != cfg.ocspMode() {
t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v",
i, test.ocspMode, cfg.ocspMode())
}
if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters {
t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v",
i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters)
}
if test.config.ClientTimeout != cfg.ClientTimeout {
t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v",
i, test.config.ClientTimeout, cfg.ClientTimeout)
}
if test.config.JWTClientTimeout != cfg.JWTClientTimeout {
t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v",
i, test.config.JWTClientTimeout, cfg.JWTClientTimeout)
}
if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout {
t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v",
i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout)
}
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
driverErrG, okG := err.(*SnowflakeError)
if okE && !okG || !okE && okG {
t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err)
}
if okE && okG {
if driverErrE.Number != driverErrG.Number {
t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number)
}
} else {
t1 := reflect.TypeOf(err)
t2 := reflect.TypeOf(test.err)
if t1 != t2 {
t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err)
}
}
}
}

})
}
}

Expand Down

0 comments on commit 3623a16

Please sign in to comment.