From 228e9d2a3bd22477344c229039b4ac2608e83342 Mon Sep 17 00:00:00 2001 From: rajsigma <124622569+rajsigma@users.noreply.github.com> Date: Wed, 30 Aug 2023 16:33:44 -0700 Subject: [PATCH] Sf rebase#2 (#163) * SNOW-824297: Bumped up GoLang connector PATCH version from 1.6.20 to 1.6.21 (#800) * use the context object supplied by BeginTx instead of from the parent connection (#804) * Make port/portStr variables in cmd more consistent (#803) * Add tests to increase code coverage (#784) add tests to increase code coverage to baseline 80% * SNOW-828186 Remove code duplication in transaction.go (#807) * SNOW-829489: Fix CLA Assistant Github Action (#811) * SNOW-829187 Add multistatement query example (#810) * Add missing makefiles for example apps (#809) * Add multistatement exec example (#812) * SNOW-829454: Async API Example (#813) * SNOW-830268: Add PUT/GET feature example (#814) * SNOW-838450: Bumped up GoLang connector PATCH version from 1.6.21 to 1.6.22 (#820) @noreview - This is an automated process. No review is required Co-authored-by: Jenkins User * SNOW-832816: Unify and remove duplicates of getDsn() function (#818) * SNOW-832885 Add arrow_batches example (#819) * Add null checks before accessing connection config during chunk downloading (#821) * SNOW-832816: Unify and remove duplicates of getDsn() function (#823) * Fix multistatement demo omitting first result set (#827) * SNOW-840123 Upgrade arrow 12.0.0 -> 12.0.1 (#829) * retry async request if still in progress (#824) update comment * SNOW-723750: Panic with readonly file system (#828) * add jenkins user for cla assistant check (#830) * add mutex lock to prevent race condition in exec() function (#834) defer mutex unlock * SNOW-847338: Add timeout for authentication in external browser (#835) * SNOW-852381: Add documentation and DSN parameter for external browser timeout (#837) * SNOW-853914 Add distributed fetch example (#840) * SNOW-832825: Fix async documentation (#839) * SNOW-856850: Add missing parameter check (#842) * SNOW-833537 Separate JWT http client to specify timeout (#841) * SNOW-847417: Added support for sql.Null types for query bind mapping (#844) * SNOW-833537 Each time retry keypair auth with new token (#845) * SNOW-857829 Fix username and password requiredness (#846) * fix(arrow): handle non-arrow result sets (#851) * SNOW-645253 Handle binding named parameters (#850) * SNOW-870818: Add Snyk to CLA Assistant allowlist (#854) * SNOW-870818: Add snyk-bot to CLA Assistant allowlist (#855) * SNOW-871839: Fix Snyk permissions (#857) * SNOW-859548 Reuse connection in tests (#856) * SNOW-859548 Replace test table in each test (#858) * SNOW-875425: Bumped up GoLang connector PATCH version from 1.6.22 to 1.6.23 (#861) * Enable procedure calling test (#863) * cover critical areas test code coverage for logger, dsn and connection (#866) * SNOW-833537 Add JWT retry docs (#868) * SNOW-880396: Do not cancel other matrix builds when one fails (#870) * SNOW-880396: Do not cancel other matrix builds when one fails * SNOW-880396: Cancel builds when new code in branch appears * increase code coverage on coverter.go (#867) * Remove executable bits in file open/creation Co-authored-by: Rami <72725910+ramikg@users.noreply.github.com> * SNOW-859548 Refactor runTests to runDBTest (#869) * SNOW-880442: Add pass context to arrow fetch function (#871) * refactor download files (#875) * SNOW-540086: Missing critical areas for code coverage (#873) * SNOW-870356: Integrate code coverage with CodeCov (#876) * SNOW-845282: Allow configuring tmpdir in DSN (#874) * SNOW-857660: Init rows location once (#882) * SNOW-845282: Add docs about tmpDirPath (#884) * SNOW-859547: Refactor tests to create a new test for each test case (#877) * SNOW-894815: Disable TestConcurrentReadOnParams (#886) * [SNOW-892549] Remove references to sfcdev1 (#885) * SNOW-897024: Bumped up GoLang connector PATCH version from 1.6.23 to 1.6.24 (#891) @noreview - This is an automated process. No review is required Co-authored-by: Jenkins User * SNOW-895534: Add HTAP query context struct (#888) * SNOW-726742 Remove queryID from snowflakeConn (#892) * SNOW-898353 Fix snyk permissions (#893) * SNOW-889572 Refactored snowflake type implementation to map (#890) * SNOW-895534: Add HTAP query context entries to cache (#889) * SNOW-848019 Change global error objects to instantiated by function (#897) * SNOW-726742: Implement GetQueryId for statements (#899) * [Maintenance] Create Fork [Maintenance] enable CI tests in our repo (#20) add README note describing how CI secrets are set up change yaml to reference environment remove tests for platforms other than Ubuntu + AWS disable staticcheck add encrypted parameters file add private key disable tests that fail [Maintenance] Ignore vendored libs * [Fix] arrowToValue handling of timestamp_ltz at non-nanosecond scales (#24) * [Fix] SIG-18794: Fix handling of results response for canceled queries (#58) * [Feature] Add fetchers for result, monitoring, query-graph. * add hook to fetch "monitoring" data and query-graph, in goroutines * add support to fetching monitoring info * add monitoring threshold * move to Duration Lint/Build fixes nit: Goimports fixes - [Feature] [monitoring] fetch the query graph for slow queries (#26) * fetch query graph * fix * more fixes update interface move monitoring to a goroutine; client decides how long it'll wait (#27) add MonitoringResultFetcher to fetch monitoring (#48) - not use qid from SFconnection when fetch monitoring data (#52) * [Feature] allow clients to set QUERY_TAG parameter via context * [Fix] SIG-17456: Trap chunk-downloader errors with Sentry at Multiplex. (#44) * nit: Move "send error on channel" within the `recover` condition. * Improvement: don't panic _indiscriminately_, have a separate error-type that signals the internal panics we want to re-propagate. Co-authored-by: Agam Brahma * [Feature] Allow clients to specify any/null DataType explicitly - Allow clients to specify DataTypeNull explicitly (#50) * first cut of explict DataTypeNull * make lint * more tests - Allow client to specify any data type explicitly (#28) * dataTypeMode should work for all types * fix test * allow explicit binaryType declaration even if we're already using binaryType * silly attempt with *SnowflakeDataType * switch from *SnowflakeDataType --> SnowflakeDataType * bindings_test.go passes * add test for corner case * fix lint * fix null-handling tests * only need connection.CheckNamedValue * use checked cast instead of switch on type * [Feature] add a cache to the FetchResult codepath to avoid API calls (#45) lint: fix `ConnectionId` Fix CI for current master (#47) * lint: fix `ConnectionId` * Test fixes for previous change (`41c90a09`) * nit: tweaks to dsn-test Co-authored-by: Agam Brahma fix var inits (#46) * [Fix] SIG-18794: Fix getAsync() to not panic on context exceeded + test. (#56) * [Fix] shutdown race over access of the connection's "restful" structure. * [Feature] Improvements to Async handling SIG-18794: Refactor the monitoring fetcher & auto-cancel async (#57) * SIG-18794: Refactor the monitoring fetcher & auto-cancel async * better configuration * wait until the monitoring data refers to completed queries SIG-16907: Error out early when invalid state for multi-statement requests. (#43) Remove the debug-tracing added earlier. Co-authored-by: Agam Brahma SIG-18794: Make getAsync wait until the query completes or times out. (#59) * SIG-18794: Make getAsync wait until the query completes or times out. * make the getStatus for an async made query block until the ctx timeout SIG-18794: Fix a bug when looping to complete an async query (#60) * discovered a new error code the needs looping when async, enhanced the test * [Maintenance] Rebase: 6/6/22 reconcile with origin/master nit: fix build and gitignore dead code Lint fix nit: fix erroneous deletion nit: s/uuid/UUID nit: refactore use of monitoring fetcher for multi-stmt Refactor out older version of `getMonitoringResult` that didn't have the `endpoint` parameter * [Feature] Enable fetching raw Arrow records directly * [Feature][SIG-24903] Prototype for Snowflake SubmitSync (#81) Adds a SubmitSync method to the Snowflake driver that directly calls Snowflake's POST /queries/v1/query-request API endpoint synchronously, waiting for up to 45 seconds (the current, fixed timeout) but does not enter the "ping pong" phase of fetching query results for long-running queries. Instead, the caller is responsible for using the query ID to wait for and fetch query results. This prototype mainly exists right now so that we can test the relevant functionality and compare performance with the fully asynchronous mode of execution. * [Feature] Snowflake driver to support `WaitForQueryCompletion()`; wait for results to finish without returning result rows (#84) * [Feature] Enable extracting tokens from snowflake conn and constructing TokenAccessor (#86) * Add TokenGetter interface * Add public constructor for TokenAccessor * Expose FillMissingConfigParameters * Add comment * Add comments * Fix comment * [Maintenance] Rebase with upstream 2022-09-19 * If FuncGet() returns no code dont convert to str (#89) * test with context deadline (#92) * add panic Log whether or not we use the cache here (#95) * Log whether or not we use the cache here * Update monitoring.go * experiment * Update connection.go * make status error thing more clear an explicit * more logging * Don't assume error type * add more logging if request fails * pass qid * also tag code * Retry one more time & log response * log one more thing * skipCache * add the proper logging * retry on bad response * retry at most 2 times * lint * Dont check status before block * Don't cache non success responses (#109) Don't cache non success responses and debug the body response * Fix atoi error. (#110) * the result can return just Data field, so we parse a code as an error only if it is set as an error. * Revert "Don't cache non success responses" (#111) Revert "Don't cache non success responses (#109)" This reverts commit ff7d6b960dceadfd459d8f08c70f154e51c21d7b. * fix incorrect error sending (#114) * fix res.Success in async.go (#115) * Add StateDuration to QueryMonitoringData struct (#116) * add logging for sf cache bug (#117) * add logging for sf cache bug * add log without data to see if prod is seeing this issue * Protect against nil and raise error instead (#118) * [Feature] Add support for client log context hooks (#119) * [Feature] Add support for client log context hooks * Pin honnef.co/go/stools/cmd/staticcheck dep version * Add comment * save changes * fix up * bump down go * put back build test to what it was * just test on linux like we used to * fix status * update test back to what it was * Add a nil check for snowflakeRows.Close (#123) * Only cache successful API responses (#124) * Only cache success responses * Add another if * [Maint] squash 8 commits from my last rebase (#125) * rebase master with head * Add a nil check for snowflakeRows.Close (#123) * Only cache successful API responses (#124) * Only cache success responses * Add another if --------- Co-authored-by: Eric Bannatyne * fmt * use unsupported type * use TimestampNtz as default when the client does not specify data type explicitly * [SIG-35025] Add a nil check in snowflake driver heartbeat function (#133) * Add a nil check in snowflake driver heartbeat function * Update heartbeat.go * [SIG-36502]add more visibility into chunk downloader (#132) * Add status on snowflakes status result so we know if the query is queued in the warehouse (#139) * proof of concept: materialization more info * add status to status response * Don't require username & password if using token accessor (#144) * dont require user & pass for token accessor type * undo rows commit * undo connection commit * Adding Changes of sf_stable_master to master branch (#147) * Add more visibility in chunk downloader * Added changes that are on stable branch but not on master * Added comments to pass lint * remove dependency on arrow v0 (#148) * [SIG-40301] fix integer value not in range error (#149) * use bigFloat when convert int64 to float (#150) * [SIG-40455] fix arrow batch int64 to ntz/ltz (#151) * add arrowToRecord tz tests (#152) * [SIG-40738] convert binary array to string array (#153) * [SIG-40820]don't panic if type mismatch is decode arrow (#154) * Revert "[SIG-40738] convert binary array to string array" (#155) * [SIG-41137] remove decode fixed types from arrowToRecord (#156) * fix recordToSchema and not do conversion for fixedType (#157) * [Security] Remove logging sql text from the snowflake driver (part 2 of 2) (#158) Remove logging sql text from the snowflake driver * Update connection.go * reconcile merge changes * Update bindings_test.go * Increase test timeout to 60 mins to see if it helps with timeout * Fix a hanging test. Change timeout to 30s again. * Fix DSN Test * Fix DSN TESTS * Update submit_sync_test.go * fix TestFunctionParameters and TestValueToString * Revert "SNOW-645253 Handle binding named parameters (#850)" This reverts commit 67ec6cf62b6160c1fa575e89e8bed21364ef69cf. * Revert "Revert "SNOW-645253 Handle binding named parameters (#850)"" This reverts commit a6a8a01c74fc0db4804e6bc964f5c4c8aa310e35. * Fix DSN for external Browser * dont break things * Final Fix --------- Co-authored-by: Kiran Dama <69480841+sfc-gh-kdama@users.noreply.github.com> Co-authored-by: Lorna <115649563+sfc-gh-ext-simba-lb@users.noreply.github.com> Co-authored-by: Piotr Fus Co-authored-by: Dawid Heyman Co-authored-by: Piotr Bulawa Co-authored-by: Angel Antonio Avalos Cisneros Co-authored-by: Matt Topol Co-authored-by: Dominik Przybysz <132913826+sfc-gh-dprzybysz@users.noreply.github.com> Co-authored-by: Rami <72725910+ramikg@users.noreply.github.com> Co-authored-by: Srikanth Reddy Kumbham Co-authored-by: Agam Brahma Co-authored-by: Luke Paulsen Co-authored-by: Mihai Claudiu Toader Co-authored-by: Yifeng-Sigma Co-authored-by: Jack Qian Co-authored-by: Agam Brahma Co-authored-by: Greg Owen Co-authored-by: Max Seiden Co-authored-by: fengqingthu Co-authored-by: Eric Bannatyne Co-authored-by: Madison Chamberlain <46542378+madisonchamberlain@users.noreply.github.com> Co-authored-by: Eric Bannatyne Co-authored-by: mansap22 <110418152+mansap22@users.noreply.github.com> Co-authored-by: Ayman Elkfrawy <120422207+ayman-sigma@users.noreply.github.com> Co-authored-by: GregOwen Co-authored-by: sureshmula2 <126016714+sureshmula2@users.noreply.github.com> Co-authored-by: Ryan Kwong Co-authored-by: Madison Chamberlain --- async_test.go | 103 +++++- bind_uploader.go | 23 +- bindings_test.go | 311 ++++++++++++---- ci/scripts/test_component.sh | 3 +- cmd/showparam/showparam.go | 59 +-- connection.go | 87 +++-- connection_util.go | 14 +- converter.go | 119 ++++-- converter_test.go | 332 ++++++++++++++--- datatype.go | 11 - datatype_test.go | 40 +- driver_test.go | 396 +++++++++++--------- dsn.go | 194 +++++++++- dsn_test.go | 689 +++++++++++++++++++++++++++-------- errors.go | 44 ++- htap_test.go | 55 ++- query.go | 5 + rows.go | 14 +- statement_test.go | 270 +++++++++++++- submit_sync_test.go | 10 +- telemetry_test.go | 206 +++++++++++ util.go | 13 + 22 files changed, 2311 insertions(+), 687 deletions(-) diff --git a/async_test.go b/async_test.go index 447337cc6..e0190ecdb 100644 --- a/async_test.go +++ b/async_test.go @@ -1,9 +1,10 @@ -// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. +// Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved. package gosnowflake import ( "context" + "database/sql" "fmt" "testing" "time" @@ -16,7 +17,7 @@ func TestAsyncMode(t *testing.T) { var idx int var v string - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) defer rows.Close() @@ -45,12 +46,39 @@ func TestAsyncMode(t *testing.T) { }) } +func TestAsyncModeMultiStatement(t *testing.T) { + withMultiStmtCtx, _ := WithMultiStatement(context.Background(), 6) + ctx := WithAsyncMode(withMultiStmtCtx) + multiStmtQuery := "begin;\n" + + "delete from test_multi_statement_async;\n" + + "insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" + + "select 1;\n" + + "select 2;\n" + + "rollback;" + + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("drop table if exists test_multi_statement_async") + dbt.mustExec(`create or replace table test_multi_statement_async( + c1 number, c2 string) as select 10, 'z'`) + defer dbt.mustExec("drop table if exists test_multi_statement_async") + + res := dbt.mustExecContext(ctx, multiStmtQuery) + count, err := res.RowsAffected() + if err != nil { + t.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 3 { + t.Fatalf("expected 3 affected rows, got %d", count) + } + }) +} + func TestAsyncModeCancel(t *testing.T) { withCancelCtx, cancel := context.WithCancel(context.Background()) ctx := WithAsyncMode(withCancelCtx) numrows := 100000 - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) cancel() }) @@ -67,7 +95,7 @@ func TestAsyncModeNoFetch(t *testing.T) { // completes, so we make the test take longer than 45s secondsToRun := 50 - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { start := time.Now() rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectTimelineGenerator, secondsToRun)) defer rows.Close() @@ -104,7 +132,7 @@ func TestAsyncModeNoFetch(t *testing.T) { func TestAsyncQueryFail(t *testing.T) { ctx := WithAsyncMode(context.Background()) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, "selectt 1") defer rows.Close() @@ -125,10 +153,18 @@ func TestMultipleAsyncQueries(t *testing.T) { ch1 := make(chan string) ch2 := make(chan string) - runTests(t, dsn, func(dbt *DBTest) { - rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + db := openDB(t) + + runDBTest(t, func(dbt *DBTest) { + rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + if err != nil { + t.Fatalf("can't read rows1: %v", err) + } defer rows1.Close() - rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + if err != nil { + t.Fatalf("can't read rows2: %v", err) + } defer rows2.Close() go retrieveRows(rows1, ch1) @@ -155,11 +191,19 @@ func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) { ch1 := make(chan string) ch2 := make(chan string) - runTests(t, dsn, func(dbt *DBTest) { - rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1)) + db := openDB(t) + + runDBTest(t, func(dbt *DBTest) { + rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>3))", s1)) + if err != nil { + t.Fatalf("can't read rows1: %v", err) + } defer rows1.Close() - rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2)) + rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%s' from table (generator(timelimit=>7))", s2)) + if err != nil { + t.Fatalf("can't read rows2: %v", err) + } defer rows2.Close() go retrieveRows(rows1, ch1) @@ -179,7 +223,7 @@ func TestMultipleAsyncSuccessAndFailedQueries(t *testing.T) { }) } -func retrieveRows(rows *RowsExtended, ch chan string) { +func retrieveRows(rows *sql.Rows, ch chan string) { var s string for rows.Next() { if err := rows.Scan(&s); err != nil { @@ -191,3 +235,38 @@ func retrieveRows(rows *RowsExtended, ch chan string) { ch <- s close(ch) } + +func TestLongRunningAsyncQuery(t *testing.T) { + conn := openConn(t) + defer conn.Close() + + ctx, _ := WithMultiStatement(context.Background(), 0) + query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data" + + rows, err := conn.QueryContext(WithAsyncMode(ctx), query) + if err != nil { + t.Fatalf("failed to run a query. %v, err: %v", query, err) + } + defer rows.Close() + var v string + i := 0 + for { + for rows.Next() { + err := rows.Scan(&v) + if err != nil { + t.Fatalf("failed to get result. err: %v", err) + } + if v == "" { + t.Fatal("should have returned a result") + } + results := []string{"waited 50 seconds", "Statement executed successfully."} + if v != results[i] { + t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v) + } + i++ + } + if !rows.NextResultSet() { + break + } + } +} diff --git a/bind_uploader.go b/bind_uploader.go index da817f159..5c10ff5e3 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -5,6 +5,7 @@ package gosnowflake import ( "bytes" "context" + "database/sql" "database/sql/driver" "fmt" "reflect" @@ -221,6 +222,10 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, dataType = binding.Value.(SnowflakeDataType) default: // This binding is an actual parameter for the query + if tnt, ok := binding.Value.(TypedNullTime); ok { + dataType = convertTzTypeToSnowflakeType(tnt.TzType) + binding.Value = tnt.Time + } t := goTypeToSnowflake(binding.Value, dataType) var val interface{} if t == sliceType { @@ -235,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, if t == nullType || t == unSupportedType { t = textType // if null or not supported, pass to GS as text } - bindValues[strconv.Itoa(idx)] = execBindParameter{ + bindValues[bindingName(binding, idx)] = execBindParameter{ Type: t.String(), Value: val, } @@ -245,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, return bindValues, nil } +func bindingName(nv driver.NamedValue, idx int) string { + if nv.Name != "" { + return nv.Name + } + return strconv.Itoa(idx) +} + func arrayBindValueCount(bindValues []driver.NamedValue) int { if !isArrayBind(bindValues) { return 0 @@ -298,3 +310,12 @@ func supportedArrayBind(nv *driver.NamedValue) bool { return false } } + +func supportedNullBind(nv *driver.NamedValue) bool { + switch reflect.TypeOf(nv.Value) { + case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullInt64{}), + reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullFloat64{}), reflect.TypeOf(TypedNullTime{}): + return true + } + return false +} diff --git a/bindings_test.go b/bindings_test.go index 490729d14..44561545e 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" "math/rand" + "reflect" "strconv" "testing" "time" @@ -23,9 +24,9 @@ const ( selectAllSQL = "select * from TEST_PREP_STATEMENT ORDER BY 1" createTableSQLBulkArray = `create or replace table test_bulk_array(c1 INTEGER, - c2 FLOAT, c3 BOOLEAN, c4 STRING, C5 BINARY)` + c2 FLOAT, c3 BOOLEAN, c4 STRING, C5 BINARY, C6 INTEGER)` deleteTableSQLBulkArray = "drop table if exists test_bulk_array" - insertSQLBulkArray = "insert into test_bulk_array values(?, ?, ?, ?, ?)" + insertSQLBulkArray = "insert into test_bulk_array values(?, ?, ?, ?, ?, ?)" selectAllSQLBulkArray = "select * from test_bulk_array ORDER BY 1" createTableSQLBulkArrayDateTimeTimestamp = `create or replace table test_bulk_array_DateTimeTimestamp( @@ -36,9 +37,9 @@ const ( ) func TestBindingNull(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id int, c1 STRING, c2 BOOLEAN)") - _, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?, ?)", + _, err := dbt.exec("INSERT INTO test VALUES (1, ?, ?)", DataTypeText, "hello", DataTypeNull, nil, ) @@ -50,54 +51,53 @@ func TestBindingNull(t *testing.T) { } func TestBindingFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} expected := 42.23 var out float64 var rows *RowsExtended for _, v := range types { - dbt.mustExec(fmt.Sprintf("CREATE TABLE test (id int, value %v)", v)) - dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) - rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) + t.Run(v, func(t *testing.T) { + dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE test (id int, value %v)", v)) + dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } // TestBindingUint64 tests uint64 binding. Should fail as unit64 is not a // supported binding value by Go's sql package. func TestBindingUint64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := []string{"INTEGER"} + runDBTest(t, func(dbt *DBTest) { expected := uint64(18446744073709551615) - for _, v := range types { - dbt.mustExec(fmt.Sprintf("CREATE TABLE test (id int, value %v)", v)) - if _, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { - dbt.Fatal("should fail as uint64 values with high bit set are not supported.") - } else { - logger.Infof("expected err: %v", err) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value INTEGER)") + if _, err := dbt.exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { + dbt.Fatal("should fail as uint64 values with high bit set are not supported.") + } else { + logger.Infof("expected err: %v", err) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } func TestBindingDateTimeTimestamp(t *testing.T) { createDSN(PSTLocation) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { expected := time.Now() dbt.mustExec( "CREATE OR REPLACE TABLE tztest (id int, ntz timestamp_ntz, ltz timestamp_ltz, dt date, tm time)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") if err != nil { dbt.Fatal(err.Error()) } @@ -163,36 +163,33 @@ func TestBindingDateTimeTimestamp(t *testing.T) { } func TestBindingBinary(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary, c binary)") + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary)") var b = []byte{0x01, 0x02, 0x03} - dbt.mustExec("INSERT INTO bintest(id,b,c) VALUES(1, ?, ?)", DataTypeBinary, b, DataTypeBinary, b) - rows := dbt.mustQuery("SELECT b, c FROM bintest WHERE id=?", 1) + dbt.mustExec("INSERT INTO bintest(id,b) VALUES(1, ?)", DataTypeBinary, b) + rows := dbt.mustQuery("SELECT b FROM bintest WHERE id=?", 1) defer rows.Close() if rows.Next() { var rb []byte - var rc []byte - if err := rows.Scan(&rb, &rc); err != nil { + if err := rows.Scan(&rb); err != nil { dbt.Errorf("failed to scan data. err: %v", err) } if !bytes.Equal(b, rb) { dbt.Errorf("failed to match data. expected: %v, got: %v", b, rb) } - if !bytes.Equal(b, rc) { - dbt.Errorf("failed to match data. expected: %v, got: %v", b, rc) - } } else { dbt.Errorf("no data") } dbt.mustExec("DROP TABLE bintest") }) + } func TestBindingTimestampTZ(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { expected := time.Now() dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, tz timestamp_tz)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") if err != nil { dbt.Fatal(err.Error()) } @@ -218,7 +215,7 @@ func TestBindingTimestampTZ(t *testing.T) { // SNOW-755844: Test the use of a pointer *time.Time type in user-defined structures to perform updates/inserts func TestBindingTimePtrInStruct(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { type timePtrStruct struct { id *int timeVal *time.Time @@ -231,7 +228,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -240,7 +237,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = &expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -265,7 +262,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { // SNOW-755844: Test the use of a time.Time type in user-defined structures to perform updates/inserts func TestBindingTimeInStruct(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { type timeStruct struct { id int timeVal time.Time @@ -278,7 +275,7 @@ func TestBindingTimeInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -287,7 +284,7 @@ func TestBindingTimeInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -311,14 +308,14 @@ func TestBindingTimeInStruct(t *testing.T) { } func TestBindingInterface(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -338,13 +335,13 @@ func TestBindingInterface(t *testing.T) { } func TestBindingInterfaceString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery(selectVariousTypes) defer rows.Close() if !rows.Next() { dbt.Error("failed to query") } - var v1, v2, v3, v4, v5, v6 interface{} + var v1, v2, v3, v4, v5, v6 any if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { dbt.Errorf("failed to scan: %#v", err) } @@ -365,9 +362,9 @@ func TestBindingInterfaceString(t *testing.T) { } func TestBulkArrayBindingInterfaceNil(t *testing.T) { - nilArray := make([]interface{}, 1) + nilArray := make([]any, 1) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) @@ -430,31 +427,35 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { } func TestBulkArrayBindingInterface(t *testing.T) { - intArray := make([]interface{}, 3) + intArray := make([]any, 3) intArray[0] = int32(100) intArray[1] = int32(200) - fltArray := make([]interface{}, 3) + fltArray := make([]any, 3) fltArray[0] = float64(0.1) fltArray[2] = float64(5.678) - boolArray := make([]interface{}, 3) + boolArray := make([]any, 3) boolArray[1] = false boolArray[2] = true - strArray := make([]interface{}, 3) + strArray := make([]any, 3) strArray[2] = "test3" - byteArray := make([]interface{}, 3) + byteArray := make([]any, 3) byteArray[0] = []byte{0x01, 0x02, 0x03} byteArray[2] = []byte{0x07, 0x08, 0x09} - runTests(t, dsn, func(dbt *DBTest) { + int64Array := make([]any, 3) + int64Array[0] = int64(100) + int64Array[1] = int64(200) + + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQLBulkArray) defer dbt.mustExec(deleteTableSQLBulkArray) dbt.mustExec(insertSQLBulkArray, Array(&intArray), Array(&fltArray), - Array(&boolArray), Array(&strArray), Array(&byteArray)) + Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) defer rows.Close() @@ -463,10 +464,11 @@ func TestBulkArrayBindingInterface(t *testing.T) { var v2 sql.NullBool var v3 sql.NullString var v4 []byte + var v5 sql.NullInt64 cnt := 0 for i := 0; rows.Next(); i++ { - if err := rows.Scan(&v0, &v1, &v2, &v3, &v4); err != nil { + if err := rows.Scan(&v0, &v1, &v2, &v3, &v4, &v5); err != nil { t.Fatal(err) } if v0.Valid { @@ -504,6 +506,13 @@ func TestBulkArrayBindingInterface(t *testing.T) { } else if v4 != nil { t.Fatalf("failed to fetch the []byte column v4. expected %v, got: %v", byteArray[i], v4) } + if v5.Valid { + if v5.Int64 != int64Array[i] { + t.Fatalf("failed to fetch the sql.NullInt64 column v5. expected %v, got: %v", int64Array[i], v5.Int64) + } + } else if int64Array[i] != nil { + t.Fatalf("failed to fetch the sql.NullInt64 column v5. expected %v, got: %v", int64Array[i], v5) + } cnt++ } if cnt != len(intArray) { @@ -521,27 +530,27 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { if err != nil { t.Error(err) } - ntzArray := make([]interface{}, 3) + ntzArray := make([]any, 3) ntzArray[0] = now ntzArray[1] = now.Add(1) - ltzArray := make([]interface{}, 3) + ltzArray := make([]any, 3) ltzArray[1] = now.Add(2).In(loc) ltzArray[2] = now.Add(3).In(loc) - tzArray := make([]interface{}, 3) + tzArray := make([]any, 3) tzArray[0] = tz.Add(4).In(loc) tzArray[2] = tz.Add(5).In(loc) - dtArray := make([]interface{}, 3) + dtArray := make([]any, 3) dtArray[0] = tz.Add(6).In(loc) dtArray[1] = now.Add(7).In(loc) - tmArray := make([]interface{}, 3) + tmArray := make([]any, 3) tmArray[1] = now.Add(8).In(loc) tmArray[2] = now.Add(9).In(loc) - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQLBulkArrayDateTimeTimestamp) defer dbt.mustExec(deleteTableSQLBulkArrayDateTimeTimestamp) @@ -643,11 +652,11 @@ func testBindingArray(t *testing.T, bulk bool) { dtArray := []time.Time{now.Add(9), now.Add(10), now.Add(11)} tmArray := []time.Time{now.Add(12), now.Add(13), now.Add(14)} - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) if bulk { - if _, err := dbt.db.Exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { + if _, err := dbt.exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { t.Error(err) } } @@ -711,7 +720,7 @@ func testBindingArray(t *testing.T, bulk bool) { } func TestBulkArrayBinding(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string)", dbname)) numRows := 100000 intArr := make([]int, numRows) @@ -722,6 +731,7 @@ func TestBulkArrayBinding(t *testing.T) { } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?)", dbname), Array(&intArr), Array(&strArr)) rows := dbt.mustQuery("select * from " + dbname) + defer rows.Close() cnt := 0 var i int var s string @@ -755,7 +765,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { tempTableName := fmt.Sprintf("test_table_%v", randomString(5)) ctx := context.Background() - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec(fmt.Sprintf("CREATE TABLE %s (C VARCHAR(64) NOT NULL)", tempTableName)) defer dbt.mustExec("drop table " + tempTableName) @@ -764,6 +774,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName), Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -773,6 +784,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { } rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -786,7 +798,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { } func TestBulkArrayMultiPartBindingInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer)") startNum := 1000000 endNum := 3000000 @@ -795,12 +807,13 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { for i := startNum; i < endNum; i++ { intArr[i-startNum] = i } - _, err := dbt.db.Exec("insert into binding_test values (?)", Array(&intArr)) + _, err := dbt.exec("insert into binding_test values (?)", Array(&intArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1") + defer rows.Close() cnt := startNum var i int for rows.Next() { @@ -820,15 +833,15 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { } func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table binding_test (c1 integer, c2 string)") startNum := 1000000 endNum := 2000000 numRows := endNum - startNum // Define the integer and string arrays - intArr := make([]interface{}, numRows) - stringArr := make([]interface{}, numRows) + intArr := make([]any, numRows) + stringArr := make([]any, numRows) for i := startNum; i < endNum; i++ { intArr[i-startNum] = i stringArr[i-startNum] = fmt.Sprint(i) @@ -842,12 +855,13 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { stringArr[2] = nil stringArr[3] = nil - _, err := dbt.db.Exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) + _, err := dbt.exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1,c2") + defer rows.Close() cnt := startNum var i sql.NullInt32 var s sql.NullString @@ -879,3 +893,148 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { dbt.mustExec("DROP TABLE binding_test") }) } + +func TestFunctionParameters(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input any + nullResult bool + }{ + {"textAndNullStringResultInNull", "text", sql.NullString{}, true}, + {"numberAndNullInt64ResultInNull", "number", sql.NullInt64{}, true}, + {"floatAndNullFloat64ResultInNull", "float", sql.NullFloat64{}, true}, + {"booleanAndAndNullBoolResultInNull", "boolean", sql.NullBool{}, true}, + {"dateAndTypedNullTimeResultInNull", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, + {"datetimeAndTypedNullTimeResultInNull", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timeAndTypedNullTimeResultInNull", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, + {"timestampAndTypedNullTimeResultInNull", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ntzAndTypedNullTimeResultInNull", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ltzAndTypedNullTimeResultInNull", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, + {"timestamp_tzAndTypedNullTimeResultInNull", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, + {"textAndStringResultInNotNull", "text", "string", false}, + {"numberAndIntegerResultInNotNull", "number", 123, false}, + {"floatAndFloatResultInNotNull", "float", 123.01, false}, + {"booleanAndBooleanResultInNotNull", "boolean", true, false}, + {"dateAndTimeResultInNotNull", "date", time.Now(), false}, + {"datetimeAndTimeResultInNotNull", "datetime", time.Now(), false}, + {"timeAndTimeResultInNotNull", "time", time.Now(), false}, + {"timestampAndTimeResultInNotNull", "timestamp", time.Now(), false}, + {"timestamp_ntzAndTimeResultInNotNull", "timestamp_ntz", time.Now(), false}, + {"timestamp_ltzAndTimeResultInNotNull", "timestamp_ltz", time.Now(), false}, + {"timestamp_tzAndTimeResultInNotNull", "timestamp_tz", time.Now(), false}, + } + + runDBTest(t, func(dbt *DBTest) { + for _, tc := range testcases { + t.Run(tc.testDesc, func(t *testing.T) { + query := fmt.Sprintf(` + CREATE OR REPLACE FUNCTION NULLPARAMFUNCTION("param1" %v) + RETURNS TABLE("r1" %v) + LANGUAGE SQL + AS 'select param1';`, tc.paramType, tc.paramType) + dbt.mustExec(query) + rows, err := dbt.query("select * from table(NULLPARAMFUNCTION(?))", tc.input) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if rows.Err() != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("no rows fetched") + } + var r1 any + err = rows.Scan(&r1) + if err != nil { + t.Fatal(err) + } + if tc.nullResult && r1 != nil { + t.Fatalf("the result for %v is of type %v but should be null", tc.paramType, reflect.TypeOf(r1)) + } + if !tc.nullResult && r1 == nil { + t.Fatalf("the result for %v should not be null", tc.paramType) + } + }) + } + }) +} + +func TestVariousBindingModes(t *testing.T) { + testcases := []struct { + testDesc string + paramType string + input any + isNil bool + }{ + {"textAndString", "text", "string", false}, + {"numberAndInteger", "number", 123, false}, + {"floatAndFloat", "float", 123.01, false}, + {"booleanAndBoolean", "boolean", true, false}, + {"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false}, + {"datetimeAndTime", "datetime", time.Now(), false}, + {"timeAndTime", "time", "12:34:56", false}, + {"timestampAndTime", "timestamp", time.Now(), false}, + {"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false}, + {"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false}, + {"timestamp_tzAndTime", "timestamp_tz", time.Now(), false}, + {"textAndNullString", "text", sql.NullString{}, true}, + {"numberAndNullInt64", "number", sql.NullInt64{}, true}, + {"floatAndNullFloat64", "float", sql.NullFloat64{}, true}, + {"booleanAndAndNullBool", "boolean", sql.NullBool{}, true}, + {"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true}, + {"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true}, + {"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true}, + {"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true}, + {"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true}, + } + + bindingModes := []struct { + param string + query string + transform func(any) any + }{ + { + param: "?", + transform: func(v any) any { return v }, + }, + { + param: ":1", + transform: func(v any) any { return v }, + }, + { + param: ":param", + transform: func(v any) any { return sql.Named("param", v) }, + }, + } + + runDBTest(t, func(dbt *DBTest) { + for _, tc := range testcases { + for _, bindingMode := range bindingModes { + t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) { + query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) + dbt.mustExec(query) + if _, err := dbt.exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil { + t.Fatal(err) + } + if tc.isNil { + query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL" + } else { + query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param) + } + rows, err := dbt.query(query, bindingMode.transform(tc.input)) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if !rows.Next() { + t.Fatal("Expected to return a row") + } + }) + } + } + }) +} diff --git a/ci/scripts/test_component.sh b/ci/scripts/test_component.sh index ec3700569..4829c40f3 100755 --- a/ci/scripts/test_component.sh +++ b/ci/scripts/test_component.sh @@ -12,5 +12,4 @@ if [[ -n "$GITHUB_WORKFLOW" ]]; then fi env | grep SNOWFLAKE | grep -v PASS | sort cd $TOPDIR -go test -timeout 30m -race $COVFLAGS -v . - +go test -timeout 30m -race -coverprofile=coverage.txt -covermode=atomic -v . diff --git a/cmd/showparam/showparam.go b/cmd/showparam/showparam.go index ce77261cd..e29713a08 100644 --- a/cmd/showparam/showparam.go +++ b/cmd/showparam/showparam.go @@ -6,62 +6,31 @@ import ( "flag" "fmt" "log" - "os" - "strconv" sf "github.com/snowflakedb/gosnowflake" ) -// getDSN constructs a DSN based on the test connection parameters -func getDSN(params map[string]*string) (string, *sf.Config, error) { - env := func(k string, failOnMissing bool) string { - if value := os.Getenv(k); value != "" { - return value - } - if failOnMissing { - log.Fatalf("%v environment variable is not set.", k) - } - return "" - } - - account := env("SNOWFLAKE_TEST_ACCOUNT", true) - user := env("SNOWFLAKE_TEST_USER", true) - password := env("SNOWFLAKE_TEST_PASSWORD", true) - host := env("SNOWFLAKE_TEST_HOST", false) - port := env("SNOWFLAKE_TEST_PORT", false) - protocol := env("SNOWFLAKE_TEST_PROTOCOL", false) - - portStr, err := strconv.Atoi(port) - if err != nil { - return "", nil, err - } - cfg := &sf.Config{ - Account: account, - User: user, - Password: password, - Host: host, - Port: portStr, - Protocol: protocol, - Params: params, - } - - dsn, err := sf.DSN(cfg) - return dsn, cfg, err -} - func main() { if !flag.Parsed() { flag.Parse() } - tmfmt := "MM-DD-YYYY" - dsn, cfg, err := getDSN( - map[string]*string{ - "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, // session parameter - }) + cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{ + {Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true}, + {Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true}, + {Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true}, + {Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false}, + {Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false}, + {Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false}, + }) if err != nil { - log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) + log.Fatalf("failed to create Config, err: %v", err) + } + tmfmt := "MM-DD-YYYY" + cfg.Params = map[string]*string{ + "TIMESTAMP_OUTPUT_FORMAT": &tmfmt, // session parameter } + dsn, err := sf.DSN(cfg) if err != nil { log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err) } diff --git a/connection.go b/connection.go index 7a172f812..a81747d9b 100644 --- a/connection.go +++ b/connection.go @@ -60,16 +60,17 @@ const ( const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { - ctx context.Context - cfg *Config - rest *snowflakeRestful - restMu sync.RWMutex // guard shutdown race - SequenceCounter uint64 - QueryID string - SQLState string - telemetry *snowflakeTelemetry - internal InternalClient - execRespCache *execRespCache + ctx context.Context + cfg *Config + rest *snowflakeRestful + restMu sync.RWMutex + SequenceCounter uint64 + QueryID string + SQLState string + telemetry *snowflakeTelemetry + internal InternalClient + execRespCache *execRespCache + queryContextCache *queryContextCache } var ( @@ -143,9 +144,12 @@ func (sc *snowflakeConn) exec( } logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code) if !data.Success { - return nil, (populateErrorFields(code, data)).exceptionTelemetry(sc) + err = (populateErrorFields(code, data)).exceptionTelemetry(sc) + return nil, err } + sc.queryContextCache.add(data.Data.QueryContext.Entries...) + // handle PUT/GET commands if isFileTransfer(query) { data, err = sc.processFileTransfer(ctx, data, query, isInternal) @@ -159,8 +163,6 @@ func (sc *snowflakeConn) exec( sc.cfg.Schema = data.Data.FinalSchemaName sc.cfg.Role = data.Data.FinalRoleName sc.cfg.Warehouse = data.Data.FinalWarehouseName - sc.QueryID = data.Data.QueryID - sc.SQLState = data.Data.SQLState sc.populateSessionParameters(data.Data.Parameters) return data, err } @@ -196,7 +198,7 @@ func (sc *snowflakeConn) BeginTx( false /* isInternal */, isDesc, nil); err != nil { return nil, err } - return &snowflakeTx{sc}, nil + return &snowflakeTx{sc, ctx}, nil } func (sc *snowflakeConn) cleanup() { @@ -219,7 +221,7 @@ func (sc *snowflakeConn) Close() (err error) { sc.stopHeartBeat() defer sc.cleanup() - if !sc.cfg.KeepSessionAlive { + if sc.cfg != nil && !sc.cfg.KeepSessionAlive { if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { logger.Error(err) } @@ -259,9 +261,9 @@ func (sc *snowflakeConn) ExecContext( if err != nil { logger.WithContext(ctx).Infof("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -288,11 +290,10 @@ func (sc *snowflakeConn) ExecContext( rows := &snowflakeResult{ affectedRows: updatedRows, insertID: -1, - queryID: sc.QueryID, + queryID: data.Data.QueryID, } // last insert id is not supported by Snowflake - rows.monitoring = mkMonitoringFetcher(sc, sc.QueryID, time.Since(qStart)) - + rows.monitoring = mkMonitoringFetcher(sc, data.Data.QueryID, time.Since(qStart)) return rows, nil } else if isMultiStmt(&data.Data) { rows, err := sc.handleMultiExec(ctx, data.Data) @@ -348,9 +349,9 @@ func (sc *snowflakeConn) queryContextInternal( if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -369,8 +370,8 @@ func (sc *snowflakeConn) queryContextInternal( rows := new(snowflakeRows) rows.sc = sc - rows.queryID = sc.QueryID - rows.monitoring = mkMonitoringFetcher(sc, sc.QueryID, time.Since(qStart)) + rows.queryID = data.Data.QueryID + rows.monitoring = mkMonitoringFetcher(sc, data.Data.QueryID, time.Since(qStart)) if isSubmitSync(ctx) && data.Code == queryInProgressCode { rows.status = QueryStatusInProgress @@ -442,10 +443,10 @@ func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error { // distinguish them from arguments of type []byte return nil } - if supported := supportedArrayBind(nv); !supported { - return driver.ErrSkip + if supportedNullBind(nv) || supportedArrayBind(nv) { + return nil } - return nil + return driver.ErrSkip } func (sc *snowflakeConn) GetQueryStatus( @@ -479,9 +480,9 @@ func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bin if err != nil { logger.WithContext(ctx).Errorf("error: %v", err) if data != nil { - code, err := strconv.Atoi(data.Code) - if err != nil { - return nil, err + code, e := strconv.Atoi(data.Code) + if e != nil { + return nil, e } return nil, (&SnowflakeError{ Number: code, @@ -619,11 +620,18 @@ func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, erro // ArrowStreamLoader is a convenience interface for downloading // Snowflake results via multiple Arrow Record Batch streams. +// +// Some queries from Snowflake do not return Arrow data regardless +// of the settings, such as "SHOW WAREHOUSES". In these cases, +// you'll find TotalRows() > 0 but GetBatches returns no batches +// and no errors. In this case, the data is accessible via JSONData +// with the actual types matching up to the metadata in RowTypes. type ArrowStreamLoader interface { GetBatches() ([]ArrowStreamBatch, error) TotalRows() int64 RowTypes() []execResponseRowType Location() *time.Location + JSONData() [][]*string } type snowflakeArrowStreamChunkDownloader struct { @@ -646,6 +654,9 @@ func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.T func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType { return scd.RowSet.RowType } +func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string { + return scd.RowSet.JSON +} // the server might have had an empty first batch, check if we can decode // that first batch, if not we skip it. @@ -714,9 +725,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { sc := &snowflakeConn{ - SequenceCounter: 0, - ctx: ctx, - cfg: &config, + SequenceCounter: 0, + ctx: ctx, + cfg: &config, + queryContextCache: (&queryContextCache{}).init(), } var st http.RoundTripper = SnowflakeTransport if sc.cfg.Transporter == nil { @@ -765,11 +777,16 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err Timeout: sc.cfg.ClientTimeout, Transport: st, }, + JWTClient: &http.Client{ + Timeout: sc.cfg.JWTClientTimeout, + Transport: st, + }, TokenAccessor: tokenAccessor, LoginTimeout: sc.cfg.LoginTimeout, RequestTimeout: sc.cfg.RequestTimeout, FuncPost: postRestful, FuncGet: getRestful, + FuncAuthPost: postAuthRestful, FuncPostQuery: postRestfulQuery, FuncPostQueryHelper: postRestfulQueryHelper, FuncRenewSession: renewRestfulSession, diff --git a/connection_util.go b/connection_util.go index a85ffea22..f9dad37e2 100644 --- a/connection_util.go +++ b/connection_util.go @@ -24,20 +24,22 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool { } func (sc *snowflakeConn) startHeartBeat() { - if !sc.isClientSessionKeepAliveEnabled() { + if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - sc.rest.HeartBeat = &heartbeat{ - restful: sc.rest, + if sc.rest != nil { + sc.rest.HeartBeat = &heartbeat{ + restful: sc.rest, + } + sc.rest.HeartBeat.start() } - sc.rest.HeartBeat.start() } func (sc *snowflakeConn) stopHeartBeat() { - if !sc.isClientSessionKeepAliveEnabled() { + if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - if sc.rest.HeartBeat != nil { + if sc.rest != nil && sc.rest.HeartBeat != nil { sc.rest.HeartBeat.stop() } } diff --git a/converter.go b/converter.go index 19ac55499..4070d5fc5 100644 --- a/converter.go +++ b/converter.go @@ -4,6 +4,7 @@ package gosnowflake import ( "context" + "database/sql" "database/sql/driver" "encoding/hex" "fmt" @@ -57,17 +58,18 @@ func isInterfaceArrayBinding(t interface{}) bool { // goTypeToSnowflake translates Go data type to Snowflake data type. func goTypeToSnowflake(v driver.Value, dataType SnowflakeDataType) snowflakeType { + // (raj) This will fail build. Reconcile after merge if dataType == nil { switch t := v.(type) { case SnowflakeDataType: return changeType - case int64: + case int64, sql.NullInt64: return fixedType - case float64: + case float64, sql.NullFloat64: return realType - case bool: + case bool, sql.NullBool: return booleanType - case string: + case string, sql.NullString: return textType case []byte: if t == nil { @@ -75,7 +77,7 @@ func goTypeToSnowflake(v driver.Value, dataType SnowflakeDataType) snowflakeType } // If we don't have an explicit data type, binary blobs are unsupported return unSupportedType - case time.Time: + case time.Time, sql.NullTime: // Default timestamp type return timestampNtzType } @@ -151,37 +153,71 @@ func valueToString(v driver.Value, dataType SnowflakeDataType) (*string, error) s := v1.String() return &s, nil case reflect.Struct: - if tm, ok := v.(time.Time); ok { - switch { - case dataType.Equals(DataTypeDate): - _, offset := tm.Zone() - tm = tm.Add(time.Second * time.Duration(offset)) - s := strconv.FormatInt(tm.Unix()*1000, 10) - return &s, nil - case dataType.Equals(DataTypeTime): - s := fmt.Sprintf("%d", - (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) - return &s, nil - case dataType.Equals(DataTypeTimestampNtz) || dataType.Equals(DataTypeTimestampLtz) || dataType == nil: - // NOTE(greg): when the client has not given us an explicit dataType - // (dataType == nil), we assume DataTypeTimestampNtz for compatibility - // with the upstream driver - unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) - m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) - unixTime.Mul(unixTime, m) - tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) - s := unixTime.Add(unixTime, tmNanos).String() - return &s, nil - case dataType.Equals(DataTypeTimestampTz): - _, offset := tm.Zone() - s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) - return &s, nil + switch typedVal := v.(type) { + case time.Time: + return timeTypeValueToString(typedVal, dataType) + case sql.NullTime: + if !typedVal.Valid { + return nil, nil } + return timeTypeValueToString(typedVal.Time, dataType) + case sql.NullBool: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatBool(typedVal.Bool) + return &s, nil + case sql.NullInt64: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatInt(typedVal.Int64, 10) + return &s, nil + case sql.NullFloat64: + if !typedVal.Valid { + return nil, nil + } + s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32) + return &s, nil + case sql.NullString: + if !typedVal.Valid { + return nil, nil + } + return &typedVal.String, nil } } return nil, fmt.Errorf("unsupported type: %v", v1.Kind()) } +func timeTypeValueToString(tm time.Time, dataType SnowflakeDataType) (*string, error) { + switch { + case dataType.Equals(DataTypeDate): + _, offset := tm.Zone() + tm = tm.Add(time.Second * time.Duration(offset)) + s := strconv.FormatInt(tm.Unix()*1000, 10) + return &s, nil + case dataType.Equals(DataTypeTime): + s := fmt.Sprintf("%d", + (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) + return &s, nil + case dataType.Equals(DataTypeTimestampNtz) || dataType.Equals(DataTypeTimestampLtz) || dataType == nil: + // NOTE(greg): when the client has not given us an explicit dataType + // (dataType == nil), we assume DataTypeTimestampNtz for compatibility + // with the upstream driver + unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) + m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) + unixTime.Mul(unixTime, m) + tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) + s := unixTime.Add(unixTime, tmNanos).String() + return &s, nil + case dataType.Equals(DataTypeTimestampTz): + _, offset := tm.Zone() + s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) + return &s, nil + } + return nil, fmt.Errorf("unsupported time type: %v", dataType) +} + // extractTimestamp extracts the internal timestamp data to epoch time in seconds and milliseconds func extractTimestamp(srcValue *string) (sec int64, nsec int64, err error) { logger.Debugf("SRC: %v", srcValue) @@ -1207,3 +1243,26 @@ func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.L meta := sc.Metadata() return arrow.NewSchema(fields, &meta), nil } + +// TypedNullTime is required to properly bind the null value with the snowflakeType as the Snowflake functions +// require the type of the field to be provided explicitly for the null values +type TypedNullTime struct { + Time sql.NullTime + TzType timezoneType +} + +func convertTzTypeToSnowflakeType(tzType timezoneType) SnowflakeDataType { + switch tzType { + case TimestampNTZType: + return DataTypeTimestampNtz + case TimestampLTZType: + return DataTypeTimestampLtz + case TimestampTZType: + return DataTypeTimestampTz + case DateType: + return DataTypeDate + case TimeType: + return DataTypeTime + } + return nil +} diff --git a/converter_test.go b/converter_test.go index 9d3b2c3ad..21167c818 100644 --- a/converter_test.go +++ b/converter_test.go @@ -4,12 +4,15 @@ package gosnowflake import ( "context" + "database/sql" "database/sql/driver" "fmt" "io" + "math" "math/big" "math/cmplx" "reflect" + "strings" "testing" "time" @@ -45,6 +48,21 @@ func stringFloatToDecimal(src string, scale int64) (decimal128.Num, bool) { return decimal128.New(high.Int64(), low.Uint64()), ok } +func stringFloatToInt(src string, scale int64) (int64, bool) { + b, ok := new(big.Float).SetString(src) + if !ok { + return 0, ok + } + s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil)) + n := new(big.Float).Mul(b, s) + var z big.Int + n.Int(&z) + if !z.IsInt64() { + return 0, false + } + return z.Int64(), true +} + type tcGoTypeToSnowflake struct { in interface{} tmode SnowflakeDataType @@ -90,10 +108,12 @@ func TestGoTypeToSnowflake(t *testing.T) { {in: []int{1}, tmode: nil, out: unSupportedType}, } for _, test := range testcases { - a := goTypeToSnowflake(test.in, test.tmode) - if a != test.out { - t.Errorf("failed. in: %v, tmode: %v, expected: %v, got: %v", test.in, test.tmode, test.out, a) - } + t.Run(fmt.Sprintf("%v_%v_%v", test.in, test.out, test.tmode), func(t *testing.T) { + a := goTypeToSnowflake(test.in, test.tmode) + if a != test.out { + t.Errorf("failed. in: %v, tmode: %v, expected: %v, got: %v", test.in, test.tmode, test.out, a) + } + }) } } @@ -119,13 +139,16 @@ func TestSnowflakeTypeToGo(t *testing.T) { {in: arrayType, scale: 0, out: reflect.TypeOf("")}, {in: binaryType, scale: 0, out: reflect.TypeOf([]byte{})}, {in: booleanType, scale: 0, out: reflect.TypeOf(true)}, + {in: sliceType, scale: 0, out: reflect.TypeOf("")}, } for _, test := range testcases { - a := snowflakeTypeToGo(test.in, test.scale) - if a != test.out { - t.Errorf("failed. in: %v, scale: %v, expected: %v, got: %v", - test.in, test.scale, test.out, a) - } + t.Run(fmt.Sprintf("%v_%v", test.in, test.out), func(t *testing.T) { + a := snowflakeTypeToGo(test.in, test.scale) + if a != test.out { + t.Errorf("failed. in: %v, scale: %v, expected: %v, got: %v", + test.in, test.scale, test.out, a) + } + }) } } @@ -140,6 +163,10 @@ func TestValueToString(t *testing.T) { localTime := time.Date(2019, 2, 6, 14, 17, 31, 123456789, time.FixedZone("-08:00", -8*3600)) utcTime := time.Date(2019, 2, 6, 22, 17, 31, 123456789, time.UTC) expectedUnixTime := "1549491451123456789" // time.Unix(1549491451, 123456789).Format(time.RFC3339) == "2019-02-06T14:17:31-08:00" + expectedBool := "true" + expectedInt64 := "1" + expectedFloat64 := "1.1" + expectedString := "teststring" if s, err := valueToString(localTime, DataTypeTimestampLtz); err != nil { t.Error("unexpected error") @@ -156,10 +183,42 @@ func TestValueToString(t *testing.T) { } else if *s != expectedUnixTime { t.Errorf("expected '%v', got '%v'", expectedUnixTime, *s) } + + if s, err := valueToString(sql.NullBool{Bool: true, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedBool, s) + } else if *s != expectedBool { + t.Errorf("expected '%v', got '%v'", expectedBool, *s) + } + + if s, err := valueToString(sql.NullInt64{Int64: 1, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedInt64, s) + } else if *s != expectedInt64 { + t.Errorf("expected '%v', got '%v'", expectedInt64, *s) + } + + if s, err := valueToString(sql.NullFloat64{Float64: 1.1, Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedFloat64, s) + } else if *s != expectedFloat64 { + t.Errorf("expected '%v', got '%v'", expectedFloat64, *s) + } + + if s, err := valueToString(sql.NullString{String: "teststring", Valid: true}, DataTypeTimestampLtz); err != nil { + t.Error("unexpected error") + } else if s == nil { + t.Errorf("expected '%v', got %v", expectedString, s) + } else if *s != expectedString { + t.Errorf("expected '%v', got '%v'", expectedString, *s) + } } func TestExtractTimestamp(t *testing.T) { - s := "1234abcdef" + s := "1234abcdef" // pragma: allowlist secret _, _, err := extractTimestamp(&s) if err == nil { t.Errorf("should raise error: %v", s) @@ -188,12 +247,14 @@ func TestStringToValue(t *testing.T) { } for _, tt := range types { - rowType = &execResponseRowType{ - Type: tt, - } - if err = stringToValue(&dest, *rowType, &source, nil); err == nil { - t.Errorf("should raise error. type: %v, value:%v", tt, source) - } + t.Run(tt, func(t *testing.T) { + rowType = &execResponseRowType{ + Type: tt, + } + if err = stringToValue(&dest, *rowType, &source, nil); err == nil { + t.Errorf("should raise error. type: %v, value:%v", tt, source) + } + }) } sources := []string{ @@ -207,12 +268,14 @@ func TestStringToValue(t *testing.T) { for _, ss := range sources { for _, tt := range types { - rowType = &execResponseRowType{ - Type: tt, - } - if err = stringToValue(&dest, *rowType, &ss, nil); err == nil { - t.Errorf("should raise error. type: %v, value:%v", tt, source) - } + t.Run(ss+tt, func(t *testing.T) { + rowType = &execResponseRowType{ + Type: tt, + } + if err = stringToValue(&dest, *rowType, &ss, nil); err == nil { + t.Errorf("should raise error. type: %v, value:%v", tt, source) + } + }) } } @@ -235,21 +298,25 @@ type tcArrayToString struct { func TestArrayToString(t *testing.T) { testcases := []tcArrayToString{ {in: driver.NamedValue{Value: &intArray{1, 2}}, typ: fixedType, out: []string{"1", "2"}}, + {in: driver.NamedValue{Value: &int32Array{1, 2}}, typ: fixedType, out: []string{"1", "2"}}, {in: driver.NamedValue{Value: &int64Array{3, 4, 5}}, typ: fixedType, out: []string{"3", "4", "5"}}, {in: driver.NamedValue{Value: &float64Array{6.7}}, typ: realType, out: []string{"6.7"}}, + {in: driver.NamedValue{Value: &float32Array{1.5}}, typ: realType, out: []string{"1.5"}}, {in: driver.NamedValue{Value: &boolArray{true, false}}, typ: booleanType, out: []string{"true", "false"}}, {in: driver.NamedValue{Value: &stringArray{"foo", "bar", "baz"}}, typ: textType, out: []string{"foo", "bar", "baz"}}, } for _, test := range testcases { - s, a := snowflakeArrayToString(&test.in, false) - if s != test.typ { - t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.typ, s) - } - for i, v := range a { - if *v != test.out[i] { - t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out[i], a) + t.Run(strings.Join(test.out, "_"), func(t *testing.T) { + s, a := snowflakeArrayToString(&test.in, false) + if s != test.typ { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.typ, s) } - } + for i, v := range a { + if *v != test.out[i] { + t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out[i], a) + } + } + }) } } @@ -272,20 +339,86 @@ func TestArrowToValue(t *testing.T) { } for _, tc := range []struct { - logical string - physical string - rowType execResponseRowType - values interface{} - builder array.Builder - append func(b array.Builder, vs interface{}) - compare func(src interface{}, dst []snowflakeValue) int + logical string + physical string + rowType execResponseRowType + values interface{} + builder array.Builder + append func(b array.Builder, vs interface{}) + compare func(src interface{}, dst []snowflakeValue) int + higherPrecision bool }{ + { + logical: "fixed", + physical: "number", // default: number(38, 0) + values: []int64{1, 2}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "number(38,5)", + rowType: execResponseRowType{Scale: 5}, + values: []string{"1.05430", "2.08983"}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int64Builder).Append(num) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := intToBigFloat(num, 5) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, { logical: "fixed", - physical: "number", // default: number(38, 0) - values: []int64{1, 2}, + physical: "number(38,5)", + rowType: execResponseRowType{Scale: 5}, + values: []string{"1.05430", "2.08983"}, builder: array.NewInt64Builder(pool), - append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int64Builder).Append(num) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, }, { logical: "fixed", @@ -316,6 +449,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -347,6 +481,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -363,6 +498,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -379,6 +515,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", @@ -395,13 +532,79 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "fixed", - physical: "int64", - values: []int64{1, 2}, - builder: array.NewInt64Builder(pool), - append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + physical: "int32", + values: []string{"1.23456", "2.34567"}, + rowType: execResponseRowType{Scale: 5}, + builder: array.NewInt32Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int32Builder).Append(int32(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := intToBigFloat(num, 5) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "int32", + values: []string{"1.23456", "2.34567"}, + rowType: execResponseRowType{Scale: 5}, + builder: array.NewInt32Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 5) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int32Builder).Append(int32(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 5) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 5, float64(num)/math.Pow10(int(5))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, + }, + { + logical: "fixed", + physical: "int64", + values: []int64{1, 2}, + builder: array.NewInt64Builder(pool), + append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) }, + higherPrecision: true, }, { logical: "boolean", @@ -458,6 +661,7 @@ func TestArrowToValue(t *testing.T) { } return -1 }, + higherPrecision: true, }, { logical: "timestamp_ntz", @@ -614,7 +818,9 @@ func TestArrowToValue(t *testing.T) { meta := tc.rowType meta.Type = tc.logical - if err := arrowToValue(dest, meta, arr, localTime.Location(), true); err != nil { + withHigherPrecision := tc.higherPrecision + + if err := arrowToValue(dest, meta, arr, localTime.Location(), withHigherPrecision); err != nil { t.Fatalf("error: %s", err) } @@ -1034,3 +1240,39 @@ func TestLargeTimestampBinding(t *testing.T) { } } } + +func TestTimeTypeValueToString(t *testing.T) { + timeValue, err := time.Parse("2006-01-02 15:04:05", "2020-01-02 10:11:12") + if err != nil { + t.Fatal(err) + } + offsetTimeValue, err := time.ParseInLocation("2006-01-02 15:04:05", "2020-01-02 10:11:12", Location(6*60)) + if err != nil { + t.Fatal(err) + } + + testcases := []struct { + in time.Time + dataType SnowflakeDataType + out string + }{ + {timeValue, DataTypeDate, "1577959872000"}, + {timeValue, DataTypeTime, "36672000000000"}, + {timeValue, DataTypeTimestampNtz, "1577959872000000000"}, + {timeValue, DataTypeTimestampLtz, "1577959872000000000"}, + {timeValue, DataTypeTimestampTz, "1577959872000000000 1440"}, + {offsetTimeValue, DataTypeTimestampTz, "1577938272000000000 1800"}, + } + + for _, tc := range testcases { + t.Run(tc.out, func(t *testing.T) { + output, err := timeTypeValueToString(tc.in, tc.dataType) + if err != nil { + t.Error(err) + } + if strings.Compare(tc.out, *output) != 0 { + t.Errorf("failed to convert time %v of type %v. expected: %v, received: %v", tc.in, tc.dataType, tc.out, *output) + } + }) + } +} diff --git a/datatype.go b/datatype.go index b419e9531..07ab9bc73 100644 --- a/datatype.go +++ b/datatype.go @@ -86,17 +86,6 @@ func (dt SnowflakeDataType) Equals(o SnowflakeDataType) bool { return bytes.Equal(([]byte)(dt), ([]byte)(o)) } -// SnowflakeDataType is the type used by clients to explicitly indicate the type -// of an argument to ExecContext and friends. We use a separate public-facing -// type rather than a Go primitive type so that we can always differentiate -// between args that indicate type and args that are values. -type SnowflakeDataType []byte - -// Equals checks if dt and o represent the same type indicator -func (dt SnowflakeDataType) Equals(o SnowflakeDataType) bool { - return bytes.Equal(([]byte)(dt), ([]byte)(o)) -} - var ( // DataTypeFixed is a FIXED datatype. DataTypeFixed = SnowflakeDataType{fixedType.Byte()} diff --git a/datatype_test.go b/datatype_test.go index aa7763779..65f59fd9e 100644 --- a/datatype_test.go +++ b/datatype_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. +// Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved. package gosnowflake @@ -33,18 +33,34 @@ func TestClientTypeToInternal(t *testing.T) { err: fmt.Errorf(errMsgInvalidByteArray, nil)}, } for _, ts := range testcases { - tmode, err := clientTypeToInternal(ts.tp) - if ts.err == nil { - if err != nil { - t.Errorf("failed to get datatype mode: %v", err) - } - if tmode != ts.tmode { - t.Errorf("wrong data type: %v", tmode) - } - } else { - if err == nil { - t.Errorf("should raise an error: %v", ts.err) + t.Run(fmt.Sprintf("%v_%v", ts.tp, ts.tmode), func(t *testing.T) { + tmode, err := clientTypeToInternal(ts.tp) + if ts.err == nil { + if err != nil { + t.Errorf("failed to get datatype mode: %v", err) + } + if tmode != ts.tmode { + t.Errorf("wrong data type: %v", tmode) + } + } else { + if err == nil { + t.Errorf("should raise an error: %v", ts.err) + } } + }) + } +} + +func TestPopulateSnowflakeParameter(t *testing.T) { + columns := []string{"key", "value", "default", "level", "description", "set_by_user", "set_in_job", "set_on", "set_by_thread_id", "set_by_thread_name", "set_by_class", "parameter_comment", "type", "is_expired", "expires_at", "set_by_controlling_parameter", "activate_version", "partial_rollout"} + p := SnowflakeParameter{} + cols := make([]interface{}, len(columns)) + for i := 0; i < len(columns); i++ { + cols[i] = populateSnowflakeParameter(columns[i], &p) + } + for i := 0; i < len(cols); i++ { + if cols[i] == nil { + t.Fatal("failed to populate parameter") } } } diff --git a/driver_test.go b/driver_test.go index c45479f9e..88ddb1a05 100644 --- a/driver_test.go +++ b/driver_test.go @@ -163,7 +163,7 @@ func TestMain(m *testing.M) { type DBTest struct { *testing.T - db *sql.DB + conn *sql.Conn } func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) { @@ -187,7 +187,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExten close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -218,7 +218,7 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -228,8 +228,13 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i } } +func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) { + return dbt.conn.QueryContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) { rows := dbt.mustQuery(query, args...) + defer rows.Close() cnt := 0 for rows.Next() { cnt++ @@ -239,6 +244,10 @@ func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...inte } } +func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { + return dbt.conn.PrepareContext(context.Background(), query) +} + func (dbt *DBTest) fail(method, query string, err error) { if len(query) > 300 { query = "[query too large to print]" @@ -247,21 +256,21 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.Exec(query, args...) - if err != nil { - dbt.fail("exec", query, err) - } - return res + return dbt.mustExecContext(context.Background(), query, args...) } func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.ExecContext(ctx, query, args...) + res, err := dbt.conn.ExecContext(ctx, query, args...) if err != nil { dbt.fail("exec context", query, err) } return res } +func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) { + return dbt.conn.ExecContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) { var ok bool pr, sc, ok = ct.DecimalSize() @@ -304,29 +313,36 @@ func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) { } func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { - stmt, err := dbt.db.Prepare(query) + stmt, err := dbt.conn.PrepareContext(context.Background(), query) if err != nil { dbt.fail("prepare", query, err) } return stmt } -func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - db, err := sql.Open("sigmacomputing+gosnowflake", dsn) +func runDBTest(t *testing.T, test func(dbt *DBTest)) { + conn := openConn(t) + defer conn.Close() + dbt := &DBTest{t, conn} + + test(dbt) +} + +func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) { + config, err := ParseDSN(dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Error(err) } - defer db.Close() - - if _, err = db.Exec("DROP TABLE IF EXISTS test"); err != nil { - t.Fatalf("failed to drop table: %v", err) + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) } - - dbt := &DBTest{t, db} - for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") + defer sc.Close() + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) } + + test(sc) } func runningOnAWS() bool { @@ -412,10 +428,10 @@ func invalidHostErrorTests(invalidDNS string, mstr []string, t *testing.T) { } func TestCommentOnlyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "--" // just a comment, no query - rows, err := dbt.db.Query(query) + rows, err := dbt.query(query) if err == nil { rows.Close() dbt.fail("query", query, err) @@ -429,15 +445,15 @@ func TestCommentOnlyQuery(t *testing.T) { } func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "select 1 from dual where 1=0" // just a comment, no query - rows := dbt.db.QueryRow(query) - var v1 interface{} + rows := dbt.conn.QueryRowContext(context.Background(), query) + var v1 any if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } - rows = dbt.db.QueryRowContext(context.Background(), query) + rows = dbt.conn.QueryRowContext(context.Background(), query) if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } @@ -445,10 +461,10 @@ func TestEmptyQuery(t *testing.T) { } func TestEmptyQueryWithRequestID(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { query := "select 1" ctx := WithRequestID(context.Background(), NewUUID()) - rows := dbt.db.QueryRowContext(ctx, query) + rows := dbt.conn.QueryRowContext(ctx, query) var v1 interface{} if err := rows.Scan(&v1); err != nil { dbt.Errorf("should not have failed with valid request id. err: %v", err) @@ -457,9 +473,9 @@ func TestEmptyQueryWithRequestID(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { // Create Table - dbt.mustExec("CREATE TABLE test (value BOOLEAN)") + dbt.mustExec("CREATE OR REPLACE TABLE test (value BOOLEAN)") // Test for unexpected Data var out bool @@ -555,7 +571,7 @@ func TestInt(t *testing.T) { } func testInt(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := []string{"INT", "INTEGER"} in := int64(42) var out int64 @@ -563,24 +579,26 @@ func testInt(t *testing.T, json bool) { // SIGNED for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %d != %d", v, in, out) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) + } + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -589,32 +607,34 @@ func TestFloat32(t *testing.T) { } func testFloat32(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 var rows *RowsExtended for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - err := rows.Scan(&out) - if err != nil { - dbt.Errorf("failed to scan data: %v", err) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) } - if in != out { - dbt.Errorf("%s: %g != %g", v, in, out) + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + err := rows.Scan(&out) + if err != nil { + dbt.Errorf("failed to scan data: %v", err) + } + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -623,29 +643,31 @@ func TestFloat64(t *testing.T) { } func testFloat64(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} expected := 42.23 var out float64 var rows *RowsExtended for _, v := range types { - if json { - dbt.mustExec(forceJSON) - } - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (42.23)") - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) + t.Run(v, func(t *testing.T) { + if json { + dbt.mustExec(forceJSON) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -654,7 +676,7 @@ func TestString(t *testing.T) { } func testString(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -664,24 +686,26 @@ func testString(t *testing.T, json bool) { var rows *RowsExtended for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - - rows = dbt.mustQuery("SELECT value FROM test") - defer rows.Close() - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %s != %s", v, in, out) + t.Run(v, func(t *testing.T) { + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") + }) } + dbt.mustExec("DROP TABLE IF EXISTS test") // BLOB (Snowflake doesn't support BLOB type but STRING covers large text data) - dbt.mustExec("CREATE TABLE test (id int, value STRING)") + dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value STRING)") id := 2 in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam @@ -695,7 +719,7 @@ func testString(t *testing.T, json bool) { gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.` dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) - if err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { + if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { dbt.Errorf("BLOB: %s != %s", in, out) @@ -786,7 +810,7 @@ func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) { scan(rows, &cd, &ct, &cts) }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -852,18 +876,20 @@ func testDateTime(t *testing.T, json bool) { {t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)}, }}, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) } @@ -921,18 +947,20 @@ func testTimestampLTZ(t *testing.T, json bool) { }, }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) // Revert timezone to UTC, which is default for the test suit @@ -969,18 +997,20 @@ func testTimestampTZ(t *testing.T, json bool) { }, }, } - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } for _, setups := range testcases { - for _, setup := range setups.tests { - if setup.s == "" { - // fill time string wherever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) + t.Run(setups.dbtype, func(t *testing.T) { + for _, setup := range setups.tests { + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } + setup.run(t, dbt, setups.dbtype, setups.tlayout) } - setup.run(t, dbt, setups.dbtype, setups.tlayout) - } + }) } }) } @@ -990,17 +1020,17 @@ func TestNULL(t *testing.T) { } func testNULL(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } - nullStmt, err := dbt.db.Prepare("SELECT NULL") + nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL") if err != nil { dbt.Fatal(err) } defer nullStmt.Close() - nonNullStmt, err := dbt.db.Prepare("SELECT 1") + nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1") if err != nil { dbt.Fatal(err) } @@ -1101,7 +1131,7 @@ func testNULL(t *testing.T, json bool) { // Insert nil b = nil success := false - if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil { dbt.Fatal(err) } if !success { @@ -1110,7 +1140,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input==nil b = nil - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b != nil { @@ -1118,7 +1148,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input!=nil b = []byte("") - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b == nil { @@ -1126,7 +1156,7 @@ func testNULL(t *testing.T, json bool) { } // Insert NULL - dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)") dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) var out interface{} @@ -1148,7 +1178,7 @@ func TestVariant(t *testing.T) { } func testVariant(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1170,7 +1200,7 @@ func TestArray(t *testing.T) { } func testArray(t *testing.T, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1193,7 +1223,7 @@ func TestLargeSetResult(t *testing.T) { } func testLargeSetResult(t *testing.T, numrows int, json bool) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { if json { dbt.mustExec(forceJSON) } @@ -1217,7 +1247,7 @@ func testLargeSetResult(t *testing.T, numrows int, json bool) { } func TestPingpongQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { numrows := 1 rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))") defer rows.Close() @@ -1232,7 +1262,7 @@ func TestPingpongQuery(t *testing.T) { } func TestDML(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)") if err := insertData(dbt, false); err != nil { dbt.Fatalf("failed to insert data: %v", err) @@ -1258,7 +1288,7 @@ func TestDML(t *testing.T) { } func insertData(dbt *DBTest, commit bool) error { - tx, err := dbt.db.Begin() + tx, err := dbt.conn.BeginTx(context.Background(), nil) if err != nil { dbt.Fatalf("failed to begin transaction: %v", err) } @@ -1314,7 +1344,7 @@ func queryTestTx(tx *sql.Tx) (*map[int]string, error) { func queryTest(dbt *DBTest) (*map[int]string, error) { var c1 int var c2 string - rows, err := dbt.db.Query("SELECT c1, c2 FROM test") + rows, err := dbt.query("SELECT c1, c2 FROM test") if err != nil { return nil, err } @@ -1330,11 +1360,11 @@ func queryTest(dbt *DBTest) (*map[int]string, error) { } func TestCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - _, err := dbt.db.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") + _, err := dbt.conn.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") if err == nil { dbt.Fatal("No timeout error returned") } @@ -1345,8 +1375,8 @@ func TestCancelQuery(t *testing.T) { } func TestPing(t *testing.T) { - db := openDB(t) - if err := db.Ping(); err != nil { + db := openConn(t) + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("failed to ping. err: %v", err) } if err := db.PingContext(context.Background()); err != nil { @@ -1355,7 +1385,7 @@ func TestPing(t *testing.T) { if err := db.Close(); err != nil { t.Fatalf("failed to close db. err: %v", err) } - if err := db.Ping(); err == nil { + if err := db.PingContext(context.Background()); err == nil { t.Fatal("should have failed to ping") } if err := db.PingContext(context.Background()); err == nil { @@ -1365,7 +1395,7 @@ func TestPing(t *testing.T) { func TestDoubleDollar(t *testing.T) { // no escape is required for dollar signs - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { sql := `create or replace function dateErr(I double) returns date language javascript strict as $$ @@ -1387,10 +1417,10 @@ $$ func TestTimezoneSessionParameter(t *testing.T) { createDSN(PSTLocation) - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() - rows, err := db.Query("SHOW PARAMETERS LIKE 'TIMEZONE'") + rows, err := conn.QueryContext(context.Background(), "SHOW PARAMETERS LIKE 'TIMEZONE'") if err != nil { t.Errorf("failed to run show parameters. err: %v", err) } @@ -1410,13 +1440,13 @@ func TestTimezoneSessionParameter(t *testing.T) { } func TestLargeSetResultCancel(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { c := make(chan error) ctx, cancel := context.WithCancel(context.Background()) go func() { // attempt to run a 100 seconds query, but it should be canceled in 1 second timelimit := 100 - rows, err := dbt.db.QueryContext( + rows, err := dbt.conn.QueryContext( ctx, fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit)) if err != nil { @@ -1468,32 +1498,34 @@ func TestValidateDatabaseParameter(t *testing.T) { }, } for idx, tc := range testcases { - newDSN := tc.dsn - parameters := url.Values{} - if protocol != "" { - parameters.Add("protocol", protocol) - } - if account != "" { - parameters.Add("account", account) - } - for k, v := range tc.params { - parameters.Add(k, v) - } - newDSN += "?" + parameters.Encode() - db, err := sql.Open("sigmacomputing+gosnowflake", newDSN) - // actual connection won't happen until run a query - if err != nil { - t.Fatalf("error creating a connection object: %s", err.Error()) - } - defer db.Close() - if _, err = db.Exec("SELECT 1"); err == nil { - t.Fatal("should cause an error.") - } - if driverErr, ok := err.(*SnowflakeError); ok { - if driverErr.Number != tc.errorCode { // not exist error - t.Errorf("got unexpected error: %v in %v", err, idx) + t.Run(dsn, func(t *testing.T) { + newDSN := tc.dsn + parameters := url.Values{} + if protocol != "" { + parameters.Add("protocol", protocol) } - } + if account != "" { + parameters.Add("account", account) + } + for k, v := range tc.params { + parameters.Add(k, v) + } + newDSN += "?" + parameters.Encode() + db, err := sql.Open("sigmacomputing+gosnowflake", newDSN) + // actual connection won't happen until run a query + if err != nil { + t.Fatalf("error creating a connection object: %s", err.Error()) + } + defer db.Close() + if _, err = db.Exec("SELECT 1"); err == nil { + t.Fatal("should cause an error.") + } + if driverErr, ok := err.(*SnowflakeError); ok { + if driverErr.Number != tc.errorCode { // not exist error + t.Errorf("got unexpected error: %v in %v", err, idx) + } + } + }) } } @@ -1520,7 +1552,7 @@ func TestSpecifyWarehouseDatabase(t *testing.T) { } func TestFetchNil(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2") defer rows.Close() var c1 sql.NullInt64 @@ -1579,6 +1611,18 @@ func TestOpenWithConfig(t *testing.T) { db.Close() } +func TestOpenWithInvalidConfig(t *testing.T) { + config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing") + if err != nil { + t.Fatalf("failed to parse dsn. err: %v", err) + } + driver := SnowflakeDriver{} + _, err = driver.OpenWithConfig(context.Background(), *config) + if err == nil || !strings.Contains(err.Error(), "/non-existing") { + t.Fatalf("should fail on missing directory") + } +} + type CountingTransport struct { requests int } @@ -1653,8 +1697,9 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { // This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks // the session parameter. createDSNWithClientSessionKeepAlive() - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'") + defer rows.Close() if !rows.Next() { t.Fatal("failed to get timezone.") } @@ -1667,15 +1712,16 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value) } - rows = dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") - defer rows.Close() + rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") + defer rows2.Close() }) } func TestTimePrecision(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { dbt.mustExec("create or replace table z3 (t1 time(5))") rows := dbt.mustQuery("select * from z3") + defer rows.Close() cols, err := rows.ColumnTypes() if err != nil { t.Error(err) diff --git a/dsn.go b/dsn.go index 1689a4b3b..bd3364e86 100644 --- a/dsn.go +++ b/dsn.go @@ -4,11 +4,15 @@ package gosnowflake import ( "crypto/rsa" + "crypto/x509" "encoding/base64" + "encoding/pem" + "errors" "fmt" "net" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -17,11 +21,13 @@ import ( ) const ( - defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response - defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout - defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout - defaultJWTTimeout = 60 * time.Second - defaultDomain = ".snowflakecomputing.com" + defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response + defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth + defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout + defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout + defaultJWTTimeout = 60 * time.Second + defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login + defaultDomain = ".snowflakecomputing.com" // default monitoring fetcher config values defaultMonitoringFetcherQueryMonitoringThreshold = 45 * time.Second @@ -69,10 +75,12 @@ type Config struct { OktaURL *url.URL - LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response - RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response - JWTExpireTimeout time.Duration // JWT expire after timeout - ClientTimeout time.Duration // Timeout for network round trip + read out http response + LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response + RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response + JWTExpireTimeout time.Duration // JWT expire after timeout + ClientTimeout time.Duration // Timeout for network round trip + read out http response + JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place + ExternalBrowserTimeout time.Duration // Timeout for external browser login Application string // application name. InsecureMode bool // driver doesn't check certificate revocation status @@ -90,6 +98,8 @@ type Config struct { Tracing string // sets logging level + TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc + MfaToken string // Internally used to cache the MFA token IDToken string // Internally used to cache the Id Token for external browser ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. @@ -118,6 +128,17 @@ type MonitoringFetcherConfig struct { RetrySleepDuration time.Duration } +// Validate enables testing if config is correct. +// A driver client may call it manually, but it is also called during opening first connection. +func (c *Config) Validate() error { + if c.TmpDirPath != "" { + if _, err := os.Stat(c.TmpDirPath); err != nil { + return err + } + } + return nil +} + // ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED func (c *Config) ocspMode() string { if c.InsecureMode { @@ -147,7 +168,7 @@ func DSN(cfg *Config) (dsn string, err error) { posDot := strings.Index(cfg.Account, ".") if posDot > 0 { if cfg.Region != "" { - return "", ErrInvalidRegion + return "", errInvalidRegion() } cfg.Region = cfg.Account[posDot+1:] cfg.Account = cfg.Account[:posDot] @@ -192,6 +213,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ClientTimeout != defaultClientTimeout { params.Add("clientTimeout", strconv.FormatInt(int64(cfg.ClientTimeout/time.Second), 10)) } + if cfg.JWTClientTimeout != defaultJWTClientTimeout { + params.Add("jwtClientTimeout", strconv.FormatInt(int64(cfg.JWTClientTimeout/time.Second), 10)) + } if cfg.LoginTimeout != defaultLoginTimeout { params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10)) } @@ -201,6 +225,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.JWTExpireTimeout != defaultJWTTimeout { params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10)) } + if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout { + params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10)) + } if cfg.Application != clientType { params.Add("application", cfg.Application) } @@ -229,6 +256,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.Tracing != "" { params.Add("tracing", cfg.Tracing) } + if cfg.TmpDirPath != "" { + params.Add("tmpDirPath", cfg.TmpDirPath) + } params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) @@ -421,14 +451,23 @@ func fillMissingConfigParameters(cfg *Config) error { } } if strings.Trim(cfg.Account, " ") == "" { - return ErrEmptyAccount + return errEmptyAccount() + } + + if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" { + return errEmptyUsername() + } + + if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" { + return errEmptyPassword() } if cfg.Authenticator != AuthTypeOAuth && cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser && strings.Trim(cfg.User, " ") == "" { // oauth and token accessor do not require a username - return ErrEmptyUsername + return errEmptyUsername() } if cfg.Authenticator != AuthTypeExternalBrowser && @@ -437,7 +476,7 @@ func fillMissingConfigParameters(cfg *Config) error { cfg.Authenticator != AuthTypeTokenAccessor && strings.Trim(cfg.Password, " ") == "" { // no password parameter is required for EXTERNALBROWSER, OAUTH JWT, or TOKENACCESSOR. - return ErrEmptyPassword + return errEmptyPassword() } if strings.Trim(cfg.Protocol, " ") == "" { cfg.Protocol = "https" @@ -476,6 +515,12 @@ func fillMissingConfigParameters(cfg *Config) error { if cfg.ClientTimeout == 0 { cfg.ClientTimeout = defaultClientTimeout } + if cfg.JWTClientTimeout == 0 { + cfg.JWTClientTimeout = defaultJWTClientTimeout + } + if cfg.ExternalBrowserTimeout == 0 { + cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout + } if strings.Trim(cfg.Application, " ") == "" { cfg.Application = clientType } @@ -512,6 +557,19 @@ func fillMissingConfigParameters(cfg *Config) error { return nil } +func authRequiresUser(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser +} + +func authRequiresPassword(cfg *Config) bool { + return cfg.Authenticator != AuthTypeOAuth && + cfg.Authenticator != AuthTypeTokenAccessor && + cfg.Authenticator != AuthTypeExternalBrowser && + cfg.Authenticator != AuthTypeJwt +} + // transformAccountToHost transforms host to account name func transformAccountToHost(cfg *Config) (err error) { if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" { @@ -619,6 +677,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "jwtClientTimeout": + cfg.JWTClientTimeout, err = parseTimeout(value) + if err != nil { + return + } case "loginTimeout": cfg.LoginTimeout, err = parseTimeout(value) if err != nil { @@ -634,6 +697,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "externalBrowserTimeout": + cfg.ExternalBrowserTimeout, err = parseTimeout(value) + if err != nil { + return err + } case "application": cfg.Application = value case "authenticator": @@ -726,6 +794,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "tmpDirPath": + cfg.TmpDirPath = value default: if cfg.Params == nil { cfg.Params = make(map[string]*string) @@ -758,3 +828,101 @@ func parseTimeout(value string) (time.Duration, error) { } return time.Duration(vv * int64(time.Second)), nil } + +// ConfigParam is used to bind the name of the Config field with the environment variable and set the requirement for it +type ConfigParam struct { + Name string + EnvName string + FailOnMissing bool +} + +// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config +func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { + var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string + var privateKey *rsa.PrivateKey + var err error + if len(properties) == 0 || properties == nil { + return nil, errors.New("missing configuration parameters for the connection") + } + for _, prop := range properties { + value, err := GetFromEnv(prop.EnvName, prop.FailOnMissing) + if err != nil { + return nil, err + } + switch prop.Name { + case "Account": + account = value + case "User": + user = value + case "Password": + password = value + case "Role": + role = value + case "Host": + host = value + case "Port": + portStr = value + case "Protocol": + protocol = value + case "Warehouse": + warehouse = value + case "Database": + database = value + case "Region": + region = value + case "Passcode": + passcode = value + case "Schema": + schema = value + case "Application": + application = value + case "PrivateKey": + privateKey, err = parsePrivateKeyFromFile(value) + if err != nil { + return nil, err + } + } + } + + port := 443 // snowflake default port + if len(portStr) > 0 { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, err + } + } + + cfg := &Config{ + Account: account, + User: user, + Password: password, + Role: role, + Host: host, + Port: port, + Protocol: protocol, + Warehouse: warehouse, + Database: database, + Schema: schema, + PrivateKey: privateKey, + Region: region, + Passcode: passcode, + Application: application, + } + return cfg, nil +} + +func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + block, _ := pem.Decode(bytes) + if block == nil { + return nil, errors.New("failed to parse PEM block containing the private key") + } + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return privateKey.(*rsa.PrivateKey), nil +} diff --git a/dsn_test.go b/dsn_test.go index 52a0c9d08..4d7fe76b5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -3,9 +3,16 @@ package gosnowflake import ( + cr "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "net/url" + "os" "reflect" + "strconv" + "strings" "testing" "time" ) @@ -18,7 +25,6 @@ type tcParseDSN struct { } func TestParseDSN(t *testing.T) { - privKeyPKCS8 := generatePKCS8StringSupress(testPrivKey) privKeyPKCS1 := generatePKCS1String(testPrivKey) testcases := []tcParseDSN{ @@ -31,6 +37,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -44,6 +52,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -55,6 +65,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -67,6 +79,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -80,6 +94,8 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -93,6 +109,8 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -105,6 +123,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -117,6 +137,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -129,6 +151,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -142,6 +166,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -155,6 +181,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -168,9 +196,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { dsn: "@host:123/db/schema?account=ac&protocol=http", @@ -181,9 +211,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { dsn: "user:p@host:123/db/schema?protocol=http", @@ -194,9 +226,11 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:p@a.snowflakecomputing.com/db/pa?account=a&protocol=https&role=r&timezone=UTC&warehouse=w", @@ -207,6 +241,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -219,6 +255,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -231,6 +269,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -243,6 +283,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -258,6 +300,22 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + }, + ocspMode: ocspModeFailOpen, + }, + { + dsn: "u:p@a?database=d&externalBrowserTimeout=20", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, + Database: "d", Schema: "", + ExternalBrowserTimeout: 20 * time.Second, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, }, ocspMode: ocspModeFailOpen, }, @@ -271,6 +329,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, }, @@ -282,6 +342,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{ @@ -300,6 +362,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeInsecure, err: nil, @@ -314,6 +378,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -348,12 +414,12 @@ func TestParseDSN(t *testing.T) { { dsn: "u:u@/+/+?account=+&=0", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:u@/+/+?account=+&=+&=+", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "user%40%2F1:p%3A%40s@/db%2F?account=ac", @@ -363,6 +429,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -376,6 +444,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -394,6 +464,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -407,6 +479,8 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{Number: ErrCodePrivateKeyParseError}, @@ -419,6 +493,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -431,6 +507,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailClosed, err: nil, @@ -443,6 +521,8 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, InsecureMode: true, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeInsecure, err: nil, @@ -453,7 +533,9 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -464,7 +546,9 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, @@ -475,122 +559,210 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300&jwtClientTimeout=45", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: 300 * time.Second, + JWTClientTimeout: 45 * time.Second, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, err: nil, }, { - dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300", + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: 300 * time.Second, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + TmpDirPath: "/tmp", }, ocspMode: ocspModeFailOpen, err: nil, }, } + for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:777/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 777, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA, AuthTypeJwt} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "", Password: "", + Protocol: "http", Host: "host", Port: 888, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: errEmptyUsername(), + }) + } + + for _, at := range []AuthType{AuthTypeSnowflake, AuthTypeUsernamePasswordMFA} { + testcases = append(testcases, tcParseDSN{ + dsn: fmt.Sprintf("user@host:888/db/schema?account=ac&protocol=http&authenticator=%v", strings.ToLower(at.String())), + config: &Config{ + Account: "ac", User: "user", Password: "", + Protocol: "http", Host: "host", Port: 888, + Database: "db", Schema: "schema", + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + Authenticator: at, + }, + ocspMode: ocspModeFailOpen, + err: errEmptyPassword(), + }) + } + for i, test := range testcases { - // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) - cfg, err := ParseDSN(test.dsn) - switch { - case test.err == nil: - if err != nil { - t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) - } - if test.config.Host != cfg.Host { - t.Fatalf("%d: Failed to match host. expected: %v, got: %v", - i, test.config.Host, cfg.Host) - } - if test.config.Account != cfg.Account { - t.Fatalf("%d: Failed to match account. expected: %v, got: %v", - i, test.config.Account, cfg.Account) - } - if test.config.User != cfg.User { - t.Fatalf("%d: Failed to match user. expected: %v, got: %v", - i, test.config.User, cfg.User) - } - if test.config.Password != cfg.Password { - t.Fatalf("%d: Failed to match password. expected: %v, got: %v", - i, test.config.Password, cfg.Password) - } - if test.config.Database != cfg.Database { - t.Fatalf("%d: Failed to match database. expected: %v, got: %v", - i, test.config.Database, cfg.Database) - } - if test.config.Schema != cfg.Schema { - t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", - i, test.config.Schema, cfg.Schema) - } - if test.config.Warehouse != cfg.Warehouse { - t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", - i, test.config.Warehouse, cfg.Warehouse) - } - if test.config.Role != cfg.Role { - t.Fatalf("%d: Failed to match role. expected: %v, got: %v", - i, test.config.Role, cfg.Role) - } - if test.config.Region != cfg.Region { - t.Fatalf("%d: Failed to match region. expected: %v, got: %v", - i, test.config.Region, cfg.Region) - } - if test.config.Protocol != cfg.Protocol { - t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", - i, test.config.Protocol, cfg.Protocol) - } - if test.config.Passcode != cfg.Passcode { - t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", - i, test.config.Passcode, cfg.Passcode) - } - if test.config.PasscodeInPassword != cfg.PasscodeInPassword { - t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", - i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) - } - if test.config.Authenticator != cfg.Authenticator { - t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", - i, test.config.Authenticator.String(), cfg.Authenticator.String()) - } - if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { - t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", - i, test.config.OktaURL, cfg.OktaURL) - } - if test.config.OCSPFailOpen != cfg.OCSPFailOpen { - t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", - i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) - } - if test.ocspMode != cfg.ocspMode() { - t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", - i, test.ocspMode, cfg.ocspMode()) - } - if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { - t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", - i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) - } - if test.config.ClientTimeout != cfg.ClientTimeout { - t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", - i, test.config.ClientTimeout, cfg.ClientTimeout) - } - case test.err != nil: - driverErrE, okE := test.err.(*SnowflakeError) - driverErrG, okG := err.(*SnowflakeError) - if okE && !okG || !okE && okG { - t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) - } - if okE && okG { - if driverErrE.Number != driverErrG.Number { - t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + t.Run(test.dsn, func(t *testing.T) { + cfg, err := ParseDSN(test.dsn) + switch { + case test.err == nil: + if err != nil { + t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) + } + if test.config.Host != cfg.Host { + t.Fatalf("%d: Failed to match host. expected: %v, got: %v", + i, test.config.Host, cfg.Host) + } + if test.config.Account != cfg.Account { + t.Fatalf("%d: Failed to match account. expected: %v, got: %v", + i, test.config.Account, cfg.Account) + } + if test.config.User != cfg.User { + t.Fatalf("%d: Failed to match user. expected: %v, got: %v", + i, test.config.User, cfg.User) + } + if test.config.Password != cfg.Password { + t.Fatalf("%d: Failed to match password. expected: %v, got: %v", + i, test.config.Password, cfg.Password) + } + if test.config.Database != cfg.Database { + t.Fatalf("%d: Failed to match database. expected: %v, got: %v", + i, test.config.Database, cfg.Database) + } + if test.config.Schema != cfg.Schema { + t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", + i, test.config.Schema, cfg.Schema) + } + if test.config.Warehouse != cfg.Warehouse { + t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", + i, test.config.Warehouse, cfg.Warehouse) + } + if test.config.Role != cfg.Role { + t.Fatalf("%d: Failed to match role. expected: %v, got: %v", + i, test.config.Role, cfg.Role) } - } else { - t1 := reflect.TypeOf(err) - t2 := reflect.TypeOf(test.err) - if t1 != t2 { - t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + if test.config.Region != cfg.Region { + t.Fatalf("%d: Failed to match region. expected: %v, got: %v", + i, test.config.Region, cfg.Region) + } + if test.config.Protocol != cfg.Protocol { + t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", + i, test.config.Protocol, cfg.Protocol) + } + if test.config.Passcode != cfg.Passcode { + t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", + i, test.config.Passcode, cfg.Passcode) + } + if test.config.PasscodeInPassword != cfg.PasscodeInPassword { + t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", + i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) + } + if test.config.Authenticator != cfg.Authenticator { + t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", + i, test.config.Authenticator.String(), cfg.Authenticator.String()) + } + if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { + t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", + i, test.config.OktaURL, cfg.OktaURL) + } + if test.config.OCSPFailOpen != cfg.OCSPFailOpen { + t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", + i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) + } + if test.ocspMode != cfg.ocspMode() { + t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", + i, test.ocspMode, cfg.ocspMode()) + } + if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { + t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", + i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) + } + if test.config.ClientTimeout != cfg.ClientTimeout { + t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", + i, test.config.ClientTimeout, cfg.ClientTimeout) + } + if test.config.JWTClientTimeout != cfg.JWTClientTimeout { + t.Fatalf("%d: Failed to match JWTClientTimeout. expected: %v, got: %v", + i, test.config.JWTClientTimeout, cfg.JWTClientTimeout) + } + if test.config.ExternalBrowserTimeout != cfg.ExternalBrowserTimeout { + t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", + i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) + } + if test.config.TmpDirPath != cfg.TmpDirPath { + t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath) + } + case test.err != nil: + driverErrE, okE := test.err.(*SnowflakeError) + driverErrG, okG := err.(*SnowflakeError) + if okE && !okG || !okE && okG { + t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) + } + if okE && okG { + if driverErrE.Number != driverErrG.Number { + t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + } + } else { + t1 := reflect.TypeOf(err) + t2 := reflect.TypeOf(test.err) + if t1 != t2 { + t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + } } } - } + + }) } } @@ -603,7 +775,6 @@ type tcDSN struct { func TestDSN(t *testing.T) { tmfmt := "MM-DD-YYYY" testConnectionID := "abcd-0123-4567-1234" - testcases := []tcDSN{ { cfg: &Config{ @@ -641,7 +812,7 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -672,6 +843,17 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a", + Region: "r", + ExternalBrowserTimeout: 20 * time.Second, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.r.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&externalBrowserTimeout=20&ocspFailOpen=true®ion=r&validateDefaultParameters=true", + }, { cfg: &Config{ User: "", @@ -679,7 +861,7 @@ func TestDSN(t *testing.T) { Account: "a", ConnectionID: testConnectionID, }, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { cfg: &Config{ @@ -688,16 +870,16 @@ func TestDSN(t *testing.T) { Account: "a", ConnectionID: testConnectionID, }, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { cfg: &Config{ User: "u", Password: "p", - Account: "", + Account: "a.e", ConnectionID: testConnectionID, }, - err: ErrEmptyAccount, + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, { cfg: &Config{ @@ -726,7 +908,7 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -749,13 +931,14 @@ func TestDSN(t *testing.T) { }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a", - Authenticator: AuthTypeExternalBrowser, - ConnectionID: testConnectionID, + User: "u", + Password: "p", + Account: "a", + Authenticator: AuthTypeExternalBrowser, + ClientStoreTemporaryCredential: ConfigBoolTrue, + ConnectionID: testConnectionID, }, - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=true&connectionId=abcd-0123-4567-1234&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -872,43 +1055,101 @@ func TestDSN(t *testing.T) { Region: "r", ConnectionID: testConnectionID, }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.b.c", - ClientTimeout: 300 * time.Second, - ConnectionID: testConnectionID, + User: "u", + Password: "p", + Account: "a.b.c", + ClientTimeout: 300 * time.Second, + JWTClientTimeout: 60 * time.Second, + ConnectionID: testConnectionID, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.e", - MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: time.Second * 56, - }, + User: "u", + Password: "p", + Account: "a.b.c", + ClientTimeout: 300 * time.Second, + JWTExpireTimeout: 30 * time.Second, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&connectionId=abcd-0123-4567-1234&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Protocol: "http", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_queryRuntimeThresholdMs=56000&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true&protocol=http®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a.e", - MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: time.Second * 56, - MaxDuration: time.Second * 14, - RetrySleepDuration: time.Millisecond * 45, - }, + User: "u", + Password: "p", + Account: "a.b.c", + Tracing: "debug", ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_maxDurationMs=14000&monitoringFetcher_queryRuntimeThresholdMs=56000&monitoringFetcher_retrySleepDurationMs=45&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&tracing=debug&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeUsernamePasswordMFA, + ClientRequestMfaToken: ConfigBoolTrue, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=true&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeUsernamePasswordMFA, + ClientRequestMfaToken: ConfigBoolFalse, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=false&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Warehouse: "wh", + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true&warehouse=wh", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Token: "t", + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeTokenAccessor, + ConnectionID: testConnectionID, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=tokenaccessor&connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ @@ -916,34 +1157,162 @@ func TestDSN(t *testing.T) { Password: "p", Account: "a.e", MonitoringFetcher: MonitoringFetcherConfig{ - QueryRuntimeThreshold: defaultMonitoringFetcherQueryMonitoringThreshold, - MaxDuration: defaultMonitoringFetcherMaxDuration, - RetrySleepDuration: defaultMonitoringFetcherRetrySleepDuration, + QueryRuntimeThreshold: time.Second * 20, }, ConnectionID: testConnectionID, }, - dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&ocspFailOpen=true®ion=e&validateDefaultParameters=true", + dsn: "u:p@a.e.snowflakecomputing.com:443?connectionId=abcd-0123-4567-1234&monitoringFetcher_queryRuntimeThresholdMs=20000&ocspFailOpen=true®ion=e&validateDefaultParameters=true", }, } for _, test := range testcases { - dsn, err := DSN(test.cfg) - if test.err == nil && err == nil { - if dsn != test.dsn { - t.Errorf("failed to get DSN. expected: %v, got:\n %v", test.dsn, dsn) + t.Run(test.dsn, func(t *testing.T) { + dsn, err := DSN(test.cfg) + if test.err == nil && err == nil { + if dsn != test.dsn { + t.Errorf("failed to get DSN. expected: %v, got:\n %v", test.dsn, dsn) + } + _, err := ParseDSN(dsn) + if err != nil { + t.Errorf("failed to parse DSN. dsn: %v, err: %v", dsn, err) + } } - _, err := ParseDSN(dsn) - if err != nil { - t.Errorf("failed to parse DSN. dsn: %v, err: %v", dsn, err) + if test.err != nil && err == nil { + t.Errorf("expected error. dsn: %v, err: %v", test.dsn, test.err) } - continue - } - if test.err != nil && err == nil { - t.Errorf("expected error. dsn: %v, err: %v", test.dsn, test.err) - continue + if err != nil && test.err == nil { + t.Errorf("failed to match. err: %v", err) + } + }) + } +} + +func TestParsePrivateKeyFromFileMissingFile(t *testing.T) { + _, err := parsePrivateKeyFromFile("nonexistent") + + if err == nil { + t.Error("should report error for nonexistent file") + } +} + +func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) { + pemFile := createTmpFile("exampleKey.pem", []byte("gibberish")) + _, err := parsePrivateKeyFromFile(pemFile) + + if err == nil { + t.Error("should report error for wrong data in file") + } +} + +func TestParsePrivateKeyFromFile(t *testing.T) { + generatedKey, _ := rsa.GenerateKey(cr.Reader, 1024) + pemKey, _ := x509.MarshalPKCS8PrivateKey(generatedKey) + pemData := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: pemKey, + }, + ) + keyFile := createTmpFile("exampleKey.pem", pemData) + defer os.Remove(keyFile) + + parsedKey, err := parsePrivateKeyFromFile(keyFile) + if err != nil { + t.Errorf("unable to parse pam file from path: %v, err: %v", keyFile, err) + } else if !parsedKey.Equal(generatedKey) { + t.Errorf("generated key does not equal to parsed key from file\ngeneratedKey=%v\nparsedKey=%v", + generatedKey, parsedKey) + } +} + +func createTmpFile(fileName string, content []byte) string { + tempFile, _ := os.CreateTemp("", fileName) + tempFile.Write(content) + absolutePath := tempFile.Name() + return absolutePath +} + +type configParamToValue struct { + configParam string + value string +} + +func TestGetConfigFromEnv(t *testing.T) { + envMap := map[string]configParamToValue{ + "SF_TEST_ACCOUNT": {"Account", "account"}, + "SF_TEST_USER": {"User", "user"}, + "SF_TEST_PASSWORD": {"Password", "password"}, + "SF_TEST_ROLE": {"Role", "role"}, + "SF_TEST_HOST": {"Host", "host"}, + "SF_TEST_PORT": {"Port", "8080"}, + "SF_TEST_PROTOCOL": {"Protocol", "http"}, + "SF_TEST_WAREHOUSE": {"Warehouse", "warehouse"}, + "SF_TEST_DATABASE": {"Database", "database"}, + "SF_TEST_REGION": {"Region", "region"}, + "SF_TEST_PASSCODE": {"Passcode", "passcode"}, + "SF_TEST_SCHEMA": {"Schema", "schema"}, + "SF_TEST_APPLICATION": {"Application", "application"}, + } + var properties = make([]*ConfigParam, len(envMap)) + i := 0 + for key, ctv := range envMap { + os.Setenv(key, ctv.value) + cfgParam := ConfigParam{ctv.configParam, key, true} + properties[i] = &cfgParam + i++ + } + defer func() { + for key := range envMap { + os.Unsetenv(key) } - if err != nil && test.err == nil { - t.Errorf("failed to match. err: %v", err) - continue + }() + + cfg, err := GetConfigFromEnv(properties) + if err != nil { + t.Errorf("unable to parse env variables to Config, err: %v", err) + } + + err = checkConfig(*cfg, envMap) + if err != nil { + t.Error(err) + } +} + +func checkConfig(cfg Config, envMap map[string]configParamToValue) error { + appendError := func(errArray []string, envName string, expected string, received string) []string { + errArray = append(errArray, fmt.Sprintf("field %v expected value: %v, received value: %v", envName, expected, received)) + return errArray + } + + value := reflect.ValueOf(cfg) + typeOfCfg := value.Type() + cfgValues := make(map[string]interface{}, value.NumField()) + for i := 0; i < value.NumField(); i++ { + cfgValues[typeOfCfg.Field(i).Name] = value.Field(i).Interface() + } + + var errArray []string + for key, ctv := range envMap { + if ctv.configParam == "Port" { + if portStr := strconv.Itoa(cfgValues[ctv.configParam].(int)); portStr != ctv.value { + errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) + } + } else if cfgValues[ctv.configParam] != ctv.value { + errArray = appendError(errArray, key, ctv.value, cfgValues[ctv.configParam].(string)) } } + + if errArray != nil { + return fmt.Errorf(strings.Join(errArray, "\n")) + } + + return nil +} + +func TestConfigValidateTmpDirPath(t *testing.T) { + cfg := &Config{ + TmpDirPath: "/not/existing", + } + if err := cfg.Validate(); err == nil { + t.Fatalf("Should fail on not existing TmpDirPath") + } } diff --git a/errors.go b/errors.go index bdd1bdd67..366e5dcb5 100644 --- a/errors.go +++ b/errors.go @@ -66,7 +66,7 @@ func (se *SnowflakeError) generateTelemetryExceptionData() *telemetryData { } func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *telemetryData) error { - if sc != nil { + if sc != nil && sc.telemetry != nil { return sc.telemetry.addLog(data) } return nil // TODO oob telemetry @@ -82,7 +82,7 @@ func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError // return populated error fields replacing the default response func populateErrorFields(code int, data *execResponse) *SnowflakeError { - err := ErrUnknownError + err := errUnknownError() if code != -1 { err.Number = code } @@ -299,32 +299,44 @@ const ( errMsgAsyncWithNoResults = "async with no results" ) -var ( - // ErrEmptyAccount is returned if a DNS doesn't include account parameter. - ErrEmptyAccount = &SnowflakeError{ +// Returned if a DNS doesn't include account parameter. +func errEmptyAccount() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyAccountCode, Message: "account is empty", } - // ErrEmptyUsername is returned if a DNS doesn't include user parameter. - ErrEmptyUsername = &SnowflakeError{ +} + +// Returned if a DNS doesn't include user parameter. +func errEmptyUsername() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyUsernameCode, Message: "user is empty", } - // ErrEmptyPassword is returned if a DNS doesn't include password parameter. - ErrEmptyPassword = &SnowflakeError{ +} + +// Returned if a DNS doesn't include password parameter. +func errEmptyPassword() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyPasswordCode, - Message: "password is empty"} + Message: "password is empty", + } +} - // ErrInvalidRegion is returned if a DSN's implicit region from account parameter and explicit region parameter conflict. - ErrInvalidRegion = &SnowflakeError{ +// Returned if a DSN's implicit region from account parameter and explicit region parameter conflict. +func errInvalidRegion() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeRegionOverlap, - Message: "two regions specified"} + Message: "two regions specified", + } +} - // ErrUnknownError is returned if the server side returns an error without meaningful message. - ErrUnknownError = &SnowflakeError{ +// Returned if the server side returns an error without meaningful message. +func errUnknownError() *SnowflakeError { + return &SnowflakeError{ Number: -1, SQLState: "-1", Message: "an unknown server side error occurred", QueryID: "-1", } -) +} diff --git a/htap_test.go b/htap_test.go index ad58b8b78..39fae90f0 100644 --- a/htap_test.go +++ b/htap_test.go @@ -101,37 +101,36 @@ func trimWhitespaces(s string) string { ) } -// Not released yet, we should add them after they are released -// func TestAddingQueryContextCacheEntry(t *testing.T) { -// runSnowflakeConnTest(t, func(sc *snowflakeConn) { -// t.Run("First query (may be on empty cache)", func(t *testing.T) { -// entriesBefore := sc.queryContextCache.entries -// if _, err := sc.Query("SELECT 1", nil); err != nil { -// t.Fatalf("cannot query. %v", err) -// } -// entriesAfter := sc.queryContextCache.entries +func TestAddingQueryContextCacheEntry(t *testing.T) { + runSnowflakeConnTest(t, func(sc *snowflakeConn) { + t.Run("First query (may be on empty cache)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries -// if !containsNewEntries(entriesAfter, entriesBefore) { -// t.Error("no new entries added to the query context cache") -// } -// }) + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) -// t.Run("Second query (cache should not be empty)", func(t *testing.T) { -// entriesBefore := sc.queryContextCache.entries -// if len(entriesBefore) == 0 { -// t.Fatalf("cache should not be empty after first query") -// } -// if _, err := sc.Query("SELECT 1", nil); err != nil { -// t.Fatalf("cannot query. %v", err) -// } -// entriesAfter := sc.queryContextCache.entries + t.Run("Second query (cache should not be empty)", func(t *testing.T) { + entriesBefore := sc.queryContextCache.entries + if len(entriesBefore) == 0 { + t.Fatalf("cache should not be empty after first query") + } + if _, err := sc.Query("SELECT 1", nil); err != nil { + t.Fatalf("cannot query. %v", err) + } + entriesAfter := sc.queryContextCache.entries -// if !containsNewEntries(entriesAfter, entriesBefore) { -// t.Error("no new entries added to the query context cache") -// } -// }) -// }) -// } + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + }) +} func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryContextEntry) bool { if len(entriesAfter) > len(entriesBefore) { diff --git a/query.go b/query.go index 882e1b62a..50473e16f 100644 --- a/query.go +++ b/query.go @@ -125,6 +125,11 @@ type execResponseData struct { Command string `json:"command,omitempty"` Kind string `json:"kind,omitempty"` Operation string `json:"operation,omitempty"` + + // HTAP + QueryContext struct { + Entries []queryContextEntry `json:"entries,omitempty"` + } `json:"queryContext,omitempty"` } type execResponse struct { diff --git a/rows.go b/rows.go index 71e87caa9..314981a94 100644 --- a/rows.go +++ b/rows.go @@ -45,9 +45,17 @@ type snowflakeRows struct { err error errChannel chan error monitoring *monitoringResult + location *time.Location asyncRequestID UUID } +func (rows *snowflakeRows) getLocation() *time.Location { + if rows.location == nil && rows.sc != nil && rows.sc.cfg != nil { + rows.location = getCurrentLocation(rows.sc.cfg.Params) + } + return rows.location +} + type snowflakeValue interface{} type chunkRowType struct { @@ -233,11 +241,7 @@ func (rows *snowflakeRows) Next(dest []driver.Value) (err error) { for i, n := 0, len(row.RowSet); i < n; i++ { // could move to chunk downloader so that each go routine // can convert data - var loc *time.Location - if rows.sc != nil { - loc = getCurrentLocation(rows.sc.cfg.Params) - } - err = stringToValue(&dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], loc) + err = stringToValue(&dest[i], rows.ChunkDownloader.getRowType()[i], row.RowSet[i], rows.getLocation()) if err != nil { return err } diff --git a/statement_test.go b/statement_test.go index 5d1554202..340e1e1d6 100644 --- a/statement_test.go +++ b/statement_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. +//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. package gosnowflake @@ -7,7 +8,10 @@ import ( "database/sql" "database/sql/driver" "fmt" + "net/http" + "net/url" "testing" + "time" ) func openDB(t *testing.T) *sql.DB { @@ -17,20 +21,30 @@ func openDB(t *testing.T) *sql.DB { if db, err = sql.Open("sigmacomputing+gosnowflake", dsn); err != nil { t.Fatalf("failed to open db. %v, err: %v", dsn, err) } + return db } -func TestGetQueryID(t *testing.T) { - db := openDB(t) - defer db.Close() +func openConn(t *testing.T) *sql.Conn { + var db *sql.DB + var conn *sql.Conn + var err error - ctx := context.TODO() - conn, err := db.Conn(ctx) - if err != nil { - t.Error(err) + if db, err = sql.Open("sigmacomputing+gosnowflake", dsn); err != nil { + t.Fatalf("failed to open db. %v, err: %v", dsn, err) + } + if conn, err = db.Conn(context.Background()); err != nil { + t.Fatalf("failed to open connection: %v", err) } + return conn +} - if err = conn.Raw(func(x interface{}) error { +func TestGetQueryID(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + if err := conn.Raw(func(x interface{}) error { rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil) if err != nil { return err @@ -71,7 +85,7 @@ func TestEmitQueryID(t *testing.T) { cnt := 0 var idx int var v string - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) defer rows.Close() @@ -144,9 +158,10 @@ func TestE2EFetchResultByID(t *testing.T) { } func TestWithDescribeOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { ctx := WithDescribeOnly(context.Background()) rows := dbt.mustQueryContext(ctx, selectVariousTypes) + defer rows.Close() cols, err := rows.Columns() if err != nil { t.Error(err) @@ -166,10 +181,243 @@ func TestWithDescribeOnly(t *testing.T) { }) } +func TestCallStatement(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + in1 := float64(1) + in2 := string("[2,3]") + expected := "1 \"[2,3]\" [2,3]" + var out string + dbt.exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true") + dbt.mustExec("create or replace procedure " + + "TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " + + "returns string language javascript as $$ " + + "let res = snowflake.execute({sqlText: 'select ? c1, ? c2', binds:[IN1, JSON.stringify(IN2)]}); " + + "res.next(); " + + "return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " + + "$$;") + stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))") + if err != nil { + dbt.Errorf("failed to prepare query: %v", err) + } + defer stmt.Close() + err = stmt.QueryRow(in1, in2).Scan(&out) + if err != nil { + dbt.Errorf("failed to scan: %v", err) + } + if expected != out { + dbt.Errorf("expected: %s, got: %s", expected, out) + } + dbt.mustExec("drop procedure if exists TEST_SP_CALL_STMT_ENABLED(float, variant)") + }) +} + +func TestStmtExec(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { + t.Fatalf("failed to create table: %v", err) + } + + if err := conn.Raw(func(x interface{}) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)") + if err != nil { + t.Error(err) + } + _, err = stmt.(*snowflakeStmt).Exec(nil) + if err != nil { + t.Error(err) + } + _, err = stmt.(*snowflakeStmt).Query(nil) + if err != nil { + t.Error(err) + } + return nil + }); err != nil { + t.Fatalf("failed to drop table: %v", err) + } + + if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { + t.Fatalf("failed to drop table: %v", err) + } +} + +func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, + }, nil +} + +func TestUnitCheckQueryStatus(t *testing.T) { + sc := getDefaultSnowflakeConn() + ctx := context.Background() + qid := NewUUID() + + sr := &snowflakeRestful{ + FuncGet: getStatusSuccessButInvalidJSONfunc, + TokenAccessor: getSimpleTokenAccessor(), + } + sc.rest = sr + _, err := sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("invalid json. should have failed") + } + sc.rest.FuncGet = funcGetQueryRespFail + _, err = sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("should have failed") + } + + sc.rest.FuncGet = funcGetQueryRespError + _, err = sc.checkQueryStatus(ctx, qid.String()) + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrQueryStatus { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number) + } +} + +func TestStatementQueryIdForQueries(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Rows, error) + }{ + { + "query", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + }, + { + "queryContext", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestStatementQueryIdForExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Result, error) + }{ + { + "exec", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + }, + { + "execContext", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + func TestWithQueryTag(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runDBTest(t, func(dbt *DBTest) { testQueryTag := "TEST QUERY TAG" ctx := WithQueryTag(context.Background(), testQueryTag) + ctx, fn := context.WithCancel(ctx) + // For whatever reason we have to cancel the context explicitly + // To prevent sql conn.Close() from hanging. + defer fn() // This query itself will be part of the history and will have the query tag rows := dbt.mustQueryContext( diff --git a/submit_sync_test.go b/submit_sync_test.go index 0baad913f..1f19a2017 100644 --- a/submit_sync_test.go +++ b/submit_sync_test.go @@ -50,8 +50,9 @@ func TestSubmitQuerySync(t *testing.T) { // Set a long threshold to prevent the monitoring fetch from kicking in. MonitoringFetcher: MonitoringFetcherConfig{QueryRuntimeThreshold: 1 * time.Hour}, }, - rest: sr, - telemetry: testTelemetry, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } res, err := sc.SubmitQuerySync(context.TODO(), "") @@ -129,8 +130,9 @@ func TestSubmitQuerySyncQueryComplete(t *testing.T) { // Set a long threshold to prevent the monitoring fetch from kicking in. MonitoringFetcher: MonitoringFetcherConfig{QueryRuntimeThreshold: 1 * time.Hour}, }, - rest: sr, - telemetry: testTelemetry, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } res, err := sc.SubmitQuerySync(context.TODO(), "") diff --git a/telemetry_test.go b/telemetry_test.go index 49826512f..54ab12c85 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -476,3 +476,209 @@ func TestAddLogError(t *testing.T) { t.Fatal("should have failed") } } + +func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { + return nil, errors.New("failed to upload metrics to telemetry") +} + +func TestTelemetryError(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } +} + +func TestTelemetryDisabledOnBadResponse(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + + st.enabled = true + st.sr.FuncPost = postTestQueryNotExecuting + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + + st.enabled = true + st.sr.FuncPost = postTestSuccessButInvalidJSON + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } +} + +func TestTelemetryDisabled(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: false, // disable + flushSize: defaultFlushSize, + } + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + st.enabled = true + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + st.enabled = false + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } +} + +func TestAddLogError(t *testing.T) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: 1, + } + + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } +} diff --git a/util.go b/util.go index ff37f9dc5..5c2c676c4 100644 --- a/util.go +++ b/util.go @@ -5,7 +5,9 @@ package gosnowflake import ( "context" "database/sql/driver" + "fmt" "io" + "os" "strings" "sync" "time" @@ -272,3 +274,14 @@ func escapeForCSV(value string) string { } return value } + +// GetFromEnv is used to get the value of an environment variable from the system +func GetFromEnv(name string, failOnMissing bool) (string, error) { + if value := os.Getenv(name); value != "" { + return value, nil + } + if failOnMissing { + return "", fmt.Errorf("%v environment variable is not set", name) + } + return "", nil +}