Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

db: preload embedded fields #393

Open
wants to merge 6 commits into
base: v3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 74 additions & 56 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,42 @@
// package main
//
// import (
// "log"
// "log"
//
// "upper.io/db.v3/postgresql" // Imports the postgresql adapter.
// "upper.io/db.v3/postgresql" // Imports the postgresql adapter.
// )
//
// var settings = postgresql.ConnectionURL{
// Database: `booktown`,
// Host: `demo.upper.io`,
// User: `demouser`,
// Password: `demop4ss`,
// Database: `booktown`,
// Host: `demo.upper.io`,
// User: `demouser`,
// Password: `demop4ss`,
// }
//
// // Book represents a book.
// type Book struct {
// ID uint `db:"id"`
// Title string `db:"title"`
// AuthorID uint `db:"author_id"`
// SubjectID uint `db:"subject_id"`
// ID uint `db:"id"`
// Title string `db:"title"`
// AuthorID uint `db:"author_id"`
// SubjectID uint `db:"subject_id"`
// }
//
// func main() {
// sess, err := postgresql.Open(settings)
// if err != nil {
// log.Fatal(err)
// }
// defer sess.Close()
//
// var books []Book
// if err := sess.Collection("books").Find().OrderBy("title").All(&books); err != nil {
// log.Fatal(err)
// }
//
// log.Println("Books:")
// for _, book := range books {
// log.Printf("%q (ID: %d)\n", book.Title, book.ID)
// }
// sess, err := postgresql.Open(settings)
// if err != nil {
// log.Fatal(err)
// }
// defer sess.Close()
//
// var books []Book
// if err := sess.Collection("books").Find().OrderBy("title").All(&books); err != nil {
// log.Fatal(err)
// }
//
// log.Println("Books:")
// for _, book := range books {
// log.Printf("%q (ID: %d)\n", book.Title, book.ID)
// }
// }
//
// See more usage examples and documentation for users at
Expand Down Expand Up @@ -184,7 +184,7 @@ type Unmarshaler interface {
//
// // Where age equals 18.
// db.Cond{"age": 18}
// // // Where age is greater than or equal to 18.
// // // Where age is greater than or equal to 18.
// db.Cond{"age >=": 18}
//
// // Where id is in a list of ids.
Expand Down Expand Up @@ -241,6 +241,9 @@ func (c Cond) Empty() bool {
return true
}

// Relation represents a relation between columns or tables.
type Relation map[string]interface{}

type rawValue struct {
v string
a *[]interface{} // This may look ugly but allows us to use db.Raw() as keys for db.Cond{}.
Expand Down Expand Up @@ -429,17 +432,17 @@ func NewConstraint(key interface{}, value interface{}) Constraint {
//
// Examples:
//
// // MOD(29, 9)
// db.Func("MOD", 29, 9)
// // MOD(29, 9)
// db.Func("MOD", 29, 9)
//
// // CONCAT("foo", "bar")
// db.Func("CONCAT", "foo", "bar")
// // CONCAT("foo", "bar")
// db.Func("CONCAT", "foo", "bar")
//
// // NOW()
// db.Func("NOW")
// // NOW()
// db.Func("NOW")
//
// // RTRIM("Hello ")
// db.Func("RTRIM", "Hello ")
// // RTRIM("Hello ")
// db.Func("RTRIM", "Hello ")
func Func(name string, args ...interface{}) Function {
if len(args) == 1 {
if reflect.TypeOf(args[0]).Kind() == reflect.Slice {
Expand Down Expand Up @@ -471,20 +474,20 @@ func (f *dbFunc) Name() string {
//
// Examples:
//
// // name = "Peter" AND last_name = "Parker"
// db.And(
// db.Cond{"name": "Peter"},
// db.Cond{"last_name": "Parker "},
// )
//
// // (name = "Peter" OR name = "Mickey") AND last_name = "Mouse"
// db.And(
// db.Or(
// db.Cond{"name": "Peter"},
// db.Cond{"name": "Mickey"},
// ),
// db.Cond{"last_name": "Mouse"},
// )
// // name = "Peter" AND last_name = "Parker"
// db.And(
// db.Cond{"name": "Peter"},
// db.Cond{"last_name": "Parker "},
// )
//
// // (name = "Peter" OR name = "Mickey") AND last_name = "Mouse"
// db.And(
// db.Or(
// db.Cond{"name": "Peter"},
// db.Cond{"name": "Mickey"},
// ),
// db.Cond{"last_name": "Mouse"},
// )
func And(conds ...Compound) *Intersection {
return &Intersection{newCompound(conds...)}
}
Expand All @@ -494,11 +497,11 @@ func And(conds ...Compound) *Intersection {
//
// Example:
//
// // year = 2012 OR year = 1987
// db.Or(
// db.Cond{"year": 2012},
// db.Cond{"year": 1987},
// )
// // year = 2012 OR year = 1987
// db.Or(
// db.Cond{"year": 2012},
// db.Cond{"year": 1987},
// )
func Or(conds ...Compound) *Union {
return &Union{newCompound(defaultJoin(conds...)...)}
}
Expand All @@ -508,8 +511,8 @@ func Or(conds ...Compound) *Union {
//
// Example:
//
// // SOUNDEX('Hello')
// Raw("SOUNDEX('Hello')")
// // SOUNDEX('Hello')
// Raw("SOUNDEX('Hello')")
//
// Raw returns a value that satifies the db.RawValue interface.
func Raw(value string, args ...interface{}) RawValue {
Expand Down Expand Up @@ -659,6 +662,21 @@ type Result interface {
// or columns.
Group(...interface{}) Result

// Preload direct one-to-one relations with other tables. It takes a
// db.Relation argument, which is a map of all relations you want to preload.
//
// Example:
//
// q := publicationCollection.Find().Preload(db.Relation{
// "artist": artistCollection.Find(
// "artist.id = publication.author_id",
// ),
// })
//
// Preload returns all records from the left collection and only matching
// records on the right (left join).
Preload(Relation) Result

// Delete deletes all items within the result set. `Offset()` and `Limit()` are
// not honoured by `Delete()`.
Delete() error
Expand Down Expand Up @@ -747,7 +765,7 @@ type Result interface {
//
// You can define the pagination order and add constraints to your result:
//
// cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id")
// cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id")
// res = cursor.NextPage(lowerBound)
NextPage(cursorValue interface{}) Result

Expand Down
20 changes: 15 additions & 5 deletions internal/sqladapter/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type BaseCollection interface {

// PrimaryKeys returns the table's primary keys.
PrimaryKeys() []string

// Columns returns the table's columns.
Columns() []string
}

type condsFilter interface {
Expand All @@ -65,7 +68,9 @@ type collection struct {
BaseCollection
PartialCollection

pk []string
primaryKeys []string
columns []string

err error
}

Expand All @@ -76,22 +81,27 @@ var (
// NewBaseCollection returns a collection with basic methods.
func NewBaseCollection(p PartialCollection) BaseCollection {
c := &collection{PartialCollection: p}
c.pk, c.err = c.Database().PrimaryKeys(c.Name())
c.primaryKeys, c.columns, c.err = c.Database().Columns(c.Name())
return c
}

// PrimaryKeys returns the collection's primary keys, if any.
func (c *collection) PrimaryKeys() []string {
return c.pk
return c.primaryKeys
}

// PrimaryKeys returns the collection's columns, if any.
func (c *collection) Columns() []string {
return c.columns
}

func (c *collection) filterConds(conds ...interface{}) []interface{} {
if tr, ok := c.PartialCollection.(condsFilter); ok {
return tr.FilterConds(conds...)
}
if len(conds) == 1 && len(c.pk) == 1 {
if len(conds) == 1 && len(c.primaryKeys) == 1 {
if id := conds[0]; IsKeyValue(id) {
conds[0] = db.Cond{c.pk[0]: id}
conds[0] = db.Cond{c.primaryKeys[0]: id}
}
}
return conds
Expand Down
4 changes: 2 additions & 2 deletions internal/sqladapter/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ type PartialDatabase interface {
// LookupName returns the name of the database.
LookupName() (string, error)

// PrimaryKeys returns all primary keys on the table.
PrimaryKeys(name string) ([]string, error)
// Columns returns all columns on the table.
Columns(name string) (primaryKeys []string, columns []string, err error)

// NewCollection allocates a new collection by name.
NewCollection(name string) db.Collection
Expand Down
79 changes: 77 additions & 2 deletions internal/sqladapter/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
package sqladapter

import (
"errors"
"fmt"
"sync"
"sync/atomic"

Expand All @@ -30,6 +32,10 @@ import (
"upper.io/db.v3/lib/sqlbuilder"
)

type hasColumns interface {
Columns(table string) ([]string, []string, error)
}

type Result struct {
builder sqlbuilder.SQLBuilder

Expand All @@ -42,6 +48,11 @@ type Result struct {
fn func(*result) error
}

type assocOne struct {
res db.Result
alias string
}

// result represents a delimited set of items bound by a condition.
type result struct {
table string
Expand All @@ -60,8 +71,12 @@ type result struct {
orderBy []interface{}
groupBy []interface{}
conds [][]interface{}

preloadOne []assocOne
}

type preloadAllFn func(db.Cond) db.Result

func filter(conds []interface{}) []interface{} {
return conds
}
Expand Down Expand Up @@ -298,6 +313,19 @@ func (r *Result) Update(values interface{}) error {
return r.setErr(err)
}

func (r *Result) Preload(relation db.Relation) db.Result {
return r.frame(func(res *result) error {
for key, val := range relation {
if finder, ok := val.(db.Result); ok {
res.preloadOne = append(res.preloadOne, assocOne{res: finder, alias: key})
continue
}
return fmt.Errorf("expecting a relation with db.Result value, got %T", val)
}
return nil
})
}

func (r *Result) TotalPages() (uint, error) {
query, err := r.buildPaginator()
if err != nil {
Expand Down Expand Up @@ -383,13 +411,60 @@ func (r *Result) buildPaginator() (sqlbuilder.Paginator, error) {
return nil, err
}

sel := r.SQLBuilder().Select(res.fields...).
From(res.table).
sel := r.SQLBuilder().SelectFrom(res.table).
Limit(res.limit).
Offset(res.offset).
GroupBy(res.groupBy...).
OrderBy(res.orderBy...)

if len(res.fields) > 0 {
sel = sel.Columns(res.fields...)
}

if len(res.preloadOne) > 0 {
sess, ok := r.SQLBuilder().(hasColumns)
if !ok {
return nil, errors.New("Could not create join")
}

columns := []interface{}{}

_, cs, err := sess.Columns(res.table)
if err != nil {
return nil, err
}

for _, c := range cs {
columns = append(columns, fmt.Sprintf("%s.%s AS %s", res.table, c, c))
}
sel = sel.Columns(columns...)

for _, assocOne := range res.preloadOne {

ff, err := assocOne.res.(*Result).fastForward()
if err != nil {
return nil, r.setErr(err)
}

ffConds := []interface{}{}
for i := range ff.conds {
ffConds = append(ffConds, filter(ff.conds[i])...)
}
sel = sel.LeftJoin(ff.table).On(ffConds...)

_, cs, err := sess.Columns(ff.table)
if err != nil {
return nil, r.setErr(err)
}

columns := []interface{}{}
for _, c := range cs {
columns = append(columns, fmt.Sprintf("%s.%s AS %s.%s", ff.table, c, assocOne.alias, c))
}
sel = sel.Columns(columns...)
}
}

for i := range res.conds {
sel = sel.And(filter(res.conds[i])...)
}
Expand Down
Loading