Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SIG-18794: Make getAsync wait until the query completes or times out. #59

Merged
merged 2 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ wss-golang-agent.config
wss-unified-agent.jar
whitesource/
*.swp
/vendor/github.com/snowflakedb/gosnowflake
129 changes: 77 additions & 52 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ import (
"time"
)

func isAsyncModeNoFetch(ctx context.Context) bool {
if flag, ok := ctx.Value(asyncModeNoFetch).(bool); ok && flag {
return true
}

return false
}

func (sr *snowflakeRestful) processAsync(
ctx context.Context,
respd *execResponse,
Expand Down Expand Up @@ -66,45 +74,39 @@ func (sr *snowflakeRestful) getAsync(
defer close(errChannel)
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
// if we failed here because of top level context cancellation we want to cancel the original query
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)

// the get call pulling for result status is
var response *execResponse
var err error
for response == nil || (!response.Success && parseCode(response.Code) == ErrQueryExecutionInProgress) {
response, err = sr.getAsyncOrStatus(ctx, URL, headers, timeout)

if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)
}
}
}
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}

respd := execResponse{}
err = json.NewDecoder(resp.Body).Decode(&respd)
resp.Body.Close()
if err != nil {
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
return err
sfError.Message = err.Error()
errChannel <- sfError
return err
}
}

sc := &snowflakeConn{rest: sr, cfg: cfg}
if respd.Success {
if response.Success {
if resType == execResultType {
res.insertID = -1
if isDml(respd.Data.StatementTypeID) {
res.affectedRows, err = updateRows(respd.Data)
if isDml(response.Data.StatementTypeID) {
res.affectedRows, err = updateRows(response.Data)
if err != nil {
return err
}
} else if isMultiStmt(&respd.Data) {
r, err := sc.handleMultiExec(ctx, respd.Data)
} else if isMultiStmt(&response.Data) {
r, err := sc.handleMultiExec(ctx, response.Data)
if err != nil {
res.errChannel <- err
return err
Expand All @@ -115,39 +117,62 @@ func (sr *snowflakeRestful) getAsync(
return err
}
}
res.queryID = respd.Data.QueryID
res.queryID = response.Data.QueryID
res.errChannel <- nil // mark exec status complete
} else {
rows.sc = sc
rows.queryID = respd.Data.QueryID
if isMultiStmt(&respd.Data) {
if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
rows.errChannel <- err
close(errChannel)
return err
rows.queryID = response.Data.QueryID

if !isAsyncModeNoFetch(ctx) {
if isMultiStmt(&response.Data) {
if err = sc.handleMultiQuery(ctx, response.Data, rows); err != nil {
rows.errChannel <- err
close(errChannel)
return err
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, response.Data))
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
_ = rows.ChunkDownloader.start()
}
rows.ChunkDownloader.start()
rows.errChannel <- nil // mark query status complete
}
} else {
var code int
if respd.Code != "" {
code, err = strconv.Atoi(respd.Code)
if err != nil {
code = -1
}
} else {
code = -1
}
errChannel <- &SnowflakeError{
Number: code,
SQLState: respd.Data.SQLState,
Message: respd.Message,
QueryID: respd.Data.QueryID,
Number: parseCode(response.Code),
SQLState: response.Data.SQLState,
Message: response.Message,
QueryID: response.Data.QueryID,
}
}
return nil
}

func parseCode(codeStr string) int {
if code, err := strconv.Atoi(codeStr); err == nil {
return code
}

return -1
}

func (sr *snowflakeRestful) getAsyncOrStatus(
ctx context.Context,
url *url.URL,
headers map[string]string,
timeout time.Duration) (*execResponse, error) {
resp, err := sr.FuncGet(ctx, sr, url, headers, timeout)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer func() { _ = resp.Body.Close() }()
}

response := &execResponse{}
if err = json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, err
}

return response, nil
}
39 changes: 39 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,45 @@ func TestAsyncMode(t *testing.T) {
})
}

func TestAsyncModeNoFetch(t *testing.T) {
ctx := WithAsyncModeNoFetch(WithAsyncMode(context.Background()))
numrows := 100000

runTests(t, dsn, func(dbt *DBTest) {
rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
defer rows.Close()

// Next() will block and wait until results are available
if rows.Next() == true {
t.Fatalf("next should have returned no rows")
}
if err := rows.Err(); err == nil {
t.Fatalf("we should have an error thrown")
}
columns, err := rows.Columns()
if columns != nil {
t.Fatalf("there should be no column array returned")
}
if err == nil {
t.Fatalf("we should have an error thrown")
}

if rows.Scan(nil) == nil {
t.Fatalf("we should have an error thrown")
}

dbt.mustExec("create or replace table test_async_exec (value boolean)")
res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
count, err := res.RowsAffected()
if err != nil {
t.Fatalf("res.RowsAffected() returned error: %v", err)
}
if count != 1 {
t.Fatalf("expected 1 affected row, got %d", count)
}
})
}

