From 69eedd248d925b9f37b8d67f782ee0bcc6a8e363 Mon Sep 17 00:00:00 2001 From: Mihai Claudiu Toader Date: Wed, 16 Mar 2022 14:51:57 -0700 Subject: [PATCH] [Feature] Improvements to Async handling SIG-18794: Refactor the monitoring fetcher & auto-cancel async (#57) * SIG-18794: Refactor the monitoring fetcher & auto-cancel async * better configuration * wait until the monitoring data refers to completed queries SIG-16907: Error out early when invalid state for multi-statement requests. (#43) Remove the debug-tracing added earlier. Co-authored-by: Agam Brahma SIG-18794: Make getAsync wait until the query completes or times out. (#59) * SIG-18794: Make getAsync wait until the query completes or times out. * make the getStatus for an async made query block until the ctx timeout SIG-18794: Fix a bug when looping to complete an async query (#60) * discovered a new error code the needs looping when async, enhanced the test --- async.go | 145 +++++++++++++++++++++++++++++++++----------------- async_test.go | 46 ++++++++++++++++ connection.go | 16 ++++-- dsn.go | 75 ++++++++++++++++++++++++++ dsn_test.go | 84 +++++++++++++++++++---------- errors.go | 9 ++++ monitoring.go | 31 ++++++++--- restful.go | 2 +- rows.go | 25 ++++++--- util.go | 6 +++ 10 files changed, 347 insertions(+), 92 deletions(-) diff --git a/async.go b/async.go index 5c477dc1a..7f66e62af 100644 --- a/async.go +++ b/async.go @@ -11,12 +11,21 @@ import ( "time" ) +func isAsyncModeNoFetch(ctx context.Context) bool { + if flag, ok := ctx.Value(asyncModeNoFetch).(bool); ok && flag { + return true + } + + return false +} + func (sr *snowflakeRestful) processAsync( ctx context.Context, respd *execResponse, headers map[string]string, timeout time.Duration, - cfg *Config) (*execResponse, error) { + cfg *Config, + requestID uuid) (*execResponse, error) { // placeholder object to return to user while retrieving results rows := new(snowflakeRows) res := new(snowflakeResult) @@ -34,9 +43,10 @@ func (sr *snowflakeRestful) processAsync( default: return respd, nil } - // spawn goroutine to retrieve asynchronous results - go sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg) + go func() { + _ = sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, requestID, cfg) + }() return respd, nil } @@ -47,6 +57,7 @@ func (sr *snowflakeRestful) getAsync( timeout time.Duration, res *snowflakeResult, rows *snowflakeRows, + requestID uuid, cfg *Config) error { resType := getResultType(ctx) var errChannel chan error @@ -63,38 +74,39 @@ func (sr *snowflakeRestful) getAsync( defer close(errChannel) token, _, _ := sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) - resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout) - if err != nil { - logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) - sfError.Message = err.Error() - errChannel <- sfError - return err - } - if resp.Body != nil { - defer resp.Body.Close() - } - respd := execResponse{} - err = json.NewDecoder(resp.Body).Decode(&respd) - resp.Body.Close() - if err != nil { - logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) - sfError.Message = err.Error() - errChannel <- sfError - return err + // the get call pulling for result status is + var response *execResponse + var err error + for response == nil || isQueryInProgress(response) { + response, err = sr.getAsyncOrStatus(ctx, URL, headers, timeout) + + if err != nil { + logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) + if err == context.Canceled || err == context.DeadlineExceeded { + // use the default top level 1 sec timeout for cancellation as throughout the driver + if err := cancelQuery(context.TODO(), sr, requestID, time.Second); err != nil { + logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err) + } + } + + sfError.Message = err.Error() + errChannel <- sfError + return err + } } sc := &snowflakeConn{rest: sr, cfg: cfg} - if respd.Success { + if response.Success { if resType == execResultType { res.insertID = -1 - if isDml(respd.Data.StatementTypeID) { - res.affectedRows, err = updateRows(respd.Data) + if isDml(response.Data.StatementTypeID) { + res.affectedRows, err = updateRows(response.Data) if err != nil { return err } - } else if isMultiStmt(&respd.Data) { - r, err := sc.handleMultiExec(ctx, respd.Data) + } else if isMultiStmt(&response.Data) { + r, err := sc.handleMultiExec(ctx, response.Data) if err != nil { res.errChannel <- err return err @@ -105,38 +117,75 @@ func (sr *snowflakeRestful) getAsync( return err } } - res.queryID = respd.Data.QueryID + res.queryID = response.Data.QueryID res.errChannel <- nil // mark exec status complete } else { rows.sc = sc - rows.queryID = respd.Data.QueryID - if isMultiStmt(&respd.Data) { - if err = sc.handleMultiQuery(ctx, respd.Data, rows); err != nil { - rows.errChannel <- err - return err + rows.queryID = response.Data.QueryID + + if !isAsyncModeNoFetch(ctx) { + if isMultiStmt(&response.Data) { + if err = sc.handleMultiQuery(ctx, response.Data, rows); err != nil { + rows.errChannel <- err + close(errChannel) + return err + } + } else { + rows.addDownloader(populateChunkDownloader(ctx, sc, response.Data)) } - } else { - rows.addDownloader(populateChunkDownloader(ctx, sc, respd.Data)) + _ = rows.ChunkDownloader.start() } - rows.ChunkDownloader.start() rows.errChannel <- nil // mark query status complete } } else { - var code int - if respd.Code != "" { - code, err = strconv.Atoi(respd.Code) - if err != nil { - code = -1 - } - } else { - code = -1 - } errChannel <- &SnowflakeError{ - Number: code, - SQLState: respd.Data.SQLState, - Message: respd.Message, - QueryID: respd.Data.QueryID, + Number: parseCode(response.Code), + SQLState: response.Data.SQLState, + Message: response.Message, + QueryID: response.Data.QueryID, } } return nil } + +func isQueryInProgress(execResponse *execResponse) bool { + if !execResponse.Success { + return false + } + + switch parseCode(execResponse.Code) { + case ErrQueryExecutionInProgress, ErrAsyncExecutionInProgress: + return true + default: + return false + } +} + +func parseCode(codeStr string) int { + if code, err := strconv.Atoi(codeStr); err == nil { + return code + } + + return -1 +} + +func (sr *snowflakeRestful) getAsyncOrStatus( + ctx context.Context, + url *url.URL, + headers map[string]string, + timeout time.Duration) (*execResponse, error) { + resp, err := sr.FuncGet(ctx, sr, url, headers, timeout) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + + response := &execResponse{} + if err = json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, err + } + + return response, nil +} diff --git a/async_test.go b/async_test.go index a742dbdac..023f440cb 100644 --- a/async_test.go +++ b/async_test.go @@ -59,6 +59,52 @@ func TestAsyncModeCancel(t *testing.T) { }) } +const ( + //selectTimelineGenerator = "SELECT COUNT(*) FROM TABLE(GENERATOR(TIMELIMIT=>%v))" + selectTimelineGenerator = "SELECT SYSTEM$WAIT(%v, 'SECONDS')" +) + +func TestAsyncModeNoFetch(t *testing.T) { + ctx := WithAsyncMode(WithAsyncModeNoFetch(context.Background())) + // the default behavior of the async wait is to wait for 45s. We want to make sure we wait until the query actually + // completes, so we make the test take longer than 45s + secondsToRun := 50 + + runTests(t, dsn, func(dbt *DBTest) { + start := time.Now() + rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectTimelineGenerator, secondsToRun)) + defer rows.Close() + + // Next() will block and wait until results are available + if rows.Next() == true { + t.Fatalf("next should have returned no rows") + } + columns, err := rows.Columns() + if columns != nil { + t.Fatalf("there should be no column array returned") + } + if err == nil { + t.Fatalf("we should have an error thrown") + } + if rows.Scan(nil) == nil { + t.Fatalf("we should have an error thrown") + } + if (time.Second * time.Duration(secondsToRun)) > time.Since(start) { + t.Fatalf("tset should should have run for %d seconds", secondsToRun) + } + + dbt.mustExec("create or replace table test_async_exec (value boolean)") + res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)") + count, err := res.RowsAffected() + if err != nil { + t.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 1 { + t.Fatalf("expected 1 affected row, got %d", count) + } + }) +} + func TestAsyncQueryFail(t *testing.T) { ctx := WithAsyncMode(context.Background()) runTests(t, dsn, func(dbt *DBTest) { diff --git a/connection.go b/connection.go index d79724a3a..9fd365e4f 100644 --- a/connection.go +++ b/connection.go @@ -411,6 +411,16 @@ func (sc *snowflakeConn) queryContextInternal( if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { return nil, err } + if data.Data.ResultIDs == "" && rows.ChunkDownloader == nil { + // SIG-16907: We have no results to download here. + logger.WithContext(ctx).Errorf("Encountered empty result-ids for a multi-statement request. Query-id: %s, Query: %s", data.Data.QueryID, query) + return nil, (&SnowflakeError{ + Number: ErrQueryIDFormat, + SQLState: data.Data.SQLState, + Message: "ExecResponse for multi-statement request had no ResultIDs", + QueryID: data.Data.QueryID, + }).exceptionTelemetry(sc) + } } else { rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data)) } @@ -834,17 +844,17 @@ type ResultFetcher interface { // MonitoringResultFetcher is an interface which allows to fetch monitoringResult // with snowflake connection and query-id. type MonitoringResultFetcher interface { - FetchMonitoringResult(queryID string) (*monitoringResult, error) + FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error) } // FetchMonitoringResult returns a monitoringResult object // Multiplex can call monitoringResult.Monitoring() to get the QueryMonitoringData -func (sc *snowflakeConn) FetchMonitoringResult(queryID string) (*monitoringResult, error) { +func (sc *snowflakeConn) FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error) { if sc.rest == nil { return nil, driver.ErrBadConn } // set the fake runtime just to bypass fast query - monitoringResult := mkMonitoringFetcher(sc, queryID, time.Minute*10) + monitoringResult := mkMonitoringFetcher(sc, queryID, runtime) return monitoringResult, nil } diff --git a/dsn.go b/dsn.go index 1e007816d..08be2926b 100644 --- a/dsn.go +++ b/dsn.go @@ -22,6 +22,11 @@ const ( defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout defaultJWTTimeout = 60 * time.Second defaultDomain = ".snowflakecomputing.com" + + // default monitoring fetcher config values + defaultMonitoringFetcherQueryMonitoringThreshold = 45 * time.Second + defaultMonitoringFetcherMaxDuration = 10 * time.Second + defaultMonitoringFetcherRetrySleepDuration = 250 * time.Second ) // ConfigBool is a type to represent true or false in the Config @@ -85,15 +90,37 @@ type Config struct { Tracing string // sets logging level +<<<<<<< HEAD MfaToken string // Internally used to cache the MFA token IDToken string // Internally used to cache the Id Token for external browser ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. +======= + // Monitoring fetcher config + MonitoringFetcher MonitoringFetcherConfig + +>>>>>>> 95d57be ([Feature] Improvements to Async handling) // An identifier for this Config. Used to associate multiple connection instances with // a single logical sql.DB connection. ConnectionID string } +// MonitoringFetcherConfig provides some knobs to control the behavior of the monitoring data fetcher +type MonitoringFetcherConfig struct { + // QueryRuntimeThreshold specifies the threshold, over which we'll fetch the monitoring + // data for a successful snowflake query. We use a time-based threshold, since there is + // a non-zero latency cost to fetch this data, and we want to bound the additional latency. + // By default, we bound to a 2% increase in latency - assuming worst case 100ms - when + // fetching this metadata. + QueryRuntimeThreshold time.Duration + + // max time to wait until we get a proper monitoring sample for a query + MaxDuration time.Duration + + // Wait time between monitoring retries + RetrySleepDuration time.Duration +} + // ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED func (c *Config) ocspMode() string { if c.InsecureMode { @@ -218,6 +245,16 @@ func DSN(cfg *Config) (dsn string, err error) { params.Add("clientStoreTemporaryCredential", strconv.FormatBool(cfg.ClientStoreTemporaryCredential != ConfigBoolFalse)) } + if cfg.MonitoringFetcher.QueryRuntimeThreshold != defaultMonitoringFetcherQueryMonitoringThreshold { + params.Add("monitoringFetcher_queryRuntimeThresholdMs", durationAsMillis(cfg.MonitoringFetcher.QueryRuntimeThreshold)) + } + if cfg.MonitoringFetcher.MaxDuration != defaultMonitoringFetcherMaxDuration { + params.Add("monitoringFetcher_maxDurationMs", durationAsMillis(cfg.MonitoringFetcher.MaxDuration)) + } + if cfg.MonitoringFetcher.RetrySleepDuration != defaultMonitoringFetcherRetrySleepDuration { + params.Add("monitoringFetcher_retrySleepDurationMs", durationAsMillis(cfg.MonitoringFetcher.RetrySleepDuration)) + } + if cfg.ConnectionID != "" { params.Add("connectionId", cfg.ConnectionID) } @@ -446,6 +483,16 @@ func fillMissingConfigParameters(cfg *Config) error { cfg.ValidateDefaultParameters = ConfigBoolTrue } + if cfg.MonitoringFetcher.QueryRuntimeThreshold == 0 { + cfg.MonitoringFetcher.QueryRuntimeThreshold = defaultMonitoringFetcherQueryMonitoringThreshold + } + if cfg.MonitoringFetcher.MaxDuration == 0 { + cfg.MonitoringFetcher.MaxDuration = defaultMonitoringFetcherMaxDuration + } + if cfg.MonitoringFetcher.RetrySleepDuration == 0 { + cfg.MonitoringFetcher.RetrySleepDuration = defaultMonitoringFetcherRetrySleepDuration + } + if cfg.ConnectionID == "" { cfg.ConnectionID = uuid.New().String() } @@ -659,6 +706,21 @@ func parseDSNParams(cfg *Config, params string) (err error) { } case "tracing": cfg.Tracing = value + case "monitoringFetcher_queryRuntimeThresholdMs": + cfg.MonitoringFetcher.QueryRuntimeThreshold, err = parseMillisToDuration(value) + if err != nil { + return err + } + case "monitoringFetcher_maxDurationMs": + cfg.MonitoringFetcher.MaxDuration, err = parseMillisToDuration(value) + if err != nil { + return err + } + case "monitoringFetcher_retrySleepDurationMs": + cfg.MonitoringFetcher.RetrySleepDuration, err = parseMillisToDuration(value) + if err != nil { + return err + } default: if cfg.Params == nil { cfg.Params = make(map[string]*string) @@ -669,6 +731,19 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } +func parseMillisToDuration(value string) (time.Duration, error) { + intValue, err := strconv.ParseInt(value, 10, 64) + if err == nil { + return time.Millisecond * time.Duration(intValue), nil + } + + return 0, err +} + +func durationAsMillis(duration time.Duration) string { + return strconv.FormatInt(duration.Milliseconds(), 10) +} + func parseTimeout(value string) (time.Duration, error) { var vv int64 var err error diff --git a/dsn_test.go b/dsn_test.go index 88cbc7305..52a0c9d08 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -612,7 +612,7 @@ func TestDSN(t *testing.T) { Account: "a-aofnadsf.somewhere.azure", ConnectionID: testConnectionID, }, - dsn: "u:p@a-aofnadsf.somewhere.azure.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=somewhere.azure&validateDefaultParameters=true", + dsn: "u:p@a-aofnadsf.somewhere.azure.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=somewhere.azure&validateDefaultParameters=true", }, { cfg: &Config{ @@ -621,7 +621,7 @@ func TestDSN(t *testing.T) { Account: "a-aofnadsf.global", ConnectionID: testConnectionID, }, - dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=global&validateDefaultParameters=true", + dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=global&validateDefaultParameters=true", }, { cfg: &Config{ @@ -631,7 +631,7 @@ func TestDSN(t *testing.T) { Region: "us-west-2", ConnectionID: testConnectionID, }, - dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=global&validateDefaultParameters=true", + dsn: "u:p@a-aofnadsf.global.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=global&validateDefaultParameters=true", }, { cfg: &Config{ @@ -650,7 +650,7 @@ func TestDSN(t *testing.T) { Account: "a", ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -660,7 +660,7 @@ func TestDSN(t *testing.T) { Region: "us-west-2", ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -670,7 +670,7 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=r&validateDefaultParameters=true", + dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ @@ -706,7 +706,7 @@ func TestDSN(t *testing.T) { Account: "a.e", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -716,7 +716,7 @@ func TestDSN(t *testing.T) { Region: "us-west-2", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -745,7 +745,7 @@ func TestDSN(t *testing.T) { Application: "special go", ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.snowflakecomputing.com:443?application=special+go&connectionId=abcd-0123-4567-1234&database=db&loginTimeout=10&ocspFailOpen=true&passcode=db&passcodeInPassword=true&queryMonitoringThreshold=5®ion=b&requestTimeout=300&role=ro&schema=sc&validateDefaultParameters=true", + dsn: "u:p@a.b.snowflakecomputing.com:443?application=special+go&connectionId=abcd-0123-4567-1234&database=db&loginTimeout=10&ocspFailOpen=true&passcode=db&passcodeInPassword=true®ion=b&requestTimeout=300&role=ro&schema=sc&validateDefaultParameters=true", }, { cfg: &Config{ @@ -755,7 +755,7 @@ func TestDSN(t *testing.T) { Authenticator: AuthTypeExternalBrowser, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -769,7 +769,7 @@ func TestDSN(t *testing.T) { }, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=https%3A%2F%2Fsc.okta.com&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=https%3A%2F%2Fsc.okta.com&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -781,7 +781,7 @@ func TestDSN(t *testing.T) { }, ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -793,7 +793,7 @@ func TestDSN(t *testing.T) { }, ConnectionID: testConnectionID, }, - dsn: "u:%3A%40abc@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=e&validateDefaultParameters=true", + dsn: "u:%3A%40abc@a.e.snowflakecomputing.com:443?TIMESTAMP_OUTPUT_FORMAT=MM-DD-YYYY&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -803,7 +803,7 @@ func TestDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -813,7 +813,7 @@ func TestDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenFalse, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=false&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=false&validateDefaultParameters=true", }, { cfg: &Config{ @@ -823,7 +823,7 @@ func TestDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolFalse, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=false", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=false", }, { cfg: &Config{ @@ -833,7 +833,7 @@ func TestDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -843,7 +843,7 @@ func TestDSN(t *testing.T) { InsecureMode: true, ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&insecureMode=true&ocspFailOpen=true&queryMonitoringThreshold=5&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&insecureMode=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -852,7 +852,7 @@ func TestDSN(t *testing.T) { Account: "a.b.c", ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ @@ -862,7 +862,7 @@ func TestDSN(t *testing.T) { Region: "us-west-2", ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ @@ -882,17 +882,47 @@ func TestDSN(t *testing.T) { ClientTimeout: 300 * time.Second, ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=5®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.e", - QueryMonitoringThreshold: 20 * time.Second, - ConnectionID: testConnectionID, + User: "u", + Password: "p", + Account: "a.e", + MonitoringFetcher: MonitoringFetcherConfig{ + QueryRuntimeThreshold: time.Second * 56, + }, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_queryRuntimeThresholdMs=56000&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.e", + MonitoringFetcher: MonitoringFetcherConfig{ + QueryRuntimeThreshold: time.Second * 56, + MaxDuration: time.Second * 14, + RetrySleepDuration: time.Millisecond * 45, + }, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_maxDurationMs=14000&monitoringFetcher_queryRuntimeThresholdMs=56000&monitoringFetcher_retrySleepDurationMs=45&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.e", + MonitoringFetcher: MonitoringFetcherConfig{ + QueryRuntimeThreshold: defaultMonitoringFetcherQueryMonitoringThreshold, + MaxDuration: defaultMonitoringFetcherMaxDuration, + RetrySleepDuration: defaultMonitoringFetcherRetrySleepDuration, + }, + ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&queryMonitoringThreshold=20®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, } for _, test := range testcases { diff --git a/errors.go b/errors.go index c49fc921c..bdd1bdd67 100644 --- a/errors.go +++ b/errors.go @@ -247,6 +247,14 @@ const ( ErrRoleNotExist = 390189 // ErrObjectNotExistOrAuthorized is a GS error code for the case that the server-side object specified does not exist ErrObjectNotExistOrAuthorized = 390201 + + /* Extra error code */ + + // ErrQueryExecutionInProgress is returned when monitoring an async query reaches 45s + ErrQueryExecutionInProgress = 333333 + + // ErrAsyncExecutionInProgress is returned when monitoring an async query reaches 45s + ErrAsyncExecutionInProgress = 333334 ) const ( @@ -288,6 +296,7 @@ const ( errMsgNoResultIDs = "no result IDs returned with the multi-statement query" errMsgQueryStatus = "server ErrorCode=%s, ErrorMessage=%s" errMsgInvalidPadding = "invalid padding on input" + errMsgAsyncWithNoResults = "async with no results" ) var ( diff --git a/monitoring.go b/monitoring.go index 7ed6297a9..37d762d4b 100644 --- a/monitoring.go +++ b/monitoring.go @@ -311,7 +311,7 @@ func (sc *snowflakeConn) buildRowsForRunningQuery( func mkMonitoringFetcher(sc *snowflakeConn, qid string, runtime time.Duration) *monitoringResult { // Exit early if this was a "fast" query - if runtime < FetchQueryMonitoringDataThreshold { + if runtime < sc.cfg.MonitoringFetcher.QueryRuntimeThreshold { return nil } @@ -334,14 +334,31 @@ func monitoring( ) { defer close(resp) - ctx, cancel := context.WithTimeout(context.Background(), sc.rest.RequestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), sc.cfg.MonitoringFetcher.MaxDuration) defer cancel() - var m monitoringResponse - err := sc.getMonitoringResult(ctx, "queries", qid, &m) - if err == nil && len(m.Data.Queries) == 1 { - resp <- &m.Data.Queries[0] + var queryMonitoringData *QueryMonitoringData + for { + var m monitoringResponse + if err := sc.getMonitoringResult(ctx, "queries", qid, &m); err != nil { + break + } + + if len(m.Data.Queries) == 1 { + queryMonitoringData = &m.Data.Queries[0] + if !strToQueryStatus(queryMonitoringData.Status).isRunning() { + break + } + } + + time.Sleep(sc.cfg.MonitoringFetcher.RetrySleepDuration) + } + + if queryMonitoringData != nil { + resp <- queryMonitoringData } + + return } func queryGraph( @@ -352,7 +369,7 @@ func queryGraph( defer close(resp) // Bound the GET request to 1 second in the absolute worst case. - ctx, cancel := context.WithTimeout(context.Background(), sc.rest.RequestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), sc.cfg.MonitoringFetcher.MaxDuration) defer cancel() var qg queryGraphResponse diff --git a/restful.go b/restful.go index 04da332ce..a0d439e1b 100644 --- a/restful.go +++ b/restful.go @@ -236,7 +236,7 @@ func postRestfulQueryHelper( // if asynchronous query in progress, kick off retrieval but return object if respd.Code == queryInProgressAsyncCode && isAsyncMode(ctx) { - return sr.processAsync(ctx, &respd, headers, timeout, cfg) + return sr.processAsync(ctx, &respd, headers, timeout, cfg, requestID) } for isSessionRenewed || respd.Code == queryInProgressCode || respd.Code == queryInProgressAsyncCode { diff --git a/rows.go b/rows.go index 48f642ea5..c5dfe1a9b 100644 --- a/rows.go +++ b/rows.go @@ -4,6 +4,7 @@ package gosnowflake import ( "database/sql/driver" + "fmt" "io" "reflect" "strings" @@ -44,6 +45,7 @@ type snowflakeRows struct { err error errChannel chan error monitoring *monitoringResult + asyncRequestID uuid } type snowflakeValue interface{} @@ -77,12 +79,16 @@ func (rows *snowflakeRows) ColumnTypeDatabaseTypeName(index int) string { if err := rows.waitForAsyncQueryStatus(); err != nil { return err.Error() } + if rows.ChunkDownloader == nil { + return "" + } + return strings.ToUpper(rows.ChunkDownloader.getRowType()[index].Type) } // ColumnTypeLength returns the length of the column func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return 0, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -96,7 +102,7 @@ func (rows *snowflakeRows) ColumnTypeLength(index int) (length int64, ok bool) { } func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return false, false } if index < 0 || index > len(rows.ChunkDownloader.getRowType()) { @@ -106,7 +112,7 @@ func (rows *snowflakeRows) ColumnTypeNullable(index int) (nullable, ok bool) { } func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return 0, 0, false } rowType := rows.ChunkDownloader.getRowType() @@ -125,7 +131,7 @@ func (rows *snowflakeRows) ColumnTypePrecisionScale(index int) (precision, scale } func (rows *snowflakeRows) Columns() []string { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return make([]string, 0) } logger.Debug("Rows.Columns") @@ -137,7 +143,7 @@ func (rows *snowflakeRows) Columns() []string { } func (rows *snowflakeRows) ColumnTypeScanType(index int) reflect.Type { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return nil } return snowflakeTypeToGo( @@ -176,6 +182,9 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { if err = rows.waitForAsyncQueryStatus(); err != nil { return err } + if rows.ChunkDownloader == nil { + return fmt.Errorf(errMsgAsyncWithNoResults) + } row, err := rows.ChunkDownloader.next() if err != nil { // includes io.EOF @@ -212,7 +221,7 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { } func (rows *snowflakeRows) HasNextResultSet() bool { - if err := rows.waitForAsyncQueryStatus(); err != nil { + if err := rows.waitForAsyncQueryStatus(); err != nil || rows.ChunkDownloader == nil { return false } return rows.ChunkDownloader.hasNextResultSet() @@ -222,6 +231,10 @@ func (rows *snowflakeRows) NextResultSet() error { if err := rows.waitForAsyncQueryStatus(); err != nil { return err } + if rows.ChunkDownloader == nil { + return fmt.Errorf(errMsgAsyncWithNoResults) + } + if len(rows.ChunkDownloader.getChunkMetas()) == 0 { if rows.ChunkDownloader.getNextChunkDownloader() == nil { return io.EOF diff --git a/util.go b/util.go index b09c542d2..155516ca3 100644 --- a/util.go +++ b/util.go @@ -18,6 +18,7 @@ type contextKey string const ( multiStatementCount contextKey = "MULTI_STATEMENT_COUNT" asyncMode contextKey = "ASYNC_MODE_QUERY" + asyncModeNoFetch contextKey = "ASYNC_MODE_NO_FETCH_QUERY" queryIDChannel contextKey = "QUERY_ID_CHANNEL" snowflakeRequestIDKey contextKey = "SNOWFLAKE_REQUEST_ID" fetchResultByID contextKey = "SF_FETCH_RESULT_BY_ID" @@ -45,6 +46,11 @@ func WithAsyncMode(ctx context.Context) context.Context { return context.WithValue(ctx, asyncMode, true) } +// WithAsyncModeNoFetch returns a context that, when you execute a query in async mode, will not fetch results +func WithAsyncModeNoFetch(ctx context.Context) context.Context { + return context.WithValue(ctx, asyncModeNoFetch, true) +} + // WithQueryIDChan returns a context that contains the channel to receive the query ID func WithQueryIDChan(ctx context.Context, c chan<- string) context.Context { return context.WithValue(ctx, queryIDChannel, c)