Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests to increase code coverage #784

Merged
merged 4 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion async_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
// Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved.

package gosnowflake

Expand Down Expand Up @@ -44,6 +44,33 @@ func TestAsyncMode(t *testing.T) {
})
}

func TestAsyncModeMultiStatement(t *testing.T) {
withMultiStmtCtx, _ := WithMultiStatement(context.Background(), 6)
ctx := WithAsyncMode(withMultiStmtCtx)
multiStmtQuery := "begin;\n" +
"delete from test_multi_statement_async;\n" +
"insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" +
"select 1;\n" +
"select 2;\n" +
"rollback;"

runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_async")
dbt.mustExec(`create or replace table test_multi_statement_async(
c1 number, c2 string) as select 10, 'z'`)
defer dbt.mustExec("drop table if exists test_multi_statement_async")

res := dbt.mustExecContext(ctx, multiStmtQuery)
count, err := res.RowsAffected()
if err != nil {
t.Fatalf("res.RowsAffected() returned error: %v", err)
}
if count != 3 {
t.Fatalf("expected 3 affected rows, got %d", count)
}
})
}

func TestAsyncModeCancel(t *testing.T) {
withCancelCtx, cancel := context.WithCancel(context.Background())
ctx := WithAsyncMode(withCancelCtx)
Expand Down
121 changes: 121 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,45 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _
}, nil
}

func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedMfaToken" {
return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
}
return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
MfaToken: "mockedMfaToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedMfaToken" {
return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
}
return &authResponse{
Success: false,
Data: authResponseMain{},
Message: "auth failed",
Code: "260008",
}, nil
}

func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
Expand All @@ -273,6 +312,45 @@ func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *url
}, nil
}

func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedIDToken" {
return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token)
}
return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
IDToken: "mockedIDToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
}
return &authResponse{
Success: false,
Data: authResponseMain{},
Message: "auth failed",
Code: "260008",
}, nil
}

func getDefaultSnowflakeConn() *snowflakeConn {
cfg := Config{
Account: "a",
Expand Down Expand Up @@ -531,6 +609,36 @@ func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken
sc.cfg.MfaToken = "mockedMfaToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
}

func TestUnitAuthenticateWithConfigMFA(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuth: postAuthCheckUsernamePasswordMfa,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
sc.rest = sr
sc.ctx = context.TODO()
err = authenticateWithConfig(sc)
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
Expand All @@ -547,6 +655,19 @@ func TestUnitAuthenticateExternalBrowser(t *testing.T) {
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserToken
sc.cfg.IDToken = "mockedIDToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
}

// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
Expand Down
14 changes: 14 additions & 0 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ func getIdpURLProofKey(
if err != nil {
return "", "", err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
sr.TokenAccessor.SetTokens("", "", -1)
code, err := strconv.Atoi(respd.Code)
if err != nil {
code = -1
return "", "", err
}
return "", "", &SnowflakeError{
Number: code,
SQLState: SQLStateConnectionRejected,
Message: respd.Message,
}
}
return respd.Data.SSOURL, respd.Data.ProofKey, nil
}

Expand Down
68 changes: 68 additions & 0 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
package gosnowflake

import (
"context"
"errors"
"strings"
"testing"
"time"
)

func TestGetTokenFromResponseFail(t *testing.T) {
Expand Down Expand Up @@ -44,3 +48,67 @@ func TestGetTokenFromResponse(t *testing.T) {
t.Errorf("Expected: %s, found: %s", expected, token)
}
}

func TestBuildResponse(t *testing.T) {
resp := buildResponse("Go")
bytes := resp.Bytes()
respStr := string(bytes[:])
if !strings.Contains(respStr, "Your identity was confirmed and propagated to Snowflake Go.\nYou can close this window now and go back where you started from.") {
t.Fatalf("failed to build response")
}
}

func postAuthExternalBrowserError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{}, errors.New("failed to get SAML response")
}

func postAuthExternalBrowserFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Message: "external browser auth failed",
}, nil
}

func postAuthExternalBrowserFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Message: "failed to connect to db",
Code: "260008",
}, nil
}

func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
authenticator := "externalbrowser"
application := "testapp"
account := "testaccount"
user := "u"
password := "p"
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
driverErr, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
if driverErr.Number != ErrCodeFailedToConnect {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToConnect, driverErr.Number)
}
}
38 changes: 32 additions & 6 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ type azureLocation struct {
path string
}

type azureAPI interface {
UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
sasToken := info.Creds.AzureSasToken
u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
Expand Down Expand Up @@ -68,7 +75,12 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{
AccessConditions: &blob.AccessConditions{},
CPKInfo: &blob.CPKInfo{},
Expand All @@ -91,8 +103,12 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
meta.resStatus = uploaded
metadata := resp.Metadata
var encData encryptionData
if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil {
return nil, err

_, ok = metadata["Encryptiondata"]
if ok {
if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil {
return nil, err
}
}

matdesc, ok := metadata["Matdesc"]
Expand Down Expand Up @@ -171,7 +187,12 @@ func (util *snowflakeAzureClient) uploadFile(
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
if meta.srcStream != nil {
uploadSrc := meta.srcStream
if meta.realSrcStream != nil {
Expand Down Expand Up @@ -207,7 +228,7 @@ func (util *snowflakeAzureClient) uploadFile(
if err != nil {
var se *azcore.ResponseError
if errors.As(err, &se) {
if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse.Request.Response) {
if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse) {
meta.resStatus = renewToken
} else {
meta.resStatus = needRetry
Expand Down Expand Up @@ -246,7 +267,12 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm)
if err != nil {
return err
Expand Down
Loading
Loading