Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-16.0] fix concurrency on stream execute engine primitives (#14586) #14590

Merged
merged 2 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ require (
github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249
github.com/spyzhov/ajson v0.8.0
golang.org/x/exp v0.0.0-20230131160201-f062dba9d201
golang.org/x/sync v0.3.0
golang.org/x/tools/cmd/cover v0.1.0-deprecated
modernc.org/sqlite v1.20.3
)
Expand Down Expand Up @@ -195,7 +196,6 @@ require (
go4.org/intern v0.0.0-20220617035311-6925f38cc365 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect
golang.org/x/exp/typeparams v0.0.0-20230131160201-f062dba9d201 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect
Expand Down
174 changes: 174 additions & 0 deletions go/sqltypes/parse_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sqltypes

import (
"fmt"
"io"
"reflect"
"strconv"
"strings"
"text/scanner"

querypb "vitess.io/vitess/go/vt/proto/query"
)

// ParseRows parses the output generated by fmt.Sprintf("#v", rows), and reifies the original []sqltypes.Row
// NOTE: This is not meant for production use!
func ParseRows(input string) ([]Row, error) {
type state int
const (
stInvalid state = iota
stInit
stBeginRow
stInRow
stInValue0
stInValue1
stInValue2
)

var (
scan scanner.Scanner
result []Row
row Row
vtype int32
st = stInit
)

scan.Init(strings.NewReader(input))

for tok := scan.Scan(); tok != scanner.EOF; tok = scan.Scan() {
var next state

switch st {
case stInit:
if tok == '[' {
next = stBeginRow
}
case stBeginRow:
switch tok {
case '[':
next = stInRow
case ']':
return result, nil
}
case stInRow:
switch tok {
case ']':
result = append(result, row)
row = nil
next = stBeginRow
case scanner.Ident:
ident := scan.TokenText()

if ident == "NULL" {
row = append(row, NULL)
continue
}

var ok bool
vtype, ok = querypb.Type_value[ident]
if !ok {
return nil, fmt.Errorf("unknown SQL type %q at %s", ident, scan.Position)
}
next = stInValue0
}
case stInValue0:
if tok == '(' {
next = stInValue1
}
case stInValue1:
literal := scan.TokenText()
switch tok {
case scanner.String:
var err error
literal, err = strconv.Unquote(literal)
if err != nil {
return nil, fmt.Errorf("failed to parse literal string at %s: %w", scan.Position, err)
}
fallthrough
case scanner.Int, scanner.Float:
row = append(row, MakeTrusted(Type(vtype), []byte(literal)))
next = stInValue2
}
case stInValue2:
if tok == ')' {
next = stInRow
}
}
if next == stInvalid {
return nil, fmt.Errorf("unexpected token '%s' at %s", scan.TokenText(), scan.Position)
}
st = next
}
return nil, io.ErrUnexpectedEOF
}

type RowMismatchError struct {
err error
want, got []Row
}

func (e *RowMismatchError) Error() string {
return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot: %v", e.err, e.want, e.got)
}

func RowsEquals(want, got []Row) error {
if len(want) != len(got) {
return &RowMismatchError{
err: fmt.Errorf("expected %d rows in result, got %d", len(want), len(got)),
want: want,
got: got,
}
}

var matched = make([]bool, len(want))
for _, aa := range want {
var ok bool
for i, bb := range got {
if matched[i] {
continue
}
if reflect.DeepEqual(aa, bb) {
matched[i] = true
ok = true
break
}
}
if !ok {
return &RowMismatchError{
err: fmt.Errorf("row %v is missing from result", aa),
want: want,
got: got,
}
}
}
for _, m := range matched {
if !m {
return fmt.Errorf("not all elements matched")
}
}
return nil
}

func RowsEqualsStr(wantStr string, got []Row) error {
want, err := ParseRows(wantStr)
if err != nil {
return fmt.Errorf("malformed row assertion: %w", err)
}
return RowsEquals(want, got)
}
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
53 changes: 53 additions & 0 deletions go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,59 @@ func TestDistinct(t *testing.T) {
}
}

func TestDistinctStreamAsync(t *testing.T) {
distinct := &Distinct{
Source: &fakePrimitive{
results: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("myid|id|num|name", "varchar|int64|int64|varchar"),
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"z|1|1|a",
"a|1|1|t",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
),
async: true,
},
CheckCols: []CheckCol{
{Col: 0, Collation: collations.CollationUtf8mb4ID},
{Col: 1, Collation: collations.CollationBinaryID},
{Col: 2, Collation: collations.CollationBinaryID},
{Col: 3, Collation: collations.CollationUtf8mb4ID},
},
}

qr := &sqltypes.Result{}
err := distinct.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(`
[[VARCHAR("c") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("z") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("t")]]`, qr.Rows))
}

func TestWeightStringFallBack(t *testing.T) {
offsetOne := 1
checkCols := []CheckCol{{
Expand Down
51 changes: 49 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() []Primitive {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,46 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
var g errgroup.Group
var fields []*querypb.Field
if len(f.results) > 0 {
fields = f.results[0].Fields
}
for _, res := range f.results {
qr := res
g.Go(func() error {
if qr == nil {
return f.sendErr
}
if err := callback(&sqltypes.Result{Fields: fields}); err != nil {
return err
}
result := &sqltypes.Result{}
for i := 0; i < len(qr.Rows); i++ {
result.Rows = append(result.Rows, qr.Rows[i])
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
return err
}
result = &sqltypes.Result{}
}
}
if len(result.Rows) != 0 {
if err := callback(result); err != nil {
return err
}
}
return nil
})
}
return g.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -79,10 +80,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation())
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value
env.Fields = results.Fields

mu.Lock()
defer mu.Unlock()
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down