Skip to content

Commit

Permalink
[Feature] Improvements to Async handling
Browse files Browse the repository at this point in the history
SIG-18794: Refactor the monitoring fetcher & auto-cancel async (#57)

* SIG-18794: Refactor the monitoring fetcher & auto-cancel async

  * better configuration
  * wait until the monitoring data refers to completed queries

SIG-16907: Error out early when invalid state for multi-statement requests. (#43)

Remove the debug-tracing added earlier.

Co-authored-by: Agam Brahma <agam@sigmacomputing.com>

SIG-18794: Make getAsync wait until the query completes or times out. (#59)

* SIG-18794: Make getAsync wait until the query completes or times out.

  * make the getStatus for an async made query block until the ctx
    timeout

SIG-18794: Fix a bug when looping to complete an async query (#60)

* discovered a new error code that needs looping when async;
     enhanced the test
  • Loading branch information
mtoader authored and rajsigma committed Sep 8, 2023
1 parent 56287d0 commit f384bc5
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 94 deletions.
24 changes: 4 additions & 20 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
Expand All @@ -26,7 +25,7 @@ func (sr *snowflakeRestful) processAsync(
headers map[string]string,
timeout time.Duration,
cfg *Config,
requestID UUID) (*execResponse, error) {
requestID uuid) (*execResponse, error) {
// placeholder object to return to user while retrieving results
rows := new(snowflakeRows)
res := new(snowflakeResult)
Expand Down Expand Up @@ -58,7 +57,7 @@ func (sr *snowflakeRestful) getAsync(
timeout time.Duration,
res *snowflakeResult,
rows *snowflakeRows,
requestID UUID,
requestID uuid,
cfg *Config) error {
resType := getResultType(ctx)
var errChannel chan error
Expand Down Expand Up @@ -98,9 +97,7 @@ func (sr *snowflakeRestful) getAsync(
}

sc := &snowflakeConn{rest: sr, cfg: cfg}
// the result response sometimes contains only Data and not anything else.
// if code is not set we treat as success
if response.Success || response.Code == "" {
if response.Success {
if resType == execResultType {
res.insertID = -1
if isDml(response.Data.StatementTypeID) {
Expand Down Expand Up @@ -136,11 +133,7 @@ func (sr *snowflakeRestful) getAsync(
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, response.Data))
}
if err := rows.ChunkDownloader.start(); err != nil {
rows.errChannel <- err
close(errChannel)
return err
}
_ = rows.ChunkDownloader.start()
}
rows.errChannel <- nil // mark query status complete
}
Expand Down Expand Up @@ -181,19 +174,10 @@ func (sr *snowflakeRestful) getAsyncOrStatus(
url *url.URL,
headers map[string]string,
timeout time.Duration) (*execResponse, error) {
startTime := time.Now()
resp, err := sr.FuncGet(ctx, sr, url, headers, timeout)
if err != nil {
return nil, err
}
if reportAsyncErrorFromContext(ctx) {
// if we dont get a response, or we get a bad response, this is not expected, so derive the information to know
// why this happened and panic with that message
if resp == nil || resp.StatusCode != http.StatusOK {
panicMessage := newPanicMessage(ctx, resp, startTime, timeout)
panic(panicMessage)
}
}
if resp.Body != nil {
defer func() { _ = resp.Body.Close() }()
}
Expand Down
46 changes: 46 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ func TestAsyncModeCancel(t *testing.T) {
})
}

const (
//selectTimelineGenerator = "SELECT COUNT(*) FROM TABLE(GENERATOR(TIMELIMIT=>%v))"
selectTimelineGenerator = "SELECT SYSTEM$WAIT(%v, 'SECONDS')"
)

// TestAsyncModeNoFetch verifies that a query issued in async "no fetch" mode
// blocks until the server-side query actually completes (beyond the default
// 45s async wait), returns no rows or columns, and that an async exec (DML)
// still reports affected rows correctly.
func TestAsyncModeNoFetch(t *testing.T) {
	ctx := WithAsyncMode(WithAsyncModeNoFetch(context.Background()))
	// The default behavior of the async wait is to wait for 45s. We want to
	// make sure we wait until the query actually completes, so we make the
	// test query take longer than 45s.
	secondsToRun := 50

	runTests(t, dsn, func(dbt *DBTest) {
		start := time.Now()
		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectTimelineGenerator, secondsToRun))
		defer rows.Close()

		// Next() will block and wait until results are available; in no-fetch
		// mode there must be no rows at all.
		if rows.Next() {
			t.Fatalf("next should have returned no rows")
		}
		columns, err := rows.Columns()
		if columns != nil {
			t.Fatalf("there should be no column array returned")
		}
		if err == nil {
			t.Fatalf("we should have an error thrown")
		}
		if rows.Scan(nil) == nil {
			t.Fatalf("we should have an error thrown")
		}
		// If we returned before the query's runtime elapsed, we did not
		// actually wait for completion.
		if elapsed := time.Since(start); elapsed < time.Duration(secondsToRun)*time.Second {
			t.Fatalf("test should have run for at least %d seconds, ran for %v", secondsToRun, elapsed)
		}

		dbt.mustExec("create or replace table test_async_exec (value boolean)")
		res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
		count, err := res.RowsAffected()
		if err != nil {
			t.Fatalf("res.RowsAffected() returned error: %v", err)
		}
		if count != 1 {
			t.Fatalf("expected 1 affected row, got %d", count)
		}
	})
}

func TestAsyncQueryFail(t *testing.T) {
ctx := WithAsyncMode(context.Background())
runTests(t, dsn, func(dbt *DBTest) {
Expand Down
16 changes: 13 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,16 @@ func (sc *snowflakeConn) queryContextInternal(
if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil {
return nil, err
}
if data.Data.ResultIDs == "" && rows.ChunkDownloader == nil {
// SIG-16907: We have no results to download here.
logger.WithContext(ctx).Errorf("Encountered empty result-ids for a multi-statement request. Query-id: %s, Query: %s", data.Data.QueryID, query)
return nil, (&SnowflakeError{
Number: ErrQueryIDFormat,
SQLState: data.Data.SQLState,
Message: "ExecResponse for multi-statement request had no ResultIDs",
QueryID: data.Data.QueryID,
}).exceptionTelemetry(sc)
}
} else {
rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data))
}
Expand Down Expand Up @@ -834,17 +844,17 @@ type ResultFetcher interface {
// MonitoringResultFetcher is an interface which allows fetching a
// monitoringResult given a snowflake connection and query-id.
// runtime is the observed query runtime, forwarded to the fetcher so it can
// decide how to poll for monitoring data.
type MonitoringResultFetcher interface {
	FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error)
}

// FetchMonitoringResult returns a monitoringResult object
// Multiplex can call monitoringResult.Monitoring() to get the QueryMonitoringData
func (sc *snowflakeConn) FetchMonitoringResult(queryID string) (*monitoringResult, error) {
func (sc *snowflakeConn) FetchMonitoringResult(queryID string, runtime time.Duration) (*monitoringResult, error) {
if sc.rest == nil {
return nil, driver.ErrBadConn
}

// set the fake runtime just to bypass fast query
monitoringResult := mkMonitoringFetcher(sc, queryID, time.Minute*10)
monitoringResult := mkMonitoringFetcher(sc, queryID, runtime)
return monitoringResult, nil
}
75 changes: 75 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ const (
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultDomain = ".snowflakecomputing.com"

// default monitoring fetcher config values
defaultMonitoringFetcherQueryMonitoringThreshold = 45 * time.Second
defaultMonitoringFetcherMaxDuration              = 10 * time.Second
// BUG FIX: this was 250 * time.Second, which cannot be right — a single
// retry sleep would dwarf the 10s MaxDuration, and the DSN parameter is
// named monitoringFetcher_retrySleepDurationMs (milliseconds scale).
defaultMonitoringFetcherRetrySleepDuration = 250 * time.Millisecond
)

// ConfigBool is a type to represent true or false in the Config
Expand Down Expand Up @@ -85,15 +90,37 @@ type Config struct {

Tracing string // sets logging level

<<<<<<< HEAD
MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux.
ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux.
=======
// Monitoring fetcher config
MonitoringFetcher MonitoringFetcherConfig

>>>>>>> 95d57be ([Feature] Improvements to Async handling)
// An identifier for this Config. Used to associate multiple connection instances with
// a single logical sql.DB connection.
ConnectionID string
}

// MonitoringFetcherConfig provides some knobs to control the behavior of the monitoring data fetcher
type MonitoringFetcherConfig struct {
// QueryRuntimeThreshold specifies the threshold, over which we'll fetch the monitoring
// data for a successful snowflake query. We use a time-based threshold, since there is
// a non-zero latency cost to fetch this data, and we want to bound the additional latency.
// By default, we bound to a 2% increase in latency - assuming worst case 100ms - when
// fetching this metadata.
QueryRuntimeThreshold time.Duration

// MaxDuration is the maximum time to wait until we get a proper monitoring
// sample for a query before giving up.
MaxDuration time.Duration

// RetrySleepDuration is the wait time between successive monitoring
// fetch retries.
RetrySleepDuration time.Duration
}

// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED
func (c *Config) ocspMode() string {
if c.InsecureMode {
Expand Down Expand Up @@ -218,6 +245,16 @@ func DSN(cfg *Config) (dsn string, err error) {
params.Add("clientStoreTemporaryCredential", strconv.FormatBool(cfg.ClientStoreTemporaryCredential != ConfigBoolFalse))
}

if cfg.MonitoringFetcher.QueryRuntimeThreshold != defaultMonitoringFetcherQueryMonitoringThreshold {
params.Add("monitoringFetcher_queryRuntimeThresholdMs", durationAsMillis(cfg.MonitoringFetcher.QueryRuntimeThreshold))
}
if cfg.MonitoringFetcher.MaxDuration != defaultMonitoringFetcherMaxDuration {
params.Add("monitoringFetcher_maxDurationMs", durationAsMillis(cfg.MonitoringFetcher.MaxDuration))
}
if cfg.MonitoringFetcher.RetrySleepDuration != defaultMonitoringFetcherRetrySleepDuration {
params.Add("monitoringFetcher_retrySleepDurationMs", durationAsMillis(cfg.MonitoringFetcher.RetrySleepDuration))
}

if cfg.ConnectionID != "" {
params.Add("connectionId", cfg.ConnectionID)
}
Expand Down Expand Up @@ -446,6 +483,16 @@ func fillMissingConfigParameters(cfg *Config) error {
cfg.ValidateDefaultParameters = ConfigBoolTrue
}

if cfg.MonitoringFetcher.QueryRuntimeThreshold == 0 {
cfg.MonitoringFetcher.QueryRuntimeThreshold = defaultMonitoringFetcherQueryMonitoringThreshold
}
if cfg.MonitoringFetcher.MaxDuration == 0 {
cfg.MonitoringFetcher.MaxDuration = defaultMonitoringFetcherMaxDuration
}
if cfg.MonitoringFetcher.RetrySleepDuration == 0 {
cfg.MonitoringFetcher.RetrySleepDuration = defaultMonitoringFetcherRetrySleepDuration
}

if cfg.ConnectionID == "" {
cfg.ConnectionID = uuid.New().String()
}
Expand Down Expand Up @@ -659,6 +706,21 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}
case "tracing":
cfg.Tracing = value
case "monitoringFetcher_queryRuntimeThresholdMs":
cfg.MonitoringFetcher.QueryRuntimeThreshold, err = parseMillisToDuration(value)
if err != nil {
return err
}
case "monitoringFetcher_maxDurationMs":
cfg.MonitoringFetcher.MaxDuration, err = parseMillisToDuration(value)
if err != nil {
return err
}
case "monitoringFetcher_retrySleepDurationMs":
cfg.MonitoringFetcher.RetrySleepDuration, err = parseMillisToDuration(value)
if err != nil {
return err
}
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand All @@ -669,6 +731,19 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return
}

func parseMillisToDuration(value string) (time.Duration, error) {
intValue, err := strconv.ParseInt(value, 10, 64)
if err == nil {
return time.Millisecond * time.Duration(intValue), nil
}

return 0, err
}

func durationAsMillis(duration time.Duration) string {
return strconv.FormatInt(duration.Milliseconds(), 10)
}

func parseTimeout(value string) (time.Duration, error) {
var vv int64
var err error
Expand Down
Loading

0 comments on commit f384bc5

Please sign in to comment.