Skip to content

Commit

Permalink
[Feature] Improvements to Async handling
Browse files Browse the repository at this point in the history
SIG-18794: Refactor the monitoring fetcher & auto-cancel async (#57)

* SIG-18794: Refactor the monitoring fetcher & auto-cancel async

  * better configuration
  * wait until the monitoring data refers to completed queries

SIG-16907: Error out early when invalid state for multi-statement requests. (#43)

Remove the debug-tracing added earlier.

Co-authored-by: Agam Brahma <agam@sigmacomputing.com>

SIG-18794: Make getAsync wait until the query completes or times out. (#59)

* SIG-18794: Make getAsync wait until the query completes or times out.

  * make the getStatus for an async made query block until the ctx
    timeout

SIG-18794: Fix a bug when looping to complete an async query (#60)

* discovered a new error code that needs looping when async,
     enhanced the test
  • Loading branch information
mtoader authored and madisonchamberlain committed Feb 8, 2023
1 parent 31161ad commit d6a90fc
Show file tree
Hide file tree
Showing 10 changed files with 347 additions and 92 deletions.
145 changes: 97 additions & 48 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ import (
"time"
)

// isAsyncModeNoFetch reports whether the context carries the
// asyncModeNoFetch flag, meaning async query results should not be
// fetched by the driver.
func isAsyncModeNoFetch(ctx context.Context) bool {
	flag, ok := ctx.Value(asyncModeNoFetch).(bool)
	return ok && flag
}

func (sr *snowflakeRestful) processAsync(
ctx context.Context,
respd *execResponse,
headers map[string]string,
timeout time.Duration,
cfg *Config) (*execResponse, error) {
cfg *Config,
requestID uuid) (*execResponse, error) {
// placeholder object to return to user while retrieving results
rows := new(snowflakeRows)
res := new(snowflakeResult)
Expand All @@ -34,9 +43,10 @@ func (sr *snowflakeRestful) processAsync(
default:
return respd, nil
}

// spawn goroutine to retrieve asynchronous results
go sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
go func() {
_ = sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, requestID, cfg)
}()
return respd, nil
}

Expand All @@ -47,6 +57,7 @@ func (sr *snowflakeRestful) getAsync(
timeout time.Duration,
res *snowflakeResult,
rows *snowflakeRows,
requestID uuid,
cfg *Config) error {
resType := getResultType(ctx)
var errChannel chan error
Expand All @@ -63,38 +74,39 @@ func (sr *snowflakeRestful) getAsync(
defer close(errChannel)
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}

respd := execResponse{}
err = json.NewDecoder(resp.Body).Decode(&respd)
resp.Body.Close()
if err != nil {
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
return err
// poll the GET call for result status until the query completes or errors out
var response *execResponse
var err error
for response == nil || isQueryInProgress(response) {
response, err = sr.getAsyncOrStatus(ctx, URL, headers, timeout)

if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)
}
}

sfError.Message = err.Error()
errChannel <- sfError
return err
}
}

sc := &snowflakeConn{rest: sr, cfg: cfg}
if respd.Success {
if response.Success {
if resType == execResultType {
res.insertID = -1
if isDml(respd.Data.StatementTypeID) {
res.affectedRows, err = updateRows(respd.Data)
if isDml(response.Data.StatementTypeID) {
res.affectedRows, err = updateRows(response.Data)
if err != nil {
return err
}
} else if isMultiStmt(&respd.Data) {
r, err := sc.handleMultiExec(ctx, respd.Data)
} else if isMultiStmt(&response.Data) {
r, err := sc.handleMultiExec(ctx, response.Data)
if err != nil {
res.errChannel <- err
return err
Expand All @@ -105,38 +117,75 @@ func (sr *snowflakeRestful) getAsync(
return err
}
}
res.queryID = respd.Data.QueryID
res.queryID = response.Data.QueryID
res.errChannel <- nil // mark exec status complete
} else {
rows.sc = sc
rows.queryID = respd.Data.QueryID
if isMultiStmt(&respd.Data) {
if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil {
rows.errChannel <- err
return err
rows.queryID = response.Data.QueryID

if !isAsyncModeNoFetch(ctx) {
if isMultiStmt(&response.Data) {
if err = sc.handleMultiQuery(ctx, response.Data, rows); err != nil {
rows.errChannel <- err
close(errChannel)
return err
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, response.Data))
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data))
_ = rows.ChunkDownloader.start()
}
rows.ChunkDownloader.start()
rows.errChannel <- nil // mark query status complete
}
} else {
var code int
if respd.Code != "" {
code, err = strconv.Atoi(respd.Code)
if err != nil {
code = -1
}
} else {
code = -1
}
errChannel <- &SnowflakeError{
Number: code,
SQLState: respd.Data.SQLState,
Message: respd.Message,
QueryID: respd.Data.QueryID,
Number: parseCode(response.Code),
SQLState: response.Data.SQLState,
Message: response.Message,
QueryID: response.Data.QueryID,
}
}
return nil
}

// isQueryInProgress reports whether the server response describes a query
// that is still executing: a successful response whose status code is one
// of the in-progress codes. Unsuccessful responses are never in progress.
func isQueryInProgress(execResponse *execResponse) bool {
	if !execResponse.Success {
		return false
	}

	code := parseCode(execResponse.Code)
	return code == ErrQueryExecutionInProgress || code == ErrAsyncExecutionInProgress
}

// parseCode converts a Snowflake status/error code string to an int.
// It returns -1 when the string is empty or not a valid integer.
func parseCode(codeStr string) int {
	code, err := strconv.Atoi(codeStr)
	if err != nil {
		return -1
	}
	return code
}

// getAsyncOrStatus issues a single GET against the query result/status
// endpoint and decodes the JSON body into an execResponse. Callers (see
// getAsync) loop on this while the response still reports an in-progress
// status code.
func (sr *snowflakeRestful) getAsyncOrStatus(
	ctx context.Context,
	resultURL *url.URL, // renamed from `url` to avoid shadowing the net/url package
	headers map[string]string,
	timeout time.Duration) (*execResponse, error) {
	resp, err := sr.FuncGet(ctx, sr, resultURL, headers, timeout)
	if err != nil {
		return nil, err
	}
	if resp.Body != nil {
		// best-effort close; any decode error below is what matters
		defer func() { _ = resp.Body.Close() }()
	}

	response := &execResponse{}
	// decode directly into the struct; the original passed &response,
	// a needless **execResponse indirection
	if err = json.NewDecoder(resp.Body).Decode(response); err != nil {
		return nil, err
	}

	return response, nil
}
46 changes: 46 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ func TestAsyncModeCancel(t *testing.T) {
})
}

const (
	//selectTimelineGenerator = "SELECT COUNT(*) FROM TABLE(GENERATOR(TIMELIMIT=>%v))"
	// selectTimelineGenerator is a query template that makes the server
	// wait for %v seconds, simulating a long-running query in async tests.
	selectTimelineGenerator = "SELECT SYSTEM$WAIT(%v, 'SECONDS')"
)

// TestAsyncModeNoFetch verifies that with the async no-fetch mode enabled
// the driver waits for query completion but does not download the result
// set: queries report no rows/columns (and error on access), while exec
// still reports affected rows.
func TestAsyncModeNoFetch(t *testing.T) {
	ctx := WithAsyncMode(WithAsyncModeNoFetch(context.Background()))
	// the default behavior of the async wait is to wait for 45s. We want to make sure we wait until the query actually
	// completes, so we make the test take longer than 45s
	secondsToRun := 50

	runTests(t, dsn, func(dbt *DBTest) {
		start := time.Now()
		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectTimelineGenerator, secondsToRun))
		defer rows.Close()

		// Next() will block and wait until results are available
		if rows.Next() { // idiomatic: no `== true` comparison
			t.Fatalf("next should have returned no rows")
		}
		columns, err := rows.Columns()
		if columns != nil {
			t.Fatalf("there should be no column array returned")
		}
		if err == nil {
			t.Fatalf("we should have an error thrown")
		}
		if rows.Scan(nil) == nil {
			t.Fatalf("we should have an error thrown")
		}
		if (time.Second * time.Duration(secondsToRun)) > time.Since(start) {
			// fixed typo in the failure message ("tset should should")
			t.Fatalf("test should have run for at least %d seconds", secondsToRun)
		}

		dbt.mustExec("create or replace table test_async_exec (value boolean)")
		res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 1 {
			t.Fatalf("expected 1 affected row, got %d", count)
		}
	})
}

