From 4c04398725ebbaf591827ea725bd61a4847ecf74 Mon Sep 17 00:00:00 2001 From: Mihai Claudiu Toader Date: Fri, 11 Mar 2022 14:47:50 -0800 Subject: [PATCH] [Fix] SIG-18794: Fix getAsync() to not panic on context exceeded + test. (#56) --- async_test.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/async_test.go b/async_test.go index f7d685b27..a742dbdac 100644 --- a/async_test.go +++ b/async_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "testing" + "time" ) func TestAsyncMode(t *testing.T) { @@ -89,6 +90,7 @@ func TestMultipleAsyncQueries(t *testing.T) { go retrieveRows(rows1, ch1) go retrieveRows(rows2, ch2) + select { case res := <-ch1: t.Fatalf("value %v should not have been called earlier.", res) @@ -100,6 +102,40 @@ func TestMultipleAsyncQueries(t *testing.T) { }) } +func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) { + ctx := WithAsyncMode(context.Background()) + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + s1 := "foo" + s2 := "bar" + ch1 := make(chan string) + ch2 := make(chan string) + + runTests(t, dsn, func(dbt *DBTest) { + rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1)) + defer rows1.Close() + + rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2)) + defer rows2.Close() + + go retrieveRows(rows1, ch1) + go retrieveRows(rows2, ch2) + + res1 := <-ch1 + if res1 != s1 { + t.Fatalf("query failed. expected: %v, got: %v", s1, res1) + } + + // wait until rows2 is done + <-ch2 + driverErr, ok := rows2.Err().(*SnowflakeError) + if !ok || driverErr == nil || driverErr.Number != ErrAsync { + t.Fatalf("Snowflake ErrAsync expected. got: %T, %v", rows2.Err(), rows2.Err()) + } + }) +} + func retrieveRows(rows *RowsExtended, ch chan string) { var s string for rows.Next() {