Skip to content

Commit

Permalink
feat: add model embedding via embed:prefix_
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Dec 19, 2021
1 parent 0b4c3fc commit 9a2cedc
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 20 deletions.
58 changes: 58 additions & 0 deletions internal/dbtest/db_test.go
Expand Up @@ -236,6 +236,8 @@ func TestDB(t *testing.T) {
{testUpsert},
{testMultiUpdate},
{testTxScanAndCount},
{testEmbedModelValue},
{testEmbedModelPointer},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -1029,3 +1031,59 @@ func testTxScanAndCount(t *testing.T, db *bun.DB) {
require.NoError(t, err)
}
}

func testEmbedModelValue(t *testing.T, db *bun.DB) {
type Embed struct {
Foo string
Bar string
}
type Model struct {
X Embed `bun:"embed:x_"`
Y Embed `bun:"embed:y_"`
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

m1 := &Model{
X: Embed{Foo: "x.foo", Bar: "x.bar"},
Y: Embed{Foo: "y.foo", Bar: "y.bar"},
}
_, err = db.NewInsert().Model(m1).Exec(ctx)
require.NoError(t, err)

var m2 Model
err = db.NewSelect().Model(&m2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, *m1, m2)
}

func testEmbedModelPointer(t *testing.T, db *bun.DB) {
type Embed struct {
Foo string
Bar string
}
type Model struct {
X *Embed `bun:"embed:x_"`
Y *Embed `bun:"embed:y_"`
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

m1 := &Model{
X: &Embed{Foo: "x.foo", Bar: "x.bar"},
Y: &Embed{Foo: "y.foo", Bar: "y.bar"},
}
_, err = db.NewInsert().Model(m1).Exec(ctx)
require.NoError(t, err)

var m2 Model
err = db.NewSelect().Model(&m2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, *m1, m2)
}
46 changes: 26 additions & 20 deletions schema/table.go
Expand Up @@ -203,7 +203,7 @@ func (t *Table) fieldByGoName(name string) *Field {
func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, nil)
t.addFields(t.Type, "", nil)

if len(t.PKs) == 0 {
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
Expand All @@ -230,7 +230,7 @@ func (t *Table) initFields() {
}
}

func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
func (t *Table) addFields(typ reflect.Type, prefix string, index []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
unexported := f.PkgPath != ""
Expand All @@ -242,10 +242,6 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
continue
}

// Make a copy so the slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)

if f.Anonymous {
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
Expand All @@ -258,7 +254,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
if fieldType.Kind() != reflect.Struct {
continue
}
t.addFields(fieldType, append(index, f.Index...))
t.addFields(fieldType, "", withIndex(index, f.Index))

tag := tagparser.Parse(f.Tag.Get("bun"))
if _, inherit := tag.Options["inherit"]; inherit {
Expand All @@ -274,7 +270,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
continue
}

if field := t.newField(f, index); field != nil {
if field := t.newField(f, prefix, index); field != nil {
t.addField(field)
}
}
Expand Down Expand Up @@ -315,10 +311,20 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
}

//nolint
func (t *Table) newField(f reflect.StructField, index []int) *Field {
sqlName := internal.Underscore(f.Name)
func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))

if prefix, ok := tag.Option("embed"); ok {
fieldType := indirectType(f.Type)
if fieldType.Kind() != reflect.Struct {
panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct",
t.TypeName, f.Name, fieldType.Kind()))
}
t.addFields(fieldType, prefix, withIndex(index, f.Index))
return nil
}

sqlName := internal.Underscore(f.Name)
if tag.Name != "" && tag.Name != sqlName {
if isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
Expand All @@ -328,18 +334,18 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
sqlName = tag.Name
}

if s, ok := tag.Option("column"); ok {
sqlName = s
}
sqlName = prefix + sqlName

for name := range tag.Options {
if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
}
}

index = append(index, f.Index...)
index = withIndex(index, f.Index)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
return field
Expand Down Expand Up @@ -795,7 +801,7 @@ func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
f.GoName = field.GoName + "_" + f.GoName
f.Name = field.Name + "__" + f.Name
f.SQLName = t.quoteIdent(f.Name)
f.Index = appendNew(field.Index, f.Index...)
f.Index = withIndex(field.Index, f.Index)

t.fieldsMapMu.Lock()
if _, ok := t.FieldMap[f.Name]; !ok {
Expand Down Expand Up @@ -853,13 +859,6 @@ func (t *Table) quoteIdent(s string) Safe {
return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
}

func appendNew(dst []int, src ...int) []int {
cp := make([]int, len(dst)+len(src))
copy(cp, dst)
copy(cp[len(dst):], src)
return cp
}

func isKnownTableOption(name string) bool {
switch name {
case "table", "alias", "select":
Expand Down Expand Up @@ -991,3 +990,10 @@ func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time
return field.ScanWithCheck(fv, tm)
}
}

func withIndex(a, b []int) []int {
dest := make([]int, 0, len(a)+len(b))
dest = append(dest, a...)
dest = append(dest, b...)
return dest
}

0 comments on commit 9a2cedc

Please sign in to comment.