forked from cockroachdb/cockroach
/
stream_merger.go
187 lines (175 loc) · 5.4 KB
/
stream_merger.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
// Copyright 2016 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
//
// Author: Irfan Sharif (irfansharif@cockroachlabs.com)
package distsqlrun
import (
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/pkg/errors"
)
// We define a group to be a set of rows from a given source with the same
// group key, in this case the set of ordered columns. streamMerger emits
// batches of rows that are the cross-product of matching groups from each
// stream.
type streamMerger struct {
	left  streamGroupAccumulator // accumulates same-key groups from the left stream
	right streamGroupAccumulator // accumulates same-key groups from the right stream
	// datumAlloc is scratch space reused when decoding EncDatums during
	// row comparisons.
	datumAlloc sqlbase.DatumAlloc
	// outputBuffer holds the (left, right) row pairs produced by the most
	// recent computeBatch call; a nil entry marks an unmatched side.
	outputBuffer [][2]sqlbase.EncDatumRow
}
// computeBatch adds the cross-product of the next matching set of groups
// from each of streams to the output buffer.
func (sm *streamMerger) computeBatch() error {
	// Reuse the buffer's backing storage across calls.
	sm.outputBuffer = sm.outputBuffer[:0]

	lrow, err := sm.left.peekAtCurrentGroup()
	if err != nil {
		return err
	}
	rrow, err := sm.right.peekAtCurrentGroup()
	if err != nil {
		return err
	}
	if lrow == nil && rrow == nil {
		// Both streams are exhausted; nothing to emit.
		return nil
	}

	cmp, err := CompareEncDatumRowForMerge(
		lrow, rrow, sm.left.ordering, sm.right.ordering, &sm.datumAlloc)
	if err != nil {
		return err
	}

	if cmp == 0 {
		// We found matching groups; we'll output the cross-product.
		leftGroup, err := sm.left.advanceGroup()
		if err != nil {
			return err
		}
		rightGroup, err := sm.right.advanceGroup()
		if err != nil {
			return err
		}
		// TODO(andrei): if groups are large and we have a limit, we might want to
		// stream through the leftGroup instead of accumulating it all.
		for _, l := range leftGroup {
			for _, r := range rightGroup {
				sm.outputBuffer = append(sm.outputBuffer, [2]sqlbase.EncDatumRow{l, r})
			}
		}
		return nil
	}

	// One side sorts strictly first (or the other side is exhausted): drain
	// that side's group and emit each of its rows paired with nil.
	src, side := &sm.left, 0
	if cmp > 0 {
		src, side = &sm.right, 1
	}
	group, err := src.advanceGroup()
	if err != nil {
		return err
	}
	for _, row := range group {
		var pair [2]sqlbase.EncDatumRow
		pair[side] = row
		sm.outputBuffer = append(sm.outputBuffer, pair)
	}
	return nil
}
// CompareEncDatumRowForMerge compares two EncDatumRows for merging.
// When merging two streams and preserving the order (as in a MergeSort or
// a MergeJoin) compare the head of the streams, emitting the one that sorts
// first. It allows for the EncDatumRow to be nil if one of the streams is
// exhausted (and hence nil). CompareEncDatumRowForMerge returns 0 when both
// rows are nil, and a nil row is considered greater than any non-nil row.
// CompareEncDatumRowForMerge assumes that the two rows have the same columns
// in the same orders, but can handle different ordering directions. It takes
// a DatumAlloc which is used for decoding if any underlying EncDatum is not
// yet decoded.
func CompareEncDatumRowForMerge(
	lhs, rhs sqlbase.EncDatumRow,
	leftOrdering, rightOrdering sqlbase.ColumnOrdering,
	da *sqlbase.DatumAlloc,
) (int, error) {
	// Nil rows sort last: two nil rows are equal, and a nil row is greater
	// than any non-nil row.
	switch {
	case lhs == nil && rhs == nil:
		return 0, nil
	case lhs == nil:
		return 1, nil
	case rhs == nil:
		return -1, nil
	}
	if len(leftOrdering) != len(rightOrdering) {
		return 0, errors.Errorf(
			"cannot compare two EncDatumRow types that have different length ColumnOrderings",
		)
	}
	// TODO(radu): plumb EvalContext
	evalCtx := &parser.EvalContext{}
	for i := range leftOrdering {
		lCol := leftOrdering[i].ColIdx
		rCol := rightOrdering[i].ColIdx
		c, err := lhs[lCol].Compare(da, evalCtx, &rhs[rCol])
		if err != nil {
			return 0, err
		}
		if c == 0 {
			// Equal on this column; move on to the next one.
			continue
		}
		// Flip the result when this column is ordered descending.
		if leftOrdering[i].Direction == encoding.Descending {
			return -c, nil
		}
		return c, nil
	}
	return 0, nil
}
// NextBatch computes and returns the next batch of (left, right) row pairs.
// An empty batch indicates that both input streams were exhausted.
func (sm *streamMerger) NextBatch() ([][2]sqlbase.EncDatumRow, error) {
	err := sm.computeBatch()
	if err != nil {
		return nil, err
	}
	return sm.outputBuffer, nil
}
// makeStreamMerger creates a streamMerger, joining rows from leftSource with
// rows from rightSource.
//
// All metadata from the sources is forwarded to metadataSink.
//
// The two orderings must have the same length and must agree on direction
// column by column; otherwise an error is returned.
func makeStreamMerger(
	leftSource RowSource,
	leftOrdering sqlbase.ColumnOrdering,
	rightSource RowSource,
	rightOrdering sqlbase.ColumnOrdering,
	metadataSink RowReceiver,
) (streamMerger, error) {
	if len(leftOrdering) != len(rightOrdering) {
		return streamMerger{}, errors.Errorf(
			"ordering lengths don't match: %d and %d", len(leftOrdering), len(rightOrdering))
	}
	for i, ord := range leftOrdering {
		if ord.Direction != rightOrdering[i].Direction {
			// Lowercase, contextual error per Go convention (the original
			// "Ordering mismatch" was capitalized and named no column).
			return streamMerger{}, errors.Errorf(
				"ordering direction mismatch at column %d", i)
		}
	}
	return streamMerger{
		left: makeStreamGroupAccumulator(
			MakeNoMetadataRowSource(leftSource, metadataSink),
			leftOrdering),
		right: makeStreamGroupAccumulator(
			MakeNoMetadataRowSource(rightSource, metadataSink),
			rightOrdering),
	}, nil
}