Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix naaj panic caused by wrong field types check #42482

Merged
merged 4 commits into from Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 30 additions & 15 deletions executor/hash_table.go
Expand Up @@ -153,7 +153,7 @@ func (c *hashRowContainer) GetMatchedRows(probeKey uint64, probeRow chunk.Row, h
}

func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
// for NAAJ probe row with null, we should match them with all build rows.
var (
ok bool
Expand All @@ -180,16 +180,20 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
// else like
// (null, 1, 2), we should use the not-null probe bit to filter rows. Only fetch rows like
// ( ? , 1, 2), that exactly with value as 1 and 2 in the second and third join key column.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckProbeColPos = needCheckProbeColPos[:0]
needCheckBuildColPos = needCheckBuildColPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
for i := 0; i < keyColLen; i++ {
// since all bucket is from hash table (Not Null), so the buildSideNullBits check is eliminated.
if probeKeyNullBits.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
}
var mayMatchedRow chunk.Row
Expand All @@ -200,7 +204,7 @@ func (c *hashRowContainer) GetAllMatchedRows(probeHCtx *hashContext, probeSideRo
}
if probeKeyNullBits != nil && len(probeHCtx.naKeyColIdx) > 1 {
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -287,7 +291,7 @@ func (c *hashRowContainer) GetMatchedRowsAndPtrs(probeKey uint64, probeRow chunk
}

func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRow chunk.Row,
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildRowPos, needCheckProbeRowPos []int) ([]chunk.Row, error) {
probeKeyNullBits *bitmap.ConcurrentBitmap, matched []chunk.Row, needCheckBuildColPos, needCheckProbeColPos []int, needCheckBuildTypes, needCheckProbeTypes []*types.FieldType) ([]chunk.Row, error) {
var (
ok bool
err error
Expand All @@ -306,8 +310,10 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
// case2: left side (probe side) don't have null
// left side key <1, 2>, actually we should fetch <1,null>, <null, 2>, <null, null> from the null bucket because
// case like <3,null> is obviously not matched with the probe key.
needCheckProbeRowPos = needCheckProbeRowPos[:0]
needCheckBuildRowPos = needCheckBuildRowPos[:0]
needCheckProbeColPos = needCheckProbeColPos[:0]
needCheckBuildColPos = needCheckBuildColPos[:0]
needCheckBuildTypes = needCheckBuildTypes[:0]
needCheckProbeTypes = needCheckProbeTypes[:0]
keyColLen := len(c.hCtx.naKeyColIdx)
if probeKeyNullBits != nil {
// when the probeKeyNullBits is not nil, it means the probe key has null values, where we should distinguish
Expand All @@ -325,11 +331,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
if probeKeyNullBits.UnsafeIsSet(i) || nullEntry.nullBitMap.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand All @@ -346,11 +354,13 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo
if nullEntry.nullBitMap.UnsafeIsSet(i) {
continue
}
needCheckBuildRowPos = append(needCheckBuildRowPos, c.hCtx.naKeyColIdx[i])
needCheckProbeRowPos = append(needCheckProbeRowPos, probeHCtx.naKeyColIdx[i])
needCheckBuildColPos = append(needCheckBuildColPos, c.hCtx.naKeyColIdx[i])
needCheckBuildTypes = append(needCheckBuildTypes, c.hCtx.allTypes[i])
needCheckProbeColPos = append(needCheckProbeColPos, probeHCtx.naKeyColIdx[i])
needCheckProbeTypes = append(needCheckProbeTypes, probeHCtx.allTypes[i])
}
// check the idxs-th value of the join columns.
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, c.hCtx.allTypes, needCheckBuildRowPos, probeSideRow, probeHCtx.allTypes, needCheckProbeRowPos)
ok, err = codec.EqualChunkRow(c.sc, mayMatchedRow, needCheckBuildTypes, needCheckBuildColPos, probeSideRow, needCheckProbeTypes, needCheckProbeColPos)
if err != nil {
return nil, err
}
Expand All @@ -366,6 +376,11 @@ func (c *hashRowContainer) GetNullBucketRows(probeHCtx *hashContext, probeSideRo

// matchJoinKey checks if join keys of buildRow and probeRow are logically equal.
func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) {
if len(c.hCtx.naKeyColIdx) > 0 {
return codec.EqualChunkRow(c.sc,
buildRow, c.hCtx.allTypes, c.hCtx.naKeyColIdx,
probeRow, probeHCtx.allTypes, probeHCtx.naKeyColIdx)
}
return codec.EqualChunkRow(c.sc,
buildRow, c.hCtx.allTypes, c.hCtx.keyColIdx,
probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx)
Expand Down
24 changes: 14 additions & 10 deletions executor/join.go
Expand Up @@ -98,8 +98,10 @@ type probeWorker struct {
rowIters *chunk.Iterator4Slice
rowContainerForProbe *hashRowContainer
// for every naaj probe worker, pre-allocate the int slice for store the join column index to check.
needCheckBuildRowPos []int
needCheckProbeRowPos []int
needCheckBuildColPos []int
needCheckProbeColPos []int
needCheckBuildTypes []*types.FieldType
needCheckProbeTypes []*types.FieldType
probeChkResourceCh chan *probeChkResource
joinChkResourceCh chan *chunk.Chunk
probeResultCh chan *chunk.Chunk
Expand Down Expand Up @@ -177,8 +179,10 @@ func (e *HashJoinExec) Close() error {
for _, w := range e.probeWorkers {
w.buildSideRows = nil
w.buildSideRowPtrs = nil
w.needCheckBuildRowPos = nil
w.needCheckProbeRowPos = nil
w.needCheckBuildColPos = nil
w.needCheckProbeColPos = nil
w.needCheckBuildTypes = nil
w.needCheckProbeTypes = nil
w.joinChkResourceCh = nil
}

Expand Down Expand Up @@ -605,7 +609,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// step2: match the null bucket secondly.
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -650,7 +654,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
// case1: <?, null> NOT IN (empty set): ----------------------> result is <rhs, 1>.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> result is <rhs, null>.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -680,7 +684,7 @@ func (w *probeWorker) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeK
}
}
// Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any).
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -729,7 +733,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
if probeKeyNullBits == nil {
// step1: match null bucket first.
// need fetch the "valid" rows every time. (nullBits map check is necessary)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -804,7 +808,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
// case1: <?, null> NOT IN (empty set): ----------------------> accept rhs row.
// case2: <?, null> NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row.
// Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows := w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down Expand Up @@ -834,7 +838,7 @@ func (w *probeWorker) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKey
}
}
// Step2: match all hash table bucket build rows.
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildRowPos, w.needCheckProbeRowPos)
w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes)
buildSideRows = w.buildSideRows
if err != nil {
joinResult.err = err
Expand Down
15 changes: 15 additions & 0 deletions executor/join_test.go
Expand Up @@ -83,3 +83,18 @@ func TestUsingAndNaturalJoinSchema(t *testing.T) {
tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...))
}
}

func TestTiDBNAAJ(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=0;")
tk.MustExec("create table t(a decimal(40,0), b bigint(20) not null);")
tk.MustExec("insert into t values(7,8),(7,8),(3,4),(3,4),(9,2),(9,2),(2,0),(2,0),(0,4),(0,4),(8,8),(8,8),(6,1),(6,1),(NULL, 0),(NULL,0);")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
tk.MustExec("set @@session.tidb_enable_null_aware_anti_join=1;")
tk.MustQuery("select ( table1 . a , table1 . b ) NOT IN ( SELECT 3 , 2 UNION SELECT 9, 2 ) AS field2 from t as table1 order by field2;").Check(testkit.Rows(
"0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"))
}