Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

Commit

Permalink
go sqlgen: implement read path
Browse files Browse the repository at this point in the history
  • Loading branch information
berfarah committed Aug 23, 2018
1 parent 9fa76f3 commit df7d436
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 162 deletions.
17 changes: 17 additions & 0 deletions fields/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,20 @@ func (d Descriptor) ValidateSQLType() error {
}
return d.Scanner().Scan(val)
}

func (d Descriptor) copy(from, to reflect.Value, isValid bool) {
// Set non-pointer by setting reference
if !d.Ptr {
to.Set(from)
return
}

if !isValid {
return
}

// Set pointer by creating a new reference.
tmp := reflect.New(d.Type)
tmp.Elem().Set(from)
to.Set(tmp)
}
5 changes: 5 additions & 0 deletions fields/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ func (s Scanner) Interface() interface{} {
return nil
}

// CopyTo copies the scanner value to another reflect.Value. This is used for setting structs.
func (s *Scanner) CopyTo(to reflect.Value) {
s.copy(s.value, to, s.isValid)
}

// Scan satisfies the sql.Scanner interface.
// The src value will be one of the following:
// int64
Expand Down
41 changes: 41 additions & 0 deletions sqlgen/deprecated.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqlgen

import (
"database/sql/driver"
"fmt"
)

type NullBytes struct {
Bytes []byte
Valid bool
}

func (b *NullBytes) Scan(value interface{}) error {
if value == nil {
b.Bytes = nil
b.Valid = false
}
switch value := value.(type) {
case nil:
b.Bytes = nil
b.Valid = false
case []byte:
// copy value since the MySQL driver reuses buffers
b.Bytes = make([]byte, len(value))
copy(b.Bytes, value)
b.Valid = true
case string:
b.Bytes = []byte(value)
b.Valid = true
default:
return fmt.Errorf("cannot convert %v to bytes", value)
}
return nil
}

func (b *NullBytes) Value() (driver.Value, error) {
if !b.Valid {
return nil, nil
}
return b.Bytes, nil
}
187 changes: 25 additions & 162 deletions sqlgen/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@ package sqlgen
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"unicode"

"github.com/go-sql-driver/mysql"
"github.com/samsarahq/thunder/internal"
"github.com/samsarahq/thunder/fields"
)

type Filter map[string]interface{}
Expand Down Expand Up @@ -53,89 +49,43 @@ const (
UniqueId
)

type NullBytes struct {
Bytes []byte
Valid bool
}

func (b *NullBytes) Scan(value interface{}) error {
if value == nil {
b.Bytes = nil
b.Valid = false
}
switch value := value.(type) {
case nil:
b.Bytes = nil
b.Valid = false
case []byte:
// copy value since the MySQL driver reuses buffers
b.Bytes = make([]byte, len(value))
copy(b.Bytes, value)
b.Valid = true
case string:
b.Bytes = []byte(value)
b.Valid = true
default:
return fmt.Errorf("cannot convert %v to bytes", value)
}
return nil
}

func (b *NullBytes) Value() (driver.Value, error) {
if !b.Valid {
return nil, nil
}
return b.Bytes, nil
}

// Types should implement both the sql.Scanner and driver.Valuer interface.
var defaultScannableTypes = map[reflect.Type]func() Scannable{
// These types should not be pointer types; pointer types are handled
// automatically and are treated as optional fields.
reflect.TypeOf(string("")): func() Scannable { return new(sql.NullString) },
reflect.TypeOf(int64(0)): func() Scannable { return new(sql.NullInt64) },
reflect.TypeOf(int32(0)): func() Scannable { return new(sql.NullInt64) },
reflect.TypeOf(int16(0)): func() Scannable { return new(sql.NullInt64) },
reflect.TypeOf(bool(false)): func() Scannable { return new(sql.NullBool) },
reflect.TypeOf(float64(0)): func() Scannable { return new(sql.NullFloat64) },
reflect.TypeOf([]byte{}): func() Scannable { return new(NullBytes) },
reflect.TypeOf(time.Time{}): func() Scannable { return new(mysql.NullTime) },
}

