From 3b321b08601c4b8dc6bcaa24adea20875883ac14 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 18 Jan 2022 09:52:12 +0200 Subject: [PATCH] fix: properly discover json.Marshaler on ptr field --- internal/dbtest/db_test.go | 29 +++++++++++++++++++ internal/dbtest/query_test.go | 6 ++++ .../testdata/snapshots/TestQuery-mariadb-106 | 1 + .../testdata/snapshots/TestQuery-mysql5-106 | 1 + .../testdata/snapshots/TestQuery-mysql8-106 | 1 + .../testdata/snapshots/TestQuery-pg-106 | 1 + .../testdata/snapshots/TestQuery-pgx-106 | 1 + .../testdata/snapshots/TestQuery-sqlite-106 | 1 + schema/append_value.go | 3 ++ schema/reflect.go | 1 + schema/scan.go | 2 ++ 11 files changed, 47 insertions(+) create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mariadb-106 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-106 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-106 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-106 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-106 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-106 diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 3c00d0ba..a3a2dd53 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -240,6 +240,7 @@ func TestDB(t *testing.T) { {testTxScanAndCount}, {testEmbedModelValue}, {testEmbedModelPointer}, + {testJSONMarshaler}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -1155,3 +1156,31 @@ func testEmbedTypeField(t *testing.T, db *bun.DB) { require.NoError(t, err) require.Equal(t, *m1, m2) } + +type JSONField struct { + Foo string `json:"foo"` +} + +func (f *JSONField) MarshalJSON() ([]byte, error) { + return []byte(`{"foo": "bar"}`), nil +} + +func testJSONMarshaler(t *testing.T, db *bun.DB) { + type Model struct { + Field *JSONField + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + m1 := &Model{Field: new(JSONField)} + _, err = db.NewInsert().Model(m1).Exec(ctx) + require.NoError(t, err) + + var m2 Model + err = db.NewSelect().Model(&m2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, "bar", m2.Field.Foo) +} diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 02a77754..e7cfe2e2 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -656,6 +656,12 @@ func TestQuery(t *testing.T) { } return db.NewInsert().Model(&Model{ID: ID("embed")}) }, + func(db *bun.DB) schema.QueryAppender { + type Model struct { + Raw *json.RawMessage `bun:",nullzero"` + } + return db.NewInsert().Model(new(Model)) + }, } timeRE := regexp.MustCompile(`'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-106 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-106 new file mode 100644 index 00000000..bc6a3a09 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-106 @@ -0,0 +1 @@ +INSERT INTO `models` (`raw`) VALUES (DEFAULT) RETURNING `raw` diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-106 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-106 new file mode 100644 index 00000000..bef7d892 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-106 @@ -0,0 +1 @@ +INSERT INTO `models` (`raw`) VALUES (DEFAULT) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-106 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-106 new file mode 100644 index 00000000..bef7d892 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-106 @@ -0,0 +1 @@ +INSERT INTO `models` (`raw`) VALUES (DEFAULT) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-106 b/internal/dbtest/testdata/snapshots/TestQuery-pg-106 new file mode 100644 index 00000000..ef36680d --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-106 @@ -0,0 +1 @@ +INSERT INTO "models" ("raw") VALUES (DEFAULT) RETURNING "raw" diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-106 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-106 new file mode 100644 index 00000000..ef36680d --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-106 @@ -0,0 +1 @@ +INSERT INTO "models" ("raw") VALUES (DEFAULT) RETURNING "raw" diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-106 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-106 new file mode 100644 index 00000000..f2c1073d --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-106 @@ -0,0 +1 @@ +INSERT INTO "models" ("raw") VALUES (NULL) RETURNING "raw" diff --git a/schema/append_value.go b/schema/append_value.go index 04d1636b..e6587cd6 100644 --- a/schema/append_value.go +++ b/schema/append_value.go @@ -128,6 +128,9 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc { case reflect.Interface: return ifaceAppenderFunc case reflect.Ptr: + if typ.Implements(jsonMarshalerType) { + return AppendJSONValue + } if fn := Appender(dialect, typ.Elem()); fn != nil { return PtrAppender(fn) } diff --git a/schema/reflect.go b/schema/reflect.go index 5b20b196..ff14f3bf 100644 --- a/schema/reflect.go +++ b/schema/reflect.go @@ -17,6 +17,7 @@ var ( driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() + jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() ) func indirectType(t reflect.Type) reflect.Type { diff --git a/schema/scan.go b/schema/scan.go index 587c386b..541c658f 100644 --- a/schema/scan.go +++ b/schema/scan.go @@ -94,6 +94,8 @@ func scanner(typ reflect.Type) ScannerFunc { } switch typ { + case bytesType: + return scanBytes case timeType: return scanTime case ipType: