Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions generator/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func (p *Processor) processPackage() (*Package, error) {
}

pkg.SetModels(models)
if err := pkg.addMissingRelationships(); err != nil {
return nil, err
}
for _, ctor := range ctors {
p.tryMatchConstructor(pkg, ctor)
}
Expand Down
11 changes: 11 additions & 0 deletions generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func (td *TemplateData) genFieldsTimeTruncations(buf *bytes.Buffer, fields []*Fi
func (td *TemplateData) GenColumnAddresses(model *Model) string {
var buf bytes.Buffer
td.genFieldsColumnAddresses(&buf, model.Fields)
for _, fk := range model.ImplicitFKs {
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
buf.WriteString(fmt.Sprintf("return types.Nullable(kallax.VirtualColumn(\"%s\", r, new(%s))), nil\n", fk.Name, fk.Type))
}
return buf.String()
}

Expand Down Expand Up @@ -129,6 +133,10 @@ func (td *TemplateData) IdentifierType(f *Field) string {
func (td *TemplateData) GenColumnValues(model *Model) string {
var buf bytes.Buffer
td.genFieldsValues(&buf, model.Fields)
for _, fk := range model.ImplicitFKs {
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
}
return buf.String()
}

Expand Down Expand Up @@ -159,6 +167,9 @@ func (td *TemplateData) genFieldsValues(buf *bytes.Buffer, fields []*Field) {
func (td *TemplateData) GenModelColumns(model *Model) string {
var buf bytes.Buffer
td.genFieldsColumns(&buf, model.Fields)
for _, fk := range model.ImplicitFKs {
buf.WriteString(fmt.Sprintf("kallax.NewSchemaField(\"%s\"),\n", fk.Name))
}
return buf.String()
}

Expand Down
62 changes: 62 additions & 0 deletions generator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,59 @@ func (p *Package) FindModel(name string) *Model {
return p.indexedModels[name]
}

func (p *Package) addMissingRelationships() error {
for _, m := range p.Models {
for _, f := range m.Fields {
if f.Kind == Relationship && !f.IsInverse() {
if err := p.trySetFK(f.TypeSchemaName(), f); err != nil {
return err
}
}
}
}

return nil
}

func (p *Package) trySetFK(model string, fk *Field) error {
m := p.FindModel(model)
if m == nil {
return fmt.Errorf("kallax: cannot assign implicit foreign key to non-existent model %s", model)
}

var found bool
for _, f := range m.Fields {
if f.Kind == Relationship {
if f.ForeignKey() == fk.ForeignKey() {
found = true
break
}
} else {
if f.ColumnName() == fk.ForeignKey() {
found = true
break
}
}
}

if !found {
for _, ifk := range m.ImplicitFKs {
if ifk.Name == fk.ForeignKey() {
found = true
break
}
}
}

if !found {
m.ImplicitFKs = append(m.ImplicitFKs, ImplicitFK{
Name: fk.ForeignKey(),
Type: identifierType(fk.Model.ID),
})
}
return nil
}

const (
// StoreNamePattern is the pattern used to name stores.
StoreNamePattern = "%sStore"
Expand Down Expand Up @@ -182,6 +235,10 @@ type Model struct {
Type string
// Fields contains the list of fields in the model.
Fields []*Field
// ImplicitFKs contains the list of fks that are implicit based on
// other models' definitions, such as foreign keys with no explicit inverse
// on the related model.
ImplicitFKs []ImplicitFK
// ID contains the identifier field of the model.
ID *Field
// Events contains the list of events implemented by the model.
Expand Down Expand Up @@ -499,6 +556,11 @@ func relationshipsOnFields(fields []*Field) []*Field {
return result
}

type ImplicitFK struct {
Name string
Type string
}

// Field is the representation of a model field.
type Field struct {
// Name is the field name.
Expand Down
5 changes: 5 additions & 0 deletions tests/kallax.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ func (r *Child) ColumnAddress(col string) (interface{}, error) {
return (*kallax.NumericID)(&r.ID), nil
case "name":
return &r.Name, nil
case "parent_id":
return types.Nullable(kallax.VirtualColumn("parent_id", r, new(kallax.NumericID))), nil

default:
return nil, fmt.Errorf("kallax: invalid column in Child: %s", col)
Expand All @@ -624,6 +626,8 @@ func (r *Child) Value(col string) (interface{}, error) {
return r.ID, nil
case "name":
return r.Name, nil
case "parent_id":
return r.Model.VirtualColumn(col), nil

default:
return nil, fmt.Errorf("kallax: invalid column in Child: %s", col)
Expand Down Expand Up @@ -10773,6 +10777,7 @@ var Schema = &schema{
true,
kallax.NewSchemaField("id"),
kallax.NewSchemaField("name"),
kallax.NewSchemaField("parent_id"),
),
ID: kallax.NewSchemaField("id"),
Name: kallax.NewSchemaField("name"),
Expand Down
16 changes: 10 additions & 6 deletions tests/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,12 @@ func (s *StoreSuite) TestInsert_RelWithNoInverse() {
s.NoError(store.Insert(p))
s.NotEqual(0, p.ID)

var count int
err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count)
p, err := store.FindOne(NewParentQuery().WithChildren(nil))
s.NoError(err)
s.Equal(3, count)
s.Len(p.Children, 3)
for _, c := range p.Children {
s.NotEqual(int64(0), c.ID)
}
}

func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() {
Expand All @@ -298,8 +300,10 @@ func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() {
s.NoError(store.Insert(p))
s.NotEqual(0, p.ID)

var count int
err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count)
p, err := store.FindOne(NewParentNoPtrQuery().WithChildren(nil))
s.NoError(err)
s.Equal(3, count)
s.Len(p.Children, 3)
for _, c := range p.Children {
s.NotEqual(int64(0), c.ID)
}
}