From 25d907fac9d777897d7de242a1041665c5f07672 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Tue, 13 Jun 2017 16:40:28 +0200 Subject: [PATCH] add implicit inverse foreign keys when necessary --- generator/processor.go | 3 ++ generator/template.go | 11 ++++++++ generator/types.go | 62 ++++++++++++++++++++++++++++++++++++++++++ tests/kallax.go | 5 ++++ tests/store_test.go | 16 +++++++---- 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/generator/processor.go b/generator/processor.go index 4d078ff..d3f3b67 100644 --- a/generator/processor.go +++ b/generator/processor.go @@ -158,6 +158,9 @@ func (p *Processor) processPackage() (*Package, error) { } pkg.SetModels(models) + if err := pkg.addMissingRelationships(); err != nil { + return nil, err + } for _, ctor := range ctors { p.tryMatchConstructor(pkg, ctor) } diff --git a/generator/template.go b/generator/template.go index 0f49162..bc4f3a1 100644 --- a/generator/template.go +++ b/generator/template.go @@ -80,6 +80,10 @@ func (td *TemplateData) genFieldsTimeTruncations(buf *bytes.Buffer, fields []*Fi func (td *TemplateData) GenColumnAddresses(model *Model) string { var buf bytes.Buffer td.genFieldsColumnAddresses(&buf, model.Fields) + for _, fk := range model.ImplicitFKs { + buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name)) + buf.WriteString(fmt.Sprintf("return types.Nullable(kallax.VirtualColumn(\"%s\", r, new(%s))), nil\n", fk.Name, fk.Type)) + } return buf.String() } @@ -129,6 +133,10 @@ func (td *TemplateData) IdentifierType(f *Field) string { func (td *TemplateData) GenColumnValues(model *Model) string { var buf bytes.Buffer td.genFieldsValues(&buf, model.Fields) + for _, fk := range model.ImplicitFKs { + buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name)) + buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n")) + } return buf.String() } @@ -159,6 +167,9 @@ func (td *TemplateData) genFieldsValues(buf *bytes.Buffer, fields []*Field) { func (td *TemplateData) GenModelColumns(model *Model) string { var buf bytes.Buffer td.genFieldsColumns(&buf, model.Fields) + for _, fk := range model.ImplicitFKs { + buf.WriteString(fmt.Sprintf("kallax.NewSchemaField(\"%s\"),\n", fk.Name)) + } return buf.String() } diff --git a/generator/types.go b/generator/types.go index fdec423..2270785 100644 --- a/generator/types.go +++ b/generator/types.go @@ -153,6 +153,59 @@ func (p *Package) FindModel(name string) *Model { return p.indexedModels[name] } +func (p *Package) addMissingRelationships() error { + for _, m := range p.Models { + for _, f := range m.Fields { + if f.Kind == Relationship && !f.IsInverse() { + if err := p.trySetFK(f.TypeSchemaName(), f); err != nil { + return err + } + } + } + } + + return nil +} + +func (p *Package) trySetFK(model string, fk *Field) error { + m := p.FindModel(model) + if m == nil { + return fmt.Errorf("kallax: cannot assign implicit foreign key to non-existent model %s", model) + } + + var found bool + for _, f := range m.Fields { + if f.Kind == Relationship { + if f.ForeignKey() == fk.ForeignKey() { + found = true + break + } + } else { + if f.ColumnName() == fk.ForeignKey() { + found = true + break + } + } + } + + if !found { + for _, ifk := range m.ImplicitFKs { + if ifk.Name == fk.ForeignKey() { + found = true + break + } + } + } + + if !found { + m.ImplicitFKs = append(m.ImplicitFKs, ImplicitFK{ + Name: fk.ForeignKey(), + Type: identifierType(fk.Model.ID), + }) + } + return nil +} + const ( // StoreNamePattern is the pattern used to name stores. StoreNamePattern = "%sStore" @@ -182,6 +235,10 @@ type Model struct { Type string // Fields contains the list of fields in the model. Fields []*Field + // ImplicitFKs contains the list of fks that are implicit based on + // other models' definitions, such as foreign keys with no explicit inverse + // on the related model. + ImplicitFKs []ImplicitFK // ID contains the identifier field of the model. ID *Field // Events contains the list of events implemented by the model. @@ -499,6 +556,11 @@ func relationshipsOnFields(fields []*Field) []*Field { return result } +type ImplicitFK struct { + Name string + Type string +} + // Field is the representation of a model field. type Field struct { // Name is the field name. diff --git a/tests/kallax.go b/tests/kallax.go index 930a4d7..3982146 100644 --- a/tests/kallax.go +++ b/tests/kallax.go @@ -611,6 +611,8 @@ func (r *Child) ColumnAddress(col string) (interface{}, error) { return (*kallax.NumericID)(&r.ID), nil case "name": return &r.Name, nil + case "parent_id": + return types.Nullable(kallax.VirtualColumn("parent_id", r, new(kallax.NumericID))), nil default: return nil, fmt.Errorf("kallax: invalid column in Child: %s", col) @@ -624,6 +626,8 @@ func (r *Child) Value(col string) (interface{}, error) { return r.ID, nil case "name": return r.Name, nil + case "parent_id": + return r.Model.VirtualColumn(col), nil default: return nil, fmt.Errorf("kallax: invalid column in Child: %s", col) @@ -10773,6 +10777,7 @@ var Schema = &schema{ true, kallax.NewSchemaField("id"), kallax.NewSchemaField("name"), + kallax.NewSchemaField("parent_id"), ), ID: kallax.NewSchemaField("id"), Name: kallax.NewSchemaField("name"), diff --git a/tests/store_test.go b/tests/store_test.go index f28904f..9aa09a4 100644 --- a/tests/store_test.go +++ b/tests/store_test.go @@ -278,10 +278,12 @@ func (s *StoreSuite) TestInsert_RelWithNoInverse() { s.NoError(store.Insert(p)) s.NotEqual(0, p.ID) - var count int - err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count) + p, err := store.FindOne(NewParentQuery().WithChildren(nil)) s.NoError(err) - s.Equal(3, count) + s.Len(p.Children, 3) + for _, c := range p.Children { + s.NotEqual(int64(0), c.ID) + } } func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() { @@ -298,8 +300,10 @@ func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() { s.NoError(store.Insert(p)) s.NotEqual(0, p.ID) - var count int - err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count) + p, err := store.FindOne(NewParentNoPtrQuery().WithChildren(nil)) s.NoError(err) - s.Equal(3, count) + s.Len(p.Children, 3) + for _, c := range p.Children { + s.NotEqual(int64(0), c.ID) + } }