Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

btuan/count #230

Merged
merged 5 commits into from
Apr 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## [Unreleased]

### Changed

#### `sqlgen`
- Implemented a basic `(*sqlgen.DB).Count` receiver that wraps `SELECT COUNT(*)` functionality in SQL databases. ([#230](https://github.com/samsarahq/thunder/pull/230))


## [0.5.0] 2019-01-10

### Changed
Expand Down
38 changes: 38 additions & 0 deletions sqlgen/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,23 @@ func NewDB(conn *sql.DB, schema *Schema) *DB {
return db
}

func (db *DB) baseCount(ctx context.Context, query *baseCountQuery) (int64, error) {
countQuery, err := query.makeCountQuery()
if err != nil {
return 0, err
}

clause, args := countQuery.ToSQL()

var count int64
err = db.QueryExecer(ctx).QueryRowContext(ctx, clause, args...).Scan(&count)
if err != nil {
return 0, err
}

return count, err
}

func (db *DB) BaseQuery(ctx context.Context, query *BaseSelectQuery) ([]interface{}, error) {
if query.Options == nil && !db.HasTx(ctx) && batch.HasBatching(ctx) {
rows, err := db.batchFetch.Invoke(ctx, query)
Expand Down Expand Up @@ -123,6 +140,27 @@ func (db *DB) execWithTrace(ctx context.Context, query SQLQuery, operationName s
return db.QueryExecer(ctx).ExecContext(ctx, clause, args...)
}

// Count counts the number of relevant rows in a database, matching options in filter
//
// model should be a pointer to a struct, for example:
//
// count, err := db.Count(ctx, &User{}, &res, Filter{})
// if err != nil { ... }
//
func (db *DB) Count(ctx context.Context, model interface{}, filter Filter) (int64, error) {
query, err := db.Schema.makeCount(model, filter)
if err != nil {
return 0, err
}

count, err := db.baseCount(ctx, query)
if err != nil {
return 0, err
}

return count, nil
}

// Query fetches a collection of rows from the database
//
// result should be a pointer to a slice of pointers to structs, for example:
Expand Down
4 changes: 4 additions & 0 deletions sqlgen/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ func TestIntegrationBasic(t *testing.T) {
Mood: &mood,
},
}, users)

numBobs, err := db.Count(context.Background(), &User{}, Filter{"name": "Bob"})
assert.NoError(t, err)
assert.Equal(t, int64(1), numBobs)
}

// TestContextCancelBeforeRowsScan demonstrates we don't
Expand Down
22 changes: 22 additions & 0 deletions sqlgen/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ type SQLQuery interface {
ToSQL() (string, []interface{})
}

type countQuery struct {
Table string
Where *SimpleWhere
}

// ToSQL builds a parameterized SELECT COUNT(*) FROM x ... statement
func (q *countQuery) ToSQL() (string, []interface{}) {
var buffer bytes.Buffer

buffer.WriteString("SELECT COUNT(*)")
buffer.WriteString(" FROM ")
buffer.WriteString(q.Table)

where, whereValues := q.Where.ToSQL()
if where != "" {
buffer.WriteString(" WHERE ")
buffer.WriteString(where)
}

return buffer.String(), whereValues
}

// SelectQuery represents a SELECT query
type SelectQuery struct {
Table string
Expand Down
18 changes: 18 additions & 0 deletions sqlgen/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ func TestSimpleWhere(t *testing.T) {
}, "foo = ? AND bar = ?", []interface{}{1, 2}, t)
}

func TestCountQuery(t *testing.T) {
testQuery(&countQuery{
Table: "foo",
Where: &SimpleWhere{
Columns: []string{"bar"},
Values: []interface{}{3},
},
}, "SELECT COUNT(*) FROM foo WHERE bar = ?", []interface{}{3}, t)

testQuery(&countQuery{
Table: "foo2",
Where: &SimpleWhere{
Columns: []string{"baz"},
Values: []interface{}{"xyz"},
},
}, "SELECT COUNT(*) FROM foo2 WHERE baz = ?", []interface{}{"xyz"}, t)
}

func TestSelectQuery(t *testing.T) {
testQuery(&SelectQuery{
Table: "foo",
Expand Down
48 changes: 48 additions & 0 deletions sqlgen/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,54 @@ func makeWhere(table *Table, filter Filter) (*SimpleWhere, error) {
}, nil
}

type baseCountQuery struct {
Table *Table
Filter Filter
}

func (b *baseCountQuery) makeCountQuery() (*countQuery, error) {
where, err := makeWhere(b.Table, b.Filter)
if err != nil {
return nil, err
}

return &countQuery{
Table: b.Table.Name,
Where: where,
}, nil
}

var errBadCountModelType = errors.New("count model value should be a pointer to a struct")

func checkCountModelTypeShape(typ reflect.Type) (reflect.Type, error) {
if typ.Kind() != reflect.Ptr {
return nil, errBadCountModelType
}
typ = typ.Elem()
if typ.Kind() != reflect.Struct {
return nil, errBadCountModelType
}
return typ, nil
}

func (s *Schema) makeCount(model interface{}, filter Filter) (*baseCountQuery, error) {
ptr := reflect.ValueOf(model)
typ, err := checkCountModelTypeShape(ptr.Type())
if err != nil {
return nil, err
}

table, err := s.get(typ)
if err != nil {
return nil, err
}

return &baseCountQuery{
Table: table,
Filter: filter,
}, nil
}

type BaseSelectQuery struct {
Table *Table
Filter Filter
Expand Down
41 changes: 41 additions & 0 deletions sqlgen/reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,47 @@ func TestMakeWhere(t *testing.T) {
}
}

func TestMakeCount(t *testing.T) {
s := NewSchema()
if err := s.RegisterType("users", AutoIncrement, user{}); err != nil {
t.Fatal(err)
}

var usr user
query, err := s.makeCount(&usr, Filter{"id": 10})
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(query, &baseCountQuery{
Table: s.ByName["users"],
Filter: Filter{"id": 10},
}) {
t.Error("bad count")
}

query, err = s.makeCount(&usr, Filter{})
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(query, &baseCountQuery{
Table: s.ByName["users"],
Filter: Filter{},
}) {
t.Error("bad count")
}

query, err = s.makeCount(&usr, Filter{"name": "bob", "age": 10})
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(query, &baseCountQuery{
Table: s.ByName["users"],
Filter: Filter{"name": "bob", "age": 10},
}) {
t.Error("bad count")
}
}

func TestMakeSelect(t *testing.T) {
s := NewSchema()
if err := s.RegisterType("users", AutoIncrement, user{}); err != nil {
Expand Down