func TestAsyncQueryFail(t *testing.T) {
ctx := WithAsyncMode(context.Background())
runTests(t, dsn, func(dbt *DBTest) {
Expand Down
6 changes: 6 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ const (
ErrRoleNotExist = 390189
// ErrObjectNotExistOrAuthorized is a GS error code for the case that the server-side object specified does not exist
ErrObjectNotExistOrAuthorized = 390201

/* Extra error code */

// ErrQueryExecutionInProgress is returned when monitoring an async query reaches 45s
ErrQueryExecutionInProgress = 333333
)

const (
Expand Down Expand Up @@ -267,6 +272,7 @@ const (
errMsgFailedToConvertToS3Client = "failed to convert interface to s3 client"
errMsgNoResultIDs = "no result IDs returned with the multi-statement query"
errMsgQueryStatus = "server ErrorCode=%s, ErrorMessage=%s"
errMsgAsyncWithNoResults = "async with no results"
)

var (
Expand Down
24 changes: 18 additions & 6 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gosnowflake

import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"strings"
Expand Down Expand Up @@ -71,12 +72,16 @@ func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string {
if err := rows.waitForAsyncQueryStatus(); err != nil {
return err.Error()
}
if rows.ChunkDownloader == nil {
return ""
}

return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type)
}

// ColumnTypeLength returns the length of the column
func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return 0, false
}
if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
Expand All @@ -90,7 +95,7 @@ func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) {
}

func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return false, false
}
if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
Expand All @@ -100,7 +105,7 @@ func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) {
}

func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return 0, 0, false
}
rowType := rows.ChunkDownloader.getRowType()
Expand All @@ -119,7 +124,7 @@ func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale
}

func (rows *snowflakeRows) Columns() []string {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return make([]string, 0)
}
logger.Debug("Rows.Columns")
Expand All @@ -131,7 +136,7 @@ func (rows *snowflakeRows) Columns() []string {
}

func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return nil
}
return snowflakeTypeToGo(
Expand Down Expand Up @@ -164,6 +169,9 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) {
if err = rows.waitForAsyncQueryStatus(); err != nil {
return err
}
if rows.ChunkDownloader == nil {
return fmt.Errorf(errMsgAsyncWithNoResults)
}
row, err := rows.ChunkDownloader.next()
if err != nil {
// includes io.EOF
Expand Down Expand Up @@ -196,7 +204,7 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) {
}

func (rows *snowflakeRows) HasNextResultSet() bool {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return false
}
return rows.ChunkDownloader.hasNextResultSet()
Expand All @@ -206,6 +214,10 @@ func (rows *snowflakeRows) NextResultSet() error {
if err := rows.waitForAsyncQueryStatus(); err != nil {
return err
}
if rows.ChunkDownloader == nil {
return fmt.Errorf(errMsgAsyncWithNoResults)
}

if len(rows.ChunkDownloader.getChunkMetas()) == 0 {
if rows.ChunkDownloader.getNextChunkDownloader() == nil {
return io.EOF
Expand Down
6 changes: 6 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type contextKey string
const (
multiStatementCount contextKey = "MULTI_STATEMENT_COUNT"
asyncMode contextKey = "ASYNC_MODE_QUERY"
asyncModeNoFetch contextKey = "ASYNC_MODE_NO_FETCH_QUERY"
queryIDChannel contextKey = "QUERY_ID_CHANNEL"
snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID"
fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID"
Expand Down Expand Up @@ -44,6 +45,11 @@ func WithAsyncMode(ctx context.Context) context.Context {
return context.WithValue(ctx, asyncMode, true)
}

// WithAsyncModeNoFetch returns a context that, when you execute a query in async mode, will not fetch results
func WithAsyncModeNoFetch(ctx context.Context) context.Context {
return context.WithValue(ctx, asyncModeNoFetch, true)
}

// WithQueryIDChan returns a context that contains the channel to receive the query ID
func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context {
return context.WithValue(ctx, queryIDChannel, c)
Expand Down