Skip to content
This repository has been archived by the owner on Feb 15, 2023. It is now read-only.

serialize dependencies into protobuf #165

Merged
merged 7 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphql/end_to_end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func TestEndToEndAwaitAndCache(t *testing.T) {

user := schema.Object("User", User{})
user.FieldFunc("slow", func(ctx context.Context, u *User) *Slow {
reactive.AddDependency(ctx, u.resource)
reactive.AddDependency(ctx, u.resource, nil)
time.Sleep(100 * time.Millisecond)
return new(Slow)
})
Expand Down
33 changes: 29 additions & 4 deletions livesql/live.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/samsarahq/thunder/internal"
"github.com/samsarahq/thunder/reactive"
"github.com/samsarahq/thunder/sqlgen"
"github.com/samsarahq/thunder/thunderpb"
)

// dbResource tracks changes to a specific table matching a filter
Expand Down Expand Up @@ -75,7 +76,7 @@ func (t *dbTracker) processBinlog(update *update) {
}
}

func (t *dbTracker) registerDependency(ctx context.Context, table string, tester sqlgen.Tester) {
func (t *dbTracker) registerDependency(ctx context.Context, schema *sqlgen.Schema, table string, tester sqlgen.Tester, filter sqlgen.Filter) error {
r := &dbResource{
table: table,
tester: tester,
Expand All @@ -84,9 +85,16 @@ func (t *dbTracker) registerDependency(ctx context.Context, table string, tester
r.resource.Cleanup(func() {
t.remove(r)
})
reactive.AddDependency(ctx, r.resource)

proto, err := filterToProto(schema, table, filter)
if err != nil {
return err
}

reactive.AddDependency(ctx, r.resource, proto)

t.add(r)
return nil
}

// LiveDB is a SQL client that supports live updating queries.
Expand Down Expand Up @@ -120,7 +128,6 @@ func (ldb *LiveDB) query(ctx context.Context, query *sqlgen.BaseSelectQuery) ([]
if !reactive.HasRerunner(ctx) || ldb.HasTx(ctx) {
return ldb.DB.BaseQuery(ctx, query)
}

selectQuery, err := query.MakeSelectQuery()
if err != nil {
return nil, err
Expand All @@ -141,7 +148,8 @@ func (ldb *LiveDB) query(ctx context.Context, query *sqlgen.BaseSelectQuery) ([]

// Register the dependency before we do the query to not miss any updates
// between querying and registering.
ldb.tracker.registerDependency(ctx, query.Table.Name, tester)
// Do not fail the query if this step fails.
_ = ldb.tracker.registerDependency(ctx, ldb.Schema, query.Table.Name, tester, query.Filter)

// Perform the query.
// XXX: This will build the SQL string again... :(
Expand Down Expand Up @@ -201,3 +209,20 @@ func (ldb *LiveDB) QueryRow(ctx context.Context, result interface{}, filter sqlg
func (ldb *LiveDB) Close() error {
return ldb.Conn.Close()
}

func (ldb *LiveDB) AddDependency(ctx context.Context, proto *thunderpb.SQLFilter) error {
table, filter, err := filterFromProto(ldb.Schema, proto)
if err != nil {
return err
}

tester, err := ldb.Schema.MakeTester(table, filter)
if err != nil {
return err
}

if err := ldb.tracker.registerDependency(ctx, ldb.Schema, table, tester, filter); err != nil {
return err
}
return nil
}
150 changes: 150 additions & 0 deletions livesql/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package livesql

import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"time"

"github.com/samsarahq/thunder/internal/fields"
"github.com/samsarahq/thunder/sqlgen"
"github.com/samsarahq/thunder/thunderpb"
)

// valueToField converts a driver.Value into a thunderpb.Field.
// driver.Value must be one of the following:
// int64
// float64
// bool
// []byte
// string
// time.Time
func valueToField(value driver.Value) (*thunderpb.Field, error) {
switch column := value.(type) {
case nil:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Null}, nil
case int64:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Int, Int: column}, nil
case float64:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Float64, Float64: column}, nil
case bool:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Bool, Bool: column}, nil
case []byte:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Bytes, Bytes: column}, nil
case string:
return &thunderpb.Field{Kind: thunderpb.FieldKind_String, String_: column}, nil
case time.Time:
return &thunderpb.Field{Kind: thunderpb.FieldKind_Time, Time: column}, nil
default:
return nil, fmt.Errorf("unknown type %s", reflect.TypeOf(column))
}
}

func fieldToValue(field *thunderpb.Field) (driver.Value, error) {
switch field.Kind {
case thunderpb.FieldKind_Null:
return nil, nil
case thunderpb.FieldKind_Bool:
return field.Bool, nil
case thunderpb.FieldKind_Int:
return field.Int, nil
case thunderpb.FieldKind_Uint:
return field.Uint, nil
case thunderpb.FieldKind_String:
return field.String_, nil // field.String is a function.
case thunderpb.FieldKind_Bytes:
return field.Bytes, nil
case thunderpb.FieldKind_Float64:
return field.Float64, nil
case thunderpb.FieldKind_Time:
return field.Time, nil
default:
return nil, fmt.Errorf("unknown kind %s", field.Kind.String())
}
}

// filterToProto takes a sqlgen.Filter, runs Valuer on each filter value, and returns a thunderpb.SQLFilter.
func filterToProto(schema *sqlgen.Schema, tableName string, filter sqlgen.Filter) (*thunderpb.SQLFilter, error) {
table, ok := schema.ByName[tableName]
if !ok {
return nil, fmt.Errorf("unknown table: %s", tableName)
}

if filter == nil {
return &thunderpb.SQLFilter{Table: tableName}, nil
}

fields := make(map[string]*thunderpb.Field, len(filter))
for col, val := range filter {
column, ok := table.ColumnsByName[col]
if !ok {
return nil, fmt.Errorf("unknown column %s", col)
}

val, err := column.Descriptor.Valuer(reflect.ValueOf(val)).Value()
if err != nil {
return nil, err
}

field, err := valueToField(val)
if err != nil {
return nil, err
}
fields[col] = field
}
return &thunderpb.SQLFilter{Table: tableName, Fields: fields}, nil
}

// filterFromProto takes a thunderpb.SQLFilter, runs Scanner on each field value, and returns a sqlgen.Filter.
func filterFromProto(schema *sqlgen.Schema, proto *thunderpb.SQLFilter) (string, sqlgen.Filter, error) {
table, ok := schema.ByName[proto.Table]
if !ok {
return "", nil, fmt.Errorf("unknown table: %s", proto.Table)
}

scanners := table.Scanners.Get().([]interface{})
defer table.Scanners.Put(scanners)

filter := make(sqlgen.Filter, len(proto.Fields))
for col, field := range proto.Fields {
val, err := fieldToValue(field)
if err != nil {
return "", nil, err
}

column, ok := table.ColumnsByName[col]
if !ok {
return "", nil, fmt.Errorf("unknown column %s", col)
}

if !column.Descriptor.Ptr && val == nil {
return "", nil, errors.New("cannot unmarshal nil into non-pointer type")
}

scanner := scanners[column.Order].(*fields.Scanner)

// target is always a pointer.
var target, ptrptr reflect.Value
if column.Descriptor.Ptr {
// We need to hold onto this pointer-pointer in order to make the value addressable.
ptrptr = reflect.New(reflect.PtrTo(column.Descriptor.Type))
target = ptrptr.Elem()
} else {
target = reflect.New(column.Descriptor.Type)
}
scanner.Target(target)

if err := scanner.Scan(val); err != nil {
return "", nil, err
}

if column.Descriptor.Ptr {
filter[col] = target.Interface()
} else {
// Dereference pointer if column type is not a pointer.
filter[col] = target.Elem().Interface()
}
}
return proto.Table, filter, nil
}
107 changes: 107 additions & 0 deletions livesql/marshal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package livesql

import (
"testing"

"github.com/samsarahq/thunder/internal/testfixtures"
"github.com/samsarahq/thunder/sqlgen"
"github.com/stretchr/testify/assert"
)

func TestMarshal(t *testing.T) {
type user struct {
Id int64 `sql:",primary"`
Name *string
Uuid testfixtures.CustomType
Mood *testfixtures.CustomType
}

schema := sqlgen.NewSchema()
schema.MustRegisterType("users", sqlgen.AutoIncrement, user{})

one := int64(1)
foo := "foo"

cases := []struct {
name string
filter sqlgen.Filter
unmarshaled sqlgen.Filter
err bool
}{
{
name: "nil",
filter: nil,
unmarshaled: sqlgen.Filter{},
},
{
name: "empty",
filter: sqlgen.Filter{},
unmarshaled: sqlgen.Filter{},
},
{
name: "uuid",
filter: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
},
{
name: "uuid from bytes",
filter: sqlgen.Filter{"uuid": []byte("foo")},
unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")},
},
{
name: "nil uuid",
filter: sqlgen.Filter{"mood": nil},
unmarshaled: sqlgen.Filter{"mood": (*testfixtures.CustomType)(nil)},
},
{
name: "id",
filter: sqlgen.Filter{"id": int64(1)},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "id int32 to int64",
filter: sqlgen.Filter{"id": int32(1)},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "id int64 ptr to int64",
filter: sqlgen.Filter{"id": &one},
unmarshaled: sqlgen.Filter{"id": int64(1)},
},
{
name: "string to string ptr",
filter: sqlgen.Filter{"name": "foo"},
unmarshaled: sqlgen.Filter{"name": &foo}},
{
name: "nil to string ptr",
filter: sqlgen.Filter{"name": nil},
unmarshaled: sqlgen.Filter{"name": (*string)(nil)},
},
{
name: "nil for int64",
filter: sqlgen.Filter{"id": nil},
err: true,
},
{
name: "string for int64",
filter: sqlgen.Filter{"id": ""},
err: true,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
proto, err := filterToProto(schema, "users", c.filter)
assert.NoError(t, err)

table, filter, err := filterFromProto(schema, proto)
if c.err {
assert.NotNil(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, "users", table)
assert.Equal(t, c.unmarshaled, filter)
}
})
}
}
Loading