diff --git a/connection.go b/connection.go index 3f8801527..ca3cb0f95 100644 --- a/connection.go +++ b/connection.go @@ -806,6 +806,14 @@ func (sc *snowflakeConn) FetchResult(ctx context.Context, qid string) (driver.Ro return sc.buildRowsForRunningQuery(ctx, qid) } +// WaitForQueryCompletion waits for the result of a previously issued query, +// given the snowflake query-id. This functionality is not used by the +// go sql library but is exported to clients who can make use of this +// capability explicitly. +func (sc *snowflakeConn) WaitForQueryCompletion(ctx context.Context, qid string) error { + return sc.blockOnQueryCompletion(ctx, qid) +} + // ResultFetcher is an interface which allows a query result to be // fetched given the corresponding snowflake query-id. // @@ -815,6 +823,7 @@ func (sc *snowflakeConn) FetchResult(ctx context.Context, qid string) (driver.Ro // function. type ResultFetcher interface { FetchResult(ctx context.Context, qid string) (driver.Rows, error) + WaitForQueryCompletion(ctx context.Context, qid string) error } // MonitoringResultFetcher is an interface which allows to fetch monitoringResult diff --git a/monitoring.go b/monitoring.go index 3be6beac4..acf7d415d 100644 --- a/monitoring.go +++ b/monitoring.go @@ -189,6 +189,7 @@ func (sc *snowflakeConn) checkQueryStatus( return &queryRet, nil } +// Waits 45 seconds for a query response; return early if query finishes func (sc *snowflakeConn) getQueryResultResp( ctx context.Context, resultPath string, @@ -228,6 +229,63 @@ func (sc *snowflakeConn) getQueryResultResp( return respd, nil } +// Waits for the query to complete, then returns the response +func (sc *snowflakeConn) waitForCompletedQueryResultResp( + ctx context.Context, + resultPath string, +) (*execResponse, error) { + // if we already have the response; return that + if cachedResponse, ok := sc.execRespCache.load(resultPath); ok { + return cachedResponse, nil + } + requestID := getOrGenerateRequestIDFromContext(ctx) + headers := getHeaders() + if serviceName, ok := sc.cfg.Params[serviceName]; ok { + headers[httpHeaderServiceName] = *serviceName + } + param := make(url.Values) + param.Add(requestIDKey, requestID.String()) + param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10)) + param.Add(requestGUIDKey, NewUUID().String()) + token, _, _ := sc.rest.TokenAccessor.GetTokens() + if token != "" { + headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) + } + url := sc.rest.getFullURL(resultPath, ¶m) + + + deadline, ok := ctx.Deadline() + var timeout time.Duration + if !ok { + timeout = sc.rest.RequestTimeout + } else { + // if we have a context deadline set we want to override the default + timeout = deadline.Sub(time.Now()) + } + + // internally, pulls on FuncGet until we have a result at the result location (queryID) + var response *execResponse + var err error + for response == nil || isQueryInProgress(response) { + response, err = sc.rest.getAsyncOrStatus(ctx, url, headers, timeout) + + // if the context is canceled, we have to cancel it manually now + if err != nil { + logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) + if err == context.Canceled || err == context.DeadlineExceeded { + // use the default top level 1 sec timeout for cancellation as throughout the driver + if err := cancelQuery(context.TODO(), sc.rest, requestID, time.Second); err != nil { + logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err) + } + } + return nil, err + } + } + + sc.execRespCache.store(resultPath, response) + return response, nil +} + // Fetch query result for a query id from /queries//result endpoint. func (sc *snowflakeConn) rowsForRunningQuery( ctx context.Context, qid string, @@ -268,6 +326,44 @@ func (sc *snowflakeConn) rowsForRunningQuery( return nil } +// Wait for query to complete from a query id from /queries//result endpoint. +func (sc *snowflakeConn) blockOnRunningQuery( + ctx context.Context, qid string) error { + resultPath := fmt.Sprintf(urlQueriesResultFmt, qid) + resp, err := sc.waitForCompletedQueryResultResp(ctx, resultPath) + if err != nil { + logger.WithContext(ctx).Errorf("error: %v", err) + if resp != nil { + code, err := strconv.Atoi(resp.Code) + if err != nil { + return err + } + return (&SnowflakeError{ + Number: code, + SQLState: resp.Data.SQLState, + Message: err.Error(), + QueryID: resp.Data.QueryID, + }).exceptionTelemetry(sc) + } + return err + } + if !resp.Success { + message := resp.Message + code, err := strconv.Atoi(resp.Code) + if err != nil { + code = ErrQueryStatus + message = fmt.Sprintf("%s: (failed to parse original code: %s: %s)", message, resp.Code, err.Error()) + } + return (&SnowflakeError{ + Number: code, + SQLState: resp.Data.SQLState, + Message: message, + QueryID: resp.Data.QueryID, + }).exceptionTelemetry(sc) + } + return nil +} + // prepare a Rows object to return for query of 'qid' func (sc *snowflakeConn) buildRowsForRunningQuery( ctx context.Context, @@ -285,6 +381,16 @@ func (sc *snowflakeConn) buildRowsForRunningQuery( return rows, nil } +func (sc *snowflakeConn) blockOnQueryCompletion( + ctx context.Context, + qid string, +) error { + if err := sc.blockOnRunningQuery(ctx, qid); err != nil { + return err + } + return nil +} + func mkMonitoringFetcher(sc *snowflakeConn, qid string, runtime time.Duration) *monitoringResult { // Exit early if this was a "fast" query if runtime < sc.cfg.MonitoringFetcher.QueryRuntimeThreshold {