diff --git a/generator/templates/model.tgo b/generator/templates/model.tgo index 507cd0c..0cbc838 100644 --- a/generator/templates/model.tgo +++ b/generator/templates/model.tgo @@ -96,26 +96,17 @@ func New{{.StoreName}}(db *sql.DB) *{{.StoreName}} { return &{{.StoreName}}{kallax.NewStore(db)} } -{{if .HasRelationships}} +{{if .HasNonInverses}} func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []kallax.RecordWithSchema { - record.ClearVirtualColumns() var records []kallax.RecordWithSchema - {{range .Relationships}} - {{- if .IsInverse -}} - if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} { - record.AddVirtualColumn("{{.ForeignKey}}", record.{{.Name}}.GetID()) - records = append(records, kallax.RecordWithSchema{ - Schema.{{.TypeSchemaName}}.BaseSchema, - {{if not .IsPtr}}&{{end}}record.{{.Name}}, - }) - } - {{else if .IsOneToManyRelationship}} + {{range .NonInverses}} + {{if .IsOneToManyRelationship}} for _, rec := range record.{{.Name}} { rec.ClearVirtualColumns() rec.AddVirtualColumn("{{.ForeignKey}}", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.{{.TypeSchemaName}}.BaseSchema, - {{if not ($.IsPtrSlice .)}}&{{end}}rec, + Schema: Schema.{{.TypeSchemaName}}.BaseSchema, + Record: {{if not ($.IsPtrSlice .)}}&{{end}}rec, }) } {{else}} @@ -123,8 +114,8 @@ func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []kallax.RecordW record.{{.Name}}.ClearVirtualColumns() record.{{.Name}}.AddVirtualColumn("{{.ForeignKey}}", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.{{.TypeSchemaName}}.BaseSchema, - {{if not .IsPtr}}&{{end}}record.{{.Name}}, + Schema: Schema.{{.TypeSchemaName}}.BaseSchema, + Record: {{if not .IsPtr}}&{{end}}record.{{.Name}}, }) } {{end}} @@ -133,6 +124,23 @@ func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []kallax.RecordW } {{end}} +{{if .HasInverses}} +func (s *{{.StoreName}}) inverseRecords(record *{{.Name}}) []kallax.RecordWithSchema { + record.ClearVirtualColumns() + var records []kallax.RecordWithSchema + {{range .Inverses}} + if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} { + record.AddVirtualColumn("{{.ForeignKey}}", record.{{.Name}}.GetID()) + records = append(records, kallax.RecordWithSchema{ + Schema: Schema.{{.TypeSchemaName}}.BaseSchema, + Record: {{if not .IsPtr}}&{{end}}record.{{.Name}}, + }) + } + {{end}} + return records +} +{{end}} + // Insert inserts a {{.Name}} in the database. A non-persisted object is // required for this operation. func (s *{{.StoreName}}) Insert(record *{{.Name}}) error { @@ -146,13 +154,34 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error { } {{end}} {{if .HasRelationships}} + {{if .HasNonInverses}} records := s.relationshipRecords(record) - if len(records) > 0 { + {{end}} + {{if .HasInverses}} + inverseRecords := s.inverseRecords(record) + {{end}} + if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} { return s.Store.Transaction(func(s *kallax.Store) error { + {{if .HasInverses}} + for _, r := range inverseRecords { + if err := kallax.ApplyBeforeEvents(r.Record); err != nil { + return err + } + persisted := r.Record.IsPersisted() + + if _, err := s.Save(r.Schema, r.Record); err != nil { + return err + } + + if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil { + return err + } + } + {{end}} if err := s.Insert(Schema.{{.Name}}.BaseSchema, record); err != nil { return err } - + {{if .HasNonInverses}} for _, r := range records { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err @@ -167,6 +196,7 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error { return err } } + {{end}} {{if .Events.Has "AfterInsert"}} if err := record.AfterInsert(); err != nil { @@ -224,14 +254,37 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) ( } {{end}} {{if .HasRelationships}} + {{if .HasNonInverses}} records := s.relationshipRecords(record) - if len(records) > 0 { + {{end}} + {{if .HasInverses}} + inverseRecords := s.inverseRecords(record) + {{end}} + if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} { err = s.Store.Transaction(func(s *kallax.Store) error { + {{if .HasInverses}} + for _, r := range inverseRecords { + if err := kallax.ApplyBeforeEvents(r.Record); err != nil { + return err + } + persisted := r.Record.IsPersisted() + + if _, err := s.Save(r.Schema, r.Record); err != nil { + return err + } + + if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil { + return err + } + } + {{end}} + updated, err = s.Update(Schema.{{.Name}}.BaseSchema, record, cols...) if err != nil { return err } + {{if .HasNonInverses}} for _, r := range records { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err @@ -246,6 +299,7 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) ( return err } } + {{end}} {{if .Events.Has "AfterUpdate"}} if err := record.AfterUpdate(); err != nil { diff --git a/generator/types.go b/generator/types.go index aeee9f4..395b826 100644 --- a/generator/types.go +++ b/generator/types.go @@ -440,11 +440,43 @@ func (m *Model) Relationships() []*Field { return relationshipsOnFields(m.Fields) } +// Inverses returns the inverse relationships of the model. +func (m *Model) Inverses() []*Field { + var inverses []*Field + for _, f := range relationshipsOnFields(m.Fields) { + if f.IsInverse() { + inverses = append(inverses, f) + } + } + return inverses +} + +// NonInverses returns the relationships of the model that are not inverses. +func (m *Model) NonInverses() []*Field { + var rels []*Field + for _, f := range relationshipsOnFields(m.Fields) { + if !f.IsInverse() { + rels = append(rels, f) + } + } + return rels +} + // HasRelationships returns whether the model has relationships or not. func (m *Model) HasRelationships() bool { return len(m.Relationships()) > 0 } +// HasInverses returns whether the model has inverse relationships or not. +func (m *Model) HasInverses() bool { + return len(m.Inverses()) > 0 +} + +// HasNonInverses returns whether the model has non inverse relationships or not. +func (m *Model) HasNonInverses() bool { + return len(m.NonInverses()) > 0 +} + func relationshipsOnFields(fields []*Field) []*Field { var result []*Field for _, f := range fields { diff --git a/tests/kallax.go b/tests/kallax.go index 42e8b09..8e1c367 100644 --- a/tests/kallax.go +++ b/tests/kallax.go @@ -99,14 +99,15 @@ func NewCarStore(db *sql.DB) *CarStore { return &CarStore{kallax.NewStore(db)} } -func (s *CarStore) relationshipRecords(record *Car) []kallax.RecordWithSchema { +func (s *CarStore) inverseRecords(record *Car) []kallax.RecordWithSchema { record.ClearVirtualColumns() var records []kallax.RecordWithSchema + if record.Owner != nil { record.AddVirtualColumn("owner_id", record.Owner.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.Person.BaseSchema, - record.Owner, + Schema: Schema.Person.BaseSchema, + Record: record.Owner, }) } @@ -121,14 +122,12 @@ func (s *CarStore) Insert(record *Car) error { return err } - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { - if err := s.Insert(Schema.Car.BaseSchema, record); err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -143,6 +142,10 @@ func (s *CarStore) Insert(record *Car) error { } } + if err := s.Insert(Schema.Car.BaseSchema, record); err != nil { + return err + } + if err := record.AfterSave(); err != nil { return err } @@ -177,15 +180,12 @@ func (s *CarStore) Update(record *Car, cols ...kallax.SchemaField) (updated int6 return 0, err } - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { - updated, err = s.Update(Schema.Car.BaseSchema, record, cols...) - if err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -200,6 +200,11 @@ func (s *CarStore) Update(record *Car, cols ...kallax.SchemaField) (updated int6 } } + updated, err = s.Update(Schema.Car.BaseSchema, record, cols...) + if err != nil { + return err + } + if err := record.AfterSave(); err != nil { return err } @@ -3116,15 +3121,14 @@ func NewPersonStore(db *sql.DB) *PersonStore { } func (s *PersonStore) relationshipRecords(record *Person) []kallax.RecordWithSchema { - record.ClearVirtualColumns() var records []kallax.RecordWithSchema for _, rec := range record.Pets { rec.ClearVirtualColumns() rec.AddVirtualColumn("owner_id", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.Pet.BaseSchema, - rec, + Schema: Schema.Pet.BaseSchema, + Record: rec, }) } @@ -3132,8 +3136,8 @@ func (s *PersonStore) relationshipRecords(record *Person) []kallax.RecordWithSch record.Car.ClearVirtualColumns() record.Car.AddVirtualColumn("owner_id", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.Car.BaseSchema, - record.Car, + Schema: Schema.Car.BaseSchema, + Record: record.Car, }) } @@ -3149,8 +3153,10 @@ func (s *PersonStore) Insert(record *Person) error { } records := s.relationshipRecords(record) + if len(records) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { + if err := s.Insert(Schema.Person.BaseSchema, record); err != nil { return err } @@ -3205,8 +3211,10 @@ func (s *PersonStore) Update(record *Person, cols ...kallax.SchemaField) (update } records := s.relationshipRecords(record) + if len(records) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { + updated, err = s.Update(Schema.Person.BaseSchema, record, cols...) if err != nil { return err @@ -3787,14 +3795,15 @@ func NewPetStore(db *sql.DB) *PetStore { return &PetStore{kallax.NewStore(db)} } -func (s *PetStore) relationshipRecords(record *Pet) []kallax.RecordWithSchema { +func (s *PetStore) inverseRecords(record *Pet) []kallax.RecordWithSchema { record.ClearVirtualColumns() var records []kallax.RecordWithSchema + if record.Owner != nil { record.AddVirtualColumn("owner_id", record.Owner.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.Person.BaseSchema, - record.Owner, + Schema: Schema.Person.BaseSchema, + Record: record.Owner, }) } @@ -3809,14 +3818,12 @@ func (s *PetStore) Insert(record *Pet) error { return err } - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { - if err := s.Insert(Schema.Pet.BaseSchema, record); err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -3831,6 +3838,10 @@ func (s *PetStore) Insert(record *Pet) error { } } + if err := s.Insert(Schema.Pet.BaseSchema, record); err != nil { + return err + } + if err := record.AfterSave(); err != nil { return err } @@ -3865,15 +3876,12 @@ func (s *PetStore) Update(record *Pet, cols ...kallax.SchemaField) (updated int6 return 0, err } - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { - updated, err = s.Update(Schema.Pet.BaseSchema, record, cols...) - if err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -3888,6 +3896,11 @@ func (s *PetStore) Update(record *Pet, cols ...kallax.SchemaField) (updated int6 } } + updated, err = s.Update(Schema.Pet.BaseSchema, record, cols...) + if err != nil { + return err + } + if err := record.AfterSave(); err != nil { return err } @@ -4447,15 +4460,14 @@ func NewQueryFixtureStore(db *sql.DB) *QueryFixtureStore { } func (s *QueryFixtureStore) relationshipRecords(record *QueryFixture) []kallax.RecordWithSchema { - record.ClearVirtualColumns() var records []kallax.RecordWithSchema if record.Relation != nil { record.Relation.ClearVirtualColumns() record.Relation.AddVirtualColumn("owner_id", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.QueryRelationFixture.BaseSchema, - record.Relation, + Schema: Schema.QueryRelationFixture.BaseSchema, + Record: record.Relation, }) } @@ -4463,8 +4475,8 @@ func (s *QueryFixtureStore) relationshipRecords(record *QueryFixture) []kallax.R rec.ClearVirtualColumns() rec.AddVirtualColumn("owner_id", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.QueryRelationFixture.BaseSchema, - rec, + Schema: Schema.QueryRelationFixture.BaseSchema, + Record: rec, }) } @@ -4476,8 +4488,10 @@ func (s *QueryFixtureStore) relationshipRecords(record *QueryFixture) []kallax.R func (s *QueryFixtureStore) Insert(record *QueryFixture) error { records := s.relationshipRecords(record) + if len(records) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { + if err := s.Insert(Schema.QueryFixture.BaseSchema, record); err != nil { return err } @@ -4514,8 +4528,10 @@ func (s *QueryFixtureStore) Insert(record *QueryFixture) error { func (s *QueryFixtureStore) Update(record *QueryFixture, cols ...kallax.SchemaField) (updated int64, err error) { records := s.relationshipRecords(record) + if len(records) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { + updated, err = s.Update(Schema.QueryFixture.BaseSchema, record, cols...) if err != nil { return err @@ -5230,14 +5246,15 @@ func NewQueryRelationFixtureStore(db *sql.DB) *QueryRelationFixtureStore { return &QueryRelationFixtureStore{kallax.NewStore(db)} } -func (s *QueryRelationFixtureStore) relationshipRecords(record *QueryRelationFixture) []kallax.RecordWithSchema { +func (s *QueryRelationFixtureStore) inverseRecords(record *QueryRelationFixture) []kallax.RecordWithSchema { record.ClearVirtualColumns() var records []kallax.RecordWithSchema + if record.Owner != nil { record.AddVirtualColumn("owner_id", record.Owner.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.QueryFixture.BaseSchema, - record.Owner, + Schema: Schema.QueryFixture.BaseSchema, + Record: record.Owner, }) } @@ -5248,14 +5265,12 @@ func (s *QueryRelationFixtureStore) relationshipRecords(record *QueryRelationFix // required for this operation. func (s *QueryRelationFixtureStore) Insert(record *QueryRelationFixture) error { - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { - if err := s.Insert(Schema.QueryRelationFixture.BaseSchema, record); err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -5270,6 +5285,10 @@ func (s *QueryRelationFixtureStore) Insert(record *QueryRelationFixture) error { } } + if err := s.Insert(Schema.QueryRelationFixture.BaseSchema, record); err != nil { + return err + } + return nil }) } @@ -5286,15 +5305,12 @@ func (s *QueryRelationFixtureStore) Insert(record *QueryRelationFixture) error { // been just inserted or retrieved using a query with no custom select fields. func (s *QueryRelationFixtureStore) Update(record *QueryRelationFixture, cols ...kallax.SchemaField) (updated int64, err error) { - records := s.relationshipRecords(record) - if len(records) > 0 { + inverseRecords := s.inverseRecords(record) + + if len(inverseRecords) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { - updated, err = s.Update(Schema.QueryRelationFixture.BaseSchema, record, cols...) - if err != nil { - return err - } - for _, r := range records { + for _, r := range inverseRecords { if err := kallax.ApplyBeforeEvents(r.Record); err != nil { return err } @@ -5309,6 +5325,11 @@ func (s *QueryRelationFixtureStore) Update(record *QueryRelationFixture, cols .. } } + updated, err = s.Update(Schema.QueryRelationFixture.BaseSchema, record, cols...) + if err != nil { + return err + } + return nil }) if err != nil { @@ -6106,15 +6127,14 @@ func NewSchemaFixtureStore(db *sql.DB) *SchemaFixtureStore { } func (s *SchemaFixtureStore) relationshipRecords(record *SchemaFixture) []kallax.RecordWithSchema { - record.ClearVirtualColumns() var records []kallax.RecordWithSchema if record.Nested != nil { record.Nested.ClearVirtualColumns() record.Nested.AddVirtualColumn("schema_fixture_id", record.GetID()) records = append(records, kallax.RecordWithSchema{ - Schema.SchemaFixture.BaseSchema, - record.Nested, + Schema: Schema.SchemaFixture.BaseSchema, + Record: record.Nested, }) } @@ -6126,8 +6146,10 @@ func (s *SchemaFixtureStore) relationshipRecords(record *SchemaFixture) []kallax func (s *SchemaFixtureStore) Insert(record *SchemaFixture) error { records := s.relationshipRecords(record) + if len(records) > 0 { return s.Store.Transaction(func(s *kallax.Store) error { + if err := s.Insert(Schema.SchemaFixture.BaseSchema, record); err != nil { return err } @@ -6164,8 +6186,10 @@ func (s *SchemaFixtureStore) Insert(record *SchemaFixture) error { func (s *SchemaFixtureStore) Update(record *SchemaFixture, cols ...kallax.SchemaField) (updated int64, err error) { records := s.relationshipRecords(record) + if len(records) > 0 { err = s.Store.Transaction(func(s *kallax.Store) error { + updated, err = s.Update(Schema.SchemaFixture.BaseSchema, record, cols...) if err != nil { return err diff --git a/tests/relationships_test.go b/tests/relationships_test.go index f564d26..3827f26 100644 --- a/tests/relationships_test.go +++ b/tests/relationships_test.go @@ -116,6 +116,16 @@ func (s *RelationshipsSuite) TestEvents() { s.assertEvents(car.events, "BeforeDelete", "AfterDelete") } +func (s *RelationshipsSuite) TestSaveWithInverse() { + p := NewPerson("Foo") + car := NewCar("Bar", p) + + store := NewCarStore(s.db) + s.NoError(store.Insert(car)) + + s.NotNil(s.getPerson()) +} + func (s *RelationshipsSuite) assertEvents(evs map[string]int, events ...string) { for _, e := range events { s.Equal(1, evs[e])