diff --git a/sql/type.go b/sql/type.go index c3aee00d9..e082abd17 100644 --- a/sql/type.go +++ b/sql/type.go @@ -302,8 +302,8 @@ func (t numberT) Type() query.Type { // SQL implements Type interface. func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(t.t, nil), nil } switch t.t { @@ -428,8 +428,8 @@ var TimestampLayouts = []string{ // SQL implements Type interface. func (t timestampT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Timestamp, nil), nil } v, err := t.Convert(v) @@ -504,8 +504,8 @@ func (t dateT) Type() query.Type { } func (t dateT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Timestamp, nil), nil } v, err := t.Convert(v) @@ -561,8 +561,8 @@ func (t textT) Type() query.Type { // SQL implements Type interface. func (t textT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Text, nil), nil } v, err := t.Convert(v) @@ -598,8 +598,8 @@ func (t booleanT) Type() query.Type { // SQL implements Type interface. func (t booleanT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Bit, nil), nil } b := []byte{'0'} @@ -670,8 +670,8 @@ func (t blobT) Type() query.Type { // SQL implements Type interface. func (t blobT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Blob, nil), nil } v, err := t.Convert(v) @@ -714,8 +714,8 @@ func (t jsonT) Type() query.Type { // SQL implements Type interface. func (t jsonT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.TypeJSON, nil), nil } v, err := t.Convert(v) @@ -760,10 +760,6 @@ func (t tupleT) Type() query.Type { } func (t tupleT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil - } - return sqltypes.Value{}, fmt.Errorf("unable to convert tuple type to SQL") } @@ -825,8 +821,8 @@ func (t arrayT) Type() query.Type { } func (t arrayT) SQL(v interface{}) (sqltypes.Value, error) { - if _, ok := v.(nullT); ok { - return sqltypes.NULL, nil + if v == nil { + return sqltypes.MakeTrusted(sqltypes.TypeJSON, nil), nil } v, err := t.Convert(v) diff --git a/sql/type_test.go b/sql/type_test.go index 370e4062c..6e0333c14 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -13,7 +13,7 @@ func TestIsNull(t *testing.T) { require.True(t, IsNull(nil)) n := numberT{sqltypes.Uint64} - require.Equal(t, sqltypes.NULL, mustSQL(n.SQL(Null))) + require.Equal(t, sqltypes.MakeTrusted(sqltypes.Uint64, nil), mustSQL(n.SQL(nil))) require.Equal(t, sqltypes.NewUint64(0), mustSQL(n.SQL(uint64(0)))) }