Skip to content

Commit

Permalink
SIG-18794: Make getAsync wait until the query completes or times out. (
Browse files Browse the repository at this point in the history
…#59)

* SIG-18794: Make getAsync wait until the query completes or times out.

  * make the getStatus for an async made query block until the ctx
    timeout
  • Loading branch information
mtoader committed Mar 18, 2022
1 parent 48707af commit 2d7e6ad
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 58 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ wss-golang-agent.config
wss-unified-agent.jar
whitesource/
*.swp
/vendor/github.com/snowflakedb/gosnowflake
129 changes: 77 additions & 52 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ import (
"time"
)

func isAsyncModeNoFetch(ctx context.Context) bool {
if flag, ok := ctx.Value(asyncModeNoFetch).(bool); ok && flag {
return true
}

return false
}

func (sr *snowflakeRestful) processAsync(
ctx context.Context,
respd *execResponse,
Expand Down Expand Up @@ -66,45 +74,39 @@ func (sr *snowflakeRestful) getAsync(
defer close(errChannel)
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
// if we failed here because of top level context cancellation we want to cancel the original query
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)

// the get call pulling for result status is
var response *execResponse
var err error
for response == nil || (!response.Success && parseCode(response.Code) == ErrQueryExecutionInProgress) {
response, err = sr.getAsyncOrStatus(ctx, URL, headers, timeout)

if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)
}
}
}
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}

respd := execResponse{}
err = json.NewDecoder(resp.Body).Decode(&respd)
resp.Body.Close()
if err != nil {
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
return err
sfError.Message = err.Error()
errChannel <- sfError
return err
}
}

sc := &snowflakeConn{rest: sr, cfg: cfg}
if respd.Success {
if response.Success {
if resType == execResultType {
res.insertID = -1
if isDml(respd.Data.StatementTypeID) {
res.affectedRows, err = updateRows(respd.Data)
if isDml(response.Data.StatementTypeID) {
res.affectedRows, err = updateRows(response.Data)
if err != nil {
return err
}
} else if isMultiStmt(&respd.Data) {
r, err := sc.handleMultiExec(ctx, respd.Data)
} else if isMultiStmt(&response.Data) {
r, err := sc.handleMultiExec(ctx, response.Data)
if err != nil {
res.errChannel <- err
return err
Expand All @@ -115,39 +117,62 @@ func (sr *snowflakeRestful) getAsync(
return err
}
}
res.queryID = respd.Data.QueryID
res.queryID = response.Data.QueryID
res.errChannel <- nil // mark exec status complete
} else {
rows.sc = sc
rows.queryID = respd.Data.QueryID
if isMultiStmt(&respd.Data) {
if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
rows.errChannel <- err
close(errChannel)
return err
rows.queryID = response.Data.QueryID

if !isAsyncModeNoFetch(ctx) {
if isMultiStmt(&response.Data) {
if err = sc.handleMultiQuery(ctx, response.Data, rows); err != nil {
rows.errChannel <- err
close(errChannel)
return err
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, response.Data))
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
_ = rows.ChunkDownloader.start()
}
rows.ChunkDownloader.start()
rows.errChannel <- nil // mark query status complete
}
} else {
var code int
if respd.Code != "" {
code, err = strconv.Atoi(respd.Code)
if err != nil {
code = -1
}
} else {
code = -1
}
errChannel <- &SnowflakeError{
Number: code,
SQLState: respd.Data.SQLState,
Message: respd.Message,
QueryID: respd.Data.QueryID,
Number: parseCode(response.Code),
SQLState: response.Data.SQLState,
Message: response.Message,
QueryID: response.Data.QueryID,
}
}
return nil
}

func parseCode(codeStr string) int {
if code, err := strconv.Atoi(codeStr); err == nil {
return code
}

return -1
}

func (sr *snowflakeRestful) getAsyncOrStatus(
ctx context.Context,
url *url.URL,
headers map[string]string,
timeout time.Duration) (*execResponse, error) {
resp, err := sr.FuncGet(ctx, sr, url, headers, timeout)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer func() { _ = resp.Body.Close() }()
}

response := &execResponse{}
if err = json.NewDecoder(resp.Body).Decode(&response); err != nil {
return nil, err
}

return response, nil
}
39 changes: 39 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,45 @@ func TestAsyncMode(t *testing.T) {
})
}