func TestAsyncQueryFail(t *testing.T) {
ctx := WithAsyncMode(context.Background())
runTests(t, dsn, func(dbt *DBTest) {
Expand Down
16 changes: 13 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ func (sc *snowflakeConn) queryContextInternal(
if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil {
return nil, err
}
if data.Data.ResultIDs == "" && rows.ChunkDownloader == nil {
// SIG-16907: We have no results to download here.
logger.WithContext(ctx).Errorf("Encountered empty result-ids for a multi-statement request. Query-id: %s, Query: %s", data.Data.QueryID, query)
return nil, (&SnowflakeError{
Number: ErrQueryIDFormat,
SQLState: data.Data.SQLState,
Message: "ExecResponse for multi-statement request had no ResultIDs",
QueryID: data.Data.QueryID,
}).exceptionTelemetry(sc)
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data))
}
Expand Down Expand Up @@ -582,17 +592,17 @@ type ResultFetcher interface {
// MonitoringResultFetcher is an interface which allows to fetch monitoringResult
// with snowflake connection and query-id.
type MonitoringResultFetcher interface {
FetchMonitoringResult(queryID string) (*monitoringResult, error)
FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error)
}

// FetchMonitoringResult returns a monitoringResult object
// Multiplex can call monitoringResult.Monitoring() to get the QueryMonitoringData
func (sc *snowflakeConn) FetchMonitoringResult(queryID string) (*monitoringResult, error) {
func (sc *snowflakeConn) FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error) {
if sc.rest == nil {
return nil, driver.ErrBadConn
}

// set the fake runtime just to bypass fast query
monitoringResult := mkMonitoringFetcher(sc, queryID, time.Minute*10)
monitoringResult := mkMonitoringFetcher(sc, queryID, runtime)
return monitoringResult, nil
}
Loading

0 comments on commit d6a90fc

Please sign in to comment.