Skip to content

Commit

Permalink
[Feature] Allow clients to specify any/null DataType explicitly
Browse files Browse the repository at this point in the history
- Allow clients to specify DataTypeNull explicitly (#50)

* first cut of explict DataTypeNull

* make lint

* more tests

- Allow client to specify any data type explicitly (#28)

* dataTypeMode should work for all types

* fix test

* allow explicit binaryType declaration even if we're already using binaryType

* silly attempt with *SnowflakeDataType

* switch from *SnowflakeDataType --> SnowflakeDataType

* bindings_test.go passes

* add test for corner case

* fix lint

* fix null-handling tests

* only need connection.CheckNamedValue

* use checked cast instead of switch on type
  • Loading branch information
GregOwen authored and madisonchamberlain committed Jul 29, 2022
1 parent 55b6498 commit 8473cd8
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 98 deletions.
18 changes: 9 additions & 9 deletions bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,24 @@ func (sc *snowflakeConn) processBindings(
}

func getBindValues(bindings []driver.NamedValue) (map[string]execBindParameter, error) {
tsmode := timestampNtzType
idx := 1
var err error
bindValues := make(map[string]execBindParameter, len(bindings))
var dataType SnowflakeDataType
for _, binding := range bindings {
t := goTypeToSnowflake(binding.Value, tsmode)
if t == changeType {
tsmode, err = dataTypeMode(binding.Value)
if err != nil {
return nil, err
}
} else {
switch binding.Value.(type) {
case SnowflakeDataType:
// This binding is just specifying the type for subsequent bindings
dataType = binding.Value.(SnowflakeDataType)
default:
// This binding is an actual parameter for the query
t := goTypeToSnowflake(binding.Value, dataType)
var val interface{}
if t == sliceType {
// retrieve array binding data
t, val = snowflakeArrayToString(&binding, false)
} else {
val, err = valueToString(binding.Value, tsmode)
val, err = valueToString(binding.Value, dataType)
if err != nil {
return nil, err
}
Expand Down
26 changes: 22 additions & 4 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ const (
selectAllSQLBulkArrayDateTimeTimestamp = "select * from test_bulk_array_DateTimeTimestamp ORDER BY 1"
)

func TestBindingNull(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id int, c1 STRING, c2 BOOLEAN)")
_, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?, ?)",
DataTypeText, "hello",
DataTypeNull, nil,
)
if err != nil {
dbt.Fatal(err)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
})
}

func TestBindingFloat64(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
Expand Down Expand Up @@ -150,19 +164,23 @@ func TestBindingDateTimeTimestamp(t *testing.T) {

func TestBindingBinary(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary)")
dbt.mustExec("CREATE OR REPLACE TABLE bintest (id int, b binary, c binary)")
var b = []byte{0x01, 0x02, 0x03}
dbt.mustExec("INSERT INTO bintest(id,b) VALUES(1, ?)", DataTypeBinary, b)
rows := dbt.mustQuery("SELECT b FROM bintest WHERE id=?", 1)
dbt.mustExec("INSERT INTO bintest(id,b,c) VALUES(1, ?, ?)", DataTypeBinary, b, DataTypeBinary, b)
rows := dbt.mustQuery("SELECT b, c FROM bintest WHERE id=?", 1)
defer rows.Close()
if rows.Next() {
var rb []byte
if err := rows.Scan(&rb); err != nil {
var rc []byte
if err := rows.Scan(&rb, &rc); err != nil {
dbt.Errorf("failed to scan data. err: %v", err)
}
if !bytes.Equal(b, rb) {
dbt.Errorf("failed to match data. expected: %v, got: %v", b, rb)
}
if !bytes.Equal(b, rc) {
dbt.Errorf("failed to match data. expected: %v, got: %v", b, rc)
}
} else {
dbt.Errorf("no data")
}
Expand Down
3 changes: 2 additions & 1 deletion ci/scripts/test_component.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ if [[ -n "$GITHUB_WORKFLOW" ]]; then
fi
env | grep SNOWFLAKE | grep -v PASS | sort
cd $TOPDIR
go test -timeout 30m -race $COVFLAGS -v .
go test -timeout 30m -race $COVFLAGS -v .

5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,11 @@ func (sc *snowflakeConn) Ping(ctx context.Context) error {
// CheckNamedValue determines which types are handled by this driver aside from
// the instances captured by driver.Value
func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error {
if _, ok := nv.Value.(SnowflakeDataType); ok {
// Pass SnowflakeDataType args through without modification so that we can
// distinguish them from arguments of type []byte
return nil
}
if supported := supportedArrayBind(nv); !supported {
return driver.ErrSkip
}
Expand Down
64 changes: 34 additions & 30 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,36 @@ const (
)

// goTypeToSnowflake translates Go data type to Snowflake data type.
func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType {
switch t := v.(type) {
case int64:
return fixedType
case float64:
return realType
case bool:
return booleanType
case string:
return textType
case []byte:
if tsmode == binaryType {
return binaryType // may be redundant but ensures BINARY type
}
if t == nil {
return nullType // invalid byte array. won't take as BINARY
}
if len(t) != 1 {
func goTypeToSnowflake(v driver.Value, dataType SnowflakeDataType) snowflakeType {
if dataType == nil {
switch t := v.(type) {
case SnowflakeDataType:
return changeType
case int64:
return fixedType
case float64:
return realType
case bool:
return booleanType
case string:
return textType
case []byte:
if t == nil {
return nullType // invalid byte array. won't take as BINARY
}
// If we don't have an explicit data type, binary blobs are unsupported
return unSupportedType
case time.Time:
// Default timestamp type
return timestampNtzType
}
if _, err := dataTypeMode(t); err != nil {
} else {
// If we have an explicit type, use it
ty, err := clientTypeToInternal(dataType)
if err != nil {
return unSupportedType
}
return changeType
case time.Time:
return tsmode
return ty
}
if supportedArrayBind(&driver.NamedValue{Value: v}) {
return sliceType
Expand Down Expand Up @@ -96,7 +100,7 @@ func snowflakeTypeToGo(dbtype snowflakeType, scale int64) reflect.Type {

// valueToString converts arbitrary golang type to a string. This is mainly used in binding data with placeholders
// in queries.
func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) {
func valueToString(v driver.Value, dataType SnowflakeDataType) (*string, error) {
logger.Debugf("TYPE: %v, %v", reflect.TypeOf(v), reflect.ValueOf(v))
if v == nil {
return nil, nil
Expand All @@ -120,7 +124,7 @@ func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) {
return nil, nil
}
if bd, ok := v.([]byte); ok {
if tsmode == binaryType {
if dataType != nil && dataType.Equals(DataTypeBinary) {
s := hex.EncodeToString(bd)
return &s, nil
}
Expand All @@ -129,21 +133,21 @@ func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) {
s := v1.String()
return &s, nil
case reflect.Struct:
if tm, ok := v.(time.Time); ok {
switch tsmode {
case dateType:
if tm, ok := v.(time.Time); ok && dataType != nil {
switch {
case dataType.Equals(DataTypeDate):
_, offset := tm.Zone()
tm = tm.Add(time.Second * time.Duration(offset))
s := fmt.Sprintf("%d", tm.Unix()*1000)
return &s, nil
case timeType:
case dataType.Equals(DataTypeTime):
s := fmt.Sprintf("%d",
(tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond())
return &s, nil
case timestampNtzType, timestampLtzType:
case dataType.Equals(DataTypeTimestampNtz) || dataType.Equals(DataTypeTimestampLtz):
s := fmt.Sprintf("%d", tm.UnixNano())
return &s, nil
case timestampTzType:
case dataType.Equals(DataTypeTimestampTz):
_, offset := tm.Zone()
s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440)
return &s, nil
Expand Down
8 changes: 4 additions & 4 deletions converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func stringFloatToDecimal(src string, scale int64) (decimal128.Num, bool) {

type tcGoTypeToSnowflake struct {
in interface{}
tmode snowflakeType
tmode SnowflakeDataType
out snowflakeType
}

Expand Down Expand Up @@ -128,7 +128,7 @@ func TestSnowflakeTypeToGo(t *testing.T) {

func TestValueToString(t *testing.T) {
v := cmplx.Sqrt(-5 + 12i) // should never happen as Go sql package must have already validated.
_, err := valueToString(v, nullType)
_, err := valueToString(v, nil)
if err == nil {
t.Errorf("should raise error: %v", v)
}
Expand All @@ -138,15 +138,15 @@ func TestValueToString(t *testing.T) {
utcTime := time.Date(2019, 2, 6, 22, 17, 31, 123456789, time.UTC)
expectedUnixTime := "1549491451123456789" // time.Unix(1549491451, 123456789).Format(time.RFC3339) == "2019-02-06T14:17:31-08:00"

if s, err := valueToString(localTime, timestampLtzType); err != nil {
if s, err := valueToString(localTime, DataTypeTimestampLtz); err != nil {
t.Error("unexpected error")
} else if s == nil {
t.Errorf("expected '%v', got %v", expectedUnixTime, s)
} else if *s != expectedUnixTime {
t.Errorf("expected '%v', got '%v'", expectedUnixTime, *s)
}

if s, err := valueToString(utcTime, timestampLtzType); err != nil {
if s, err := valueToString(utcTime, DataTypeTimestampLtz); err != nil {
t.Error("unexpected error")
} else if s == nil {
t.Errorf("expected '%v', got %v", expectedUnixTime, s)
Expand Down
93 changes: 60 additions & 33 deletions datatype.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package gosnowflake
import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
)

Expand All @@ -25,8 +24,8 @@ const (
binaryType
timeType
booleanType
// the following are not snowflake types per se but internal types
nullType
// the following are not snowflake types per se but internal types
sliceType
changeType
unSupportedType
Expand Down Expand Up @@ -55,58 +54,86 @@ func getSnowflakeType(typ string) snowflakeType {
return nullType
}

// SnowflakeDataType is the type used by clients to explicitly indicate the type
// of an argument to ExecContext and friends. We use a separate public-facing
// type rather than a Go primitive type so that we can always differentiate
// between args that indicate type and args that are values.
type SnowflakeDataType []byte

// Equals checks if dt and o represent the same type indicator
func (dt SnowflakeDataType) Equals(o SnowflakeDataType) bool {
return bytes.Equal(([]byte)(dt), ([]byte)(o))
}

var (
// DataTypeFixed is a FIXED datatype.
DataTypeFixed = []byte{fixedType.Byte()}
DataTypeFixed = SnowflakeDataType{fixedType.Byte()}
// DataTypeReal is a REAL datatype.
DataTypeReal = []byte{realType.Byte()}
DataTypeReal = SnowflakeDataType{realType.Byte()}
// DataTypeText is a TEXT datatype.
DataTypeText = []byte{textType.Byte()}
DataTypeText = SnowflakeDataType{textType.Byte()}
// DataTypeDate is a Date datatype.
DataTypeDate = []byte{dateType.Byte()}
DataTypeDate = SnowflakeDataType{dateType.Byte()}
// DataTypeVariant is a TEXT datatype.
DataTypeVariant = []byte{variantType.Byte()}
DataTypeVariant = SnowflakeDataType{variantType.Byte()}
// DataTypeTimestampLtz is a TIMESTAMP_LTZ datatype.
DataTypeTimestampLtz = []byte{timestampLtzType.Byte()}
DataTypeTimestampLtz = SnowflakeDataType{timestampLtzType.Byte()}
// DataTypeTimestampNtz is a TIMESTAMP_NTZ datatype.
DataTypeTimestampNtz = []byte{timestampNtzType.Byte()}
DataTypeTimestampNtz = SnowflakeDataType{timestampNtzType.Byte()}
// DataTypeTimestampTz is a TIMESTAMP_TZ datatype.
DataTypeTimestampTz = []byte{timestampTzType.Byte()}
DataTypeTimestampTz = SnowflakeDataType{timestampTzType.Byte()}
// DataTypeObject is a OBJECT datatype.
DataTypeObject = []byte{objectType.Byte()}
DataTypeObject = SnowflakeDataType{objectType.Byte()}
// DataTypeArray is a ARRAY datatype.
DataTypeArray = []byte{arrayType.Byte()}
DataTypeArray = SnowflakeDataType{arrayType.Byte()}
// DataTypeBinary is a BINARY datatype.
DataTypeBinary = []byte{binaryType.Byte()}
DataTypeBinary = SnowflakeDataType{binaryType.Byte()}
// DataTypeTime is a TIME datatype.
DataTypeTime = []byte{timeType.Byte()}
DataTypeTime = SnowflakeDataType{timeType.Byte()}
// DataTypeBoolean is a BOOLEAN datatype.
DataTypeBoolean = []byte{booleanType.Byte()}
DataTypeBoolean = SnowflakeDataType{booleanType.Byte()}
// DataTypeNull is a NULL datatype.
DataTypeNull = SnowflakeDataType{nullType.Byte()}
)

// dataTypeMode returns the subsequent data type in a string representation.
func dataTypeMode(v driver.Value) (tsmode snowflakeType, err error) {
if bd, ok := v.([]byte); ok {
func clientTypeToInternal(cType SnowflakeDataType) (iType snowflakeType, err error) {
if cType != nil {
switch {
case bytes.Equal(bd, DataTypeDate):
tsmode = dateType
case bytes.Equal(bd, DataTypeTime):
tsmode = timeType
case bytes.Equal(bd, DataTypeTimestampLtz):
tsmode = timestampLtzType
case bytes.Equal(bd, DataTypeTimestampNtz):
tsmode = timestampNtzType
case bytes.Equal(bd, DataTypeTimestampTz):
tsmode = timestampTzType
case bytes.Equal(bd, DataTypeBinary):
tsmode = binaryType
case cType.Equals(DataTypeFixed):
iType = fixedType
case cType.Equals(DataTypeReal):
iType = realType
case cType.Equals(DataTypeText):
iType = textType
case cType.Equals(DataTypeDate):
iType = dateType
case cType.Equals(DataTypeVariant):
iType = variantType
case cType.Equals(DataTypeTimestampLtz):
iType = timestampLtzType
case cType.Equals(DataTypeTimestampNtz):
iType = timestampNtzType
case cType.Equals(DataTypeTimestampTz):
iType = timestampTzType
case cType.Equals(DataTypeObject):
iType = objectType
case cType.Equals(DataTypeArray):
iType = arrayType
case cType.Equals(DataTypeBinary):
iType = binaryType
case cType.Equals(DataTypeTime):
iType = timeType
case cType.Equals(DataTypeBoolean):
iType = booleanType
case cType.Equals(DataTypeNull):
iType = nullType
default:
return nullType, fmt.Errorf(errMsgInvalidByteArray, v)
return nullType, fmt.Errorf(errMsgInvalidByteArray, ([]byte)(cType))
}
} else {
return nullType, fmt.Errorf(errMsgInvalidByteArray, v)
return nullType, fmt.Errorf(errMsgInvalidByteArray, nil)
}
return tsmode, nil
return iType, nil
}

// SnowflakeParameter includes the columns output from SHOW PARAMETER command.
Expand Down
Loading

0 comments on commit 8473cd8

Please sign in to comment.