Skip to content

Commit

Permalink
[Feature] Snowflake driver to support WaitForQueryCompletion(); wai…
Browse files Browse the repository at this point in the history
…t for results to finish without returning result rows (#84)
  • Loading branch information
madisonchamberlain committed Apr 27, 2023
1 parent f72c414 commit e762e23
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
9 changes: 9 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,14 @@ func (sc *snowflakeConn) FetchResult(ctx context.Context, qid string) (driver.Ro
return sc.buildRowsForRunningQuery(ctx, qid)
}

// WaitForQueryCompletion waits for the result of a previously issued query,
// given the snowflake query-id. This functionality is not used by the
// go sql library but is exported to clients who can make use of this
// capability explicitly.
func (sc *snowflakeConn) WaitForQueryCompletion(ctx context.Context, qid string) error {
return sc.blockOnQueryCompletion(ctx, qid)
}

// ResultFetcher is an interface which allows a query result to be
// fetched given the corresponding snowflake query-id.
//
Expand All @@ -815,6 +823,7 @@ func (sc *snowflakeConn) FetchResult(ctx context.Context, qid string) (driver.Ro
// function.
type ResultFetcher interface {
FetchResult(ctx context.Context, qid string) (driver.Rows, error)
WaitForQueryCompletion(ctx context.Context, qid string) error
}

// MonitoringResultFetcher is an interface which allows to fetch monitoringResult
Expand Down
106 changes: 106 additions & 0 deletions monitoring.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func (sc *snowflakeConn) checkQueryStatus(
return &queryRet, nil
}

// Waits 45 seconds for a query response; return early if query finishes
func (sc *snowflakeConn) getQueryResultResp(
ctx context.Context,
resultPath string,
Expand Down Expand Up @@ -228,6 +229,63 @@ func (sc *snowflakeConn) getQueryResultResp(
return respd, nil
}

// Waits for the query to complete, then returns the response
func (sc *snowflakeConn) waitForCompletedQueryResultResp(
ctx context.Context,
resultPath string,
) (*execResponse, error) {
// if we already have the response; return that
if cachedResponse, ok := sc.execRespCache.load(resultPath); ok {
return cachedResponse, nil
}
requestID := getOrGenerateRequestIDFromContext(ctx)
headers := getHeaders()
if serviceName, ok := sc.cfg.Params[serviceName]; ok {
headers[httpHeaderServiceName] = *serviceName
}
param := make(url.Values)
param.Add(requestIDKey, requestID.String())
param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
param.Add(requestGUIDKey, NewUUID().String())
token, _, _ := sc.rest.TokenAccessor.GetTokens()
if token != "" {
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
}
url := sc.rest.getFullURL(resultPath, &param)


deadline, ok := ctx.Deadline()
var timeout time.Duration
if !ok {
timeout = sc.rest.RequestTimeout
} else {
// if we have a context deadline set we want to override the default
timeout = deadline.Sub(time.Now())
}

// internally, pulls on FuncGet until we have a result at the result location (queryID)
var response *execResponse
var err error
for response == nil || isQueryInProgress(response) {
response, err = sc.rest.getAsyncOrStatus(ctx, url, headers, timeout)

// if the context is canceled, we have to cancel it manually now
if err != nil {
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
if err == context.Canceled || err == context.DeadlineExceeded {
// use the default top level 1 sec timeout for cancellation as throughout the driver
if err := cancelQuery(context.TODO(), sc.rest, requestID, time.Second); err != nil {
logger.WithContext(ctx).Errorf("failed to cancel async query, err: %v", err)
}
}
return nil, err
}
}

sc.execRespCache.store(resultPath, response)
return response, nil
}

// Fetch query result for a query id from /queries/<qid>/result endpoint.
func (sc *snowflakeConn) rowsForRunningQuery(
ctx context.Context, qid string,
Expand Down Expand Up @@ -268,6 +326,44 @@ func (sc *snowflakeConn) rowsForRunningQuery(
return nil
}

// Wait for query to complete from a query id from /queries/<qid>/result endpoint.
func (sc *snowflakeConn) blockOnRunningQuery(
ctx context.Context, qid string) error {
resultPath := fmt.Sprintf(urlQueriesResultFmt, qid)
resp, err := sc.waitForCompletedQueryResultResp(ctx, resultPath)
if err != nil {
logger.WithContext(ctx).Errorf("error: %v", err)
if resp != nil {
code, err := strconv.Atoi(resp.Code)
if err != nil {
return err
}
return (&SnowflakeError{
Number: code,
SQLState: resp.Data.SQLState,
Message: err.Error(),
QueryID: resp.Data.QueryID,
}).exceptionTelemetry(sc)
}
return err
}
if !resp.Success {
message := resp.Message
code, err := strconv.Atoi(resp.Code)
if err != nil {
code = ErrQueryStatus
message = fmt.Sprintf("%s: (failed to parse original code: %s: %s)", message, resp.Code, err.Error())
}
return (&SnowflakeError{
Number: code,
SQLState: resp.Data.SQLState,
Message: message,
QueryID: resp.Data.QueryID,
}).exceptionTelemetry(sc)
}
return nil
}

// prepare a Rows object to return for query of 'qid'
func (sc *snowflakeConn) buildRowsForRunningQuery(
ctx context.Context,
Expand All @@ -285,6 +381,16 @@ func (sc *snowflakeConn) buildRowsForRunningQuery(
return rows, nil
}

func (sc *snowflakeConn) blockOnQueryCompletion(
ctx context.Context,
qid string,
) error {
if err := sc.blockOnRunningQuery(ctx, qid); err != nil {
return err
}
return nil
}

func mkMonitoringFetcher(sc *snowflakeConn, qid string, runtime time.Duration) *monitoringResult {
// Exit early if this was a "fast" query
if runtime < sc.cfg.MonitoringFetcher.QueryRuntimeThreshold {
Expand Down

0 comments on commit e762e23

Please sign in to comment.