Skip to content

Commit

Permalink
Feature: Add With Inner Join Selection
Browse files Browse the repository at this point in the history
Enables to easily re-use existing structs for doing join queries
  • Loading branch information
OrCh3n committed Feb 18, 2024
2 parents 2d8c873 + 3b8f352 commit 6f5520a
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 9 deletions.
43 changes: 42 additions & 1 deletion select.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,50 @@ func WithSelectStar() SelectOption {
}
}

type JoinOp struct {
Table string
On exp.JoinCondition
}

// WithInnerJoinSelection returns a select option that inner joins the given table on the given column by tableName.column = otherTable.otherColumn,
// as well as selecting the columns from the given struct. each top-level struct field will be treated as a table and each field within that struct
// will be treated as a column.
func WithInnerJoinSelection[T any](op ...JoinOp) SelectOption {
return func(_ exp.IdentifierExpression, s *goqu.SelectDataset) *goqu.SelectDataset {
for _, j := range op {
s = s.InnerJoin(goqu.T(j.Table), j.On)
}
selectFields := make([]any, 0)
for _, c := range getSelectionFieldsFromSelectionStruct(new(T)) {
selectFields = append(selectFields, c)
}
return s.Select(selectFields...)
}
}

// WithLeftJoinSelection returns a select option that left joins the given table on the given column by tableName.column = otherTable.otherColumn,
// as well as selecting the columns from the given struct. each top-level struct field will be treated as a table and each field within that struct
// will be treated as a column.
func WithLeftJoinSelection[T any](op ...JoinOp) SelectOption {
return func(_ exp.IdentifierExpression, s *goqu.SelectDataset) *goqu.SelectDataset {
for _, j := range op {
s = s.LeftJoin(goqu.T(j.Table), j.On)
}
selectFields := make([]any, 0)
for _, c := range getSelectionFieldsFromSelectionStruct(new(T)) {
selectFields = append(selectFields, c)
}
return s.Select(selectFields...)
}
}