// BuildStruct constructs a struct value defined by table and based on scannables
func BuildStruct(table *Table, scannables []interface{}) interface{} {
// BuildStruct constructs a struct value defined by table and field values.
func BuildStruct(table *Table, scanners []*fields.Scanner) interface{} {
ptr := reflect.New(table.Type)
elem := ptr.Elem()

for i, column := range table.Columns {
value, _ := scannables[i].(driver.Valuer).Value()
// These values are all copies (as opposed to references) of database values.
// This means there's no funky business that can happen with the database re-using pointers.
value := scanners[i].Interface()
if value == nil {
continue
}

if column.Type.Kind() == reflect.Ptr {
ptr := reflect.New(column.Type.Elem())
ptr.Elem().Set(reflect.ValueOf(value).Convert(column.Type.Elem()))
elem.FieldByIndex(column.Index).Set(ptr)

} else {
elem.FieldByIndex(column.Index).Set(reflect.ValueOf(value).Convert(column.Type))
}
scanners[i].CopyTo(elem.FieldByIndex(column.Index))
}

return ptr.Interface()
}

// parseQueryRow parses a row from a sql.DB query into a struct
func parseQueryRow(table *Table, scanner *sql.Rows) (interface{}, error) {
scannables := table.Scannables.Get().([]interface{})
defer table.Scannables.Put(scannables)
// Pass fields which fulfill the interface `sql.Scanner` and coerce values.
values := make([]interface{}, len(table.Columns))
for i := range values {
values[i] = table.Columns[i].Descriptor.Scanner()
}

if err := scanner.Scan(scannables...); err != nil {
if err := scanner.Scan(values...); err != nil {
return nil, err
}

return BuildStruct(table, scannables), nil
scanners := make([]*fields.Scanner, len(table.Columns))
for i := range scanners {
scanners[i] = values[i].(*fields.Scanner)
}

return BuildStruct(table, scanners), nil
}

func CopySlice(result interface{}, rows []interface{}) error {
Expand Down Expand Up @@ -181,11 +131,10 @@ type Column struct {
Name string
Primary bool

Descriptor *fields.Descriptor

Index []int
Order int

Scannable func() Scannable
Type reflect.Type
}

type Table struct {
Expand All @@ -195,23 +144,6 @@ type Table struct {

Columns []*Column
ColumnsByName map[string]*Column

Scannables *sync.Pool
}

var scanIface = reflect.TypeOf((*Scannable)(nil)).Elem()

func coerceToScannable(scalarType reflect.Type) (func() Scannable, bool) {
if reflect.PtrTo(scalarType).Implements(scanIface) {
return func() Scannable {
return reflect.New(scalarType).Interface().(Scannable)
}, true
} else if scalarType.Implements(scanIface) {
return func() Scannable {
return reflect.New(scalarType).Elem().Interface().(Scannable)
}, true
}
return nil, false
}

func (s *Schema) buildDescriptor(table string, primaryKeyType PrimaryKeyType, typ reflect.Type) (*Table, error) {
Expand Down Expand Up @@ -258,17 +190,9 @@ func (s *Schema) buildDescriptor(table string, primaryKeyType PrimaryKeyType, ty
return nil, fmt.Errorf("bad type %s: duplicate column %s", typ, column)
}

scalarType := field.Type
if field.Type.Kind() == reflect.Ptr {
scalarType = field.Type.Elem()
}

scannable, ok := s.scalarTypes[scalarType]
if !ok {
scannable, ok = coerceToScannable(scalarType)
}
if !ok {
return nil, fmt.Errorf("bad type %s: field %s has unsupported type %s", typ, field.Name, field.Type)
d := fields.New(field.Type, tags[1:])
if err := d.ValidateSQLType(); err != nil {
return nil, fmt.Errorf("bad type %s: %s %v", typ, column, err)
}

descriptor := &Column{
Expand All @@ -278,8 +202,7 @@ func (s *Schema) buildDescriptor(table string, primaryKeyType PrimaryKeyType, ty
Index: field.Index,
Order: len(columns),

Scannable: scannable,
Type: field.Type,
Descriptor: d,
}

columns = append(columns, descriptor)
Expand All @@ -297,85 +220,25 @@ func (s *Schema) buildDescriptor(table string, primaryKeyType PrimaryKeyType, ty
return nil, fmt.Errorf("bad type %s: no primary key specified", typ)
}

scannables := &sync.Pool{
New: func() interface{} {
scannables := make([]interface{}, len(columns))
for i, column := range columns {
scannables[i] = column.Scannable()
}
return scannables
},
}

return &Table{
Name: table,
Type: typ,
PrimaryKeyType: primaryKeyType,

Columns: columns,
ColumnsByName: columnsByName,

Scannables: scannables,
}, nil
}

type Scannable interface {
sql.Scanner
driver.Valuer
}

type Schema struct {
ByName map[string]*Table
ByType map[reflect.Type]*Table

scalarTypes map[reflect.Type]func() Scannable
}

func NewSchema() *Schema {
scalarTypes := make(map[reflect.Type]func() Scannable)
for typ, scannable := range defaultScannableTypes {
scalarTypes[typ] = scannable
}

return &Schema{
ByName: make(map[string]*Table),
ByType: make(map[reflect.Type]*Table),

scalarTypes: scalarTypes,
}
}

func (s *Schema) RegisterCustomScalar(scalar interface{}, makeScannable func() Scannable) error {
scalarTyp := reflect.TypeOf(scalar)
if scalarTyp.Kind() == reflect.Ptr {
return fmt.Errorf("scalar type %v must not be a pointer", scalarTyp)
}
if _, ok := s.scalarTypes[scalarTyp]; ok {
return fmt.Errorf("duplicate scalar type %v", scalarTyp)
}
s.scalarTypes[scalarTyp] = makeScannable
return nil
}

func (s *Schema) MustRegisterCustomScalar(scalar interface{}, makeScannable func() Scannable) {
if err := s.RegisterCustomScalar(scalar, makeScannable); err != nil {
panic(err)
}
}

func (s *Schema) RegisterSimpleScalar(scalar interface{}) error {
typ := reflect.TypeOf(scalar)
for match, scannable := range defaultScannableTypes {
if internal.TypesIdenticalOrScalarAliases(typ, match) {
return s.RegisterCustomScalar(scalar, scannable)
}
}
return errors.New("unknown scalar")
}

func (s *Schema) MustRegisterSimpleScalar(scalar interface{}) {
if err := s.RegisterSimpleScalar(scalar); err != nil {
panic(err)
}
}

Expand Down

0 comments on commit df7d436

Please sign in to comment.