Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix connection leak in vtworker #4585

Merged
merged 3 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 6 additions & 7 deletions go/vt/worker/result_merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ import (
"fmt"
"io"

"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"vitess.io/vitess/go/vt/vterrors"

"github.com/golang/protobuf/proto"

"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
Expand All @@ -42,8 +41,9 @@ const ResultSizeRows = 64
// The output stream will be sorted by ascending primary key order.
// It implements the ResultReader interface.
type ResultMerger struct {
inputs []ResultReader
fields []*querypb.Field
inputs []ResultReader
allInputs []ResultReader
fields []*querypb.Field
// output is the buffer of merged rows. Once it's full, we'll return it in
// Next() (wrapped in a sqltypes.Result).
output [][]sqltypes.Value
Expand Down Expand Up @@ -92,6 +92,7 @@ func NewResultMerger(inputs []ResultReader, pkFieldCount int) (*ResultMerger, er

rm := &ResultMerger{
inputs: activeInputs,
allInputs: inputs,
fields: fields,
nextRowHeap: nextRowHeap,
}
Expand Down Expand Up @@ -180,13 +181,11 @@ func (rm *ResultMerger) Next() (*sqltypes.Result, error) {

// Close closes all inputs
func (rm *ResultMerger) Close(ctx context.Context) {
for _, i := range rm.inputs {
for _, i := range rm.allInputs {
i.Close(ctx)
}
}



func (rm *ResultMerger) deleteInput(deleteMe ResultReader) {
for i, input := range rm.inputs {
if input == deleteMe {
Expand Down
86 changes: 49 additions & 37 deletions go/vt/worker/result_merger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type fakeResultReader struct {
// currentIndex is the current index within the current range.
currentIndex int
rowsReturned int
closed bool
}

// newFakeResultReader returns a new FakeResultReader.
Expand Down Expand Up @@ -113,6 +114,7 @@ func (f *fakeResultReader) Fields() []*querypb.Field {

// Close records that the reader was closed so tests can verify that
// ResultMerger (and callers) close all of their inputs.
// It is part of the ResultReader interface.
func (f *fakeResultReader) Close(ctx context.Context) {
	f.closed = true
}

// Next returns the next fake result. It is part of the ResultReader interface.
Expand Down Expand Up @@ -303,53 +305,63 @@ func TestResultMerger(t *testing.T) {
}

for _, tc := range testcases {
t.Logf("checking testcase: %v", tc.desc)
pkFieldCount := 1
if tc.multiPk {
pkFieldCount = 2
}
rm, err := NewResultMerger(tc.inputs, pkFieldCount)
if err != nil {
t.Fatal(err)
}

// Consume all merged Results.
var got []*sqltypes.Result
for {
result, err := rm.Next()
t.Run(fmt.Sprintf("checking testcase: %v", tc.desc), func(inner *testing.T) {
pkFieldCount := 1
if tc.multiPk {
pkFieldCount = 2
}
rm, err := NewResultMerger(tc.inputs, pkFieldCount)
if err != nil {
if err == io.EOF {
break
} else {
t.Fatal(err)
inner.Fatal(err)
}

// Consume all merged Results.
var got []*sqltypes.Result
for {
result, err := rm.Next()
if err != nil {
if err == io.EOF {
break
} else {
inner.Fatal(err)
}
}
got = append(got, result)
}
got = append(got, result)
}

if !reflect.DeepEqual(got, tc.want) {
for i := range got {
if i == len(tc.want) {
// got has more Results than want. Avoid index out of range errors.
break
rm.Close(context.Background())

if !reflect.DeepEqual(got, tc.want) {
for i := range got {
if i == len(tc.want) {
// got has more Results than want. Avoid index out of range errors.
break
}
if got[i].RowsAffected != tc.want[i].RowsAffected {
inner.Logf("deviating RowsAffected value for Result at index: %v got = %v, want = %v", i, got[i].RowsAffected, tc.want[i].RowsAffected)
}
inner.Logf("deviating Rows for Result at index: %v got = %v, want = %v", i, got[i].Rows, tc.want[i].Rows)
}
if got[i].RowsAffected != tc.want[i].RowsAffected {
t.Logf("deviating RowsAffected value for Result at index: %v got = %v, want = %v", i, got[i].RowsAffected, tc.want[i].RowsAffected)
if len(tc.want)-len(got) > 0 {
for i := len(got); i < len(tc.want); i++ {
inner.Logf("missing Result in got: %v", tc.want[i].Rows)
}
}
t.Logf("deviating Rows for Result at index: %v got = %v, want = %v", i, got[i].Rows, tc.want[i].Rows)
}
if len(tc.want)-len(got) > 0 {
for i := len(got); i < len(tc.want); i++ {
t.Logf("missing Result in got: %v", tc.want[i].Rows)
if len(got)-len(tc.want) > 0 {
for i := len(tc.want); i < len(got); i++ {
inner.Logf("unnecessary extra Result in got: %v", got[i].Rows)
}
}
inner.Fatalf("ResultMerger testcase '%v' failed. See output above for different rows.", tc.desc)
}
if len(got)-len(tc.want) > 0 {
for i := len(tc.want); i < len(got); i++ {
t.Logf("unnecessary extra Result in got: %v", got[i].Rows)

for _, x := range tc.inputs {
fake := x.(*fakeResultReader)
if !fake.closed {
inner.Fatal("expected inputs to be closed by now")
}
}
t.Fatalf("ResultMerger testcase '%v' failed. See output above for different rows.", tc.desc)
}
})
}
}

Expand Down
30 changes: 26 additions & 4 deletions go/vt/worker/split_clone.go
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,13 @@ func mergeOrSingle(readers []ResultReader, td *tabletmanagerdatapb.TableDefiniti

func (scw *SplitCloneWorker) getSourceResultReader(ctx context.Context, td *tabletmanagerdatapb.TableDefinition, state StatusWorkerState, chunk chunk, txID int64) (ResultReader, error) {
sourceReaders := make([]ResultReader, len(scw.sourceShards))
var readers []ResultReader
defer func() {
for _, i := range readers {
i.Close(ctx)
}
}()

for shardIndex, si := range scw.sourceShards {
var sourceResultReader ResultReader
var err error
Expand Down Expand Up @@ -941,15 +948,26 @@ func (scw *SplitCloneWorker) getSourceResultReader(ctx context.Context, td *tabl
if err != nil {
return nil, fmt.Errorf("NewRestartableResultReader for source: %v failed: %v", tp.description(), err)
}
readers = append(readers, sourceResultReader)
}
// TODO: We could end up in a situation where some readers have been created but not all. In this situation, we would not close up all readers
sourceReaders[shardIndex] = sourceResultReader
}
return mergeOrSingle(sourceReaders, td)
resultReader, err := mergeOrSingle(sourceReaders, td)
if err == nil {
readers = readers[:0]
}
return resultReader, err
}

func (scw *SplitCloneWorker) getDestinationResultReader(ctx context.Context, td *tabletmanagerdatapb.TableDefinition, state StatusWorkerState, chunk chunk) (ResultReader, error) {
destReaders := make([]ResultReader, len(scw.destinationShards))
var readers []ResultReader
defer func() {
for _, i := range readers {
i.Close(ctx)
}
}()

for shardIndex, si := range scw.destinationShards {
tp := newShardTabletProvider(scw.tsc, scw.tabletTracker, si.Keyspace(), si.ShardName(), topodatapb.TabletType_MASTER)
destResultReader, err := NewRestartableResultReader(ctx, scw.wr.Logger(), tp, td, chunk, true /* allowMultipleRetries */)
Expand All @@ -958,7 +976,11 @@ func (scw *SplitCloneWorker) getDestinationResultReader(ctx context.Context, td
}
destReaders[shardIndex] = destResultReader
}
return mergeOrSingle(destReaders, td)
resultReader, err := mergeOrSingle(destReaders, td)
if err == nil {
readers = readers[:0]
}
return resultReader, err
}

func (scw *SplitCloneWorker) cloneAChunk(ctx context.Context, td *tabletmanagerdatapb.TableDefinition, tableIndex int, chunk chunk, processError func(string, ...interface{}), state StatusWorkerState, tableStatusList *tableStatusList, keyResolver keyspaceIDResolver, start time.Time, insertChannels []chan string, txID int64, statsCounters []*stats.CountersWithSingleLabel) {
Expand Down Expand Up @@ -1354,4 +1376,4 @@ func (scw *SplitCloneWorker) closeThrottlers() {
t.Close()
delete(scw.throttlers, keyspaceAndShard)
}
}
}