Skip to content

Commit

Permalink
fix: properly handle driver.Valuer and type:json
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Dec 27, 2021
1 parent 8a97257 commit a17454a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
49 changes: 47 additions & 2 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbtest_test
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"os"
Expand Down Expand Up @@ -222,7 +223,8 @@ func TestDB(t *testing.T) {
{testScanSingleRowByRow},
{testScanRows},
{testRunInTx},
{testInsertIface},
{testJSONInterface},
{testJSONValuer},
{testSelectBool},
{testFKViolation},
{testInterfaceAny},
Expand Down Expand Up @@ -712,7 +714,7 @@ func testJSONSpecialChars(t *testing.T, db *bun.DB) {
}
}

func testInsertIface(t *testing.T, db *bun.DB) {
func testJSONInterface(t *testing.T, db *bun.DB) {
type Model struct {
ID int
Value interface{} `bun:"type:json"`
Expand All @@ -734,6 +736,49 @@ func testInsertIface(t *testing.T, db *bun.DB) {
require.NoError(t, err)
}

type JSONValue struct {
str string
}

var _ driver.Valuer = (*JSONValue)(nil)

func (v *JSONValue) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
v.str = string(src)
case string:
v.str = src
default:
panic("not reached")
}
return nil
}

func (v *JSONValue) Value() (driver.Value, error) {
return `"driver.Value"`, nil
}

func testJSONValuer(t *testing.T, db *bun.DB) {
type Model struct {
ID int
Value JSONValue `bun:"type:json"`
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

model := new(Model)
_, err = db.NewInsert().Model(model).Exec(ctx)
require.NoError(t, err)

model2 := new(Model)
err = db.NewSelect().Model(model2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, `"driver.Value"`, model2.Value.str)
}

func testSelectBool(t *testing.T, db *bun.DB) {
var flag bool
err := db.NewSelect().ColumnExpr("1").Scan(ctx, &flag)
Expand Down
14 changes: 13 additions & 1 deletion schema/append_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,24 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
return appendMsgpack
}

fieldType := field.StructField.Type

switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
if fieldType.Implements(driverValuerType) {
return appendDriverValue
}

if fieldType.Kind() != reflect.Ptr {
if reflect.PtrTo(fieldType).Implements(driverValuerType) {
return addrAppender(appendDriverValue)
}
}

return AppendJSONValue
}

return Appender(dialect, field.StructField.Type)
return Appender(dialect, fieldType)
}

func Appender(dialect Dialect, typ reflect.Type) AppenderFunc {
Expand Down

0 comments on commit a17454a

Please sign in to comment.