func BuildSelect[T any](tableName string, dst T, options ...SelectOption) (string, []any, error) {
table := goqu.T(tableName)
selectQuery := goqu.Select(getColumnsFromStruct(table, dst, skipSelect)...).From(table).WithDialect(defaultDialect)
structCols := make([]any, 0)
for _, c := range getColumnsFromStruct(table, dst, skipSelect) {
structCols = append(structCols, c)
}
selectQuery := goqu.Dialect(defaultDialect).Select(structCols...).From(table)
for _, o := range options {
selectQuery = o(table, selectQuery)
}
Expand Down
52 changes: 52 additions & 0 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqux_test
import (
"testing"

"github.com/doug-martin/goqu/v9"
"github.com/roneli/goqux"
"github.com/stretchr/testify/assert"
)
Expand All @@ -13,6 +14,27 @@ type selectModel struct {
fieldToSkip int `goqux:"skip_select"`
}

type joinModel struct {
T2 selectModel `db:"select_models"`
Table1 *Table1
}

type Table1 struct {
IntField int
StringField string `db:"cool_field"`
}

type doubleJoinModel struct {
T2 selectModel `db:"select_models"`
Table1 *Table1
T3 Table2 `db:"table_2"`
}

type Table2 struct {
IntField int
StringField string `db:"cool_field"`
}

func TestBuildSelect(t *testing.T) {
tableTests := []struct {
name string
Expand Down Expand Up @@ -64,6 +86,36 @@ func TestBuildSelect(t *testing.T) {
expectedQuery: `SELECT "select_models"."int_field" FROM "select_models" ORDER BY "select_models"."int_field" DESC`,
expectedArgs: []interface{}{},
},
{
name: "select_with_inner_join_selection",
dst: joinModel{},
options: []goqux.SelectOption{goqux.WithInnerJoinSelection[joinModel](goqux.JoinOp{
Table: "table_1",
On: goqu.On(goqux.Column("table_1", "int_field").Eq(goqux.Column("select_models", "int_field"))),
})},
expectedQuery: `SELECT "select_models"."int_field" AS "select_models.int_field", "table_1"."int_field" AS "table_1.int_field", "table_1"."cool_field" AS "table_1.cool_field" FROM "select_models" INNER JOIN "table_1" ON ("table_1"."int_field" = "select_models"."int_field")`,
},
{
name: "select_with_left_selection",
dst: joinModel{},
options: []goqux.SelectOption{goqux.WithLeftJoinSelection[joinModel](goqux.JoinOp{
Table: "table_1",
On: goqu.On(goqux.Column("table_1", "int_field").Eq(goqux.Column("select_models", "int_field"))),
})},
expectedQuery: `SELECT "select_models"."int_field" AS "select_models.int_field", "table_1"."int_field" AS "table_1.int_field", "table_1"."cool_field" AS "table_1.cool_field" FROM "select_models" LEFT JOIN "table_1" ON ("table_1"."int_field" = "select_models"."int_field")`,
},
{
name: "select_with_double_join_selection",
dst: doubleJoinModel{},
options: []goqux.SelectOption{goqux.WithInnerJoinSelection[doubleJoinModel](goqux.JoinOp{
Table: "table_1",
On: goqu.On(goqux.Column("table_1", "int_field").Eq(goqux.Column("select_models", "int_field"))),
}, goqux.JoinOp{
Table: "table_2",
On: goqu.On(goqux.Column("table_2", "int_field").Eq(goqux.Column("select_models", "int_field"))),
})},
expectedQuery: `SELECT "select_models"."int_field" AS "select_models.int_field", "table_1"."int_field" AS "table_1.int_field", "table_1"."cool_field" AS "table_1.cool_field", "table_2"."int_field" AS "table_2.int_field", "table_2"."cool_field" AS "table_2.cool_field" FROM "select_models" INNER JOIN "table_1" ON ("table_1"."int_field" = "select_models"."int_field") INNER JOIN "table_2" ON ("table_2"."int_field" = "select_models"."int_field")`,
},
}
for _, tableTest := range tableTests {
t.Run(tableTest.name, func(t *testing.T) {
Expand Down
41 changes: 36 additions & 5 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"
"time"

"github.com/doug-martin/goqu/v9"
_ "github.com/doug-martin/goqu/v9/dialect/postgres"
"github.com/doug-martin/goqu/v9/exp"
"github.com/iancoleman/strcase"
Expand Down Expand Up @@ -65,23 +66,24 @@ func encodeValues(v any, skipType string, skipZeroValues bool) map[string]SQLVal
return values
}

func getColumnsFromStruct(table exp.IdentifierExpression, s any, skipType string) []any {
func getColumnsFromStruct(table exp.IdentifierExpression, s any, skipType string) []exp.IdentifierExpression {
t := reflect.TypeOf(s)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
fields := reflect.VisibleFields(t)
var cols = make([]any, 0)
var cols = make([]exp.IdentifierExpression, 0)
for _, f := range fields {
if !f.IsExported() || strings.Contains(f.Tag.Get(tagName), skipType) {
continue
}
var colName string
if dbTag := f.Tag.Get(tagNameDb); dbTag != "" {
cols = append(cols, table.Col(cleanDbTag(dbTag)))
continue
colName = cleanDbTag(dbTag)
} else {
cols = append(cols, table.Col(strcase.ToSnake(f.Name)))
colName = strcase.ToSnake(f.Name)
}
cols = append(cols, table.Col(colName))
}
return cols
}
Expand All @@ -93,3 +95,32 @@ func cleanDbTag(tag string) string {

return tag
}

func getSelectionFieldsFromSelectionStruct(s interface{}) []exp.AliasedExpression {
cols := make([]exp.AliasedExpression, 0)
t := reflect.TypeOf(s)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
tableFields := reflect.VisibleFields(t)
for _, tf := range tableFields {
if !tf.IsExported() {
continue
}
if tf.Type.Kind() != reflect.Struct && !(tf.Type.Kind() == reflect.Ptr && tf.Type.Elem().Kind() == reflect.Struct) {
continue
}
tableName := strcase.ToSnake(tf.Name)
if dbTag := tf.Tag.Get(tagNameDb); dbTag != "" {
tableName = cleanDbTag(dbTag)
}
subTableColumns := getColumnsFromStruct(goqu.T(tableName), reflect.New(tf.Type).Elem().Interface(), skipSelect)
for _, c := range subTableColumns {
// SELECT "table"."column" AS "table.column" will make sure dbscan scans all the columns correctly
cc := c.GetCol()
cName := tableName + "." + cc.(string)
cols = append(cols, goqu.T(tableName).Col(cc).As(goqu.C(cName)))
}
}
return cols
}
71 changes: 68 additions & 3 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestGetColumnsFromStruct(t *testing.T) {
tableTests := []struct {
name string
model interface{}
expected []interface{}
expected []exp.IdentifierExpression
}{
{
name: "get_columns_from_struct",
Expand All @@ -137,15 +138,15 @@ func TestGetColumnsFromStruct(t *testing.T) {
FieldToSkip int `goqux:"skip_select"`
DbOsField int `db:"db_field"`
}{},
expected: []interface{}{goqu.T("table").Col("int_field"), goqu.T("table").Col("db_field")},
expected: []exp.IdentifierExpression{goqu.T("table").Col("int_field"), goqu.T("table").Col("db_field")},
},
{
name: "get_columns_from_struct_with_omitempty",
model: struct {
IntField int `db:"int_field,omitempty"` // should not be skipped
StringField string
}{},
expected: []interface{}{goqu.T("table").Col("int_field"), goqu.T("table").Col("string_field")},
expected: []exp.IdentifierExpression{goqu.T("table").Col("int_field"), goqu.T("table").Col("string_field")},
},
}
for _, tt := range tableTests {
Expand All @@ -155,3 +156,67 @@ func TestGetColumnsFromStruct(t *testing.T) {
})
}
}

type table1 struct {
IntField int
}

type table2 struct {
IntField int
StringField string `db:"cool_field"`
IgnoreField string `goqux:"skip_select"`
privateField string
}

func Test_getSelectionFieldsFromSelectionStruct(t *testing.T) {
tableTests := []struct {
name string
model interface{}
expected []exp.AliasedExpression
}{
{
name: "get_selection_fields_from_struct",
model: struct {
Table1 table1
Table2 table2 `db:"table_2"`
}{},
expected: []exp.AliasedExpression{
goqu.T("table_1").Col("int_field").As(goqu.C("table_1.int_field")),
goqu.T("table_2").Col("int_field").As(goqu.C("table_2.int_field")),
goqu.T("table_2").Col("cool_field").As(goqu.C("table_2.cool_field"))},
},
{
name: "get_selection_fields_from_struct_same_table",
model: struct {
Table1 table1
Table1Alias table1 `db:"table_alias"`
}{},
expected: []exp.AliasedExpression{
goqu.T("table_1").Col("int_field").As(goqu.C("table_1.int_field")),
goqu.T("table_alias").Col("int_field").As(goqu.C("table_alias.int_field")),
},
},
{
name: "get_selection_fields_from_struct_unexported_column",
model: struct {
t1Private table1
}{},
expected: []exp.AliasedExpression{},
},
{
name: "get_selection_fields_from_non_top_level_struct",
model: struct {
IntField int
FieldToSkip int `goqux:"skip_select"`
DbOsField int `db:"db_field"`
}{},
expected: []exp.AliasedExpression{},
},
}
for _, tt := range tableTests {
t.Run(tt.name, func(t *testing.T) {
columns := getSelectionFieldsFromSelectionStruct(tt.model)
assert.Equal(t, tt.expected, columns)
})
}
}

0 comments on commit 6f5520a

Please sign in to comment.