Skip to content

Commit

Permalink
Add tests to increase code coverage (#784)
Browse files Browse the repository at this point in the history
add tests to increase code coverage to baseline 80%
  • Loading branch information
sfc-gh-ext-simba-lb committed May 26, 2023
1 parent 90cde5d commit cd0f451
Show file tree
Hide file tree
Showing 34 changed files with 606,691 additions and 97 deletions.
29 changes: 28 additions & 1 deletion async_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
// Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved.

package gosnowflake

Expand Down Expand Up @@ -44,6 +44,33 @@ func TestAsyncMode(t *testing.T) {
})
}

func TestAsyncModeMultiStatement(t *testing.T) {
withMultiStmtCtx, _ := WithMultiStatement(context.Background(), 6)
ctx := WithAsyncMode(withMultiStmtCtx)
multiStmtQuery := "begin;\n" +
"delete from test_multi_statement_async;\n" +
"insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" +
"select 1;\n" +
"select 2;\n" +
"rollback;"

runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_async")
dbt.mustExec(`create or replace table test_multi_statement_async(
c1 number, c2 string) as select 10, 'z'`)
defer dbt.mustExec("drop table if exists test_multi_statement_async")

res := dbt.mustExecContext(ctx, multiStmtQuery)
count, err := res.RowsAffected()
if err != nil {
t.Fatalf("res.RowsAffected() returned error: %v", err)
}
if count != 3 {
t.Fatalf("expected 3 affected rows, got %d", count)
}
})
}

func TestAsyncModeCancel(t *testing.T) {
withCancelCtx, cancel := context.WithCancel(context.Background())
ctx := WithAsyncMode(withCancelCtx)
Expand Down
121 changes: 121 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,45 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _
}, nil
}

func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedMfaToken" {
return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
}
return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
MfaToken: "mockedMfaToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedMfaToken" {
return nil, fmt.Errorf("unexpected mfa token: %v", ar.Data.Token)
}
return &authResponse{
Success: false,
Data: authResponseMain{},
Message: "auth failed",
Code: "260008",
}, nil
}

func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
Expand All @@ -273,6 +312,45 @@ func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *url
}, nil
}

func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.Token != "mockedIDToken" {
return nil, fmt.Errorf("unexpected mfatoken: %v", ar.Data.Token)
}
return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
IDToken: "mockedIDToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
}
return &authResponse{
Success: false,
Data: authResponseMain{},
Message: "auth failed",
Code: "260008",
}, nil
}

func getDefaultSnowflakeConn() *snowflakeConn {
cfg := Config{
Account: "a",
Expand Down Expand Up @@ -531,6 +609,36 @@ func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken
sc.cfg.MfaToken = "mockedMfaToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
}

func TestUnitAuthenticateWithConfigMFA(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuth: postAuthCheckUsernamePasswordMfa,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
sc.rest = sr
sc.ctx = context.TODO()
err = authenticateWithConfig(sc)
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
Expand All @@ -547,6 +655,19 @@ func TestUnitAuthenticateExternalBrowser(t *testing.T) {
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserToken
sc.cfg.IDToken = "mockedIDToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
}

// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
Expand Down
14 changes: 14 additions & 0 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ func getIdpURLProofKey(
if err != nil {
return "", "", err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
sr.TokenAccessor.SetTokens("", "", -1)
code, err := strconv.Atoi(respd.Code)
if err != nil {
code = -1
return "", "", err
}
return "", "", &SnowflakeError{
Number: code,
SQLState: SQLStateConnectionRejected,
Message: respd.Message,
}
}
return respd.Data.SSOURL, respd.Data.ProofKey, nil
}

Expand Down
68 changes: 68 additions & 0 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
package gosnowflake

import (
"context"
"errors"
"strings"
"testing"
"time"
)

func TestGetTokenFromResponseFail(t *testing.T) {
Expand Down Expand Up @@ -44,3 +48,67 @@ func TestGetTokenFromResponse(t *testing.T) {
t.Errorf("Expected: %s, found: %s", expected, token)
}
}

func TestBuildResponse(t *testing.T) {
resp := buildResponse("Go")
bytes := resp.Bytes()
respStr := string(bytes[:])
if !strings.Contains(respStr, "Your identity was confirmed and propagated to Snowflake Go.\nYou can close this window now and go back where you started from.") {
t.Fatalf("failed to build response")
}
}

func postAuthExternalBrowserError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{}, errors.New("failed to get SAML response")
}

func postAuthExternalBrowserFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Message: "external browser auth failed",
}, nil
}

func postAuthExternalBrowserFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Message: "failed to connect to db",
Code: "260008",
}, nil
}

func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
authenticator := "externalbrowser"
application := "testapp"
account := "testaccount"
user := "u"
password := "p"
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
driverErr, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
if driverErr.Number != ErrCodeFailedToConnect {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToConnect, driverErr.Number)
}
}
38 changes: 32 additions & 6 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ type azureLocation struct {
path string
}

type azureAPI interface {
UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
sasToken := info.Creds.AzureSasToken
u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
Expand Down Expand Up @@ -68,7 +75,12 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{
AccessConditions: &blob.AccessConditions{},
CPKInfo: &blob.CPKInfo{},
Expand All @@ -91,8 +103,12 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
meta.resStatus = uploaded
metadata := resp.Metadata
var encData encryptionData
if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil {
return nil, err

_, ok = metadata["Encryptiondata"]
if ok {
if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil {
return nil, err
}
}

matdesc, ok := metadata["Matdesc"]
Expand Down Expand Up @@ -171,7 +187,12 @@ func (util *snowflakeAzureClient) uploadFile(
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
if meta.srcStream != nil {
uploadSrc := meta.srcStream
if meta.realSrcStream != nil {
Expand Down Expand Up @@ -207,7 +228,7 @@ func (util *snowflakeAzureClient) uploadFile(
if err != nil {
var se *azcore.ResponseError
if errors.As(err, &se) {
if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse.Request.Response) {
if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse) {
meta.resStatus = renewToken
} else {
meta.resStatus = needRetry
Expand Down Expand Up @@ -246,7 +267,12 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
Message: "failed to create container client",
}
}
blobClient := containerClient.NewBlockBlobClient(path)
var blobClient azureAPI
blobClient = containerClient.NewBlockBlobClient(path)
// for testing only
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, os.ModePerm)
if err != nil {
return err
Expand Down
Loading

0 comments on commit cd0f451

Please sign in to comment.