Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
sql: correctly handle nulls in SQL type conversion
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Jun 19, 2019
1 parent 5f48ea3 commit 084fa59
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
36 changes: 16 additions & 20 deletions sql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ func (t numberT) Type() query.Type {

// SQL implements Type interface.
func (t numberT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(t.t, nil), nil
}

switch t.t {
Expand Down Expand Up @@ -428,8 +428,8 @@ var TimestampLayouts = []string{

// SQL implements Type interface.
func (t timestampT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.Timestamp, nil), nil
}

v, err := t.Convert(v)
Expand Down Expand Up @@ -504,8 +504,8 @@ func (t dateT) Type() query.Type {
}

func (t dateT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.Timestamp, nil), nil
}

v, err := t.Convert(v)
Expand Down Expand Up @@ -561,8 +561,8 @@ func (t textT) Type() query.Type {

// SQL implements Type interface.
func (t textT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.Text, nil), nil
}

v, err := t.Convert(v)
Expand Down Expand Up @@ -598,8 +598,8 @@ func (t booleanT) Type() query.Type {

// SQL implements Type interface.
func (t booleanT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.Bit, nil), nil
}

b := []byte{'0'}
Expand Down Expand Up @@ -670,8 +670,8 @@ func (t blobT) Type() query.Type {

// SQL implements Type interface.
func (t blobT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.Blob, nil), nil
}

v, err := t.Convert(v)
Expand Down Expand Up @@ -714,8 +714,8 @@ func (t jsonT) Type() query.Type {

// SQL implements Type interface.
func (t jsonT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.TypeJSON, nil), nil
}

v, err := t.Convert(v)
Expand Down Expand Up @@ -760,10 +760,6 @@ func (t tupleT) Type() query.Type {
}

func (t tupleT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
}

return sqltypes.Value{}, fmt.Errorf("unable to convert tuple type to SQL")
}

Expand Down Expand Up @@ -825,8 +821,8 @@ func (t arrayT) Type() query.Type {
}

func (t arrayT) SQL(v interface{}) (sqltypes.Value, error) {
if _, ok := v.(nullT); ok {
return sqltypes.NULL, nil
if v == nil {
return sqltypes.MakeTrusted(sqltypes.TypeJSON, nil), nil
}

v, err := t.Convert(v)
Expand Down
2 changes: 1 addition & 1 deletion sql/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestIsNull(t *testing.T) {
require.True(t, IsNull(nil))

n := numberT{sqltypes.Uint64}
require.Equal(t, sqltypes.NULL, mustSQL(n.SQL(Null)))
require.Equal(t, sqltypes.MakeTrusted(sqltypes.Uint64, nil), mustSQL(n.SQL(nil)))
require.Equal(t, sqltypes.NewUint64(0), mustSQL(n.SQL(uint64(0))))
}

Expand Down

0 comments on commit 084fa59

Please sign in to comment.