func TestAsyncModeNoFetch(t *testing.T) {
ctx := WithAsyncModeNoFetch(WithAsyncMode(context.Background()))
numrows := 100000

runTests(t, dsn, func(dbt *DBTest) {
rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
defer rows.Close()

// Next() will block and wait until results are available
if rows.Next() == true {
t.Fatalf("next should have returned no rows")
}
if err := rows.Err(); err == nil {
t.Fatalf("we should have an error thrown")
}
columns, err := rows.Columns()
if columns != nil {
t.Fatalf("there should be no column array returned")
}
if err == nil {
t.Fatalf("we should have an error thrown")
}

if rows.Scan(nil) == nil {
t.Fatalf("we should have an error thrown")
}

dbt.mustExec("create or replace table test_async_exec (value boolean)")
res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
count, err := res.RowsAffected()
if err != nil {
t.Fatalf("res.RowsAffected() returned error: %v", err)
}
if count != 1 {
t.Fatalf("expected 1 affected row, got %d", count)
}
})
}

func TestAsyncQueryFail(t *testing.T) {
ctx := WithAsyncMode(context.Background())
runTests(t, dsn, func(dbt *DBTest) {
Expand Down
6 changes: 6 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ const (
ErrRoleNotExist = 390189
// ErrObjectNotExistOrAuthorized is a GS error code for the case that the server-side object specified does not exist
ErrObjectNotExistOrAuthorized = 390201

/* Extra error code */

// ErrQueryExecutionInProgress is returned when monitoring an async query reaches 45s
ErrQueryExecutionInProgress = 333333
)

const (
Expand Down Expand Up @@ -267,6 +272,7 @@ const (
errMsgFailedToConvertToS3Client = "failed to convert interface to s3 client"
errMsgNoResultIDs = "no result IDs returned with the multi-statement query"
errMsgQueryStatus = "server ErrorCode=%s, ErrorMessage=%s"
errMsgAsyncWithNoResults = "async with no results"
)

var (
Expand Down
24 changes: 18 additions & 6 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gosnowflake

import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"strings"
Expand Down Expand Up @@ -71,12 +72,16 @@ func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string {
if err := rows.waitForAsyncQueryStatus(); err != nil {
return err.Error()
}
if rows.ChunkDownloader == nil {
return ""
}

return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type)
}

// ColumnTypeLength returns the length of the column
func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return 0, false
}
if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
Expand All @@ -90,7 +95,7 @@ func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) {
}

func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return false, false
}
if index < 0 || index > len(rows.ChunkDownloader.getRowType()) {
Expand All @@ -100,7 +105,7 @@ func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) {
}

func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return 0, 0, false
}
rowType := rows.ChunkDownloader.getRowType()
Expand All @@ -119,7 +124,7 @@ func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale
}

func (rows *snowflakeRows) Columns() []string {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return make([]string, 0)
}
logger.Debug("Rows.Columns")
Expand All @@ -131,7 +136,7 @@ func (rows *snowflakeRows) Columns() []string {
}

func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return nil
}
return snowflakeTypeToGo(
Expand Down Expand Up @@ -164,6 +169,9 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) {
if err = rows.waitForAsyncQueryStatus(); err != nil {
return err
}
if rows.ChunkDownloader == nil {
return fmt.Errorf(errMsgAsyncWithNoResults)
}
row, err := rows.ChunkDownloader.next()
if err != nil {
// includes io.EOF
Expand Down Expand Up @@ -196,7 +204,7 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) {
}

func (rows *snowflakeRows) HasNextResultSet() bool {
if err := rows.waitForAsyncQueryStatus(); err != nil {
if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil {
return false
}
return rows.ChunkDownloader.hasNextResultSet()
Expand All @@ -206,6 +214,10 @@ func (rows *snowflakeRows) NextResultSet() error {
if err := rows.waitForAsyncQueryStatus(); err != nil {
return err
}
if rows.ChunkDownloader == nil {
return fmt.Errorf(errMsgAsyncWithNoResults)
}

if len(rows.ChunkDownloader.getChunkMetas()) == 0 {
if rows.ChunkDownloader.getNextChunkDownloader() == nil {
return io.EOF
Expand Down
6 changes: 6 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type contextKey string
const (
multiStatementCount contextKey = "MULTI_STATEMENT_COUNT"
asyncMode contextKey = "ASYNC_MODE_QUERY"
asyncModeNoFetch contextKey = "ASYNC_MODE_NO_FETCH_QUERY"
queryIDChannel contextKey = "QUERY_ID_CHANNEL"
snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID"
fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID"
Expand Down Expand Up @@ -44,6 +45,11 @@ func WithAsyncMode(ctx context.Context) context.Context {
return context.WithValue(ctx, asyncMode, true)
}

// WithAsyncModeNoFetch returns a context that, when you execute a query in async mode, will not fetch results
func WithAsyncModeNoFetch(ctx context.Context) context.Context {
return context.WithValue(ctx, asyncModeNoFetch, true)
}

// WithQueryIDChan returns a context that contains the channel to receive the query ID
func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context {
return context.WithValue(ctx, queryIDChannel, c)
Expand Down

0 comments on commit 2d7e6ad

Please sign in to comment.