diff --git a/async_test.go b/async_test.go index 447337cc6..e0190ecdb 100644 --- a/async_test.go +++ b/async_test.go @@ -1,9 +1,10 @@ -// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. +// Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved. package gosnowflake import ( "context" + "database/sql" "fmt" "testing" "time" @@ -16,7 +17,7 @@ func TestAsyncMode(t *testing.T) { var idx int var v string - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) defer rows.Close() @@ -45,12 +46,39 @@ func TestAsyncMode(t *testing.T) { }) } +func TestAsyncModeMultiStatement(t *testing.T) { + withMultiStmtCtx, _ := WithMultiStatement(context.Background(), 6) + ctx := WithAsyncMode(withMultiStmtCtx) + multiStmtQuery := "begin;\n" + + "delete from test_multi_statement_async;\n" + + "insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" + + "select 1;\n" + + "select 2;\n" + + "rollback;" + + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("drop table if exists test_multi_statement_async") + dbt.mustExec(`create or replace table test_multi_statement_async( + c1 number, c2 string) as select 10, 'z'`) + defer dbt.mustExec("drop table if exists test_multi_statement_async") + + res := dbt.mustExecContext(ctx, multiStmtQuery) + count, err := res.RowsAffected() + if err != nil { + t.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 3 { + t.Fatalf("expected 3 affected rows, got %d", count) + } + }) +} + func TestAsyncModeCancel(t *testing.T) { withCancelCtx, cancel := context.WithCancel(context.Background()) ctx := WithAsyncMode(withCancelCtx) numrows := 100000 - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) cancel() }) @@ -67,7 +95,7 @@ func TestAsyncModeNoFetch(t *testing.T) { // completes, so we make the test take longer than 45s secondsToRun := 50 - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { start := time.Now() rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectTimelineGenerator, secondsToRun)) defer rows.Close() @@ -104,7 +132,7 @@ func TestAsyncModeNoFetch(t *testing.T) { func TestAsyncQueryFail(t *testing.T) { ctx := WithAsyncMode(context.Background()) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, "selectt 1") defer rows.Close() @@ -125,10 +153,18 @@ func TestMultipleAsyncQueries(t *testing.T) { ch1 := make(chan string) ch2 := make(chan string) - runTests(t, dsn, func(dbt *DBTest) { - rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + db := openDB(t) + + runDBTest(t, func(dbt *DBTest) { + rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + if err != nil { + t.Fatalf("can't read rows1: %v", err) + } defer rows1.Close() - rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + if err != nil { + t.Fatalf("can't read rows2: %v", err) + } defer rows2.Close() go retrieveRows(rows1, ch1) @@ -155,11 +191,19 @@ func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) { ch1 := make(chan string) ch2 := make(chan string) - runTests(t, dsn, func(dbt *DBTest) { - rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1)) + db := openDB(t) + + runDBTest(t, func(dbt *DBTest) { + rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1)) + if err != nil { + t.Fatalf("can't read rows1: %v", err) + } defer rows1.Close() - rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2)) + rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2)) + if err != nil { + t.Fatalf("can't read rows2: %v", err) + } defer rows2.Close() go retrieveRows(rows1, ch1) @@ -179,7 +223,7 @@ func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) { }) } -func retrieveRows(rows *RowsExtended, ch chan string) { +func retrieveRows(rows *sql.Rows, ch chan string) { var s string for rows.Next() { if err := rows.Scan(&s); err != nil { @@ -191,3 +235,38 @@ func retrieveRows(rows *RowsExtended, ch chan string) { ch <- s close(ch) } + +func TestLongRunningAsyncQuery(t *testing.T) { + conn := openConn(t) + defer conn.Close() + + ctx, _ := WithMultiStatement(context.Background(), 0) + query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data" + + rows, err := conn.QueryContext(WithAsyncMode(ctx), query) + if err != nil { + t.Fatalf("failed to run a query. %v, err: %v", query, err) + } + defer rows.Close() + var v string + i := 0 + for { + for rows.Next() { + err := rows.Scan(&v) + if err != nil { + t.Fatalf("failed to get result. err: %v", err) + } + if v == "" { + t.Fatal("should have returned a result") + } + results := []string{"waited 50 seconds", "Statement executed successfully."} + if v != results[i] { + t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v) + } + i++ + } + if !rows.NextResultSet() { + break + } + } +} diff --git a/bind_uploader.go b/bind_uploader.go index da817f159..5c10ff5e3 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -5,6 +5,7 @@ package gosnowflake import ( "bytes" "context" + "database/sql" "database/sql/driver" "fmt" "reflect" @@ -221,6 +222,10 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, dataType = binding.Value.(SnowflakeDataType) default: // This binding is an actual parameter for the query + if tnt, ok := binding.Value.(TypedNullTime); ok { + dataType = convertTzTypeToSnowflakeType(tnt.TzType) + binding.Value = tnt.Time + } t := goTypeToSnowflake(binding.Value, dataType) var val interface{} if t == sliceType { @@ -235,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, if t == nullType || t == unSupportedType { t = textType // if null or not supported, pass to GS as text } - bindValues[strconv.Itoa(idx)] = execBindParameter{ + bindValues[bindingName(binding, idx)] = execBindParameter{ Type: t.String(), Value: val, } @@ -245,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, return bindValues, nil } +func bindingName(nv driver.NamedValue, idx int) string { + if nv.Name != "" { + return nv.Name + } + return strconv.Itoa(idx) +} + func arrayBindValueCount(bindValues []driver.NamedValue) int { if !isArrayBind(bindValues) { return 0 @@ -298,3 +310,12 @@ func supportedArrayBind(nv *driver.NamedValue) bool { return false } } + +func supportedNullBind(nv *driver.NamedValue) bool { + switch reflect.TypeOf(nv.Value) { + case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}), + reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}), reflect.TypeOf(TypedNullTime{}): + return true + } + return false +} diff --git a/bindings_test.go b/bindings_test.go index 490729d14..44561545e 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" "math/rand" + "reflect" "strconv" "testing" "time" @@ -23,9 +24,9 @@ const ( selectAllSQL = "select * from TEST_PREP_STATEMENT ORDER BY 1" createTableSQLBulkArray = `create or replace table test_bulk_array(c1 INTEGER, - c2 FLOAT, c3 BOOLEAN, c4 STRING, C5 BINARY)` + c2 FLOAT, c3 BOOLEAN, c4 STRING, C5 BINARY, C6 INTEGER)` deleteTableSQLBulkArray = "drop table if exists test_bulk_array" - insertSQLBulkArray = "insert into test_bulk_array values(?, ?, ?, ?, ?)" + insertSQLBulkArray = "insert into test_bulk_array values(?, ?, ?, ?, ?, ?)" selectAllSQLBulkArray = "select * from test_bulk_array ORDER BY 1" createTableSQLBulkArrayDateTimeTimestamp = `create or replace table test_bulk_array_DateTimeTimestamp( @@ -36,9 +37,9 @@ const ( ) func TestBindingNull(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id int, c1 STRING, c2 BOOLEAN)") - _, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?, ?)", + _, err := dbt.exec("INSERT INTO test VALUES (1, ?, ?)", DataTypeText, "hello", DataTypeNull, nil, ) @@ -50,54 +51,53 @@ func TestBindingNull(t *testing.T) { } func TestBindingFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} expected := 42.23 var out float64 var rows *RowsExtended for _, v := range types { - dbt.mustExec(fmt.Sprintf("CREATE TABLE test (id int, value %v)", v)) - dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) - rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) + t.Run(v, func(t *testing.T) { + dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE test (id int, value %v)", v)) + dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } // TestBindingUint64 tests uint64 binding. Should fail as unit64 is not a // supported binding value by Go's sql package. func TestBindingUint64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := []string{"INTEGER"} + runDBTest(t, func(dbt *DBTest) { expected := uint64(18446744073709551615) - for _, v := range types { - dbt.mustExec(fmt.Sprintf("CREATE TABLE test (id int, value %v)", v)) - if _, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { - dbt.Fatal("should fail as uint64 values with high bit set are not supported.") - } else { - logger.Infof("expected err: %v", err) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value INTEGER)") + if _, err := dbt.exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { + dbt.Fatal("should fail as uint64 values with high bit set are not supported.") + } else { + logger.Infof("expected err: %v", err) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } func TestBindingDateTimeTimestamp(t *testing.T) { createDSN(PSTLocation) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { expected := time.Now() dbt.mustExec( "CREATE OR REPLACE TABLE tztest (id int, ntz timestamp_ntz, ltz timestamp_ltz, dt date, tm time)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") if err != nil { dbt.Fatal(err.Error()) } @@ -163,36 +163,33 @@ func TestBindingDateTimeTimestamp(t *testing.T) { } func TestBindingBinary(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary, c binary)") + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary)") var b = []byte{0x01, 0x02, 0x03} - dbt.mustExec("INSERT INTO bintest(id,b,c) VALUES(1, ?, ?)", DataTypeBinary, b, DataTypeBinary, b) - rows := dbt.mustQuery("SELECT b, c FROM bintest WHERE id=?", 1) + dbt.mustExec("INSERT INTO bintest(id,b) VALUES(1, ?)", DataTypeBinary, b) + rows := dbt.mustQuery("SELECT b FROM bintest WHERE id=?", 1) defer rows.Close() if rows.Next() { var rb []byte - var rc []byte - if err := rows.Scan(&rb, &rc); err != nil { + if err := rows.Scan(&rb); err != nil { dbt.Errorf("failed to scan data. err: %v", err) } if !bytes.Equal(b, rb) { dbt.Errorf("failed to match data. expected: %v, got: %v", b, rb) } - if !bytes.Equal(b, rc) { - dbt.Errorf("failed to match data. expected: %v, got: %v", b, rc) - } } else { dbt.Errorf("no data") } dbt.mustExec("DROP TABLE bintest") }) + } func TestBindingTimestampTZ(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { expected := time.Now() dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, tz timestamp_tz)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") if err != nil { dbt.Fatal(err.Error()) } @@ -218,7 +215,7 @@ func TestBindingTimestampTZ(t *testing.T) { // SNOW-755844: Test the use of a pointer *time.Time type in user-defined structures to perform updates/inserts func TestBindingTimePtrInStruct(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { type timePtrStruct struct { id *int timeVal *time.Time @@ -231,7 +228,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -240,7 +237,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = &expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -265,7 +262,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { // SNOW-755844: Test the use of a time.Time type in user-defined structures to perform updates/inserts func TestBindingTimeInStruct(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { type timeStruct struct { id int timeVal time.Time @@ -278,7 +275,7 @@ func TestBindingTimeInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -287,7 +284,7 @@ func TestBindingTimeInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -311,14 +308,14 @@ func TestBindingTimeInStruct(t *testing.T) { } func TestBindingInterface(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -338,13 +335,13 @@ func TestBindingInterface(t *testing.T) { } func TestBindingInterfaceString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery(selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -365,9 +362,9 @@ func TestBindingInterfaceString(t *testing.T) { } func TestBulkArrayBindingInterfaceNil(t *testing.T) { - nilArray := make([]interface{}, 1) + nilArray := make([]any, 1) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) @@ -430,31 +427,35 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { } func TestBulkArrayBindingInterface(t *testing.T) { - intArray := make([]interface{}, 3) + intArray := make([]any, 3) intArray[0] = int32(100) intArray[1] = int32(200) - fltArray := make([]interface{}, 3) + fltArray := make([]any, 3) fltArray[0] = float64(0.1) fltArray[2] = float64(5.678) - boolArray := make([]interface{}, 3) + boolArray := make([]any, 3) boolArray[1] = false boolArray[2] = true - strArray := make([]interface{}, 3) + strArray := make([]any, 3) strArray[2] = "test3" - byteArray := make([]interface{}, 3) + byteArray := make([]any, 3) byteArray[0] = []byte{0x01, 0x02, 0x03} byteArray[2] = []byte{0x07, 0x08, 0x09} - runTests(t, dsn, func(dbt *DBTest) { + int64Array := make([]any, 3) + int64Array[0] = int64(100) + int64Array[1] = int64(200) + + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQLBulkArray) defer dbt.mustExec(deleteTableSQLBulkArray) dbt.mustExec(insertSQLBulkArray, Array(&intArray), Array(&fltArray), - Array(&boolArray), Array(&strArray), Array(&byteArray)) + Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) defer rows.Close() @@ -463,10 +464,11 @@ func TestBulkArrayBindingInterface(t *testing.T) { var v2 sql.NullBool var v3 sql.NullString var v4 []byte + var v5 sql.NullInt64 cnt := 0 for i := 0; rows.Next(); i++ { - if err := rows.Scan(&v0, &v1, &v2, &v3, &v4); err != nil { + if err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5); err != nil { t.Fatal(err) } if v0.Valid { @@ -504,6 +506,13 @@ func TestBulkArrayBindingInterface(t *testing.T) { } else if v4 != nil { t.Fatalf("failed to fetch the []byte column v4. expected %v, got: %v", byteArray[i], v4) } + if v5.Valid { + if v5.Int64 != int64Array[i] { + t.Fatalf("failed to fetch the sql.NullInt64 column v5. expected %v, got: %v", int64Array[i], v5.Int64) + } + } else if int64Array[i] != nil { + t.Fatalf("failed to fetch the sql.NullInt64 column v5. expected %v, got: %v", int64Array[i], v5) + } cnt++ } if cnt != len(intArray) { @@ -521,27 +530,27 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { if err != nil { t.Error(err) } - ntzArray := make([]interface{}, 3) + ntzArray := make([]any, 3) ntzArray[0] = now ntzArray[1] = now.Add(1) - ltzArray := make([]interface{}, 3) + ltzArray := make([]any, 3) ltzArray[1] = now.Add(2).In(loc) ltzArray[2] = now.Add(3).In(loc) - tzArray := make([]interface{}, 3) + tzArray := make([]any, 3) tzArray[0] = tz.Add(4).In(loc) tzArray[2] = tz.Add(5).In(loc) - dtArray := make([]interface{}, 3) + dtArray := make([]any, 3) dtArray[0] = tz.Add(6).In(loc) dtArray[1] = now.Add(7).In(loc) - tmArray := make([]interface{}, 3) + tmArray := make([]any, 3) tmArray[1] = now.Add(8).In(loc) tmArray[2] = now.Add(9).In(loc) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQLBulkArrayDateTimeTimestamp) defer dbt.mustExec(deleteTableSQLBulkArrayDateTimeTimestamp) @@ -643,11 +652,11 @@ func testBindingArray(t *testing.T, bulk bool) { dtArray := []time.Time{now.Add(9), now.Add(10), now.Add(11)} tmArray := []time.Time{now.Add(12), now.Add(13), now.Add(14)} - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) if bulk { - if _, err := dbt.db.Exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { + if _, err := dbt.exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { t.Error(err) } } @@ -711,7 +720,7 @@ func testBindingArray(t *testing.T, bulk bool) { } func TestBulkArrayBinding(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string)", dbname)) numRows := 100000 intArr := make([]int, numRows) @@ -722,6 +731,7 @@ func TestBulkArrayBinding(t *testing.T) { } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?)", dbname), Array(&intArr), Array(&strArr)) rows := dbt.mustQuery("select * from " + dbname) + defer rows.Close() cnt := 0 var i int var s string @@ -755,7 +765,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { tempTableName := fmt.Sprintf("test_table_%v", randomString(5)) ctx := context.Background() - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(fmt.Sprintf("CREATE TABLE %s (C VARCHAR(64) NOT NULL)", tempTableName)) defer dbt.mustExec("drop table " + tempTableName) @@ -764,6 +774,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName), Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -773,6 +784,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { } rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -786,7 +798,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { } func TestBulkArrayMultiPartBindingInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer)") startNum := 1000000 endNum := 3000000 @@ -795,12 +807,13 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { for i := startNum; i < endNum; i++ { intArr[i-startNum] = i } - _, err := dbt.db.Exec("insert into binding_test values (?)", Array(&intArr)) + _, err := dbt.exec("insert into binding_test values (?)", Array(&intArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1") + defer rows.Close() cnt := startNum var i int for rows.Next() { @@ -820,15 +833,15 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { } func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer, c2 string)") startNum := 1000000 endNum := 2000000 numRows := endNum - startNum // Define the integer and string arrays - intArr := make([]interface{}, numRows) - stringArr := make([]interface{}, numRows) + intArr := make([]any, numRows) + stringArr := make([]any, numRows) for i := startNum; i < endNum; i++ { intArr[i-startNum] = i stringArr[i-startNum] = fmt.Sprint(i) @@ -842,12 +855,13 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { stringArr[2] = nil stringArr[3] = nil - _, err := dbt.db.Exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) + _, err := dbt.exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1,c2") + defer rows.Close() cnt := startNum var i sql.NullInt32 var s sql.NullString @@ -879,3 +893,148 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { dbt.mustExec("DROP TABLE binding_test") }) } + +func TestFunctionParameters(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input any + nullResult bool + }{ + {"textAndNullStringResultInNull", "text", sql.NullString{}, true}, + {"numberAndNullInt64ResultInNull", "number", sql.NullInt64{}, true}, + {"floatAndNullFloat64ResultInNull", "float", sql.NullFloat64{}, true}, + {"booleanAndAndNullBoolResultInNull", "boolean", sql.NullBool{}, true}, + {"dateAndTypedNullTimeResultInNull", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, + {"datetimeAndTypedNullTimeResultInNull", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timeAndTypedNullTimeResultInNull", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, + {"timestampAndTypedNullTimeResultInNull", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ntzAndTypedNullTimeResultInNull", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ltzAndTypedNullTimeResultInNull", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, + {"timestamp_tzAndTypedNullTimeResultInNull", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, + {"textAndStringResultInNotNull", "text", "string", false}, + {"numberAndIntegerResultInNotNull", "number", 123, false}, + {"floatAndFloatResultInNotNull", "float", 123.01, false}, + {"booleanAndBooleanResultInNotNull", "boolean", true, false}, + {"dateAndTimeResultInNotNull", "date", time.Now(), false}, + {"datetimeAndTimeResultInNotNull", "datetime", time.Now(), false}, + {"timeAndTimeResultInNotNull", "time", time.Now(), false}, + {"timestampAndTimeResultInNotNull", "timestamp", time.Now(), false}, + {"timestamp_ntzAndTimeResultInNotNull", "timestamp_ntz", time.Now(), false}, + {"timestamp_ltzAndTimeResultInNotNull", "timestamp_ltz", time.Now(), false}, + {"timestamp_tzAndTimeResultInNotNull", "timestamp_tz", time.Now(), false}, + } + + runDBTest(t, func(dbt *DBTest) { + for _, tc := range testcases { + t.Run(tc.testDesc, func(t *testing.T) { + query := fmt.Sprintf(` + CREATE OR REPLACE FUNCTION NULLPARAMFUNCTION("param1" %v) + RETURNS TABLE("r1" %v) + LANGUAGE SQL + AS 'select param1';`, tc.paramType, tc.paramType) + dbt.mustExec(query) + rows, err := dbt.query("select * from table(NULLPARAMFUNCTION(?))", tc.input) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if rows.Err() != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("no rows fetched") + } + var r1 any + err = rows.Scan(&r1) + if err != nil { + t.Fatal(err) + } + if tc.nullResult && r1 != nil { + t.Fatalf("the result for %v is of type %v but should be null", tc.paramType, reflect.TypeOf(r1)) + } + if !tc.nullResult && r1 == nil { + t.Fatalf("the result for %v should not be null", tc.paramType) + } + }) + } + }) +} + +func TestVariousBindingModes(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input any + isNil bool + }{ + {"textAndString", "text", "string", false}, + {"numberAndInteger", "number", 123, false}, + {"floatAndFloat", "float", 123.01, false}, + {"booleanAndBoolean", "boolean", true, false}, + {"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false}, + {"datetimeAndTime", "datetime", time.Now(), false}, + {"timeAndTime", "time", "12:34:56", false}, + {"timestampAndTime", "timestamp", time.Now(), false}, + {"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false}, + {"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false}, + {"timestamp_tzAndTime", "timestamp_tz", time.Now(), false}, + {"textAndNullString", "text", sql.NullString{}, true}, + {"numberAndNullInt64", "number", sql.NullInt64{}, true}, + {"floatAndNullFloat64", "float", sql.NullFloat64{}, true}, + {"booleanAndAndNullBool", "boolean", sql.NullBool{}, true}, + {"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, + {"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, + {"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, + {"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, + } + + bindingModes := []struct { + param string + query string + transform func(any) any + }{ + { + param: "?", + transform: func(v any) any { return v }, + }, + { + param: ":1", + transform: func(v any) any { return v }, + }, + { + param: ":param", + transform: func(v any) any { return sql.Named("param", v) }, + }, + } + + runDBTest(t, func(dbt *DBTest) { + for _, tc := range testcases { + for _, bindingMode := range bindingModes { + t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) { + query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) + dbt.mustExec(query) + if _, err := dbt.exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil { + t.Fatal(err) + } + if tc.isNil { + query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL" + } else { + query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param) + } + rows, err := dbt.query(query, bindingMode.transform(tc.input)) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if !rows.Next() { + t.Fatal("Expected to return a row") + } + }) + } + } + }) +} diff --git a/ci/scripts/test_component.sh b/ci/scripts/test_component.sh index ec3700569..4829c40f3 100755 --- a/ci/scripts/test_component.sh +++ b/ci/scripts/test_component.sh @@ -12,5 +12,4 @@ if [[ -n "$GITHUB_WORKFLOW" ]]; then fi env | grep SNOWFLAKE | grep -v PASS | sort cd $TOPDIR -go test -timeout 30m -race $COVFLAGS -v . - +go test -timeout 30m -race -coverprofile=coverage.txt -covermode=atomic -v . diff --git a/cmd/showparam/showparam.go b/cmd/showparam/showparam.go index ce77261cd..e29713a08 100644 --- a/cmd/showparam/showparam.go +++ b/cmd/showparam/showparam.go @@ -6,62 +6,31 @@ import ( "flag" "fmt" "log" - "os" - "strconv" sf "github.com/snowflakedb/gosnowflake" ) -// getDSN constructs a DSN based on the test connection parameters -func getDSN(params map[string]*string) (string, *sf.Config, error) { - env := func(k string, failOnMissing bool) string { - if value := os.Getenv(k); value != "" { - return value - } - if failOnMissing { - log.Fatalf("%v environment variable is not set.", k) - } - return "" - } - - account := env("SNOWFLAKE_TEST_ACCOUNT", true) - user := env("SNOWFLAKE_TEST_USER", true) - password := env("SNOWFLAKE_TEST_PASSWORD", true) - host := env("SNOWFLAKE_TEST_HOST", false) - port := env("SNOWFLAKE_TEST_PORT", false) - protocol := env("SNOWFLAKE_TEST_PROTOCOL", false) - - portStr, err := strconv.Atoi(port) - if err != nil { - return "", nil, err - } - cfg := &sf.Config{ - Account: account, - User: user, - Password: password, - Host: host, - Port: portStr, - Protocol: protocol, - Params: params, - } - - dsn, err := sf.DSN(cfg) - return dsn, cfg, err -} - func main() { if !flag.Parsed() { flag.Parse() } - tmfmt := "MM-DD-YYYY" - dsn, cfg, err := getDSN( - map[string]*string{ - "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, // session parameter - }) + cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{ + {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, + {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, + {Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, + {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, + {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, + {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, + }) if err != nil { - log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) + log.Fatalf("failed to create Config, err: %v", err) + } + tmfmt := "MM-DD-YYYY" + cfg.Params = map[string]*string{ + "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, // session parameter } + dsn, err := sf.DSN(cfg) if err != nil { log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) } diff --git a/connection.go b/connection.go index 7a172f812..a81747d9b 100644 --- a/connection.go +++ b/connection.go @@ -60,16 +60,17 @@ const ( const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { - ctx context.Context - cfg *Config - rest *snowflakeRestful - restMu sync.RWMutex // guard shutdown race - SequenceCounter uint64 - QueryID string - SQLState string - telemetry *snowflakeTelemetry - internal InternalClient - execRespCache *execRespCache + ctx context.Context + cfg *Config + rest *snowflakeRestful + restMu sync.RWMutex + SequenceCounter uint64 + QueryID string + SQLState string + telemetry *snowflakeTelemetry + internal InternalClient + execRespCache *execRespCache + queryContextCache *queryContextCache } var ( @@ -143,9 +144,12 @@ func (sc *snowflakeConn) exec( } logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code) if !data.Success { - return nil, (populateErrorFields(code, data)).exceptionTelemetry(sc) + err = (populateErrorFields(code, data)).exceptionTelemetry(sc) + return nil, err } + sc.queryContextCache.add(data.Data.QueryContext.Entries...) + // handle PUT/GET commands if isFileTransfer(query) { data, err = sc.processFileTransfer(ctx, data, query, isInternal) @@ -159,8 +163,6 @@ func (sc *snowflakeConn) exec( sc.cfg.Schema = data.Data.FinalSchemaName sc.cfg.Role = data.Data.FinalRoleName sc.cfg.Warehouse = data.Data.FinalWarehouseName - sc.QueryID = data.Data.QueryID - sc.SQLState = data.Data.SQLState sc.populateSessionParameters(data.Data.Parameters) return data, err } @@ -196,7 +198,7 @@ func (sc *snowflakeConn) BeginTx( false /* isInternal */, isDesc, nil); err != nil { return nil, err } - return &snowflakeTx{sc}, nil + return &snowflakeTx{sc, ctx}, nil } func (sc *snowflakeConn) cleanup() { @@ -219,7 +221,7 @@ func (sc *snowflakeConn) Close() (err error) { sc.stopHeartBeat() defer sc.cleanup() - if !sc.cfg.KeepSessionAlive { + if sc.cfg != nil && !sc.cfg.KeepSessionAlive { if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { logger.Error(err) } @@ -259,9 +261,9 @@ func (sc *snowflakeConn) ExecContext( if err != nil { logger.WithContext(ctx).Infof("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -288,11 +290,10 @@ func (sc *snowflakeConn) ExecContext( rows := &snowflakeResult{ affectedRows: updatedRows, insertID: -1, - queryID: sc.QueryID, + queryID: data.Data.QueryID, } // last insert id is not supported by Snowflake - rows.monitoring = mkMonitoringFetcher(sc, sc.QueryID, time.Since(qStart)) - + rows.monitoring = mkMonitoringFetcher(sc, data.Data.QueryID, time.Since(qStart)) return rows, nil } else if isMultiStmt(&data.Data) { rows, err := sc.handleMultiExec(ctx, data.Data) @@ -348,9 +349,9 @@ func (sc *snowflakeConn) queryContextInternal( if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -369,8 +370,8 @@ func (sc *snowflakeConn) queryContextInternal( rows := new(snowflakeRows) rows.sc = sc - rows.queryID = sc.QueryID - rows.monitoring = mkMonitoringFetcher(sc, sc.QueryID, time.Since(qStart)) + rows.queryID = data.Data.QueryID + rows.monitoring = mkMonitoringFetcher(sc, data.Data.QueryID, time.Since(qStart)) if isSubmitSync(ctx) && data.Code == queryInProgressCode { rows.status = QueryStatusInProgress @@ -442,10 +443,10 @@ func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error { // distinguish them from arguments of type []byte return nil } - if supported := supportedArrayBind(nv); !supported { - return driver.ErrSkip + if supportedNullBind(nv) || supportedArrayBind(nv) { + return nil } - return nil + return driver.ErrSkip } func (sc *snowflakeConn) GetQueryStatus( @@ -479,9 +480,9 @@ func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bin if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -619,11 +620,18 @@ func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, erro // ArrowStreamLoader is a convenience interface for downloading // Snowflake results via multiple Arrow Record Batch streams. +// +// Some queries from Snowflake do not return Arrow data regardless +// of the settings, such as "SHOW WAREHOUSES". In these cases, +// you'll find TotalRows() > 0 but GetBatches returns no batches +// and no errors. In this case, the data is accessible via JSONData +// with the actual types matching up to the metadata in RowTypes. type ArrowStreamLoader interface { GetBatches() ([]ArrowStreamBatch, error) TotalRows() int64 RowTypes() []execResponseRowType Location() *time.Location + JSONData() [][]*string } type snowflakeArrowStreamChunkDownloader struct { @@ -646,6 +654,9 @@ func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.T func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType { return scd.RowSet.RowType } +func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string { + return scd.RowSet.JSON +} // the server might have had an empty first batch, check if we can decode // that first batch, if not we skip it. @@ -714,9 +725,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { sc := &snowflakeConn{ - SequenceCounter: 0, - ctx: ctx, - cfg: &config, + SequenceCounter: 0, + ctx: ctx, + cfg: &config, + queryContextCache: (&queryContextCache{}).init(), } var st http.RoundTripper = SnowflakeTransport if sc.cfg.Transporter == nil { @@ -765,11 +777,16 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err Timeout: sc.cfg.ClientTimeout, Transport: st, }, + JWTClient: &http.Client{ + Timeout: sc.cfg.JWTClientTimeout, + Transport: st, + }, TokenAccessor: tokenAccessor, LoginTimeout: sc.cfg.LoginTimeout, RequestTimeout: sc.cfg.RequestTimeout, FuncPost: postRestful, FuncGet: getRestful, + FuncAuthPost: postAuthRestful, FuncPostQuery: postRestfulQuery, FuncPostQueryHelper: postRestfulQueryHelper, FuncRenewSession: renewRestfulSession, diff --git a/connection_util.go b/connection_util.go index a85ffea22..f9dad37e2 100644 --- a/connection_util.go +++ b/connection_util.go @@ -24,20 +24,22 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { } func (sc *snowflakeConn) startHeartBeat() { - if !sc.isClientSessionKeepAliveEnabled() { + if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - sc.rest.HeartBeat = &heartbeat{ - restful: sc.rest, + if sc.rest != nil { + sc.rest.HeartBeat = &heartbeat{ + restful: sc.rest, + } + sc.rest.HeartBeat.start() } - sc.rest.HeartBeat.start() } func (sc *snowflakeConn) stopHeartBeat() { - if !sc.isClientSessionKeepAliveEnabled() { + if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - if sc.rest.HeartBeat != nil { + if sc.rest != nil && sc.rest.HeartBeat != nil { sc.rest.HeartBeat.stop() } } diff --git a/converter.go b/converter.go index 19ac55499..4070d5fc5 100644 --- a/converter.go +++ b/converter.go @@ -4,6 +4,7 @@ package gosnowflake import ( "context" + "database/sql" "database/sql/driver" "encoding/hex" "fmt" @@ -57,17 +58,18 @@ func isInterfaceArrayBinding(t interface{}) bool { // goTypeToSnowflake translates Go data type to Snowflake data type. func goTypeToSnowflake(v driver.Value, dataType SnowflakeDataType) snowflakeType { + // (raj) This will fail build. Reconcile after merge if dataType == nil { switch t := v.(type) { case SnowflakeDataType: return changeType - case int64: + case int64, sql.NullInt64: return fixedType - case float64: + case float64, sql.NullFloat64: return realType - case bool: + case bool, sql.NullBool: return booleanType - case string: + case string, sql.NullString: return textType case []byte: if t == nil { @@ -75,7 +77,7 @@ func goTypeToSnowflake(v driver.Value, dataType SnowflakeDataType) snowflakeType } // If we don't have an explicit data type, binary blobs are unsupported return unSupportedType - case time.Time: + case time.Time, sql.NullTime: // Default timestamp type return timestampNtzType } @@ -151,37 +153,71 @@ func valueToString(v driver.Value, dataType SnowflakeDataType) (*string, error) s := v1.String() return &s, nil case reflect.Struct: - if tm, ok := v.(time.Time); ok { - switch { - case dataType.Equals(DataTypeDate): - _, offset := tm.Zone() - tm = tm.Add(time.Second * time.Duration(offset)) - s := strconv.FormatInt(tm.Unix()*1000, 10) - return &s, nil - case dataType.Equals(DataTypeTime): - s := fmt.Sprintf("%d", - (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) - return &s, nil - case dataType.Equals(DataTypeTimestampNtz) || dataType.Equals(DataTypeTimestampLtz) || dataType == nil: - // NOTE(greg): when the client has not given us an explicit dataType - // (dataType == nil), we assume DataTypeTimestampNtz for compatibility - // with the upstream driver - unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) - m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) - unixTime.Mul(unixTime, m) - tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) - s := unixTime.Add(unixTime, tmNanos).String() - return &s, nil - case dataType.Equals(DataTypeTimestampTz): - _, offset := tm.Zone() - s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) - return &s, nil + switch typedVal := v.(type) { + case time.Time: + return timeTypeValueToString(typedVal, dataType) + case sql.NullTime: + if !typedVal.Valid { + return nil, nil } + return timeTypeValueToString(typedVal.Time, dataType) + case sql.NullBool: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatBool(typedVal.Bool) + return &s, nil + case sql.NullInt64: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatInt(typedVal.Int64, 10) + return &s, nil + case sql.NullFloat64: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32) + return &s, nil + case sql.NullString: + if !typedVal.Valid { + return nil, nil + } + return &typedVal.String, nil } } return nil, fmt.Errorf("unsupported type: %v", v1.Kind()) } +func timeTypeValueToString(tm time.Time, dataType SnowflakeDataType) (*string, error) { + switch { + case dataType.Equals(DataTypeDate): + _, offset := tm.Zone() + tm = tm.Add(time.Second * time.Duration(offset)) + s := strconv.FormatInt(tm.Unix()*1000, 10) + return &s, nil + case dataType.Equals(DataTypeTime): + s := fmt.Sprintf("%d", + (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) + return &s, nil + case dataType.Equals(DataTypeTimestampNtz) || dataType.Equals(DataTypeTimestampLtz) || dataType == nil: + // NOTE(greg): when the client has not given us an explicit dataType + // (dataType == nil), we assume DataTypeTimestampNtz for compatibility + // with the upstream driver + unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) + m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) + unixTime.Mul(unixTime, m) + tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) + s := unixTime.Add(unixTime, tmNanos).String() + return &s, nil + case dataType.Equals(DataTypeTimestampTz): + _, offset := tm.Zone() + s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) + return &s, nil + } + return nil, fmt.Errorf("unsupported time type: %v", dataType) +} + // extractTimestamp extracts the internal timestamp data to epoch time in seconds and milliseconds func extractTimestamp(srcValue *string) (sec int64, nsec int64, err error) { logger.Debugf("SRC: %v", srcValue) @@ -1207,3 +1243,26 @@ func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.L meta := sc.Metadata() return arrow.NewSchema(fields, &meta), nil } + +// TypedNullTime is required to properly bind the null value with the snowflakeType as the Snowflake functions +// require the type of the field to be provided explicitly for the null values +type TypedNullTime struct { + Time sql.NullTime + TzType timezoneType +} + +func convertTzTypeToSnowflakeType(tzType timezoneType) SnowflakeDataType { + switch tzType { + case TimestampNTZType: + return DataTypeTimestampNtz + case TimestampLTZType: + return DataTypeTimestampLtz + case TimestampTZType: + return DataTypeTimestampTz + case DateType: + return DataTypeDate + case TimeType: + return DataTypeTime + } + return nil +} diff --git a/converter_test.go b/converter_test.go index 9d3b2c3ad..21167c818 100644 --- a/converter_test.go +++ b/converter_test.go @@ -4,12 +4,15 @@ package gosnowflake import ( "context" + "database/sql" "database/sql/driver" "fmt" "io" + "math" "math/big" "math/cmplx" "reflect" + "strings" "testing" "time" @@ -45,6 +48,21 @@ func stringFloatToDecimal(src string, scale int64) (decimal128.Num, bool) { return decimal128.New(high.Int64(), low.Uint64()), ok } +func stringFloatToInt(src string, scale int64) (int64, bool) { + b, ok := new(big.Float).SetString(src) + if !ok { + return 0, ok + } + s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil)) + n := new(big.Float).Mul(b, s) + var z big.Int + n.Int(&z) + if !z.IsInt64() { + return 0, false + } + return z.Int64(), true +} + type tcGoTypeToSnowflake struct { in interface{} tmode SnowflakeDataType @@ -90,10 +108,12 @@ func TestGoTypeToSnowflake(t *testing.T) { {in: []int{1}, tmode: nil, out: unSupportedType}, } for _, test := range testcases { - a := goTypeToSnowflake(test.in, test.tmode) - if a != test.out { - t.Errorf("failed. in: %v, tmode: %v, expected: %v, got: %v", test.in, test.tmode, test.out, a) - } + t.Run(fmt.Sprintf("%v_%v_%v", test.in, test.out, test.tmode), func(t *testing.T) { + a := goTypeToSnowflake(test.in, test.tmode) + if a != test.out { + t.Errorf("failed. in: %v, tmode: %v, expected: %v, got: %v", test.in, test.tmode, test.out, a) + } + }) } } @@ -119,13 +139,16 @@ func TestSnowflakeTypeToGo(t *testing.T) { {in: arrayType, scale: 0, out: reflect.TypeOf("")}, {in: binaryType, scale: 0, out: reflect.TypeOf([]byte{})}, {in: booleanType, scale: 0, out: reflect.TypeOf(true)}, + {in: sliceType, scale: 0, out: reflect.TypeOf("")}, } for _, test := range testcases { - a := snowflakeTypeToGo(test.in, test.scale) - if a != test.out { - t.Errorf("failed. in: %v, scale: %v, expected: %v, got: %v", - test.in, test.scale, test.out, a) - } + t.Run(fmt.Sprintf("%v_%v", test.in, test.out), func(t *testing.T) { + a := snowflakeTypeToGo(test.in, test.scale) + if a != test.out { + t.Errorf("failed. in: %v, scale: %v, expected: %v, got: %v", + test.in, test.scale, test.out, a) + } + }) } } @@ -140,6 +163,10 @@ func TestValueToString(t *testing.T) { localTime := time.Date(2019, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) utcTime := time.Date(2019, 2, 6, 22, 17, 31, 123456789, time.UTC) expectedUnixTime := "1549491451123456789" // time.Unix(1549491451, 123456789).Format(time.RFC3339) == "2019-02-06T14:17:31-08:00" + expectedBool := "true" + expectedInt64 := "1" + expectedFloat64 := "1.1" + expectedString := "teststring" if s, err := valueToString(localTime, DataTypeTimestampLtz); err != nil { t.Error("unexpected error") @@ -156,10 +183,42 @@ func TestValueToString(t *testing.T) { } else if *s != expectedUnixTime { t.Errorf("expected '%v', got '%v'", expectedUnixTime, *s) } + + if s, err := valueToString(sql.NullBool{Bool: true, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedBool, s) + } else if *s != expectedBool { + t.Errorf("expected '%v', got '%v'", expectedBool, *s) + } + + if s, err := valueToString(sql.NullInt64{Int64: 1, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedInt64, s) + } else if *s != expectedInt64 { + t.Errorf("expected '%v', got '%v'", expectedInt64, *s) + } + + if s, err := valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedFloat64, s) + } else if *s != expectedFloat64 { + t.Errorf("expected '%v', got '%v'", expectedFloat64, *s) + } + + if s, err := valueToString(sql.NullString{String: "teststring", Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedString, s) + } else if *s != expectedString { + t.Errorf("expected '%v', got '%v'", expectedString, *s) + } } func TestExtractTimestamp(t *testing.T) { - s := "1234abcdef" + s := "1234abcdef" // pragma: allowlist secret _, _, err := extractTimestamp(&s) if err == nil { t.Errorf("should raise error: %v", s) @@ -188,12 +247,14 @@ func TestStringToValue(t *testing.T) { } for _, tt := range types { - rowType = &execResponseRowType{ - Type: tt, - } - if err = stringToValue(&dest, *rowType, &source, nil); err == nil { - t.Errorf("should raise error. type: %v, value:%v", tt, source) - } + t.Run(tt, func(t *testing.T) { + rowType = &execResponseRowType{ + Type: tt, + } + if err = stringToValue(&dest, *rowType, &source, nil); err == nil { + t.Errorf("should raise error. type: %v, value:%v", tt, source) + } + }) } sources := []string{ @@ -207,12 +268,14 @@ func TestStringToValue(t *testing.T) { for _, ss := range sources { for _, tt := range types { - rowType = &execResponseRowType{ - Type: tt, - } - if err = stringToValue(&dest, *rowType, &ss, nil); err == nil { - t.Errorf("should raise error. type: %v, value:%v", tt, source) - } + t.Run(ss+tt, func(t *testing.T) { + rowType = &execResponseRowType{ + Type: tt, + } + if err = stringToValue(&dest, *rowType, &ss, nil); err == nil { + t.Errorf("should raise error. type: %v, value:%v", tt, source) + } + }) } } @@ -235,21 +298,25 @@ type tcArrayToString struct { func TestArrayToString(t *testing.T) { testcases := []tcArrayToString{ {in: driver.NamedValue{Value: &intArray{1, 2}}, typ: fixedType, out: []string{"1", "2"}}, + {in: driver.NamedValue{Value: &int32Array{1, 2}}, typ: fixedType, out: []string{"1", "2"}}, {in: driver.NamedValue{Value: &int64Array{3, 4, 5}}, typ: fixedType, out: []string{"3", "4", "5"}}, {in: driver.NamedValue{Value: &float64Array{6.7}}, typ: realType, out: []string{"6.7"}}, + {in: driver.NamedValue{Value: &float32Array{1.5}}, typ: realType, out: []string{"1.5"}}, {in: driver.NamedValue{Value: &boolArray{true, false}}, typ: booleanType, out: []string{"true", "false"}}, {in: driver.NamedValue{Value: &stringArray{"foo", "bar", "baz"}}, typ: textType, out: []string{"foo", "bar", "baz"}}, } for _, test := range testcases { - s, a := snowflakeArrayToString(&test.in, false) - if s != test.typ { - t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.typ, s) - } - for i, v := range a { - if *v != test.out[i] { - t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out[i], a) + t.Run(strings.Join(test.out, "_"), func(t *testing.T) { + s, a := snowflakeArrayToString(&test.in, false) + if s != test.typ { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.typ, s) } - } + for i, v := range a { + if *v != test.out[i] { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out[i], a) + } + } + }) } } @@ -272,20 +339,86 @@ func TestArrowToValue(t *testing.T) { } for _, tc := range []struct { - logical string - physical string - rowType execResponseRowType - values interface{} - builder array.Builder - append func(b array.Builder, vs interface{}) - compare func(src interface{}, dst []snowflakeValue) int + logical string + physical string + rowType execResponseRowType + values interface{} + builder array.Builder + append func(b array.Builder, vs interface{}) + compare func(src interface{}, dst []snowflakeValue) int + higherPrecision bool }{ + { + logical: "fixed", + physical: "number", // default: number(38, 0) + values: []int64{1, 2}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "number(38,5)", + rowType: execResponseRowType{Scale: 5}, + values: []string{"1.05430", "2.08983"}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int64Builder).Append(num) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := intToBigFloat(num, 5) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, { logical: "fixed", - physical: "number", // default: number(38, 0) - values: []int64{1, 2}, + physical: "number(38,5)", + rowType: execResponseRowType{Scale: 5}, + values: []string{"1.05430", "2.08983"}, builder: array.NewInt64Builder(pool), - append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int64Builder).Append(num) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, }, { logical: "fixed", @@ -316,6 +449,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -347,6 +481,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -363,6 +498,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -379,6 +515,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -395,13 +532,79 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", - physical: "int64", - values: []int64{1, 2}, - builder: array.NewInt64Builder(pool), - append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + physical: "int32", + values: []string{"1.23456", "2.34567"}, + rowType: execResponseRowType{Scale: 5}, + builder: array.NewInt32Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int32Builder).Append(int32(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := intToBigFloat(num, 5) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "int32", + values: []string{"1.23456", "2.34567"}, + rowType: execResponseRowType{Scale: 5}, + builder: array.NewInt32Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int32Builder).Append(int32(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, + }, + { + logical: "fixed", + physical: "int64", + values: []int64{1, 2}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + higherPrecision: true, }, { logical: "boolean", @@ -458,6 +661,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "timestamp_ntz", @@ -614,7 +818,9 @@ func TestArrowToValue(t *testing.T) { meta := tc.rowType meta.Type = tc.logical - if err := arrowToValue(dest, meta, arr, localTime.Location(), true); err != nil { + withHigherPrecision := tc.higherPrecision + + if err := arrowToValue(dest, meta, arr, localTime.Location(), withHigherPrecision); err != nil { t.Fatalf("error: %s", err) } @@ -1034,3 +1240,39 @@ func TestLargeTimestampBinding(t *testing.T) { } } } + +func TestTimeTypeValueToString(t *testing.T) { + timeValue, err := time.Parse("2006-01-02 15:04:05", "2020-01-02 10:11:12") + if err != nil { + t.Fatal(err) + } + offsetTimeValue, err := time.ParseInLocation("2006-01-02 15:04:05", "2020-01-02 10:11:12", Location(6*60)) + if err != nil { + t.Fatal(err) + } + + testcases := []struct { + in time.Time + dataType SnowflakeDataType + out string + }{ + {timeValue, DataTypeDate, "1577959872000"}, + {timeValue, DataTypeTime, "36672000000000"}, + {timeValue, DataTypeTimestampNtz, "1577959872000000000"}, + {timeValue, DataTypeTimestampLtz, "1577959872000000000"}, + {timeValue, DataTypeTimestampTz, "1577959872000000000 1440"}, + {offsetTimeValue, DataTypeTimestampTz, "1577938272000000000 1800"}, + } + + for _, tc := range testcases { + t.Run(tc.out, func(t *testing.T) { + output, err := timeTypeValueToString(tc.in, tc.dataType) + if err != nil { + t.Error(err) + } + if strings.Compare(tc.out, *output) != 0 { + t.Errorf("failed to convert time %v of type %v. expected: %v, received: %v", tc.in, tc.dataType, tc.out, *output) + } + }) + } +} diff --git a/datatype.go b/datatype.go index b419e9531..07ab9bc73 100644 --- a/datatype.go +++ b/datatype.go @@ -86,17 +86,6 @@ func (dt SnowflakeDataType) Equals(o SnowflakeDataType) bool { return bytes.Equal(([]byte)(dt), ([]byte)(o)) } -// SnowflakeDataType is the type used by clients to explicitly indicate the type -// of an argument to ExecContext and friends. We use a separate public-facing -// type rather than a Go primitive type so that we can always differentiate -// between args that indicate type and args that are values. -type SnowflakeDataType []byte - -// Equals checks if dt and o represent the same type indicator -func (dt SnowflakeDataType) Equals(o SnowflakeDataType) bool { - return bytes.Equal(([]byte)(dt), ([]byte)(o)) -} - var ( // DataTypeFixed is a FIXED datatype. DataTypeFixed = SnowflakeDataType{fixedType.Byte()} diff --git a/datatype_test.go b/datatype_test.go index aa7763779..65f59fd9e 100644 --- a/datatype_test.go +++ b/datatype_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. +// Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved. package gosnowflake @@ -33,18 +33,34 @@ func TestClientTypeToInternal(t *testing.T) { err: fmt.Errorf(errMsgInvalidByteArray, nil)}, } for _, ts := range testcases { - tmode, err := clientTypeToInternal(ts.tp) - if ts.err == nil { - if err != nil { - t.Errorf("failed to get datatype mode: %v", err) - } - if tmode != ts.tmode { - t.Errorf("wrong data type: %v", tmode) - } - } else { - if err == nil { - t.Errorf("should raise an error: %v", ts.err) + t.Run(fmt.Sprintf("%v_%v", ts.tp, ts.tmode), func(t *testing.T) { + tmode, err := clientTypeToInternal(ts.tp) + if ts.err == nil { + if err != nil { + t.Errorf("failed to get datatype mode: %v", err) + } + if tmode != ts.tmode { + t.Errorf("wrong data type: %v", tmode) + } + } else { + if err == nil { + t.Errorf("should raise an error: %v", ts.err) + } } + }) + } +} + +func TestPopulateSnowflakeParameter(t *testing.T) { + columns := []string{"key", "value", "default", "level", "description", "set_by_user", "set_in_job", "set_on", "set_by_thread_id", "set_by_thread_name", "set_by_class", "parameter_comment", "type", "is_expired", "expires_at", "set_by_controlling_parameter", "activate_version", "partial_rollout"} + p := SnowflakeParameter{} + cols := make([]interface{}, len(columns)) + for i := 0; i < len(columns); i++ { + cols[i] = populateSnowflakeParameter(columns[i], &p) + } + for i := 0; i < len(cols); i++ { + if cols[i] == nil { + t.Fatal("failed to populate parameter") } } } diff --git a/driver_test.go b/driver_test.go index c45479f9e..88ddb1a05 100644 --- a/driver_test.go +++ b/driver_test.go @@ -163,7 +163,7 @@ func TestMain(m *testing.M) { type DBTest struct { *testing.T - db *sql.DB + conn *sql.Conn } func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) { @@ -187,7 +187,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExten close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -218,7 +218,7 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -228,8 +228,13 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i } } +func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) { + return dbt.conn.QueryContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) { rows := dbt.mustQuery(query, args...) + defer rows.Close() cnt := 0 for rows.Next() { cnt++ @@ -239,6 +244,10 @@ func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...inte } } +func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { + return dbt.conn.PrepareContext(context.Background(), query) +} + func (dbt *DBTest) fail(method, query string, err error) { if len(query) > 300 { query = "[query too large to print]" @@ -247,21 +256,21 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.Exec(query, args...) - if err != nil { - dbt.fail("exec", query, err) - } - return res + return dbt.mustExecContext(context.Background(), query, args...) } func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.ExecContext(ctx, query, args...) + res, err := dbt.conn.ExecContext(ctx, query, args...) if err != nil { dbt.fail("exec context", query, err) } return res } +func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) { + return dbt.conn.ExecContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) { var ok bool pr, sc, ok = ct.DecimalSize() @@ -304,29 +313,36 @@ func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) { } func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { - stmt, err := dbt.db.Prepare(query) + stmt, err := dbt.conn.PrepareContext(context.Background(), query) if err != nil { dbt.fail("prepare", query, err) } return stmt } -func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - db, err := sql.Open("sigmacomputing+gosnowflake", dsn) +func runDBTest(t *testing.T, test func(dbt *DBTest)) { + conn := openConn(t) + defer conn.Close() + dbt := &DBTest{t, conn} + + test(dbt) +} + +func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) { + config, err := ParseDSN(dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Error(err) } - defer db.Close() - - if _, err = db.Exec("DROP TABLE IF EXISTS test"); err != nil { - t.Fatalf("failed to drop table: %v", err) + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) } - - dbt := &DBTest{t, db} - for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") + defer sc.Close() + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) } + + test(sc) } func runningOnAWS() bool { @@ -412,10 +428,10 @@ func invalidHostErrorTests(invalidDNS string, mstr []string, t *testing.T) { } func TestCommentOnlyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "--" // just a comment, no query - rows, err := dbt.db.Query(query) + rows, err := dbt.query(query) if err == nil { rows.Close() dbt.fail("query", query, err) @@ -429,15 +445,15 @@ func TestCommentOnlyQuery(t *testing.T) { } func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "select 1 from dual where 1=0" // just a comment, no query - rows := dbt.db.QueryRow(query) - var v1 interface{} + rows := dbt.conn.QueryRowContext(context.Background(), query) + var v1 any if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } - rows = dbt.db.QueryRowContext(context.Background(), query) + rows = dbt.conn.QueryRowContext(context.Background(), query) if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } @@ -445,10 +461,10 @@ func TestEmptyQuery(t *testing.T) { } func TestEmptyQueryWithRequestID(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "select 1" ctx := WithRequestID(context.Background(), NewUUID()) - rows := dbt.db.QueryRowContext(ctx, query) + rows := dbt.conn.QueryRowContext(ctx, query) var v1 interface{} if err := rows.Scan(&v1); err != nil { dbt.Errorf("should not have failed with valid request id. err: %v", err) @@ -457,9 +473,9 @@ func TestEmptyQueryWithRequestID(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { // Create Table - dbt.mustExec("CREATE TABLE test (value BOOLEAN)") + dbt.mustExec("CREATE OR REPLACE TABLE test (value BOOLEAN)") // Test for unexpected Data var out bool @@ -555,7 +571,7 @@ func TestInt(t *testing.T) { } func testInt(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := []string{"INT", "INTEGER"} in := int64(42) var out int64 @@ -563,24 +579,26 @@ func testInt(t *testing.T, json bool) { // SIGNED for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %d != %d", v, in, out) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) + } + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -589,32 +607,34 @@ func TestFloat32(t *testing.T) { } func testFloat32(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 var rows *RowsExtended for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - err := rows.Scan(&out) - if err != nil { - dbt.Errorf("failed to scan data: %v", err) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) } - if in != out { - dbt.Errorf("%s: %g != %g", v, in, out) + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + err := rows.Scan(&out) + if err != nil { + dbt.Errorf("failed to scan data: %v", err) + } + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -623,29 +643,31 @@ func TestFloat64(t *testing.T) { } func testFloat64(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} expected := 42.23 var out float64 var rows *RowsExtended for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (42.23)") - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -654,7 +676,7 @@ func TestString(t *testing.T) { } func testString(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -664,24 +686,26 @@ func testString(t *testing.T, json bool) { var rows *RowsExtended for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %s != %s", v, in, out) + t.Run(v, func(t *testing.T) { + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") // BLOB (Snowflake doesn't support BLOB type but STRING covers large text data) - dbt.mustExec("CREATE TABLE test (id int, value STRING)") + dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value STRING)") id := 2 in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam @@ -695,7 +719,7 @@ func testString(t *testing.T, json bool) { gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.` dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) - if err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { + if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { dbt.Errorf("BLOB: %s != %s", in, out) @@ -786,7 +810,7 @@ func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) { scan(rows, &cd, &ct, &cts) }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -852,18 +876,20 @@ func testDateTime(t *testing.T, json bool) { {t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)}, }}, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) } @@ -921,18 +947,20 @@ func testTimestampLTZ(t *testing.T, json bool) { }, }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) // Revert timezone to UTC, which is default for the test suit @@ -969,18 +997,20 @@ func testTimestampTZ(t *testing.T, json bool) { }, }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) } @@ -990,17 +1020,17 @@ func TestNULL(t *testing.T) { } func testNULL(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } - nullStmt, err := dbt.db.Prepare("SELECT NULL") + nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL") if err != nil { dbt.Fatal(err) } defer nullStmt.Close() - nonNullStmt, err := dbt.db.Prepare("SELECT 1") + nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1") if err != nil { dbt.Fatal(err) } @@ -1101,7 +1131,7 @@ func testNULL(t *testing.T, json bool) { // Insert nil b = nil success := false - if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil { dbt.Fatal(err) } if !success { @@ -1110,7 +1140,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input==nil b = nil - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b != nil { @@ -1118,7 +1148,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input!=nil b = []byte("") - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b == nil { @@ -1126,7 +1156,7 @@ func testNULL(t *testing.T, json bool) { } // Insert NULL - dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)") dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) var out interface{} @@ -1148,7 +1178,7 @@ func TestVariant(t *testing.T) { } func testVariant(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1170,7 +1200,7 @@ func TestArray(t *testing.T) { } func testArray(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1193,7 +1223,7 @@ func TestLargeSetResult(t *testing.T) { } func testLargeSetResult(t *testing.T, numrows int, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1217,7 +1247,7 @@ func testLargeSetResult(t *testing.T, numrows int, json bool) { } func TestPingpongQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { numrows := 1 rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))") defer rows.Close() @@ -1232,7 +1262,7 @@ func TestPingpongQuery(t *testing.T) { } func TestDML(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)") if err := insertData(dbt, false); err != nil { dbt.Fatalf("failed to insert data: %v", err) @@ -1258,7 +1288,7 @@ func TestDML(t *testing.T) { } func insertData(dbt *DBTest, commit bool) error { - tx, err := dbt.db.Begin() + tx, err := dbt.conn.BeginTx(context.Background(), nil) if err != nil { dbt.Fatalf("failed to begin transaction: %v", err) } @@ -1314,7 +1344,7 @@ func queryTestTx(tx *sql.Tx) (*map[int]string, error) { func queryTest(dbt *DBTest) (*map[int]string, error) { var c1 int var c2 string - rows, err := dbt.db.Query("SELECT c1, c2 FROM test") + rows, err := dbt.query("SELECT c1, c2 FROM test") if err != nil { return nil, err } @@ -1330,11 +1360,11 @@ func queryTest(dbt *DBTest) (*map[int]string, error) { } func TestCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - _, err := dbt.db.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") + _, err := dbt.conn.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") if err == nil { dbt.Fatal("No timeout error returned") } @@ -1345,8 +1375,8 @@ func TestCancelQuery(t *testing.T) { } func TestPing(t *testing.T) { - db := openDB(t) - if err := db.Ping(); err != nil { + db := openConn(t) + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("failed to ping. err: %v", err) } if err := db.PingContext(context.Background()); err != nil { @@ -1355,7 +1385,7 @@ func TestPing(t *testing.T) { if err := db.Close(); err != nil { t.Fatalf("failed to close db. err: %v", err) } - if err := db.Ping(); err == nil { + if err := db.PingContext(context.Background()); err == nil { t.Fatal("should have failed to ping") } if err := db.PingContext(context.Background()); err == nil { @@ -1365,7 +1395,7 @@ func TestPing(t *testing.T) { func TestDoubleDollar(t *testing.T) { // no escape is required for dollar signs - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { sql := `create or replace function dateErr(I double) returns date language javascript strict as $$ @@ -1387,10 +1417,10 @@ $$ func TestTimezoneSessionParameter(t *testing.T) { createDSN(PSTLocation) - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() - rows, err := db.Query("SHOW PARAMETERS LIKE 'TIMEZONE'") + rows, err := conn.QueryContext(context.Background(), "SHOW PARAMETERS LIKE 'TIMEZONE'") if err != nil { t.Errorf("failed to run show parameters. err: %v", err) } @@ -1410,13 +1440,13 @@ func TestTimezoneSessionParameter(t *testing.T) { } func TestLargeSetResultCancel(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { c := make(chan error) ctx, cancel := context.WithCancel(context.Background()) go func() { // attempt to run a 100 seconds query, but it should be canceled in 1 second timelimit := 100 - rows, err := dbt.db.QueryContext( + rows, err := dbt.conn.QueryContext( ctx, fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit)) if err != nil { @@ -1468,32 +1498,34 @@ func TestValidateDatabaseParameter(t *testing.T) { }, } for idx, tc := range testcases { - newDSN := tc.dsn - parameters := url.Values{} - if protocol != "" { - parameters.Add("protocol", protocol) - } - if account != "" { - parameters.Add("account", account) - } - for k, v := range tc.params { - parameters.Add(k, v) - } - newDSN += "?" + parameters.Encode() - db, err := sql.Open("sigmacomputing+gosnowflake", newDSN) - // actual connection won't happen until run a query - if err != nil { - t.Fatalf("error creating a connection object: %s", err.Error()) - } - defer db.Close() - if _, err = db.Exec("SELECT 1"); err == nil { - t.Fatal("should cause an error.") - } - if driverErr, ok := err.(*SnowflakeError); ok { - if driverErr.Number != tc.errorCode { // not exist error - t.Errorf("got unexpected error: %v in %v", err, idx) + t.Run(dsn, func(t *testing.T) { + newDSN := tc.dsn + parameters := url.Values{} + if protocol != "" { + parameters.Add("protocol", protocol) } - } + if account != "" { + parameters.Add("account", account) + } + for k, v := range tc.params { + parameters.Add(k, v) + } + newDSN += "?" + parameters.Encode() + db, err := sql.Open("sigmacomputing+gosnowflake", newDSN) + // actual connection won't happen until run a query + if err != nil { + t.Fatalf("error creating a connection object: %s", err.Error()) + } + defer db.Close() + if _, err = db.Exec("SELECT 1"); err == nil { + t.Fatal("should cause an error.") + } + if driverErr, ok := err.(*SnowflakeError); ok { + if driverErr.Number != tc.errorCode { // not exist error + t.Errorf("got unexpected error: %v in %v", err, idx) + } + } + }) } } @@ -1520,7 +1552,7 @@ func TestSpecifyWarehouseDatabase(t *testing.T) { } func TestFetchNil(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2") defer rows.Close() var c1 sql.NullInt64 @@ -1579,6 +1611,18 @@ func TestOpenWithConfig(t *testing.T) { db.Close() } +func TestOpenWithInvalidConfig(t *testing.T) { + config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing") + if err != nil { + t.Fatalf("failed to parse dsn. err: %v", err) + } + driver := SnowflakeDriver{} + _, err = driver.OpenWithConfig(context.Background(), *config) + if err == nil || !strings.Contains(err.Error(), "/non-existing") { + t.Fatalf("should fail on missing directory") + } +} + type CountingTransport struct { requests int } @@ -1653,8 +1697,9 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { // This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks // the session parameter. createDSNWithClientSessionKeepAlive() - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'") + defer rows.Close() if !rows.Next() { t.Fatal("failed to get timezone.") } @@ -1667,15 +1712,16 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value) } - rows = dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") - defer rows.Close() + rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") + defer rows2.Close() }) } func TestTimePrecision(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table z3 (t1 time(5))") rows := dbt.mustQuery("select * from z3") + defer rows.Close() cols, err := rows.ColumnTypes() if err != nil { t.Error(err) diff --git a/dsn.go b/dsn.go index 1689a4b3b..bd3364e86 100644 --- a/dsn.go +++ b/dsn.go @@ -4,11 +4,15 @@ package gosnowflake import ( "crypto/rsa" + "crypto/x509" "encoding/base64" + "encoding/pem" + "errors" "fmt" "net" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -17,11 +21,13 @@ import ( ) const ( - defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response - defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout - defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout - defaultJWTTimeout = 60 * time.Second - defaultDomain = ".snowflakecomputing.com" + defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response + defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth + defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout + defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout + defaultJWTTimeout = 60 * time.Second + defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login + defaultDomain = ".snowflakecomputing.com" // default monitoring fetcher config values defaultMonitoringFetcherQueryMonitoringThreshold = 45 * time.Second @@ -69,10 +75,12 @@ type Config struct { OktaURL *url.URL - LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response - RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response - JWTExpireTimeout time.Duration // JWT expire after timeout - ClientTimeout time.Duration // Timeout for network round trip + read out http response + LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response + RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response + JWTExpireTimeout time.Duration // JWT expire after timeout + ClientTimeout time.Duration // Timeout for network round trip + read out http response + JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place + ExternalBrowserTimeout time.Duration // Timeout for external browser login Application string // application name. InsecureMode bool // driver doesn't check certificate revocation status @@ -90,6 +98,8 @@ type Config struct { Tracing string // sets logging level + TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc + MfaToken string // Internally used to cache the MFA token IDToken string // Internally used to cache the Id Token for external browser ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. @@ -118,6 +128,17 @@ type MonitoringFetcherConfig struct { RetrySleepDuration time.Duration } +// Validate enables testing if config is correct. +// A driver client may call it manually, but it is also called during opening first connection. +func (c *Config) Validate() error { + if c.TmpDirPath != "" { + if _, err := os.Stat(c.TmpDirPath); err != nil { + return err + } + } + return nil +} + // ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED func (c *Config) ocspMode() string { if c.InsecureMode { @@ -147,7 +168,7 @@ func DSN(cfg *Config) (dsn string, err error) { posDot := strings.Index(cfg.Account, ".") if posDot > 0 { if cfg.Region != "" { - return "", ErrInvalidRegion + return "", errInvalidRegion() } cfg.Region = cfg.Account[posDot+1:] cfg.Account = cfg.Account[:posDot] @@ -192,6 +213,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ClientTimeout != defaultClientTimeout { params.Add("clientTimeout", strconv.FormatInt(int64(cfg.ClientTimeout/time.Second), 10)) } + if cfg.JWTClientTimeout != defaultJWTClientTimeout { + params.Add("jwtClientTimeout", strconv.FormatInt(int64(cfg.JWTClientTimeout/time.Second), 10)) + } if cfg.LoginTimeout != defaultLoginTimeout { params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10)) } @@ -201,6 +225,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.JWTExpireTimeout != defaultJWTTimeout { params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10)) } + if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout { + params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10)) + } if cfg.Application != clientType { params.Add("application", cfg.Application) } @@ -229,6 +256,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.Tracing != "" { params.Add("tracing", cfg.Tracing) } + if cfg.TmpDirPath != "" { + params.Add("tmpDirPath", cfg.TmpDirPath) + } params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) @@ -421,14 +451,23 @@ func fillMissingConfigParameters(cfg *Config) error { } } if strings.Trim(cfg.Account, " ") == "" { - return ErrEmptyAccount + return errEmptyAccount() + } + + if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" { + return errEmptyUsername() + } + + if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" { + return errEmptyPassword() } if cfg.Authenticator != AuthTypeOAuth && cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser && strings.Trim(cfg.User, " ") == "" { // oauth and token accessor do not require a username - return ErrEmptyUsername + return errEmptyUsername() } if cfg.Authenticator != AuthTypeExternalBrowser && @@ -437,7 +476,7 @@ func fillMissingConfigParameters(cfg *Config) error { cfg.Authenticator != AuthTypeTokenAccessor && strings.Trim(cfg.Password, " ") == "" { // no password parameter is required for EXTERNALBROWSER, OAUTH JWT, or TOKENACCESSOR. - return ErrEmptyPassword + return errEmptyPassword() } if strings.Trim(cfg.Protocol, " ") == "" { cfg.Protocol = "https" @@ -476,6 +515,12 @@ func fillMissingConfigParameters(cfg *Config) error { if cfg.ClientTimeout == 0 { cfg.ClientTimeout = defaultClientTimeout } + if cfg.JWTClientTimeout == 0 { + cfg.JWTClientTimeout = defaultJWTClientTimeout + } + if cfg.ExternalBrowserTimeout == 0 { + cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout + } if strings.Trim(cfg.Application, " ") == "" { cfg.Application = clientType } @@ -512,6 +557,19 @@ func fillMissingConfigParameters(cfg *Config) error { return nil } +func authRequiresUser(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser +} + +func authRequiresPassword(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser && + cfg.Authenticator != AuthTypeJwt +} + // transformAccountToHost transforms host to account name func transformAccountToHost(cfg *Config) (err error) { if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" { @@ -619,6 +677,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "jwtClientTimeout": + cfg.JWTClientTimeout, err = parseTimeout(value) + if err != nil { + return + } case "loginTimeout": cfg.LoginTimeout, err = parseTimeout(value) if err != nil { @@ -634,6 +697,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "externalBrowserTimeout": + cfg.ExternalBrowserTimeout, err = parseTimeout(value) + if err != nil { + return err + } case "application": cfg.Application = value case "authenticator": @@ -726,6 +794,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "tmpDirPath": + cfg.TmpDirPath = value default: if cfg.Params == nil { cfg.Params = make(map[string]*string) @@ -758,3 +828,101 @@ func parseTimeout(value string) (time.Duration, error) { } return time.Duration(vv * int64(time.Second)), nil } + +// ConfigParam is used to bind the name of the Config field with the environment variable and set the requirement for it +type ConfigParam struct { + Name string + EnvName string + FailOnMissing bool +} + +// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config +func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { + var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string + var privateKey *rsa.PrivateKey + var err error + if len(properties) == 0 || properties == nil { + return nil, errors.New("missing configuration parameters for the connection") + } + for _, prop := range properties { + value, err := GetFromEnv(prop.EnvName, prop.FailOnMissing) + if err != nil { + return nil, err + } + switch prop.Name { + case "Account": + account = value + case "User": + user = value + case "Password": + password = value + case "Role": + role = value + case "Host": + host = value + case "Port": + portStr = value + case "Protocol": + protocol = value + case "Warehouse": + warehouse = value + case "Database": + database = value + case "Region": + region = value + case "Passcode": + passcode = value + case "Schema": + schema = value + case "Application": + application = value + case "PrivateKey": + privateKey, err = parsePrivateKeyFromFile(value) + if err != nil { + return nil, err + } + } + } + + port := 443 // snowflake default port + if len(portStr) > 0 { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, err + } + } + + cfg := &Config{ + Account: account, + User: user, + Password: password, + Role: role, + Host: host, + Port: port, + Protocol: protocol, + Warehouse: warehouse, + Database: database, + Schema: schema, + PrivateKey: privateKey, + Region: region, + Passcode: passcode, + Application: application, + } + return cfg, nil +} + +func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + block, _ := pem.Decode(bytes) + if block == nil { + return nil, errors.New("failed to parse PEM block containing the private key") + } + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return privateKey.(*rsa.PrivateKey), nil +} diff --git a/dsn_test.go b/dsn_test.go index 52a0c9d08..4d7fe76b5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -3,9 +3,16 @@ package gosnowflake import ( + cr "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "net/url" + "os" "reflect" + "strconv" + "strings" "testing" "time" ) @@ -18,7 +25,6 @@ type tcParseDSN struct { } func TestParseDSN(t *testing.T) { - privKeyPKCS8 := generatePKCS8StringSupress(testPrivKey) privKeyPKCS1 := generatePKCS1String(testPrivKey) testcases := []tcParseDSN{ @@ -31,6 +37,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -44,6 +52,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -55,6 +65,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -67,6 +79,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -80,6 +94,8 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -93,6 +109,8 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -105,6 +123,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -117,6 +137,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -129,6 +151,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -142,6 +166,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -155,6 +181,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -168,9 +196,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { dsn: "@host:123/db/schema?account=ac&protocol=http", @@ -181,9 +211,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { dsn: "user:p@host:123/db/schema?protocol=http", @@ -194,9 +226,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:p@a.snowflakecomputing.com/db/pa?account=a&protocol=https&role=r&timezone=UTC&warehouse=w", @@ -207,6 +241,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -219,6 +255,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -231,6 +269,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -243,6 +283,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -258,6 +300,22 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + }, + ocspMode: ocspModeFailOpen, + }, + { + dsn: "u:p@a?database=d&externalBrowserTimeout=20", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, + Database: "d", Schema: "", + ExternalBrowserTimeout: 20 * time.Second, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, }, ocspMode: ocspModeFailOpen, }, @@ -271,6 +329,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, }, @@ -282,6 +342,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{ @@ -300,6 +362,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeInsecure, err: nil, @@ -314,6 +378,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -348,12 +414,12 @@ func TestParseDSN(t *testing.T) { { dsn: "u:u@/+/+?account=+&=0", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:u@/+/+?account=+&=+&=+", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "user%40%2F1:p%3A%40s@/db%2F?account=ac", @@ -363,6 +429,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -376,6 +444,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -394,6 +464,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -407,6 +479,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{Number: ErrCodePrivateKeyParseError}, @@ -419,6 +493,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -431,6 +507,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailClosed, err: nil, @@ -443,6 +521,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, InsecureMode: true, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeInsecure, err: nil, @@ -453,7 +533,9 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -464,7 +546,9 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -475,122 +559,210 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300&jwtClientTimeout=45", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: 300 * time.Second, + JWTClientTimeout: 45 * time.Second, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, }, { - dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300", + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: 300 * time.Second, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + TmpDirPath: "/tmp", }, ocspMode: ocspModeFailOpen, err: nil, }, } + for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:777/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 777, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA, AuthTypeJwt} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 888, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: errEmptyUsername(), + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("user@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "user", Password: "", + Protocol: "http", Host: "host", Port: 888, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: errEmptyPassword(), + }) + } + for i, test := range testcases { - // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) - cfg, err := ParseDSN(test.dsn) - switch { - case test.err == nil: - if err != nil { - t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) - } - if test.config.Host != cfg.Host { - t.Fatalf("%d: Failed to match host. expected: %v, got: %v", - i, test.config.Host, cfg.Host) - } - if test.config.Account != cfg.Account { - t.Fatalf("%d: Failed to match account. expected: %v, got: %v", - i, test.config.Account, cfg.Account) - } - if test.config.User != cfg.User { - t.Fatalf("%d: Failed to match user. expected: %v, got: %v", - i, test.config.User, cfg.User) - } - if test.config.Password != cfg.Password { - t.Fatalf("%d: Failed to match password. expected: %v, got: %v", - i, test.config.Password, cfg.Password) - } - if test.config.Database != cfg.Database { - t.Fatalf("%d: Failed to match database. expected: %v, got: %v", - i, test.config.Database, cfg.Database) - } - if test.config.Schema != cfg.Schema { - t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", - i, test.config.Schema, cfg.Schema) - } - if test.config.Warehouse != cfg.Warehouse { - t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", - i, test.config.Warehouse, cfg.Warehouse) - } - if test.config.Role != cfg.Role { - t.Fatalf("%d: Failed to match role. expected: %v, got: %v", - i, test.config.Role, cfg.Role) - } - if test.config.Region != cfg.Region { - t.Fatalf("%d: Failed to match region. expected: %v, got: %v", - i, test.config.Region, cfg.Region) - } - if test.config.Protocol != cfg.Protocol { - t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", - i, test.config.Protocol, cfg.Protocol) - } - if test.config.Passcode != cfg.Passcode { - t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", - i, test.config.Passcode, cfg.Passcode) - } - if test.config.PasscodeInPassword != cfg.PasscodeInPassword { - t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", - i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) - } - if test.config.Authenticator != cfg.Authenticator { - t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", - i, test.config.Authenticator.String(), cfg.Authenticator.String()) - } - if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { - t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", - i, test.config.OktaURL, cfg.OktaURL) - } - if test.config.OCSPFailOpen != cfg.OCSPFailOpen { - t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", - i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) - } - if test.ocspMode != cfg.ocspMode() { - t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", - i, test.ocspMode, cfg.ocspMode()) - } - if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { - t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", - i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) - } - if test.config.ClientTimeout != cfg.ClientTimeout { - t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", - i, test.config.ClientTimeout, cfg.ClientTimeout) - } - case test.err != nil: - driverErrE, okE := test.err.(*SnowflakeError) - driverErrG, okG := err.(*SnowflakeError) - if okE && !okG || !okE && okG { - t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) - } - if okE && okG { - if driverErrE.Number != driverErrG.Number { - t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + t.Run(test.dsn, func(t *testing.T) { + cfg, err := ParseDSN(test.dsn) + switch { + case test.err == nil: + if err != nil { + t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) + } + if test.config.Host != cfg.Host { + t.Fatalf("%d: Failed to match host. expected: %v, got: %v", + i, test.config.Host, cfg.Host) + } + if test.config.Account != cfg.Account { + t.Fatalf("%d: Failed to match account. expected: %v, got: %v", + i, test.config.Account, cfg.Account) + } + if test.config.User != cfg.User { + t.Fatalf("%d: Failed to match user. expected: %v, got: %v", + i, test.config.User, cfg.User) + } + if test.config.Password != cfg.Password { + t.Fatalf("%d: Failed to match password. expected: %v, got: %v", + i, test.config.Password, cfg.Password) + } + if test.config.Database != cfg.Database { + t.Fatalf("%d: Failed to match database. expected: %v, got: %v", + i, test.config.Database, cfg.Database) + } + if test.config.Schema != cfg.Schema { + t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", + i, test.config.Schema, cfg.Schema) + } + if test.config.Warehouse != cfg.Warehouse { + t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", + i, test.config.Warehouse, cfg.Warehouse) + } + if test.config.Role != cfg.Role { + t.Fatalf("%d: Failed to match role. expected: %v, got: %v", + i, test.config.Role, cfg.Role) } - } else { - t1 := reflect.TypeOf(err) - t2 := reflect.TypeOf(test.err) - if t1 != t2 { - t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + if test.config.Region != cfg.Region { + t.Fatalf("%d: Failed to match region. expected: %v, got: %v", + i, test.config.Region, cfg.Region) + } + if test.config.Protocol != cfg.Protocol { + t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", + i, test.config.Protocol, cfg.Protocol) + } + if test.config.Passcode != cfg.Passcode { + t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", + i, test.config.Passcode, cfg.Passcode) + } + if test.config.PasscodeInPassword != cfg.PasscodeInPassword { + t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", + i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) + } + if test.config.Authenticator != cfg.Authenticator { + t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", + i, test.config.Authenticator.String(), cfg.Authenticator.String()) + } + if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { + t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", + i, test.config.OktaURL, cfg.OktaURL) + } + if test.config.OCSPFailOpen != cfg.OCSPFailOpen { + t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", + i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) + } + if test.ocspMode != cfg.ocspMode() { + t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", + i, test.ocspMode, cfg.ocspMode()) + } + if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { + t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", + i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) + } + if test.config.ClientTimeout != cfg.ClientTimeout { + t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", + i, test.config.ClientTimeout, cfg.ClientTimeout) + } + if test.config.JWTClientTimeout != cfg.JWTClientTimeout { + t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v", + i, test.config.JWTClientTimeout, cfg.JWTClientTimeout) + } + if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout { + t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", + i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) + } + if test.config.TmpDirPath != cfg.TmpDirPath { + t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath) + } + case test.err != nil: + driverErrE, okE := test.err.(*SnowflakeError) + driverErrG, okG := err.(*SnowflakeError) + if okE && !okG || !okE && okG { + t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) + } + if okE && okG { + if driverErrE.Number != driverErrG.Number { + t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + } + } else { + t1 := reflect.TypeOf(err) + t2 := reflect.TypeOf(test.err) + if t1 != t2 { + t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + } } } - } + + }) } } @@ -603,7 +775,6 @@ type tcDSN struct { func TestDSN(t *testing.T) { tmfmt := "MM-DD-YYYY" testConnectionID := "abcd-0123-4567-1234" - testcases := []tcDSN{ { cfg: &Config{ @@ -641,7 +812,7 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -672,6 +843,17 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a", + Region: "r", + ExternalBrowserTimeout: 20 * time.Second, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&externalBrowserTimeout=20&ocspFailOpen=true®ion=r&validateDefaultParameters=true", + }, { cfg: &Config{ User: "", @@ -679,7 +861,7 @@ func TestDSN(t *testing.T) { Account: "a", ConnectionID: testConnectionID, }, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { cfg: &Config{ @@ -688,16 +870,16 @@ func TestDSN(t *testing.T) { Account: "a", ConnectionID: testConnectionID, }, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { cfg: &Config{ User: "u", Password: "p", - Account: "", + Account: "a.e", ConnectionID: testConnectionID, }, - err: ErrEmptyAccount, + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -726,7 +908,7 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -749,13 +931,14 @@ func TestDSN(t *testing.T) { }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a", - Authenticator: AuthTypeExternalBrowser, - ConnectionID: testConnectionID, + User: "u", + Password: "p", + Account: "a", + Authenticator: AuthTypeExternalBrowser, + ClientStoreTemporaryCredential: ConfigBoolTrue, + ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=true&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -872,43 +1055,101 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.b.c", - ClientTimeout: 300 * time.Second, - ConnectionID: testConnectionID, + User: "u", + Password: "p", + Account: "a.b.c", + ClientTimeout: 300 * time.Second, + JWTClientTimeout: 60 * time.Second, + ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.e", - MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: time.Second * 56, - }, + User: "u", + Password: "p", + Account: "a.b.c", + ClientTimeout: 300 * time.Second, + JWTExpireTimeout: 30 * time.Second, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Protocol: "http", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_queryRuntimeThresholdMs=56000&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&protocol=http®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.e", - MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: time.Second * 56, - MaxDuration: time.Second * 14, - RetrySleepDuration: time.Millisecond * 45, - }, + User: "u", + Password: "p", + Account: "a.b.c", + Tracing: "debug", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_maxDurationMs=14000&monitoringFetcher_queryRuntimeThresholdMs=56000&monitoringFetcher_retrySleepDurationMs=45&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&tracing=debug&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeUsernamePasswordMFA, + ClientRequestMfaToken: ConfigBoolTrue, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=true&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeUsernamePasswordMFA, + ClientRequestMfaToken: ConfigBoolFalse, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=false&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Warehouse: "wh", + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true&warehouse=wh", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Token: "t", + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeTokenAccessor, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=tokenaccessor&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ @@ -916,34 +1157,162 @@ func TestDSN(t *testing.T) { Password: "p", Account: "a.e", MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: defaultMonitoringFetcherQueryMonitoringThreshold, - MaxDuration: defaultMonitoringFetcherMaxDuration, - RetrySleepDuration: defaultMonitoringFetcherRetrySleepDuration, + QueryRuntimeThreshold: time.Second * 20, }, ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_queryRuntimeThresholdMs=20000&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, } for _, test := range testcases { - dsn, err := DSN(test.cfg) - if test.err == nil && err == nil { - if dsn != test.dsn { - t.Errorf("failed to get DSN. expected: %v, got:\n %v", test.dsn, dsn) + t.Run(test.dsn, func(t *testing.T) { + dsn, err := DSN(test.cfg) + if test.err == nil && err == nil { + if dsn != test.dsn { + t.Errorf("failed to get DSN. expected: %v, got:\n %v", test.dsn, dsn) + } + _, err := ParseDSN(dsn) + if err != nil { + t.Errorf("failed to parse DSN. dsn: %v, err: %v", dsn, err) + } } - _, err := ParseDSN(dsn) - if err != nil { - t.Errorf("failed to parse DSN. dsn: %v, err: %v", dsn, err) + if test.err != nil && err == nil { + t.Errorf("expected error. dsn: %v, err: %v", test.dsn, test.err) } - continue - } - if test.err != nil && err == nil { - t.Errorf("expected error. dsn: %v, err: %v", test.dsn, test.err) - continue + if err != nil && test.err == nil { + t.Errorf("failed to match. err: %v", err) + } + }) + } +} + +func TestParsePrivateKeyFromFileMissingFile(t *testing.T) { + _, err := parsePrivateKeyFromFile("nonexistent") + + if err == nil { + t.Error("should report error for nonexistent file") + } +} + +func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) { + pemFile := createTmpFile("exampleKey.pem", []byte("gibberish")) + _, err := parsePrivateKeyFromFile(pemFile) + + if err == nil { + t.Error("should report error for wrong data in file") + } +} + +func TestParsePrivateKeyFromFile(t *testing.T) { + generatedKey, _ := rsa.GenerateKey(cr.Reader, 1024) + pemKey, _ := x509.MarshalPKCS8PrivateKey(generatedKey) + pemData := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: pemKey, + }, + ) + keyFile := createTmpFile("exampleKey.pem", pemData) + defer os.Remove(keyFile) + + parsedKey, err := parsePrivateKeyFromFile(keyFile) + if err != nil { + t.Errorf("unable to parse pam file from path: %v, err: %v", keyFile, err) + } else if !parsedKey.Equal(generatedKey) { + t.Errorf("generated key does not equal to parsed key from file\ngeneratedKey=%v\nparsedKey=%v", + generatedKey, parsedKey) + } +} + +func createTmpFile(fileName string, content []byte) string { + tempFile, _ := os.CreateTemp("", fileName) + tempFile.Write(content) + absolutePath := tempFile.Name() + return absolutePath +} + +type configParamToValue struct { + configParam string + value string +} + +func TestGetConfigFromEnv(t *testing.T) { + envMap := map[string]configParamToValue{ + "SF_TEST_ACCOUNT": {"Account", "account"}, + "SF_TEST_USER": {"User", "user"}, + "SF_TEST_PASSWORD": {"Password", "password"}, + "SF_TEST_ROLE": {"Role", "role"}, + "SF_TEST_HOST": {"Host", "host"}, + "SF_TEST_PORT": {"Port", "8080"}, + "SF_TEST_PROTOCOL": {"Protocol", "http"}, + "SF_TEST_WAREHOUSE": {"Warehouse", "warehouse"}, + "SF_TEST_DATABASE": {"Database", "database"}, + "SF_TEST_REGION": {"Region", "region"}, + "SF_TEST_PASSCODE": {"Passcode", "passcode"}, + "SF_TEST_SCHEMA": {"Schema", "schema"}, + "SF_TEST_APPLICATION": {"Application", "application"}, + } + var properties = make([]*ConfigParam, len(envMap)) + i := 0 + for key, ctv := range envMap { + os.Setenv(key, ctv.value) + cfgParam := ConfigParam{ctv.configParam, key, true} + properties[i] = &cfgParam + i++ + } + defer func() { + for key := range envMap { + os.Unsetenv(key) } - if err != nil && test.err == nil { - t.Errorf("failed to match. err: %v", err) - continue + }() + + cfg, err := GetConfigFromEnv(properties) + if err != nil { + t.Errorf("unable to parse env variables to Config, err: %v", err) + } + + err = checkConfig(*cfg, envMap) + if err != nil { + t.Error(err) + } +} + +func checkConfig(cfg Config, envMap map[string]configParamToValue) error { + appendError := func(errArray []string, envName string, expected string, received string) []string { + errArray = append(errArray, fmt.Sprintf("field %v expected value: %v, received value: %v", envName, expected, received)) + return errArray + } + + value := reflect.ValueOf(cfg) + typeOfCfg := value.Type() + cfgValues := make(map[string]interface{}, value.NumField()) + for i := 0; i < value.NumField(); i++ { + cfgValues[typeOfCfg.Field(i).Name] = value.Field(i).Interface() + } + + var errArray []string + for key, ctv := range envMap { + if ctv.configParam == "Port" { + if portStr := strconv.Itoa(cfgValues[ctv.configParam].(int)); portStr != ctv.value { + errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) + } + } else if cfgValues[ctv.configParam] != ctv.value { + errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) } } + + if errArray != nil { + return fmt.Errorf(strings.Join(errArray, "\n")) + } + + return nil +} + +func TestConfigValidateTmpDirPath(t *testing.T) { + cfg := &Config{ + TmpDirPath: "/not/existing", + } + if err := cfg.Validate(); err == nil { + t.Fatalf("Should fail on not existing TmpDirPath") + } } diff --git a/errors.go b/errors.go index bdd1bdd67..366e5dcb5 100644 --- a/errors.go +++ b/errors.go @@ -66,7 +66,7 @@ func (se *SnowflakeError) generateTelemetryExceptionData() *telemetryData { } func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *telemetryData) error { - if sc != nil { + if sc != nil && sc.telemetry != nil { return sc.telemetry.addLog(data) } return nil // TODO oob telemetry @@ -82,7 +82,7 @@ func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError // return populated error fields replacing the default response func populateErrorFields(code int, data *execResponse) *SnowflakeError { - err := ErrUnknownError + err := errUnknownError() if code != -1 { err.Number = code } @@ -299,32 +299,44 @@ const ( errMsgAsyncWithNoResults = "async with no results" ) -var ( - // ErrEmptyAccount is returned if a DNS doesn't include account parameter. - ErrEmptyAccount = &SnowflakeError{ +// Returned if a DNS doesn't include account parameter. +func errEmptyAccount() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyAccountCode, Message: "account is empty", } - // ErrEmptyUsername is returned if a DNS doesn't include user parameter. - ErrEmptyUsername = &SnowflakeError{ +} + +// Returned if a DNS doesn't include user parameter. +func errEmptyUsername() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyUsernameCode, Message: "user is empty", } - // ErrEmptyPassword is returned if a DNS doesn't include password parameter. - ErrEmptyPassword = &SnowflakeError{ +} + +// Returned if a DNS doesn't include password parameter. +func errEmptyPassword() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyPasswordCode, - Message: "password is empty"} + Message: "password is empty", + } +} - // ErrInvalidRegion is returned if a DSN's implicit region from account parameter and explicit region parameter conflict. - ErrInvalidRegion = &SnowflakeError{ +// Returned if a DSN's implicit region from account parameter and explicit region parameter conflict. +func errInvalidRegion() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeRegionOverlap, - Message: "two regions specified"} + Message: "two regions specified", + } +} - // ErrUnknownError is returned if the server side returns an error without meaningful message. - ErrUnknownError = &SnowflakeError{ +// Returned if the server side returns an error without meaningful message. +func errUnknownError() *SnowflakeError { + return &SnowflakeError{ Number: -1, SQLState: "-1", Message: "an unknown server side error occurred", QueryID: "-1", } -) +} diff --git a/htap_test.go b/htap_test.go index ad58b8b78..39fae90f0 100644 --- a/htap_test.go +++ b/htap_test.go @@ -101,37 +101,36 @@ func trimWhitespaces(s string) string { ) } -// Not released yet, we should add them after they are released -// func TestAddingQueryContextCacheEntry(t *testing.T) { -// runSnowflakeConnTest(t, func(sc *snowflakeConn) { -// t.Run("First query (may be on empty cache)", func(t *testing.T) { -// entriesBefore := sc.queryContextCache.entries -// if _, err := sc.Query("SELECT 1", nil); err != nil { -// t.Fatalf("cannot query. %v", err) -// } -// entriesAfter := sc.queryContextCache.entries +func TestAddingQueryContextCacheEntry(t *testing.T) { + runSnowflakeConnTest(t, func(sc *snowflakeConn) { + t.Run("First query (may be on empty cache)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries -// if !containsNewEntries(entriesAfter, entriesBefore) { -// t.Error("no new entries added to the query context cache") -// } -// }) + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) -// t.Run("Second query (cache should not be empty)", func(t *testing.T) { -// entriesBefore := sc.queryContextCache.entries -// if len(entriesBefore) == 0 { -// t.Fatalf("cache should not be empty after first query") -// } -// if _, err := sc.Query("SELECT 1", nil); err != nil { -// t.Fatalf("cannot query. %v", err) -// } -// entriesAfter := sc.queryContextCache.entries + t.Run("Second query (cache should not be empty)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if len(entriesBefore) == 0 { + t.Fatalf("cache should not be empty after first query") + } + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries -// if !containsNewEntries(entriesAfter, entriesBefore) { -// t.Error("no new entries added to the query context cache") -// } -// }) -// }) -// } + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + }) +} func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryContextEntry) bool { if len(entriesAfter) > len(entriesBefore) { diff --git a/query.go b/query.go index 882e1b62a..50473e16f 100644 --- a/query.go +++ b/query.go @@ -125,6 +125,11 @@ type execResponseData struct { Command string `json:"command,omitempty"` Kind string `json:"kind,omitempty"` Operation string `json:"operation,omitempty"` + + // HTAP + QueryContext struct { + Entries []queryContextEntry `json:"entries,omitempty"` + } `json:"queryContext,omitempty"` } type execResponse struct { diff --git a/rows.go b/rows.go index 71e87caa9..314981a94 100644 --- a/rows.go +++ b/rows.go @@ -45,9 +45,17 @@ type snowflakeRows struct { err error errChannel chan error monitoring *monitoringResult + location *time.Location asyncRequestID UUID } +func (rows *snowflakeRows) getLocation() *time.Location { + if rows.location == nil && rows.sc != nil && rows.sc.cfg != nil { + rows.location = getCurrentLocation(rows.sc.cfg.Params) + } + return rows.location +} + type snowflakeValue interface{} type chunkRowType struct { @@ -233,11 +241,7 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { for i, n := 0, len(row.RowSet); i < n; i++ { // could move to chunk downloader so that each go routine // can convert data - var loc *time.Location - if rows.sc != nil { - loc = getCurrentLocation(rows.sc.cfg.Params) - } - err = stringToValue(&dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], loc) + err = stringToValue(&dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], rows.getLocation()) if err != nil { return err } diff --git a/statement_test.go b/statement_test.go index 5d1554202..340e1e1d6 100644 --- a/statement_test.go +++ b/statement_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. +//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. package gosnowflake @@ -7,7 +8,10 @@ import ( "database/sql" "database/sql/driver" "fmt" + "net/http" + "net/url" "testing" + "time" ) func openDB(t *testing.T) *sql.DB { @@ -17,20 +21,30 @@ func openDB(t *testing.T) *sql.DB { if db, err = sql.Open("sigmacomputing+gosnowflake", dsn); err != nil { t.Fatalf("failed to open db. %v, err: %v", dsn, err) } + return db } -func TestGetQueryID(t *testing.T) { - db := openDB(t) - defer db.Close() +func openConn(t *testing.T) *sql.Conn { + var db *sql.DB + var conn *sql.Conn + var err error - ctx := context.TODO() - conn, err := db.Conn(ctx) - if err != nil { - t.Error(err) + if db, err = sql.Open("sigmacomputing+gosnowflake", dsn); err != nil { + t.Fatalf("failed to open db. %v, err: %v", dsn, err) + } + if conn, err = db.Conn(context.Background()); err != nil { + t.Fatalf("failed to open connection: %v", err) } + return conn +} - if err = conn.Raw(func(x interface{}) error { +func TestGetQueryID(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + if err := conn.Raw(func(x interface{}) error { rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil) if err != nil { return err @@ -71,7 +85,7 @@ func TestEmitQueryID(t *testing.T) { cnt := 0 var idx int var v string - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) defer rows.Close() @@ -144,9 +158,10 @@ func TestE2EFetchResultByID(t *testing.T) { } func TestWithDescribeOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { ctx := WithDescribeOnly(context.Background()) rows := dbt.mustQueryContext(ctx, selectVariousTypes) + defer rows.Close() cols, err := rows.Columns() if err != nil { t.Error(err) @@ -166,10 +181,243 @@ func TestWithDescribeOnly(t *testing.T) { }) } +func TestCallStatement(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + in1 := float64(1) + in2 := string("[2,3]") + expected := "1 \"[2,3]\" [2,3]" + var out string + dbt.exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true") + dbt.mustExec("create or replace procedure " + + "TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " + + "returns string language javascript as $$ " + + "let res = snowflake.execute({sqlText: 'select ? c1, ? c2', binds:[IN1, JSON.stringify(IN2)]}); " + + "res.next(); " + + "return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " + + "$$;") + stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))") + if err != nil { + dbt.Errorf("failed to prepare query: %v", err) + } + defer stmt.Close() + err = stmt.QueryRow(in1, in2).Scan(&out) + if err != nil { + dbt.Errorf("failed to scan: %v", err) + } + if expected != out { + dbt.Errorf("expected: %s, got: %s", expected, out) + } + dbt.mustExec("drop procedure if exists TEST_SP_CALL_STMT_ENABLED(float, variant)") + }) +} + +func TestStmtExec(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { + t.Fatalf("failed to create table: %v", err) + } + + if err := conn.Raw(func(x interface{}) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)") + if err != nil { + t.Error(err) + } + _, err = stmt.(*snowflakeStmt).Exec(nil) + if err != nil { + t.Error(err) + } + _, err = stmt.(*snowflakeStmt).Query(nil) + if err != nil { + t.Error(err) + } + return nil + }); err != nil { + t.Fatalf("failed to drop table: %v", err) + } + + if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { + t.Fatalf("failed to drop table: %v", err) + } +} + +func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, + }, nil +} + +func TestUnitCheckQueryStatus(t *testing.T) { + sc := getDefaultSnowflakeConn() + ctx := context.Background() + qid := NewUUID() + + sr := &snowflakeRestful{ + FuncGet: getStatusSuccessButInvalidJSONfunc, + TokenAccessor: getSimpleTokenAccessor(), + } + sc.rest = sr + _, err := sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("invalid json. should have failed") + } + sc.rest.FuncGet = funcGetQueryRespFail + _, err = sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("should have failed") + } + + sc.rest.FuncGet = funcGetQueryRespError + _, err = sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrQueryStatus { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number) + } +} + +func TestStatementQueryIdForQueries(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Rows, error) + }{ + { + "query", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + }, + { + "queryContext", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestStatementQueryIdForExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Result, error) + }{ + { + "exec", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + }, + { + "execContext", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func TestWithQueryTag(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { testQueryTag := "TEST QUERY TAG" ctx := WithQueryTag(context.Background(), testQueryTag) + ctx, fn := context.WithCancel(ctx) + // For whatever reason we have to cancel the context explicitly + // To prevent sql conn.Close() from hanging. + defer fn() // This query itself will be part of the history and will have the query tag rows := dbt.mustQueryContext( diff --git a/submit_sync_test.go b/submit_sync_test.go index 0baad913f..1f19a2017 100644 --- a/submit_sync_test.go +++ b/submit_sync_test.go @@ -50,8 +50,9 @@ func TestSubmitQuerySync(t *testing.T) { // Set a long threshold to prevent the monitoring fetch from kicking in. MonitoringFetcher: MonitoringFetcherConfig{QueryRuntimeThreshold: 1 * time.Hour}, }, - rest: sr, - telemetry: testTelemetry, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } res, err := sc.SubmitQuerySync(context.TODO(), "") @@ -129,8 +130,9 @@ func TestSubmitQuerySyncQueryComplete(t *testing.T) { // Set a long threshold to prevent the monitoring fetch from kicking in. MonitoringFetcher: MonitoringFetcherConfig{QueryRuntimeThreshold: 1 * time.Hour}, }, - rest: sr, - telemetry: testTelemetry, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } res, err := sc.SubmitQuerySync(context.TODO(), "") diff --git a/telemetry_test.go b/telemetry_test.go index 49826512f..54ab12c85 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -476,3 +476,209 @@ func TestAddLogError(t *testing.T) { t.Fatal("should have failed") } } + +func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { + return nil, errors.New("failed to upload metrics to telemetry") +} + +func TestTelemetryError(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } +} + +func TestTelemetryDisabledOnBadResponse(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + + st.enabled = true + st.sr.FuncPost = postTestQueryNotExecuting + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + + st.enabled = true + st.sr.FuncPost = postTestSuccessButInvalidJSON + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } +} + +func TestTelemetryDisabled(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: false, // disable + flushSize: defaultFlushSize, + } + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + st.enabled = true + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + st.enabled = false + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } +} + +func TestAddLogError(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: 1, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } +} diff --git a/util.go b/util.go index ff37f9dc5..5c2c676c4 100644 --- a/util.go +++ b/util.go @@ -5,7 +5,9 @@ package gosnowflake import ( "context" "database/sql/driver" + "fmt" "io" + "os" "strings" "sync" "time" @@ -272,3 +274,14 @@ func escapeForCSV(value string) string { } return value } + +// GetFromEnv is used to get the value of an environment variable from the system +func GetFromEnv(name string, failOnMissing bool) (string, error) { + if value := os.Getenv(name); value != "" { + return value, nil + } + if failOnMissing { + return "", fmt.Errorf("%v environment variable is not set", name) + } + return "", nil +}