diff --git a/generator/types.go b/generator/types.go index 2270785..9a1ad6e 100644 --- a/generator/types.go +++ b/generator/types.go @@ -562,6 +562,8 @@ type ImplicitFK struct { } // Field is the representation of a model field. +// TODO(erizocosmico): please, refactor all this structure to use precomputed +// data instead of calculating it upon each call. type Field struct { // Name is the field name. Name string @@ -714,8 +716,12 @@ func (f *Field) IsInverse() bool { return false } - for _, part := range strings.Split(f.Tag.Get("fk"), ",") { - if part == "inverse" { + if f.IsManyToManyRelationship() { + return f.isInverseThrough() + } + + for i, part := range strings.Split(f.Tag.Get("fk"), ",") { + if i > 0 && part == "inverse" { return true } } @@ -729,6 +735,58 @@ func (f *Field) IsOneToManyRelationship() bool { return f.Kind == Relationship && strings.HasPrefix(f.Type, "[]") } +// IsManyToManyRelationship reports whether the field is a many to many +// relationship. +func (f *Field) IsManyToManyRelationship() bool { + return f.Kind == Relationship && f.Tag.Get("through") != "" +} + +// ThroughTable returns the name of the intermediate table used to access the +// current field. +func (f *Field) ThroughTable() string { + return f.getThroughTablePart(0) +} + +// LeftForeignKey is the name of the column used to join the current model with +// the intermediate table. +func (f *Field) LeftForeignKey() string { + fk := f.getThroughTablePart(1) + if fk == "" { + fk = foreignKeyForModel(f.Model.Name) + } + return fk +} + +// RightForeignKey is the name of the column used to join the relationship +// model with the intermediate table. +func (f *Field) RightForeignKey() string { + fk := f.getThroughTablePart(2) + if fk == "" { + fk = foreignKeyForModel(f.TypeSchemaName()) + } + return fk +} + +func (f *Field) isInverseThrough() bool { + return f.getThroughPart(1) == "inverse" +} + +func (f *Field) getThroughPart(idx int) string { + parts := strings.Split(f.Tag.Get("through"), ",") + if len(parts) > idx { + return strings.TrimSpace(parts[idx]) + } + return "" +} + +func (f *Field) getThroughTablePart(idx int) string { + parts := strings.Split(f.getThroughPart(0), ":") + if len(parts) > idx { + return strings.TrimSpace(parts[idx]) + } + return "" +} + func foreignKeyForModel(model string) string { return toLowerSnakeCase(model) + "_id" } diff --git a/generator/types_test.go b/generator/types_test.go index 3a0bf03..ffa45b2 100644 --- a/generator/types_test.go +++ b/generator/types_test.go @@ -194,6 +194,79 @@ func (s *FieldSuite) TestValue() { } } +func (s *FieldSuite) TestIsInverse() { + cases := []struct { + tag string + expected bool + }{ + {"", false}, + {`inverse:"true"`, false}, + {`fk:"inverse"`, false}, + {`through:"inverse"`, false}, + {`fk:"foo,inverse"`, true}, + {`fk:",inverse"`, true}, + {`through:"foo,inverse"`, true}, + {`through:"foo:a:b,inverse"`, true}, + } + + for _, tt := range cases { + f := withTag(mkField("", ""), tt.tag) + f.Kind = Relationship + s.Equal(tt.expected, f.IsInverse(), tt.tag) + } +} + +func (s *FieldSuite) TestThroughTable() { + cases := []struct { + tag, expected string + }{ + {``, ""}, + {`through:"foo"`, "foo"}, + {`through:"foo,inverse"`, "foo"}, + {`through:"foo:a:b,inverse"`, "foo"}, + } + + for _, tt := range cases { + s.Equal(tt.expected, withTag(mkField("", ""), tt.tag).ThroughTable(), tt.tag) + } +} + +func (s *FieldSuite) TestLeftForeignKey() { + cases := []struct { + tag, expected string + }{ + {``, "bar_id"}, + {`through:"foo"`, "bar_id"}, + {`through:"foo,inverse"`, "bar_id"}, + {`through:"foo:a,inverse"`, "a"}, + {`through:"foo:a:b,inverse"`, "a"}, + } + + for _, tt := range cases { + f := withTag(mkField("", ""), tt.tag) + f.Model = &Model{Name: "Bar"} + s.Equal(tt.expected, f.LeftForeignKey(), tt.tag) + } +} + +func (s *FieldSuite) TestRightForeignKey() { + cases := []struct { + tag, expected string + }{ + {``, "foo_id"}, + {`through:"foo"`, "foo_id"}, + {`through:"foo,inverse"`, "foo_id"}, + {`through:"foo:a,inverse"`, "foo_id"}, + {`through:"foo:a:b,inverse"`, "b"}, + } + + for _, tt := range cases { + f := withTag(mkField("", ""), tt.tag) + f.Type = "Foo" + s.Equal(tt.expected, f.RightForeignKey(), tt.tag) + } +} + type ModelSuite struct { suite.Suite model *Model