diff --git a/engine_test.go b/engine_test.go index 1b0d39cdc..b612a9169 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1139,59 +1139,35 @@ var queries = []struct { []sql.Row{{float64(1)}, {float64(2)}, {float64(3)}}, }, { - "SELECT i, i2, s2 FROM mytable LEFT JOIN othertable ON i = i2", + "SELECT i, i2, s2 FROM mytable LEFT JOIN othertable ON i = i2 - 1", []sql.Row{ - {int64(1), int64(1), "third"}, - {int64(1), nil, nil}, - {int64(1), nil, nil}, - {int64(2), int64(2), "second"}, - {int64(2), nil, nil}, - {int64(2), nil, nil}, - {int64(3), int64(3), "first"}, - {int64(3), nil, nil}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, {int64(3), nil, nil}, }, }, { - "SELECT i, i2, s2 FROM mytable RIGHT JOIN othertable ON i = i2", + "SELECT i, i2, s2 FROM mytable RIGHT JOIN othertable ON i = i2 - 1", []sql.Row{ - {int64(1), int64(1), "third"}, - {nil, int64(1), "third"}, {nil, int64(1), "third"}, - {int64(2), int64(2), "second"}, - {nil, int64(2), "second"}, - {nil, int64(2), "second"}, - {int64(3), int64(3), "first"}, - {nil, int64(3), "first"}, - {nil, int64(3), "first"}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, }, }, { - "SELECT i, i2, s2 FROM mytable LEFT OUTER JOIN othertable ON i = i2", + "SELECT i, i2, s2 FROM mytable LEFT OUTER JOIN othertable ON i = i2 - 1", []sql.Row{ - {int64(1), int64(1), "third"}, - {int64(1), nil, nil}, - {int64(1), nil, nil}, - {int64(2), int64(2), "second"}, - {int64(2), nil, nil}, - {int64(2), nil, nil}, - {int64(3), int64(3), "first"}, - {int64(3), nil, nil}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, {int64(3), nil, nil}, }, }, { - "SELECT i, i2, s2 FROM mytable RIGHT OUTER JOIN othertable ON i = i2", + "SELECT i, i2, s2 FROM mytable RIGHT OUTER JOIN othertable ON i = i2 - 1", []sql.Row{ - {int64(1), int64(1), "third"}, - {nil, int64(1), "third"}, {nil, int64(1), "third"}, - {int64(2), int64(2), "second"}, - {nil, int64(2), "second"}, - {nil, int64(2), "second"}, 
- {int64(3), int64(3), "first"}, - {nil, int64(3), "first"}, - {nil, int64(3), "first"}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, }, }, { diff --git a/sql/plan/cross_join_test.go b/sql/plan/cross_join_test.go index 6198d44b7..fbccc20c1 100644 --- a/sql/plan/cross_join_test.go +++ b/sql/plan/cross_join_test.go @@ -4,9 +4,9 @@ import ( "io" "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" ) var lSchema = sql.Schema{ @@ -62,12 +62,12 @@ func TestCrossJoin(t *testing.T) { require.Equal("col1_1", row[0]) require.Equal("col2_1", row[1]) - require.Equal(int32(1111), row[2]) - require.Equal(int64(2222), row[3]) + require.Equal(int32(1), row[2]) + require.Equal(int64(2), row[3]) require.Equal("col1_1", row[4]) require.Equal("col2_1", row[5]) - require.Equal(int32(1111), row[6]) - require.Equal(int64(2222), row[7]) + require.Equal(int32(1), row[6]) + require.Equal(int64(2), row[7]) row, err = iter.Next() require.NoError(err) @@ -75,12 +75,12 @@ func TestCrossJoin(t *testing.T) { require.Equal("col1_1", row[0]) require.Equal("col2_1", row[1]) - require.Equal(int32(1111), row[2]) - require.Equal(int64(2222), row[3]) + require.Equal(int32(1), row[2]) + require.Equal(int64(2), row[3]) require.Equal("col1_2", row[4]) require.Equal("col2_2", row[5]) - require.Equal(int32(3333), row[6]) - require.Equal(int64(4444), row[7]) + require.Equal(int32(3), row[6]) + require.Equal(int64(4), row[7]) for i := 0; i < 2; i++ { row, err = iter.Next() @@ -139,8 +139,8 @@ func insertData(t *testing.T, table *mem.Table) { require := require.New(t) rows := []sql.Row{ - sql.NewRow("col1_1", "col2_1", int32(1111), int64(2222)), - sql.NewRow("col1_2", "col2_2", int32(3333), int64(4444)), + sql.NewRow("col1_1", "col2_1", int32(1), int64(2)), + sql.NewRow("col1_2", "col2_2", int32(3), int64(4)), } for _, r := range rows { diff --git a/sql/plan/join.go 
b/sql/plan/join.go index 62d9deae4..1b3475908 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -363,12 +363,6 @@ func joinRowIter( "right": rightName, }) - l, err := left.RowIter(ctx) - if err != nil { - span.Finish() - return nil, err - } - var inMemorySession bool _, val := ctx.Get(inMemoryJoinSessionVar) if val != nil { @@ -380,13 +374,34 @@ func joinRowIter( mode = memoryMode } + if typ == rightJoin { + r, err := right.RowIter(ctx) + if err != nil { + span.Finish() + return nil, err + } + return sql.NewSpanIter(span, &joinIter{ + typ: typ, + primary: r, + secondaryProvider: left, + ctx: ctx, + cond: cond, + mode: mode, + }), nil + } + + l, err := left.RowIter(ctx) + if err != nil { + span.Finish() + return nil, err + } return sql.NewSpanIter(span, &joinIter{ - typ: typ, - l: l, - rp: right, - ctx: ctx, - cond: cond, - mode: mode, + typ: typ, + primary: l, + secondaryProvider: right, + ctx: ctx, + cond: cond, + mode: mode, }), nil } @@ -411,46 +426,49 @@ const ( // joinIter is a generic iterator for all join types. 
type joinIter struct { - typ joinType - l sql.RowIter - rp rowIterProvider - r sql.RowIter - ctx *sql.Context - cond sql.Expression + typ joinType + primary sql.RowIter + secondaryProvider rowIterProvider + secondary sql.RowIter + ctx *sql.Context + cond sql.Expression - leftRow sql.Row + primaryRow sql.Row + foundMatch bool + rowSize int // used to compute in-memory - mode joinMode - right []sql.Row - pos int + mode joinMode + secondaryRows []sql.Row + pos int } -func (i *joinIter) loadLeft() error { - if i.leftRow == nil { - r, err := i.l.Next() +func (i *joinIter) loadPrimary() error { + if i.primaryRow == nil { + r, err := i.primary.Next() if err != nil { return err } - i.leftRow = r + i.primaryRow = r + i.foundMatch = false } return nil } -func (i *joinIter) loadRightInMemory() error { - iter, err := i.rp.RowIter(i.ctx) +func (i *joinIter) loadSecondaryInMemory() error { + iter, err := i.secondaryProvider.RowIter(i.ctx) if err != nil { return err } - i.right, err = sql.RowIterToRows(iter) + i.secondaryRows, err = sql.RowIterToRows(iter) if err != nil { return err } - if len(i.right) == 0 { + if len(i.secondaryRows) == 0 { return io.EOF } @@ -476,40 +494,40 @@ func (i *joinIter) fitsInMemory() bool { return (ms.HeapInuse + ms.StackInuse) < maxMemory } -func (i *joinIter) loadRight() (row sql.Row, skip bool, err error) { +func (i *joinIter) loadSecondary() (row sql.Row, err error) { if i.mode == memoryMode { - if len(i.right) == 0 { - if err = i.loadRightInMemory(); err != nil { - return nil, false, err + if len(i.secondaryRows) == 0 { + if err = i.loadSecondaryInMemory(); err != nil { + return nil, err } } - if i.pos >= len(i.right) { - i.leftRow = nil + if i.pos >= len(i.secondaryRows) { + i.primaryRow = nil i.pos = 0 - return nil, true, nil + return nil, io.EOF } - row := i.right[i.pos] + row := i.secondaryRows[i.pos] i.pos++ - return row, false, nil + return row, nil } - if i.r == nil { + if i.secondary == nil { var iter sql.RowIter - iter, err = 
i.rp.RowIter(i.ctx) + iter, err = i.secondaryProvider.RowIter(i.ctx) if err != nil { - return nil, false, err + return nil, err } - i.r = iter + i.secondary = iter } - rightRow, err := i.r.Next() + rightRow, err := i.secondary.Next() if err != nil { if err == io.EOF { - i.r = nil - i.leftRow = nil + i.secondary = nil + i.primaryRow = nil // If we got to this point and the mode is still unknown it means // the right side fits in memory, so the mode changes to memory @@ -518,99 +536,105 @@ func (i *joinIter) loadRight() (row sql.Row, skip bool, err error) { i.mode = memoryMode } - return nil, true, nil + return nil, io.EOF } - return nil, false, err + return nil, err } if i.mode == unknownMode { if !i.fitsInMemory() { - i.right = nil + i.secondaryRows = nil i.mode = multipassMode } else { - i.right = append(i.right, rightRow) + i.secondaryRows = append(i.secondaryRows, rightRow) } } - return rightRow, false, err + return rightRow, nil } +// Next returns the next joined row. For left and right joins a primary row +// with no matching secondary row is emitted once, padded with nils. func (i *joinIter) Next() (sql.Row, error) { for { - if err := i.loadLeft(); err != nil { + if err := i.loadPrimary(); err != nil { return nil, err } - rightRow, skip, err := i.loadRight() + primary := i.primaryRow + secondary, err := i.loadSecondary() if err != nil { + if err == io.EOF { + // loadSecondary does not reset primaryRow when the in-memory + // secondary side is empty; clear it here so the next call advances + // to a new primary row instead of replaying this one forever. + i.primaryRow = nil + if !i.foundMatch && (i.typ == leftJoin || i.typ == rightJoin) { + return i.buildRow(primary, nil), nil + } + continue + } return nil, err } - if skip { - continue - } - - row, err := i.buildRow(i.leftRow, rightRow) + row := i.buildRow(primary, secondary) + v, err := i.cond.Eval(i.ctx, row) if err != nil { return nil, err } - if row == nil { + if v == false { continue } + i.foundMatch = true return row, nil } } -// buildRow builds the resulting row using the rows from the left and right -// branches depending on the join type. The resulting node may be nil, in -// which case, that row must be skipped. 
-func (i *joinIter) buildRow(left, right sql.Row) (sql.Row, error) { - var row = make(sql.Row, len(left)+len(right)) - copy(row, left) - copy(row[len(left):], right) - - v, err := i.cond.Eval(i.ctx, row) - if err != nil { - return nil, err +// buildRow builds the resulting row using the rows from the primary and +// secondary branches depending on the join type. +// NOTE(review): rowSize is cached from the first call; if that call has a +// nil secondary (empty secondary side) the padded width is only +// len(primary) — confirm against the joined schema. +func (i *joinIter) buildRow(primary, secondary sql.Row) sql.Row { + var row sql.Row + if i.rowSize > 0 { + row = make(sql.Row, i.rowSize) + } else { + row = make(sql.Row, len(primary)+len(secondary)) + i.rowSize = len(row) } - if v == false { - switch i.typ { - case leftJoin: - for j := len(left); j < len(row); j++ { - row[j] = nil - } - return row, nil - case rightJoin: - for j := 0; j < len(left); j++ { - row[j] = nil - } - return row, nil - default: - return nil, nil - } + switch i.typ { + case rightJoin: + copy(row, secondary) + copy(row[i.rowSize-len(primary):], primary) + default: + copy(row, primary) + copy(row[len(primary):], secondary) } - return row, nil + return row } +// Close drops the buffered secondary rows and closes the primary and +// secondary iterators, reporting the primary's close error first. func (i *joinIter) Close() (err error) { - i.right = nil + i.secondaryRows = nil - if i.l != nil { - if err = i.l.Close(); err != nil { - if i.r != nil { - _ = i.r.Close() + if i.primary != nil { + if err = i.primary.Close(); err != nil { + if i.secondary != nil { + _ = i.secondary.Close() } return err } } - if i.r != nil { - err = i.r.Close() + if i.secondary != nil { + err = i.secondary.Close() } return err diff --git a/sql/plan/join_test.go b/sql/plan/join_test.go index 5c26011ba..c79bc80d5 100644 --- a/sql/plan/join_test.go +++ b/sql/plan/join_test.go @@ -4,10 +4,10 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/require" "github.com/src-d/go-mysql-server/mem" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" ) func TestJoinSchema(t *testing.T) { @@ -87,8 +87,8 @@ func testInnerJoin(t *testing.T, ctx *sql.Context) { require.Len(rows, 2) 
require.Equal([]sql.Row{ - {"col1_1", "col2_1", int32(1111), int64(2222), "col1_1", "col2_1", int32(1111), int64(2222)}, - {"col1_2", "col2_2", int32(3333), int64(4444), "col1_2", "col2_2", int32(3333), int64(4444)}, + {"col1_1", "col2_1", int32(1), int64(2), "col1_1", "col2_1", int32(1), int64(2)}, + {"col1_2", "col2_2", int32(3), int64(4), "col1_2", "col2_2", int32(3), int64(4)}, }, rows) } func TestInnerJoinEmpty(t *testing.T) { @@ -233,16 +233,17 @@ func TestLeftJoin(t *testing.T) { NewResolvedTable(ltable), NewResolvedTable(rtable), expression.NewEquals( - expression.NewGetField(0, sql.Text, "lcol1", false), - expression.NewGetField(4, sql.Text, "rcol1", false), + expression.NewPlus( + expression.NewGetField(2, sql.Text, "lcol3", false), + expression.NewLiteral(int32(2), sql.Int32), + ), + expression.NewGetField(6, sql.Text, "rcol3", false), )) rows := collectRows(t, j) require.ElementsMatch([]sql.Row{ - {"col1_1", "col2_1", int32(1111), int64(2222), "col1_1", "col2_1", int32(1111), int64(2222)}, - {"col1_1", "col2_1", int32(1111), int64(2222), nil, nil, nil, nil}, - {"col1_2", "col2_2", int32(3333), int64(4444), "col1_2", "col2_2", int32(3333), int64(4444)}, - {"col1_2", "col2_2", int32(3333), int64(4444), nil, nil, nil, nil}, + {"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)}, + {"col1_2", "col2_2", int32(3), int64(4), nil, nil, nil, nil}, }, rows) } @@ -258,15 +259,16 @@ func TestRightJoin(t *testing.T) { NewResolvedTable(ltable), NewResolvedTable(rtable), expression.NewEquals( - expression.NewGetField(0, sql.Text, "lcol1", false), - expression.NewGetField(4, sql.Text, "rcol1", false), + expression.NewPlus( + expression.NewGetField(2, sql.Text, "lcol3", false), + expression.NewLiteral(int32(2), sql.Int32), + ), + expression.NewGetField(6, sql.Text, "rcol3", false), )) rows := collectRows(t, j) require.ElementsMatch([]sql.Row{ - {"col1_1", "col2_1", int32(1111), int64(2222), "col1_1", "col2_1", int32(1111), int64(2222)}, - 
{nil, nil, nil, nil, "col1_1", "col2_1", int32(1111), int64(2222)}, - {"col1_2", "col2_2", int32(3333), int64(4444), "col1_2", "col2_2", int32(3333), int64(4444)}, - {nil, nil, nil, nil, "col1_2", "col2_2", int32(3333), int64(4444)}, + {nil, nil, nil, nil, "col1_1", "col2_1", int32(1), int64(2)}, + {"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)}, }, rows) }