From a17454ac6b95b2a2e927d0c4e4aee96494108389 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 27 Dec 2021 15:17:04 +0200 Subject: [PATCH] fix: properly handle driver.Valuer and type:json --- internal/dbtest/db_test.go | 49 ++++++++++++++++++++++++++++++++++++-- schema/append_value.go | 14 ++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 902845a2..a42f9998 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -3,6 +3,7 @@ package dbtest_test import ( "context" "database/sql" + "database/sql/driver" "encoding/json" "errors" "os" @@ -222,7 +223,8 @@ func TestDB(t *testing.T) { {testScanSingleRowByRow}, {testScanRows}, {testRunInTx}, - {testInsertIface}, + {testJSONInterface}, + {testJSONValuer}, {testSelectBool}, {testFKViolation}, {testInterfaceAny}, @@ -712,7 +714,7 @@ func testJSONSpecialChars(t *testing.T, db *bun.DB) { } } -func testInsertIface(t *testing.T, db *bun.DB) { +func testJSONInterface(t *testing.T, db *bun.DB) { type Model struct { ID int Value interface{} `bun:"type:json"` @@ -734,6 +736,49 @@ func testInsertIface(t *testing.T, db *bun.DB) { require.NoError(t, err) } +type JSONValue struct { + str string +} + +var _ driver.Valuer = (*JSONValue)(nil) + +func (v *JSONValue) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + v.str = string(src) + case string: + v.str = src + default: + panic("not reached") + } + return nil +} + +func (v *JSONValue) Value() (driver.Value, error) { + return `"driver.Value"`, nil +} + +func testJSONValuer(t *testing.T, db *bun.DB) { + type Model struct { + ID int + Value JSONValue `bun:"type:json"` + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + model := new(Model) + _, err = db.NewInsert().Model(model).Exec(ctx) + require.NoError(t, err) + + model2 := new(Model) + err = db.NewSelect().Model(model2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, `"driver.Value"`, model2.Value.str) +} + func testSelectBool(t *testing.T, db *bun.DB) { var flag bool err := db.NewSelect().ColumnExpr("1").Scan(ctx, &flag) diff --git a/schema/append_value.go b/schema/append_value.go index 5697e35e..04d1636b 100644 --- a/schema/append_value.go +++ b/schema/append_value.go @@ -58,12 +58,24 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc { return appendMsgpack } + fieldType := field.StructField.Type + switch strings.ToUpper(field.UserSQLType) { case sqltype.JSON, sqltype.JSONB: + if fieldType.Implements(driverValuerType) { + return appendDriverValue + } + + if fieldType.Kind() != reflect.Ptr { + if reflect.PtrTo(fieldType).Implements(driverValuerType) { + return addrAppender(appendDriverValue) + } + } + return AppendJSONValue } - return Appender(dialect, field.StructField.Type) + return Appender(dialect, fieldType) } func Appender(dialect Dialect, typ reflect.Type) AppenderFunc {