Skip to content

Commit

Permalink
SIG-18794: Fix getAsync() to not panic on context exceeded + test. (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtoader committed Mar 11, 2022
1 parent bf77a7e commit 3103606
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
4 changes: 0 additions & 4 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func (sr *snowflakeRestful) getAsync(
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
close(errChannel)
return err
}
if resp.Body != nil {
Expand All @@ -82,7 +81,6 @@ func (sr *snowflakeRestful) getAsync(
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
sfError.Message = err.Error()
errChannel <- sfError
close(errChannel)
return err
}

Expand All @@ -99,13 +97,11 @@ func (sr *snowflakeRestful) getAsync(
r, err := sc.handleMultiExec(ctx, respd.Data)
if err != nil {
res.errChannel <- err
close(errChannel)
return err
}
res.affectedRows, err = r.RowsAffected()
if err != nil {
res.errChannel <- err
close(errChannel)
return err
}
}
Expand Down
36 changes: 36 additions & 0 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"testing"
"time"
)

func TestAsyncMode(t *testing.T) {
Expand Down Expand Up @@ -75,6 +76,7 @@ func TestMultipleAsyncQueries(t *testing.T) {

go retrieveRows(rows1, ch1)
go retrieveRows(rows2, ch2)

select {
case res := <-ch1:
t.Fatalf("value %v should not have been called earlier.", res)
Expand All @@ -86,6 +88,40 @@ func TestMultipleAsyncQueries(t *testing.T) {
})
}

func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) {
ctx := WithAsyncMode(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()

s1 := "foo"
s2 := "bar"
ch1 := make(chan string)
ch2 := make(chan string)

runTests(t, dsn, func(dbt *DBTest) {
rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1))
defer rows1.Close()

rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2))
defer rows2.Close()

go retrieveRows(rows1, ch1)
go retrieveRows(rows2, ch2)

res1 := <-ch1
if res1 != s1 {
t.Fatalf("query failed. expected: %v, got: %v", s1, res1)
}

// wait until rows2 is done
<-ch2
driverErr, ok := rows2.Err().(*SnowflakeError)
if !ok || driverErr == nil || driverErr.Number != ErrAsync {
t.Fatalf("Snowflake ErrAsync expected. got: %T, %v", rows2.Err(), rows2.Err())
}
})
}

func retrieveRows(rows *RowsExtended, ch chan string) {
var s string
for rows.Next() {
Expand Down

0 comments on commit 3103606

Please sign in to comment.