Fix index push down #697

Merged · 8 commits · May 8, 2019
Changes from 5 commits
@@ -361,6 +361,7 @@
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<phase>generate-sources</phase>
@@ -19,12 +19,12 @@ package com.pingcap.tispark.statistics

import com.google.common.primitives.UnsignedLong
import com.pingcap.tikv.expression.{ByItem, ColumnRef, ComparisonBinaryExpression, Constant}
import com.pingcap.tikv.key.{Key, TypedKey}
import com.pingcap.tikv.key.Key
import com.pingcap.tikv.meta.TiDAGRequest.PushDownType
import com.pingcap.tikv.meta.{TiDAGRequest, TiTableInfo, TiTimestamp}
import com.pingcap.tikv.meta.{TiColumnInfo, TiDAGRequest, TiIndexInfo, TiTableInfo, TiTimestamp}
import com.pingcap.tikv.row.Row
import com.pingcap.tikv.statistics._
import com.pingcap.tikv.types.{DataType, DataTypeFactory, MySQLType}
import com.pingcap.tikv.types.BytesType
import org.slf4j.LoggerFactory

import scala.collection.JavaConversions._
@@ -69,35 +69,45 @@ object StatisticsHelper {
neededColIds: mutable.ArrayBuffer[Long],
histTable: TiTableInfo): StatisticsDTO = {
if (row.fieldCount() < 6) return null
assert(row.getLong(0) == table.getId, s"table id does not match: ${row.getLong(0)} != ${table.getId}")

birdstorm (Author, Member) commented on May 7, 2019:

I will remove this assert once all tests pass.

val isIndex = row.getLong(1) > 0
val histID = row.getLong(2)
val distinct = row.getLong(3)
val nullCount = row.getLong(4)
val histVer = row.getUnsignedLong(5)
val cMSketch = if (checkColExists(histTable, "cm_sketch")) row.getBytes(6) else null
// get index/col info for StatisticsDTO
val indexInfos = table.getIndices
.filter { _.getId == histID }
var indexInfos: mutable.Buffer[TiIndexInfo] = mutable.Buffer.empty[TiIndexInfo]

val colInfos = table.getColumns
.filter { _.getId == histID }
var colInfos: mutable.Buffer[TiColumnInfo] = mutable.Buffer.empty[TiColumnInfo]

var needed = true

// we should only query those columns that the user specified before
if (!loadAll && !neededColIds.contains(histID)) needed = false

var indexFlag = 1
var dataType: DataType = DataTypeFactory.of(MySQLType.TypeBlob)
// Columns info found
if (!isIndex && colInfos.nonEmpty) {
indexFlag = 0
dataType = colInfos.head.getType
} else if (!isIndex || indexInfos.isEmpty) {
logger.warn(
s"Cannot find histogram id $histID in table info ${table.getName} now. It may be deleted."
)
needed = false
val (indexFlag, dataType) = if (isIndex) {
indexInfos = table.getIndices.filter { _.getId == histID }
if (indexInfos.isEmpty) {
logger.warn(
s"Cannot find index histogram id $histID in table info ${table.getName}[${table.getId}] now. It may be deleted."
)
needed = false
(1, null)
} else {
(1, BytesType.BLOB)
}
} else {
colInfos = table.getColumns.filter { _.getId == histID }
if (colInfos.isEmpty) {
logger.warn(
s"Cannot find column histogram id $histID in table info ${table.getName}[${table.getId}] now. It may be deleted."
)
needed = false
(0, null)
} else {
(0, colInfos.head.getType)
}
}

if (needed) {
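For readability, the branch above boils down to one decision: index histograms carry encoded index key bytes and are therefore read with the blob type, while column histograms are read with the column's own declared type. A minimal sketch of that decision as a standalone helper (the helper itself is hypothetical and not part of the PR; only BytesType.BLOB and DataType come from the code above):

import com.pingcap.tikv.types.{BytesType, DataType}

// Hypothetical helper, not part of the PR: choose the type used to decode a histogram.
// Index histograms store encoded index key bytes, so they are decoded as blobs;
// column histograms use the column's declared type. None corresponds to the
// "needed = false" case, where the histogram id no longer resolves to anything.
def histogramDecodeType(isIndex: Boolean, colType: Option[DataType]): Option[DataType] =
  if (isIndex) Some(BytesType.BLOB) else colType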
@@ -158,8 +168,8 @@ object StatisticsHelper {
var lowerBound: Key = null
var upperBound: Key = null
// all bounds are stored as blob in bucketTable currently, decode using blob type
lowerBound = TypedKey.toTypedKey(row.getBytes(6), DataTypeFactory.of(MySQLType.TypeBlob))
upperBound = TypedKey.toTypedKey(row.getBytes(7), DataTypeFactory.of(MySQLType.TypeBlob))
lowerBound = Key.toRawKey(row.getBytes(6))
upperBound = Key.toRawKey(row.getBytes(7))
totalCount += count
buckets += new Bucket(totalCount, repeats, lowerBound, upperBound)
}
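The switch above from TypedKey.toTypedKey(..., blob type) to Key.toRawKey reflects that the bucket bounds stored in bucketTable are already memory-comparable byte strings, so they can be ordered byte-wise without re-decoding them through a column type. A small illustration under that assumption (the byte values and the assertion are hypothetical, not part of the PR):

import com.pingcap.tikv.key.Key

// Hypothetical illustration, not part of the PR: raw keys order byte-wise, so two
// bucket bounds read from bucketTable can be compared directly, regardless of the
// column type that produced them (assumes Key is Comparable, as in client-java).
val lower = Key.toRawKey(Array[Byte](1, 2, 3))
val upper = Key.toRawKey(Array[Byte](1, 2, 4))
assert(lower.compareTo(upper) < 0)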
@@ -62,6 +62,7 @@ private[statistics] case class StatisticsResult(histId: Long,
*/
object StatisticsManager {

private var session: TiSession = _
private var snapshot: Snapshot = _
private var catalog: Catalog = _
private var dbPrefix: String = _
@@ -150,7 +151,7 @@ object StatisticsManager {
// load count, modify_count, version info
loadMetaToTblStats(tblId, tblStatistic)
val req = StatisticsHelper
.buildHistogramsRequest(histTable, tblId, snapshot.getTimestamp)
.buildHistogramsRequest(histTable, tblId, session.getTimestamp)

val rows = readDAGRequest(req)
if (rows.isEmpty) return
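The change from snapshot.getTimestamp to session.getTimestamp swaps a timestamp that was fixed when the snapshot was created during initialization (see the initialize() hunk below) for one obtained at request time, so repeated statistics loads are not pinned to the startup timestamp. A minimal sketch of the difference, using only calls visible in this diff (the helper itself is hypothetical, not part of the PR):

import com.pingcap.tikv.TiSession
import com.pingcap.tikv.meta.TiTimestamp

// Hypothetical helper, not part of the PR: contrast the two timestamp sources.
def readTimestamps(session: TiSession): (TiTimestamp, TiTimestamp) = {
  val pinned = session.createSnapshot().getTimestamp // frozen at snapshot creation
  val fresh = session.getTimestamp                    // obtained for this request
  (pinned, fresh)
}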
@@ -197,7 +198,7 @@ object StatisticsManager {

private def loadMetaToTblStats(tableId: Long, tableStatistics: TableStatistics): Unit = {
val req =
StatisticsHelper.buildMetaRequest(metaTable, tableId, snapshot.getTimestamp)
StatisticsHelper.buildMetaRequest(metaTable, tableId, session.getTimestamp)

val rows = readDAGRequest(req)
if (rows.isEmpty) return
@@ -211,7 +212,7 @@ object StatisticsManager {
private def statisticsResultFromStorage(tableId: Long,
requests: Seq[StatisticsDTO]): Seq[StatisticsResult] = {
val req =
StatisticsHelper.buildBucketRequest(bucketTable, tableId, snapshot.getTimestamp)
StatisticsHelper.buildBucketRequest(bucketTable, tableId, session.getTimestamp)

val rows = readDAGRequest(req)
if (rows.isEmpty) return Nil
@@ -256,6 +257,7 @@ object StatisticsManager {
protected var initialized: Boolean = false

protected def initialize(tiSession: TiSession): Unit = {
session = tiSession
snapshot = tiSession.createSnapshot()
catalog = tiSession.getCatalog
dbPrefix = tiSession.getConf.getDBPrefix
@@ -363,7 +363,7 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
if (dagRequest.hasIndex) {
// add the first index column so that the plan will contain at least one column.
val idxColumn = dagRequest.getIndexInfo.getIndexColumns.get(0)
dagRequest.addRequiredColumn(ColumnRef.create(idxColumn.getName))
dagRequest.addRequiredColumn(ColumnRef.create(idxColumn.getName, source.table))
} else {
// add a random column so that the plan will contain at least one column.
// if the table contains a primary key then use the PK instead.
@@ -372,7 +372,7 @@ case class TiStrategy(getOrCreateTiContext: SparkSession => TiContext)(sparkSess
case e if e.isPrimaryKey => e
}
.getOrElse(source.table.getColumn(0))
dagRequest.addRequiredColumn(ColumnRef.create(column.getName))
dagRequest.addRequiredColumn(ColumnRef.create(column.getName, source.table))
}
}
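Both hunks above now pass the source table into ColumnRef.create, so the required column added to the DAG request is bound to its table metadata instead of being a bare name; that binding is what the improved ProtoConverter error message at the end of this diff reports on when resolution fails. A minimal sketch of the bound form, assuming only the API visible in this diff (the helper itself is hypothetical, not part of the PR):

import com.pingcap.tikv.expression.ColumnRef
import com.pingcap.tikv.meta.TiTableInfo

// Hypothetical helper, not part of the PR: build a column reference resolved against
// the table's metadata, e.g. for the "at least one column" fallback used above.
def boundFirstColumn(table: TiTableInfo): ColumnRef =
  ColumnRef.create(table.getColumn(0).getName, table)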

@@ -254,6 +254,26 @@ class IssueTestSuite extends BaseTiSparkSuite {
judge("select cast(count(1) as char(20)) from `tmp_empty_tbl`")
}

test("test push down filters when using index double read") {
def explainTestAndCollect(sql: String): Unit = {
val df = spark.sql(sql)
df.explain
df.collect.foreach(println)
}
explainTestAndCollect(
"select id_dt, tp_int, tp_double, tp_varchar from full_data_type_table_idx limit 10"
)
explainTestAndCollect(
"select id_dt, tp_int, tp_double, tp_varchar from full_data_type_table_idx where tp_int > 200 limit 10"
)
explainTestAndCollect(
"select id_dt, tp_int, tp_double, tp_varchar from full_data_type_table_idx where tp_int > 200 order by tp_varchar limit 10"
)
explainTestAndCollect(
"select max(tp_double) from full_data_type_table_idx where tp_int > 200 group by tp_bigint limit 10"
)
}

override def afterAll(): Unit =
try {
tidbStmt.execute("drop table if exists t")
@@ -72,7 +72,9 @@ class PrefixIndexTestSuite extends BaseTiSparkSuite {
refreshConnections()

spark.sql("select * from t1").show
runTest("select * from t1 where name = '借款策略集_网页'", skipJDBC = true)
explainAndRunTest("select * from t1 where name = '中文字符集_测试'", skipJDBC = true)
explainAndRunTest("select * from t1 where name < '中文字符集_测试'", skipJDBC = true)
explainAndRunTest("select * from t1 where name > '中文字符集_测试'", skipJDBC = true)
}

test("index double scan with predicate") {
@@ -81,18 +83,20 @@ class PrefixIndexTestSuite extends BaseTiSparkSuite {
"create table test_index(id bigint(20), c1 text default null, c2 int, c3 int, c4 int, KEY idx_c1(c1(10)))"
)
tidbStmt.execute("insert into test_index values(1, 'aairy', 10, 20, 30)")
tidbStmt.execute("insert into test_index values(1, 'dairy', 10, 20, 30)")
tidbStmt.execute("insert into test_index values(1, 'zairy', 10, 20, 30)")
tidbStmt.execute("insert into test_index values(2, 'dairy', 20, 30, 40)")
tidbStmt.execute("insert into test_index values(3, 'zairy', 30, 40, 50)")
refreshConnections() // refresh since we need to load data again
judge("select c2 from test_index where c1 > 'dairy'")
explainAndRunTest("select c1, c2 from test_index where c1 < 'dairy' and c2 > 20")
explainAndRunTest("select c1, c2 from test_index where c1 = 'dairy'")
explainAndRunTest("select c1, c2 from test_index where c1 > 'dairy'")
judge("select c2 from test_index where c1 < 'dairy'")
judge("select c2 from test_index where c1 = 'dairy'")
judge("select c2, c2 from test_index where c1 > 'dairy'")
judge("select c2, c2 from test_index where c1 < 'dairy'")
judge("select c2, c2 from test_index where c1 = 'dairy'")
judge("select max(c2) from test_index where c1 > 'dairy'")
judge("select max(c2) from test_index where c1 < 'dairy'")
judge("select max(c2) from test_index where c1 = 'dairy'")
explainAndRunTest("select max(c2) from test_index where c1 > 'dairy'")
explainAndRunTest("select max(c2) from test_index where c1 < 'dairy'")
explainAndRunTest("select max(c2) from test_index where c1 = 'dairy'")
}

override def afterAll(): Unit =
@@ -46,7 +46,7 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
StatisticsManager.loadStatisticsInfo(fDataIdxTbl)
}

ignore("Test fixed table size estimation") {
test("Test fixed table size estimation") {
tidbStmt.execute("DROP TABLE IF EXISTS `tb_fixed_float`")
tidbStmt.execute("DROP TABLE IF EXISTS `tb_fixed_int`")
tidbStmt.execute("DROP TABLE IF EXISTS `tb_fixed_time`")
@@ -102,7 +102,7 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
assert(timeBytes >= 19 * 2)
}

ignore("select count(1) from full_data_type_table_idx where tp_int = 2006469139 or tp_int < 0") {
test("select count(1) from full_data_type_table_idx where tp_int = 2006469139 or tp_int < 0") {
val indexes = fDataIdxTbl.getIndices
val idx = indexes.filter(_.getIndexColumns.asScala.exists(_.matchName("tp_int"))).head

@@ -115,7 +115,7 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
testSelectRowCount(expressions, idx, 46)
}

ignore(
test(
"select tp_int from full_data_type_table_idx where tp_int < 5390653 and tp_int > -46759812"
) {
val indexes = fDataIdxTbl.getIndices
@@ -144,7 +144,7 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
assert(rc == expectedCount)
}

val indexSelectionCases = Map(
private val indexSelectionCases = Map(
// double read case
"select tp_bigint, tp_real from full_data_type_table_idx where tp_int = 2333" -> "idx_tp_int",
"select * from full_data_type_table_idx where id_dt = 2333" -> "",
@@ -157,7 +157,7 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
indexSelectionCases.foreach((t: (String, String)) => {
val query = t._1
val idxName = t._2
ignore(query) {
test(query) {
val executedPlan = spark.sql(query).queryExecution.executedPlan
val usedIdxName = {
if (isDoubleRead(executedPlan)) {
@@ -168,6 +168,8 @@ class StatisticsManagerSuite extends BaseTiSparkSuite {
extractUsedIndex(coprocessorRDD)
}
}
spark.sql(query).show()
println(usedIdxName, idxName)
assert(usedIdxName.equals(idxName))
}
})
@@ -262,7 +262,7 @@ object SharedSQLContext extends Logging {

import com.pingcap.tispark.TiConfigConst._
sparkConf.set(PD_ADDRESSES, getOrElse(prop, PD_ADDRESSES, "127.0.0.1:2379"))
sparkConf.set(ALLOW_INDEX_READ, getFlag(prop, ALLOW_INDEX_READ).toString)
sparkConf.set(ALLOW_INDEX_READ, getFlagOrTrue(prop, ALLOW_INDEX_READ).toString)
sparkConf.set(ENABLE_AUTO_LOAD_STATISTICS, "true")
sparkConf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
sparkConf.set(REQUEST_ISOLATION_LEVEL, SNAPSHOT_ISOLATION_LEVEL)
@@ -39,14 +39,14 @@ object Utils {
}
}

def getFlag(prop: Properties, key: String): Boolean = {
val jvmProp = System.getProperty(key)
if (jvmProp != null) {
jvmProp.equalsIgnoreCase("true")
} else {
Option(prop.getProperty(key)).getOrElse("false").equalsIgnoreCase("true")
}
}
private def getFlag(prop: Properties, key: String, defValue: String): Boolean =
getOrElse(prop, key, defValue).equalsIgnoreCase("true")

def getFlagOrFalse(prop: Properties, key: String): Boolean =
getFlag(prop, key, "false")

def getFlagOrTrue(prop: Properties, key: String): Boolean =
getFlag(prop, key, "true")

def getOrElse(prop: Properties, key: String, defValue: String): String = {
val jvmProp = System.getProperty(key)
@@ -83,21 +83,10 @@ public ByteString get(ByteString key) {
* @return an Iterator that contains all results from this select request.
*/
public Iterator<Row> tableRead(TiDAGRequest dagRequest) {
if (dagRequest.hasIndex()) {
Iterator<Long> iter =
getHandleIterator(
dagRequest,
RangeSplitter.newSplitter(session.getRegionManager())
.splitRangeByRegion(dagRequest.getRanges()),
session);
return new IndexScanIterator(this, dagRequest, iter);
} else {
return getRowIterator(
dagRequest,
RangeSplitter.newSplitter(session.getRegionManager())
.splitRangeByRegion(dagRequest.getRanges()),
session);
}
return tableRead(
dagRequest,
RangeSplitter.newSplitter(session.getRegionManager())
.splitRangeByRegion(dagRequest.getRanges()));
}
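After this refactor, the single-argument tableRead splits the request's ranges by region once and delegates to a two-argument overload that keeps the index/non-index branching in one place. A hypothetical usage sketch from the Scala side (the helper and its name are assumptions, not part of the PR):

import scala.collection.JavaConverters._
import com.pingcap.tikv.Snapshot
import com.pingcap.tikv.meta.TiDAGRequest
import com.pingcap.tikv.row.Row

// Hypothetical usage, not part of the PR: callers still go through the one-argument
// tableRead; whether the request scans an index or the table is decided internally.
def readAll(snapshot: Snapshot, req: TiDAGRequest): List[Row] =
  snapshot.tableRead(req).asScala.toList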

/**
@@ -237,7 +237,8 @@ protected Expr visit(ColumnRef node, Object context) {
Map<ColumnRef, Integer> colIdOffsetMap = (Map<ColumnRef, Integer>) context;
position =
requireNonNull(
colIdOffsetMap.get(node), "Required column position info is not in a valid context.");
colIdOffsetMap.get(node),
"Required column position info " + node.getName() + " is not in a valid context.");
}
Expr.Builder builder = Expr.newBuilder();
builder.setTp(ExprType.ColumnRef);