Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

Commit

Permalink
go livesql: marshal non driver.Value type filters
Browse files Browse the repository at this point in the history
  • Loading branch information
changpingc committed Sep 20, 2018
1 parent 69aa469 commit 3ff10eb
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 9 deletions.
10 changes: 5 additions & 5 deletions livesql/live.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (t *dbTracker) processBinlog(update *update) {
}
}

func (t *dbTracker) registerDependency(ctx context.Context, table string, tester sqlgen.Tester, filter sqlgen.Filter) error {
func (t *dbTracker) registerDependency(ctx context.Context, schema *sqlgen.Schema, table string, tester sqlgen.Tester, filter sqlgen.Filter) error {
r := &dbResource{
table: table,
tester: tester,
Expand All @@ -86,7 +86,7 @@ func (t *dbTracker) registerDependency(ctx context.Context, table string, tester
t.remove(r)
})

proto, err := filterToProto(table, filter)
proto, err := filterToProto(schema, table, filter)
if err != nil {
return err
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (ldb *LiveDB) query(ctx context.Context, query *sqlgen.BaseSelectQuery) ([]
// Register the dependency before we do the query to not miss any updates
// between querying and registering.
// Do not fail the query if this step fails.
_ = ldb.tracker.registerDependency(ctx, query.Table.Name, tester, query.Filter)
_ = ldb.tracker.registerDependency(ctx, ldb.Schema, query.Table.Name, tester, query.Filter)

// Perform the query.
// XXX: This will build the SQL string again... :(
Expand Down Expand Up @@ -211,7 +211,7 @@ func (ldb *LiveDB) Close() error {
}

func (ldb *LiveDB) AddDependency(ctx context.Context, proto *thunderpb.SQLFilter) error {
table, filter, err := filterFromProto(proto)
table, filter, err := filterFromProto(ldb.Schema, proto)
if err != nil {
return err
}
Expand All @@ -221,7 +221,7 @@ func (ldb *LiveDB) AddDependency(ctx context.Context, proto *thunderpb.SQLFilter
return err
}

if err := ldb.tracker.registerDependency(ctx, table, tester, filter); err != nil {
if err := ldb.tracker.registerDependency(ctx, ldb.Schema, table, tester, filter); err != nil {
return err
}
return nil
Expand Down
71 changes: 67 additions & 4 deletions livesql/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package livesql

import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"time"

"github.com/samsarahq/thunder/internal/fields"
"github.com/samsarahq/thunder/sqlgen"
"github.com/samsarahq/thunder/thunderpb"
)
Expand Down Expand Up @@ -62,26 +64,87 @@ func fieldToValue(field *thunderpb.Field) (driver.Value, error) {
}
}

func filterToProto(table string, filter sqlgen.Filter) (*thunderpb.SQLFilter, error) {
// filterToProto takes a sqlgen.Filter, runs Valuer on each filter value, and returns a thunderpb.SQLFilter.
func filterToProto(schema *sqlgen.Schema, tableName string, filter sqlgen.Filter) (*thunderpb.SQLFilter, error) {
table, ok := schema.ByName[tableName]
if !ok {
return nil, fmt.Errorf("unknown table: %s", tableName)
}

if filter == nil {
return &thunderpb.SQLFilter{Table: tableName}, nil
}

fields := make(map[string]*thunderpb.Field, len(filter))
for col, val := range filter {
column, ok := table.ColumnsByName[col]
if !ok {
return nil, fmt.Errorf("unknown column %s", col)
}

val, err := column.Descriptor.Valuer(reflect.ValueOf(val)).Value()
if err != nil {
return nil, err
}

field, err := valueToField(val)
if err != nil {
return nil, err
}
fields[col] = field
}
return &thunderpb.SQLFilter{Table: table, Fields: fields}, nil
return &thunderpb.SQLFilter{Table: tableName, Fields: fields}, nil
}

func filterFromProto(proto *thunderpb.SQLFilter) (string, sqlgen.Filter, error) {
// filterFromProto takes a thunderpb.SQLFilter, runs Scanner on each field value, and returns a sqlgen.Filter.
func filterFromProto(schema *sqlgen.Schema, proto *thunderpb.SQLFilter) (string, sqlgen.Filter, error) {
table, ok := schema.ByName[proto.Table]
if !ok {
return "", nil, fmt.Errorf("unknown table: %s", proto.Table)
}

scanners := table.Scanners.Get().([]interface{})
defer table.Scanners.Put(scanners)

filter := make(sqlgen.Filter, len(proto.Fields))
for col, field := range proto.Fields {
val, err := fieldToValue(field)
if err != nil {
return "", nil, err
}
filter[col] = val

column, ok := table.ColumnsByName[col]
if !ok {
return "", nil, fmt.Errorf("unknown column %s", col)
}

if !column.Descriptor.Ptr && val == nil {
return "", nil, errors.New("cannot unmarshal nil into non-pointer type")
}

scanner := scanners[column.Order].(*fields.Scanner)

// target is always a pointer.
var target, ptrptr reflect.Value
if column.Descriptor.Ptr {
// We need to hold onto this pointer-pointer in order to make the value addressable.
ptrptr = reflect.New(reflect.PtrTo(column.Descriptor.Type))
target = ptrptr.Elem()
} else {
target = reflect.New(column.Descriptor.Type)
}
scanner.Target(target)

if err := scanner.Scan(val); err != nil {
return "", nil, err
}

if column.Descriptor.Ptr {
filter[col] = target.Interface()
} else {
// Dereference pointer if column type is not a pointer.
filter[col] = target.Elem().Interface()
}
}
return proto.Table, filter, nil
}
107 changes: 107 additions & 0 deletions livesql/marshal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package livesql

import (
"testing"

"github.com/samsarahq/thunder/internal/testfixtures"
"github.com/samsarahq/thunder/sqlgen"
"github.com/stretchr/testify/assert"
)

func TestMarshal(t *testing.T) {
type user struct {
Id int64 `sql:",primary"`
Name *string
Uuid testfixtures.CustomType
Mood *testfixtures.CustomType
}

schema := sqlgen.NewSchema()
schema.MustRegisterType("users", sqlgen.AutoIncrement, user{})

one := int64(1)
foo := "foo"

cases := []struct {
name string
filter sqlgen.Filter
unmarshaled sqlgen.Filter
err bool
}{
{
name: "nil",
filter: nil,
unmarshaled: sqlgen.Filter{},
},
{
name: "empty",
filter: sqlgen.Filter{},
unmarshaled: sqlgen.Filter{},
},
{
name: "uuid",
filter: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
},
{
name: "uuid from bytes",
filter: sqlgen.Filter{"uuid": []byte("foo")},
unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
},
{
name: "nil uuid",
filter: sqlgen.Filter{"mood": nil},
unmarshaled: sqlgen.Filter{"mood": (*testfixtures.CustomType)(nil)},
},
{
name: "id",
filter: sqlgen.Filter{"id": int64(1)},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "id int32 to int64",
filter: sqlgen.Filter{"id": int32(1)},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "id int64 ptr to int64",
filter: sqlgen.Filter{"id": &one},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "string to string ptr",
filter: sqlgen.Filter{"name": "foo"},
unmarshaled: sqlgen.Filter{"name": &foo}},
{
name: "nil to string ptr",
filter: sqlgen.Filter{"name": nil},
unmarshaled: sqlgen.Filter{"name": (*string)(nil)},
},
{
name: "nil for int64",
filter: sqlgen.Filter{"id": nil},
err: true,
},
{
name: "string for int64",
filter: sqlgen.Filter{"id": ""},
err: true,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
proto, err := filterToProto(schema, "users", c.filter)
assert.NoError(t, err)

table, filter, err := filterFromProto(schema, proto)
if c.err {
assert.NotNil(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, "users", table)
assert.Equal(t, c.unmarshaled, filter)
}
})
}
}

0 comments on commit 3ff10eb

Please sign in to comment.