Skip to content

Commit

Permalink
SNOW-645253 Handle binding named parameters (snowflakedb#850)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jul 18, 2023
1 parent fec38ba commit 67ec6cf
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 18 deletions.
9 changes: 8 additions & 1 deletion bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter,
if t == nullType || t == unSupportedType {
t = textType // if null or not supported, pass to GS as text
}
bindValues[strconv.Itoa(idx)] = execBindParameter{
bindValues[bindingName(binding, idx)] = execBindParameter{
Type: t.String(),
Value: val,
}
Expand All @@ -250,6 +250,13 @@ func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter,
return bindValues, nil
}

func bindingName(nv driver.NamedValue, idx int) string {
if nv.Name != "" {
return nv.Name
}
return strconv.Itoa(idx)
}

func arrayBindValueCount(bindValues []driver.NamedValue) int {
if !isArrayBind(bindValues) {
return 0
Expand Down
111 changes: 94 additions & 17 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func TestBindingInterface(t *testing.T) {
if !rows.Next() {
dbt.Error("failed to query")
}
var v1, v2, v3, v4, v5, v6 interface{}
var v1, v2, v3, v4, v5, v6 any
if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil {
dbt.Errorf("failed to scan: %#v", err)
}
Expand All @@ -327,7 +327,7 @@ func TestBindingInterfaceString(t *testing.T) {
if !rows.Next() {
dbt.Error("failed to query")
}
var v1, v2, v3, v4, v5, v6 interface{}
var v1, v2, v3, v4, v5, v6 any
if err := rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil {
dbt.Errorf("failed to scan: %#v", err)
}
Expand All @@ -348,7 +348,7 @@ func TestBindingInterfaceString(t *testing.T) {
}

func TestBulkArrayBindingInterfaceNil(t *testing.T) {
nilArray := make([]interface{}, 1)
nilArray := make([]any, 1)

runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec(createTableSQL)
Expand Down Expand Up @@ -413,22 +413,22 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) {
}

func TestBulkArrayBindingInterface(t *testing.T) {
intArray := make([]interface{}, 3)
intArray := make([]any, 3)
intArray[0] = int32(100)
intArray[1] = int32(200)

fltArray := make([]interface{}, 3)
fltArray := make([]any, 3)
fltArray[0] = float64(0.1)
fltArray[2] = float64(5.678)

boolArray := make([]interface{}, 3)
boolArray := make([]any, 3)
boolArray[1] = false
boolArray[2] = true

strArray := make([]interface{}, 3)
strArray := make([]any, 3)
strArray[2] = "test3"

byteArray := make([]interface{}, 3)
byteArray := make([]any, 3)
byteArray[0] = []byte{0x01, 0x02, 0x03}
byteArray[2] = []byte{0x07, 0x08, 0x09}

Expand Down Expand Up @@ -504,23 +504,23 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) {
if err != nil {
t.Error(err)
}
ntzArray := make([]interface{}, 3)
ntzArray := make([]any, 3)
ntzArray[0] = now
ntzArray[1] = now.Add(1)

ltzArray := make([]interface{}, 3)
ltzArray := make([]any, 3)
ltzArray[1] = now.Add(2).In(loc)
ltzArray[2] = now.Add(3).In(loc)

tzArray := make([]interface{}, 3)
tzArray := make([]any, 3)
tzArray[0] = tz.Add(4).In(loc)
tzArray[2] = tz.Add(5).In(loc)

dtArray := make([]interface{}, 3)
dtArray := make([]any, 3)
dtArray[0] = tz.Add(6).In(loc)
dtArray[1] = now.Add(7).In(loc)

tmArray := make([]interface{}, 3)
tmArray := make([]any, 3)
tmArray[1] = now.Add(8).In(loc)
tmArray[2] = now.Add(9).In(loc)

Expand Down Expand Up @@ -810,8 +810,8 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) {
numRows := endNum - startNum

// Define the integer and string arrays
intArr := make([]interface{}, numRows)
stringArr := make([]interface{}, numRows)
intArr := make([]any, numRows)
stringArr := make([]any, numRows)
for i := startNum; i < endNum; i++ {
intArr[i-startNum] = i
stringArr[i-startNum] = fmt.Sprint(i)
Expand Down Expand Up @@ -867,7 +867,7 @@ func TestFunctionParameters(t *testing.T) {
testcases := []struct {
testDesc string
paramType string
input interface{}
input any
nullResult bool
}{
{"textAndNullStringResultInNull", "text", sql.NullString{}, true},
Expand Down Expand Up @@ -912,7 +912,7 @@ func TestFunctionParameters(t *testing.T) {
if !rows.Next() {
t.Fatal()
} else {
var r1 interface{}
var r1 any
err = rows.Scan(&r1)
if err != nil {
t.Fatal(err)
Expand All @@ -930,3 +930,80 @@ func TestFunctionParameters(t *testing.T) {
}
})
}

func TestVariousBindingModes(t *testing.T) {
testcases := []struct {
testDesc string
paramType string
input any
isNil bool
}{
{"textAndString", "text", "string", false},
{"numberAndInteger", "number", 123, false},
{"floatAndFloat", "float", 123.01, false},
{"booleanAndBoolean", "boolean", true, false},
{"dateAndTime", "date", time.Now().Truncate(24 * time.Hour), false},
{"datetimeAndTime", "datetime", time.Now(), false},
{"timeAndTime", "time", "12:34:56", false},
{"timestampAndTime", "timestamp", time.Now(), false},
{"timestamp_ntzAndTime", "timestamp_ntz", time.Now(), false},
{"timestamp_ltzAndTime", "timestamp_ltz", time.Now(), false},
{"timestamp_tzAndTime", "timestamp_tz", time.Now(), false},
{"textAndNullString", "text", sql.NullString{}, true},
{"numberAndNullInt64", "number", sql.NullInt64{}, true},
{"floatAndNullFloat64", "float", sql.NullFloat64{}, true},
{"booleanAndAndNullBool", "boolean", sql.NullBool{}, true},
{"dateAndTypedNullTime", "date", TypedNullTime{sql.NullTime{}, DateType}, true},
{"datetimeAndTypedNullTime", "datetime", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timeAndTypedNullTime", "time", TypedNullTime{sql.NullTime{}, TimeType}, true},
{"timestampAndTypedNullTime", "timestamp", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timestamp_ntzAndTypedNullTime", "timestamp_ntz", TypedNullTime{sql.NullTime{}, TimestampNTZType}, true},
{"timestamp_ltzAndTypedNullTime", "timestamp_ltz", TypedNullTime{sql.NullTime{}, TimestampLTZType}, true},
{"timestamp_tzAndTypedNullTime", "timestamp_tz", TypedNullTime{sql.NullTime{}, TimestampTZType}, true},
}

bindingModes := []struct {
param string
query string
transform func(any) any
}{
{
param: "?",
transform: func(v any) any { return v },
},
{
param: ":1",
transform: func(v any) any { return v },
},
{
param: ":param",
transform: func(v any) any { return sql.Named("param", v) },
},
}

runTests(t, dsn, func(dbt *DBTest) {
for _, tc := range testcases {
for _, bindingMode := range bindingModes {
t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) {
query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType)
dbt.mustExec(query)
if _, err := dbt.db.Exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil {
t.Fatal(err)
}
if tc.isNil {
query = "SELECT * FROM BINDING_MODES WHERE param1 IS NULL"
} else {
query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param)
}
rows, err := dbt.db.Query(query, bindingMode.transform(tc.input))
if err != nil {
t.Fatal(err)
}
if !rows.Next() {
t.Fatal("Expected to return a row")
}
})
}
}
})
}

0 comments on commit 67ec6cf

Please sign in to comment.