Skip to content

Commit

Permalink
Merge pull request #2801 from alainjobart/proto
Browse files Browse the repository at this point in the history
Fixing more proto equalities.
  • Loading branch information
alainjobart committed Apr 27, 2017
2 parents a5372cb + 42072f2 commit 9389cf5
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 17 deletions.
32 changes: 32 additions & 0 deletions go/sqltypes/bind_variables.go
@@ -0,0 +1,32 @@
package sqltypes

import (
"reflect"

"github.com/golang/protobuf/proto"
)

// BindVariablesEqual compares two maps of bind variables.
// For protobuf messages we have to use "proto.Equal".
func BindVariablesEqual(x, y map[string]interface{}) bool {
if len(x) != len(y) {
return false
}
for k := range x {
vx, vy := x[k], y[k]
if reflect.TypeOf(vx) != reflect.TypeOf(vy) {
return false
}
switch vx.(type) {
case proto.Message:
if !proto.Equal(vx.(proto.Message), vy.(proto.Message)) {
return false
}
default:
if !reflect.DeepEqual(vx, vy) {
return false
}
}
}
return true
}
2 changes: 1 addition & 1 deletion go/vt/binlog/binlogplayertest/player.go
Expand Up @@ -179,7 +179,7 @@ func testStreamTables(t *testing.T, bpc binlogplayer.Client) {
if se, err := stream.Recv(); err != nil {
t.Fatalf("got error: %v", err)
} else {
if !reflect.DeepEqual(*se, *testBinlogTransaction) {
if !proto.Equal(se, testBinlogTransaction) {
t.Errorf("got wrong result, got %v expected %v", *se, *testBinlogTransaction)
}
}
Expand Down
10 changes: 10 additions & 0 deletions go/vt/mysqlctl/tmutils/schema.go
Expand Up @@ -11,6 +11,7 @@ import (
"regexp"
"strings"

"github.com/golang/protobuf/proto"
"github.com/youtube/vitess/go/vt/concurrency"

tabletmanagerdatapb "github.com/youtube/vitess/go/vt/proto/tabletmanagerdata"
Expand Down Expand Up @@ -279,3 +280,12 @@ type SchemaChange struct {
BeforeSchema *tabletmanagerdatapb.SchemaDefinition
AfterSchema *tabletmanagerdatapb.SchemaDefinition
}

// Equal compares two SchemaChange objects.
func (s *SchemaChange) Equal(s2 *SchemaChange) bool {
return s.SQL == s2.SQL &&
s.Force == s2.Force &&
s.AllowReplication == s2.AllowReplication &&
proto.Equal(s.BeforeSchema, s2.BeforeSchema) &&
proto.Equal(s.AfterSchema, s2.AfterSchema)
}
3 changes: 2 additions & 1 deletion go/vt/topo/replication.go
Expand Up @@ -6,6 +6,7 @@ package topo

import (
log "github.com/golang/glog"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"

"github.com/youtube/vitess/go/trace"
Expand Down Expand Up @@ -104,7 +105,7 @@ func RemoveShardReplicationRecord(ctx context.Context, ts Server, cell, keyspace
err := ts.UpdateShardReplicationFields(ctx, cell, keyspace, shard, func(sr *topodatapb.ShardReplication) error {
nodes := make([]*topodatapb.ShardReplication_Node, 0, len(sr.Nodes))
for _, node := range sr.Nodes {
if *node.TabletAlias != *tabletAlias {
if !proto.Equal(node.TabletAlias, tabletAlias) {
nodes = append(nodes, node)
}
}
Expand Down
22 changes: 16 additions & 6 deletions go/vt/vtgate/vtgateconntest/client.go
Expand Up @@ -76,7 +76,7 @@ type queryExecute struct {

func (q *queryExecute) equal(q2 *queryExecute) bool {
return q.SQL == q2.SQL &&
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
q.Keyspace == q2.Keyspace &&
q.TabletType == q2.TabletType &&
proto.Equal(q.Session, q2.Session) &&
Expand Down Expand Up @@ -166,7 +166,7 @@ type queryExecuteShards struct {

func (q *queryExecuteShards) equal(q2 *queryExecuteShards) bool {
return q.SQL == q2.SQL &&
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
q.Keyspace == q2.Keyspace &&
reflect.DeepEqual(q.Shards, q2.Shards) &&
q.TabletType == q2.TabletType &&
Expand Down Expand Up @@ -223,7 +223,7 @@ type queryExecuteKeyspaceIds struct {

func (q *queryExecuteKeyspaceIds) equal(q2 *queryExecuteKeyspaceIds) bool {
return q.SQL == q2.SQL &&
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
q.Keyspace == q2.Keyspace &&
reflect.DeepEqual(q.KeyspaceIds, q2.KeyspaceIds) &&
q.TabletType == q2.TabletType &&
Expand Down Expand Up @@ -279,7 +279,7 @@ type queryExecuteKeyRanges struct {

func (q *queryExecuteKeyRanges) equal(q2 *queryExecuteKeyRanges) bool {
if q.SQL != q2.SQL ||
!reflect.DeepEqual(q.BindVariables, q2.BindVariables) ||
!sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) ||
q.Keyspace != q2.Keyspace ||
len(q.KeyRanges) != len(q2.KeyRanges) ||
q.TabletType != q2.TabletType ||
Expand Down Expand Up @@ -344,7 +344,7 @@ type queryExecuteEntityIds struct {

func (q *queryExecuteEntityIds) equal(q2 *queryExecuteEntityIds) bool {
if q.SQL != q2.SQL ||
!reflect.DeepEqual(q.BindVariables, q2.BindVariables) ||
!sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) ||
q.Keyspace != q2.Keyspace ||
q.EntityColumnName != q2.EntityColumnName ||
len(q.EntityKeyspaceIDs) != len(q2.EntityKeyspaceIDs) ||
Expand Down Expand Up @@ -824,6 +824,16 @@ type querySplitQuery struct {
Algorithm querypb.SplitQueryRequest_Algorithm
}

func (q *querySplitQuery) equal(q2 *querySplitQuery) bool {
return q.Keyspace == q2.Keyspace &&
q.SQL == q2.SQL &&
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
reflect.DeepEqual(q.SplitColumns, q2.SplitColumns) &&
q.SplitCount == q2.SplitCount &&
q.NumRowsPerQueryPart == q2.NumRowsPerQueryPart &&
q.Algorithm == q2.Algorithm
}

// SplitQuery is part of the VTGateService interface
func (f *fakeVTGateService) SplitQuery(
ctx context.Context,
Expand All @@ -850,7 +860,7 @@ func (f *fakeVTGateService) SplitQuery(
NumRowsPerQueryPart: numRowsPerQueryPart,
Algorithm: algorithm,
}
if !reflect.DeepEqual(query, splitQueryRequest) {
if !query.equal(splitQueryRequest) {
f.t.Errorf("SplitQuery has wrong input: got %#v wanted %#v", query, splitQueryRequest)
}
return splitQueryResult, nil
Expand Down
35 changes: 32 additions & 3 deletions go/vt/vttablet/agentrpctest/test_agent_rpc.go
Expand Up @@ -15,6 +15,7 @@ import (

"golang.org/x/net/context"

"github.com/golang/protobuf/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/hook"
"github.com/youtube/vitess/go/vt/logutil"
Expand Down Expand Up @@ -59,10 +60,36 @@ func NewFakeRPCAgent(t *testing.T) tabletmanager.RPCAgent {
// for each possible method of the interface.
// This makes the implementations all in the same spot.

var protoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()

func compare(t *testing.T, name string, got, want interface{}) {
if !reflect.DeepEqual(got, want) {
t.Errorf("Unexpected %v: got %v expected %v", name, got, want)
typ := reflect.TypeOf(got)
if reflect.TypeOf(got) != reflect.TypeOf(want) {
goto fail
}
switch {
case typ.Implements(protoMessage):
if !proto.Equal(got.(proto.Message), want.(proto.Message)) {
goto fail
}
case typ.Kind() == reflect.Slice && typ.Elem().Implements(protoMessage):
vx, vy := reflect.ValueOf(got), reflect.ValueOf(want)
if vx.Len() != vy.Len() {
goto fail
}
for i := 0; i < vx.Len(); i++ {
if !proto.Equal(vx.Index(i).Interface().(proto.Message), vy.Index(i).Interface().(proto.Message)) {
goto fail
}
}
default:
if !reflect.DeepEqual(got, want) {
goto fail
}
}
return
fail:
t.Errorf("Unexpected %v:\ngot %#v\nwant %#v", name, got, want)
}

func compareBool(t *testing.T, name string, got bool) {
Expand Down Expand Up @@ -527,7 +554,9 @@ func (fra *fakeRPCAgent) ApplySchema(ctx context.Context, change *tmutils.Schema
if fra.panics {
panic(fmt.Errorf("test-triggered panic"))
}
compare(fra.t, "ApplySchema change", change, testSchemaChange)
if !change.Equal(testSchemaChange) {
fra.t.Errorf("Unexpected ApplySchema change:\ngot %#v\nwant %#v", change, testSchemaChange)
}
return testSchemaChangeResult[0], nil
}

Expand Down
8 changes: 4 additions & 4 deletions go/vt/vttablet/tabletconntest/fakequeryservice.go
Expand Up @@ -376,7 +376,7 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *querypb.Target,
if sql != ExecuteQuery {
f.t.Errorf("invalid Execute.Query.Sql: got %v expected %v", sql, ExecuteQuery)
}
if !reflect.DeepEqual(bindVariables, ExecuteBindVars) {
if !sqltypes.BindVariablesEqual(bindVariables, ExecuteBindVars) {
f.t.Errorf("invalid Execute.BindVariables: got %v expected %v", bindVariables, ExecuteBindVars)
}
if !proto.Equal(options, TestExecuteOptions) {
Expand Down Expand Up @@ -432,7 +432,7 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *querypb.Ta
if sql != StreamExecuteQuery {
f.t.Errorf("invalid StreamExecute.Sql: got %v expected %v", sql, StreamExecuteQuery)
}
if !reflect.DeepEqual(bindVariables, StreamExecuteBindVars) {
if !sqltypes.BindVariablesEqual(bindVariables, StreamExecuteBindVars) {
f.t.Errorf("invalid StreamExecute.BindVariables: got %v expected %v", bindVariables, StreamExecuteBindVars)
}
if !proto.Equal(options, TestExecuteOptions) {
Expand Down Expand Up @@ -538,7 +538,7 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *querypb.Tar
if f.Panics {
panic(fmt.Errorf("test-triggered panic"))
}
if !reflect.DeepEqual(queries, ExecuteBatchQueries) {
if !querytypes.BoundQueriesEqual(queries, ExecuteBatchQueries) {
f.t.Errorf("invalid ExecuteBatch.Queries: got %v expected %v", queries, ExecuteBatchQueries)
}
if !proto.Equal(options, TestExecuteOptions) {
Expand Down Expand Up @@ -684,7 +684,7 @@ func (f *FakeQueryService) SplitQuery(
panic(fmt.Errorf("test-triggered panic"))
}
f.checkTargetCallerID(ctx, "SplitQuery", target)
if !reflect.DeepEqual(query, SplitQueryBoundQuery) {
if !querytypes.BoundQueryEqual(&query, &SplitQueryBoundQuery) {
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v",
querytypes.QueryAsString(query.Sql, query.BindVariables), SplitQueryBoundQuery)
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vttablet/tabletconntest/tabletconntest.go
Expand Up @@ -8,7 +8,6 @@ package tabletconntest

import (
"io"
"reflect"
"strings"
"testing"
"time"
Expand All @@ -21,6 +20,7 @@ import (
"github.com/youtube/vitess/go/vt/vterrors"
"github.com/youtube/vitess/go/vt/vttablet/queryservice"
"github.com/youtube/vitess/go/vt/vttablet/tabletconn"
"github.com/youtube/vitess/go/vt/vttablet/tabletserver/querytypes"

querypb "github.com/youtube/vitess/go/vt/proto/query"
topodatapb "github.com/youtube/vitess/go/vt/proto/topodata"
Expand Down Expand Up @@ -728,7 +728,7 @@ func testSplitQuery(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe
if err != nil {
t.Fatalf("SplitQuery failed: %v", err)
}
if !reflect.DeepEqual(qsl, SplitQueryQuerySplitList) {
if !querytypes.QuerySplitsEqual(qsl, SplitQueryQuerySplitList) {
t.Errorf("Unexpected result from SplitQuery: got %v wanted %v", qsl, SplitQueryQuerySplitList)
}
}
Expand Down
21 changes: 21 additions & 0 deletions go/vt/vttablet/tabletserver/querytypes/bound_query.go
Expand Up @@ -9,6 +9,8 @@ package querytypes
import (
"bytes"
"fmt"

"github.com/youtube/vitess/go/sqltypes"
)

// This file defines the BoundQuery type.
Expand Down Expand Up @@ -55,3 +57,22 @@ func slimit(s string, max int) string {
}
return s
}

// BoundQueriesEqual compares two slices of BoundQuery objects.
func BoundQueriesEqual(x, y []BoundQuery) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if !BoundQueryEqual(&x[i], &y[i]) {
return false
}
}
return true
}

// BoundQueryEqual compares two BoundQuery objects.
func BoundQueryEqual(x, y *BoundQuery) bool {
return x.Sql == y.Sql &&
sqltypes.BindVariablesEqual(x.BindVariables, y.BindVariables)
}
22 changes: 22 additions & 0 deletions go/vt/vttablet/tabletserver/querytypes/query_split.go
Expand Up @@ -4,6 +4,8 @@

package querytypes

import "github.com/youtube/vitess/go/sqltypes"

// This file defines QuerySplit

// QuerySplit represents a split of a query, used for MapReduce purposes.
Expand All @@ -17,3 +19,23 @@ type QuerySplit struct {
// RowCount is the approximate number of rows this query will return
RowCount int64
}

// Equal compares two QuerySplit objects.
func (q *QuerySplit) Equal(q2 *QuerySplit) bool {
return q.Sql == q2.Sql &&
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
q.RowCount == q2.RowCount
}

// QuerySplitsEqual compares two slices of QuerySplit objects.
func QuerySplitsEqual(x, y []QuerySplit) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if !x[i].Equal(&y[i]) {
return false
}
}
return true
}

0 comments on commit 9389cf5

Please sign in to comment.