From 82ca87c7c49797d507b31fdaacf8343716d4feff Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Sep 2021 10:49:52 +0300 Subject: [PATCH] fix: make allowzero work with auto-detected primary keys --- internal/dbtest/db_test.go | 32 +++++++++++-------- internal/dbtest/query_test.go | 6 ++++ .../testdata/snapshots/TestQuery-mysql5-85 | 1 + .../testdata/snapshots/TestQuery-mysql8-85 | 1 + .../dbtest/testdata/snapshots/TestQuery-pg-85 | 1 + .../testdata/snapshots/TestQuery-pgx-85 | 1 + .../testdata/snapshots/TestQuery-sqlite-85 | 1 + schema/field.go | 4 ++- schema/table.go | 18 +++++------ 9 files changed, 41 insertions(+), 24 deletions(-) create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-85 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-85 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-85 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-85 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-85 diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index e981b2f2..fc98aca5 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -853,27 +853,31 @@ func testScanBytes(t *testing.T, db *bun.DB) { func testPointers(t *testing.T, db *bun.DB) { type Model struct { - ID *int64 `bun:",allowzero,default:0"` + ID *int64 `bun:",default:0"` Str *string } ctx := context.Background() - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + for _, id := range []int64{-1, 0, 1} { + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) - id := int64(1) - str := "hello" - models := []Model{ - {}, - {ID: &id, Str: &str}, - } - _, err = db.NewInsert().Model(&models).Exec(ctx) - require.NoError(t, err) + var model Model + if id >= 0 { + str := "hello" + model.ID = &id + model.Str = &str - var models2 []Model - err = db.NewSelect().Model(&models2).Order("id ASC").Scan(ctx) - require.NoError(t, err) + } + + _, err = db.NewInsert().Model(&model).Exec(ctx) + require.NoError(t, err) + + var model2 Model + err = db.NewSelect().Model(&model2).Order("id ASC").Scan(ctx) + require.NoError(t, err) + } } func testExists(t *testing.T, db *bun.DB) { diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index fa4e8da6..d57d2e94 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -545,6 +545,12 @@ func TestQuery(t *testing.T) { } return db.NewInsert().Model(new(Model)) }, + func(db *bun.DB) schema.QueryAppender { + type Model struct { + ID int `bun:",allowzero"` + } + return db.NewInsert().Model(new(Model)) + }, } timeRE := regexp.MustCompile(`'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-85 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-85 new file mode 100644 index 00000000..977e19ae --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-85 @@ -0,0 +1 @@ +INSERT INTO `models` (`id`) VALUES (0) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-85 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-85 new file mode 100644 index 00000000..977e19ae --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-85 @@ -0,0 +1 @@ +INSERT INTO `models` (`id`) VALUES (0) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-85 b/internal/dbtest/testdata/snapshots/TestQuery-pg-85 new file mode 100644 index 00000000..23686901 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-85 @@ -0,0 +1 @@ +INSERT INTO "models" ("id") VALUES (0) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-85 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-85 new file mode 100644 index 00000000..23686901 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-85 @@ -0,0 +1 @@ +INSERT INTO "models" ("id") VALUES (0) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-85 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-85 new file mode 100644 index 00000000..23686901 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-85 @@ -0,0 +1 @@ +INSERT INTO "models" ("id") VALUES (0) diff --git a/schema/field.go b/schema/field.go index 1e069b82..59990b92 100644 --- a/schema/field.go +++ b/schema/field.go @@ -101,7 +101,9 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { func (f *Field) markAsPK() { f.IsPK = true f.NotNull = true - f.NullZero = true + if !f.Tag.HasOption("allowzero") { + f.NullZero = true + } } func indexEqual(ind1, ind2 []int) bool { diff --git a/schema/table.go b/schema/table.go index 8bed5ed3..213f821a 100644 --- a/schema/table.go +++ b/schema/table.go @@ -181,17 +181,17 @@ func (t *Table) initFields() { t.FieldMap = make(map[string]*Field, t.Type.NumField()) t.addFields(t.Type, nil) - if len(t.PKs) > 0 { - return - } - for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { - if field, ok := t.FieldMap[name]; ok { - field.markAsPK() - t.PKs = []*Field{field} - t.DataFields = removeField(t.DataFields, field) - break + if len(t.PKs) == 0 { + for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { + if field, ok := t.FieldMap[name]; ok { + field.markAsPK() + t.PKs = []*Field{field} + t.DataFields = removeField(t.DataFields, field) + break + } } } + if len(t.PKs) == 1 { pk := t.PKs[0] if pk.SQLDefault != "" {