Skip to content

Commit

Permalink
[Fix] SIG-18794: Fix getAsync() to not panic on context exceeded + te…
Browse files Browse the repository at this point in the history
…st. (#56)
  • Loading branch information
mtoader authored and ardenma committed Mar 12, 2024
1 parent e2d4a0b commit d2c2058
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ package gosnowflake

import (
"context"
"database/sql"
"fmt"
"testing"
"time"
)

func TestAsyncMode(t *testing.T) {
Expand Down Expand Up @@ -125,22 +125,15 @@ func TestMultipleAsyncQueries(t *testing.T) {
ch1 := make(chan string)
ch2 := make(chan string)

db := openDB(t)

runDBTest(t, func(dbt *DBTest) {
rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30))
if err != nil {
t.Fatalf("can't read rows1: %v", err)
}
rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30))
defer rows1.Close()
rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10))
if err != nil {
t.Fatalf("can't read rows2: %v", err)
}
rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10))
defer rows2.Close()

go retrieveRows(rows1, ch1)
go retrieveRows(rows2, ch2)

select {
case res := <-ch1:
t.Fatalf("value %v should not have been called earlier.", res)
Expand All @@ -152,7 +145,41 @@ func TestMultipleAsyncQueries(t *testing.T) {
})
}

func retrieveRows(rows *sql.Rows, ch chan string) {
func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) {
ctx := WithAsyncMode(context.Background())
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()

s1 := "foo"
s2 := "bar"
ch1 := make(chan string)
ch2 := make(chan string)

runDBTest(t, func(dbt *DBTest) {
rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1))
defer rows1.Close()

rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2))
defer rows2.Close()

go retrieveRows(rows1, ch1)
go retrieveRows(rows2, ch2)

res1 := <-ch1
if res1 != s1 {
t.Fatalf("query failed. expected: %v, got: %v", s1, res1)
}

// wait until rows2 is done
<-ch2
driverErr, ok := rows2.Err().(*SnowflakeError)
if !ok || driverErr == nil || driverErr.Number != ErrAsync {
t.Fatalf("Snowflake ErrAsync expected. got: %T, %v", rows2.Err(), rows2.Err())
}
})
}

func retrieveRows(rows *RowsExtended, ch chan string) {
var s string
for rows.Next() {
if err := rows.Scan(&s); err != nil {
Expand Down

0 comments on commit d2c2058

Please sign in to comment.