From 35090faf4ac38784899353105d4e72908b31623c Mon Sep 17 00:00:00 2001
From: GH-Gloway <30582875+GH-Gloway@users.noreply.github.com>
Date: Wed, 18 Oct 2017 15:41:02 +0800
Subject: [PATCH] #10 Use equi-height histograms to estimate EqualTo selectivity

---
 .../catalyst/plans/logical/Statistics.scala    | 28 ++++++
 .../statsEstimation/FilterEstimation.scala     | 97 ++++++++++++++-----
 .../FilterEstimationSuite.scala                | 20 +++-
 .../StatsEstimationTestBase.scala              |  6 +-
 4 files changed, 122 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index a64562b5dbd93..c39d8caf71a1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -52,6 +52,7 @@ case class Statistics(
     sizeInBytes: BigInt,
     rowCount: Option[BigInt] = None,
     attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
+    histograms: AttributeMap[Histogram] = AttributeMap(Nil),
     hints: HintInfo = HintInfo()) {
 
   override def toString: String = "Statistics(" + simpleString + ")"
@@ -279,3 +280,37 @@ object ColumnStat extends Logging {
   }
 
 }
+
+/**
+ * An equi-height histogram.
+ *
+ * @param bucket ascending bucket boundaries; bucket i covers [bucket(i), bucket(i + 1))
+ * @param distinctCount number of distinct values falling into each bucket
+ * @param height number of rows in each bucket (equal across buckets by construction)
+ */
+case class Histogram(
+    bucket: List[Double],
+    distinctCount: List[Long],
+    height: Double) {
+
+  val min: Double = bucket.head
+  val max: Double = bucket.last
+
+  /** Returns (lower boundary, distinct count, index) of the bucket containing `point`. */
+  def getInterval(point: Double): (Double, Long, Int) = {
+    // Binary search for the greatest boundary that is <= point. Callers must
+    // ensure that point lies within [min, max].
+    var start = 0
+    var end = bucket.size - 1
+    while (start <= end) {
+      val index = (start + end) / 2
+      if (bucket(index) > point) {
+        end = index - 1
+      } else {
+        start = index + 1
+      }
+    }
+    // The loop always exits with start > end, so `end` is the last boundary <= point.
+    (bucket(end), distinctCount(end), end)
+  }
+}
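A note on the new Histogram class, for intuition: the sketch below (not part of
the patch) probes getInterval with the equi-height histogram that the test suite
builds for column cint, i.e. 10 rows in [1, 10] split into 3 buckets of height
10 / 3. The values are copied from FilterEstimationSuite.

    val hist = Histogram(
      bucket = List(1.0, 3.8333333333333335, 7.166666666666667, 10.0),
      distinctCount = List(3, 4, 3, 0),
      height = 3.3333333333333335)

    // 4.0 falls into bucket 1, i.e. [3.83..., 7.16...), which holds 4 distinct values.
    val (lowerBound, ndv, index) = hist.getInterval(4.0)
    // lowerBound == 3.8333333333333335, ndv == 4, index == 1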
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 7046b39d97e09..2a21246901fc1 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -34,6 +34,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
 
   private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
 
+  private val histogramMap = new HistogramMap(childStats.histograms)
+
   /**
    * Returns an option of Statistics for a Filter logical plan node.
    * For a given compound expression condition, this method computes filter selectivity
@@ -307,35 +309,67 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
       attr: Attribute,
       literal: Literal,
       update: Boolean): Option[BigDecimal] = {
-    if (!colStatsMap.contains(attr)) {
+    if (!colStatsMap.contains(attr) && !histogramMap.contains(attr)) {
       logDebug("[CBO] No statistics for " + attr)
       return None
     }
 
-    val colStat = colStatsMap(attr)
-    val ndv = colStat.distinctCount
-    // decide if the value is in [min, max] of the column.
-    // We currently don't store min/max for binary/string type.
-    // Hence, we assume it is in boundary for binary/string type.
-    val statsRange = Range(colStat.min, colStat.max, attr.dataType)
-    if (statsRange.contains(literal)) {
-      if (update) {
-        // We update ColumnStat structure after apply this equality predicate:
-        // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
-        // value.
-        val newStats = attr.dataType match {
-          case StringType | BinaryType =>
-            colStat.copy(distinctCount = 1, nullCount = 0)
-          case _ =>
-            colStat.copy(distinctCount = 1, min = Some(literal.value),
-              max = Some(literal.value), nullCount = 0)
-        }
-        colStatsMap.update(attr, newStats)
-      }
-
-      Some(1.0 / BigDecimal(ndv))
-    } else {
-      Some(0.0)
-    }
+    if (histogramMap.contains(attr)) {
+      // Prefer the histogram when one is available: it carries a per-bucket
+      // distinct count rather than a single table-wide NDV.
+      val histogram: Histogram = histogramMap(attr)
+      val statsRange = Range(Option(histogram.min), Option(histogram.max), attr.dataType)
+      if (statsRange.contains(literal)) {
+        if (update) {
+          // After this equality predicate, the column collapses to the literal value.
+          val avgLen = attr.dataType match {
+            case _: DecimalType => 8
+            case DoubleType | LongType => 8
+            case FloatType | IntegerType => 4
+            case BinaryType => 1
+            case _ => 4 // TODO: refine for more data types
+          }
+          val newStats = ColumnStat(distinctCount = 1, min = Some(literal.value),
+            max = Some(literal.value), nullCount = 0, avgLen = avgLen, maxLen = avgLen)
+          colStatsMap.update(attr, newStats)
+        }
+        val (_, distinctCount, _) = histogram.getInterval(
+          EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble)
+        if (distinctCount == 0) {
+          Some(0.0)
+        } else {
+          // Equi-height histogram: each of the (bucket.size - 1) buckets holds
+          // the same number of rows, and the literal is assumed to match each
+          // distinct value in its bucket with equal probability.
+          Some(1.0 / BigDecimal(distinctCount * (histogram.bucket.size - 1)))
+        }
+      } else {
+        Some(0.0)
+      }
+    } else {
+      val colStat = colStatsMap(attr)
+      val ndv = colStat.distinctCount
+      // decide if the value is in [min, max] of the column.
+      // We currently don't store min/max for binary/string type.
+      // Hence, we assume it is in boundary for binary/string type.
+      val statsRange = Range(colStat.min, colStat.max, attr.dataType)
+      if (statsRange.contains(literal)) {
+        if (update) {
+          // We update the ColumnStat structure after applying this equality
+          // predicate: set distinctCount to 1, nullCount to 0, and min/max
+          // (if they exist) to the literal value.
+          val newStats = attr.dataType match {
+            case StringType | BinaryType =>
+              colStat.copy(distinctCount = 1, nullCount = 0)
+            case _ =>
+              colStat.copy(distinctCount = 1, min = Some(literal.value),
+                max = Some(literal.value), nullCount = 0)
+          }
+          colStatsMap.update(attr, newStats)
+        }
+        Some(1.0 / BigDecimal(ndv))
+      } else {
+        Some(0.0)
+      }
+    }
   }
@@ -793,3 +820,23 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
     AttributeMap(newColumnStats.toSeq)
   }
 }
+
+case class HistogramMap(originalMap: AttributeMap[Histogram]) {
+
+  /** This map maintains the latest histogram for each column. */
+  private val updatedMap: mutable.Map[ExprId, (Attribute, Histogram)] = mutable.HashMap.empty
+
+  def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a)
+
+  def apply(a: Attribute): Histogram = {
+    if (updatedMap.contains(a.exprId)) {
+      updatedMap(a.exprId)._2
+    } else {
+      originalMap(a)
+    }
+  }
+
+  /** Updates the histogram of the given attribute in updatedMap. */
+  def update(a: Attribute, histogram: Histogram): Unit =
+    updatedMap.update(a.exprId, a -> histogram)
+}
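Concretely, for the predicate cint = 4 against the histogram shown above, the new
branch computes the following (a sketch of the arithmetic only; the final ceil is
the rounding the estimator applies to rowCount * selectivity):

    val numBuckets = 4 - 1                 // four boundaries => three buckets
    val bucketNdv = 4L                     // distinct values in the bucket containing 4.0
    val selectivity = BigDecimal(1) / BigDecimal(bucketNdv * numBuckets)  // = 1/12
    val estimatedRows = (BigDecimal(10) * selectivity)
      .setScale(0, BigDecimal.RoundingMode.CEILING)                       // = 1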
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 2fa53a6466ef2..46bb1a2e11709 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -22,7 +22,7 @@ import java.sql.Date
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -39,6 +39,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   val attrInt = AttributeReference("cint", IntegerType)()
   val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
+  // Equi-height histogram over cint: 3 buckets of height 10 / 3, with boundaries
+  // 1.0, 3.83..., 7.17..., 10.0; the trailing 0 pairs with the upper boundary.
+  val histogramStatInt = Histogram(List(1.0, 3.8333333333333335,
+    7.166666666666667, 10.0), List(3, 4, 3, 0), 3.3333333333333335)
 
   // column cbool has only 2 distinct values
   val attrBool = AttributeReference("cbool", BooleanType)()
@@ -103,6 +107,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     attrInt4 -> colStatInt4
   ))
 
+  val histograms = AttributeMap(Seq(
+    attrInt -> histogramStatInt
+  ))
+
   test("true") {
     validateEstimatedStats(
       Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)),
@@ -186,6 +194,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 1)
   }
 
+  test("cint = 4") {
+    // The histogram path estimates selectivity 1 / (4 distinct values * 3 buckets)
+    // = 1/12, so the expected row count is ceil(10 * 1/12) = 1.
+    validateEstimatedStats(
+      Filter(EqualTo(attrInt, Literal(4)), childStatsTestPlan(Seq(attrInt), 10L)),
+      Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(4), max = Some(4),
+        nullCount = 0, avgLen = 4, maxLen = 4)),
+      expectedRowCount = 1)
+  }
+
   test("cint <=> 2") {
     validateEstimatedStats(
       Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
@@ -582,7 +600,12 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     StatsTestPlan(
       outputList = outList,
      rowCount = tableRowCount,
-      attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))))
+      attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))),
+      // Attach histograms only for attributes that have one, so tests on columns
+      // without a histogram do not fail on a missing key.
+      histograms = AttributeMap(outList.collect {
+        case a if histograms.contains(a) => a -> histograms(a)
+      }))
   }
 
   private def validateEstimatedStats(
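For this uniformly distributed test column, the old NDV-based estimate and the
new histogram-based estimate happen to coincide after rounding, as the sketch
below makes explicit; with skewed data the per-bucket distinct counts differ,
and only the histogram path reflects that skew.

    val ndvSelectivity = BigDecimal(1) / BigDecimal(10)       // old path: 1 / ndv
    val histSelectivity = BigDecimal(1) / BigDecimal(4L * 3)  // new path: 1 / (bucket ndv * buckets)
    // Both estimates round up to 1 row out of 10.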
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 263f4e18803d5..67a5d294f79c1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
 import org.apache.spark.sql.types.{IntegerType, StringType}
@@ -53,11 +53,13 @@ case class StatsTestPlan(
     outputList: Seq[Attribute],
     rowCount: BigInt,
     attributeStats: AttributeMap[ColumnStat],
+    histograms: AttributeMap[Histogram] = AttributeMap(Nil),
     size: Option[BigInt] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
   override def computeStats(conf: SQLConf): Statistics = Statistics(
     // If sizeInBytes is useless in testing, we just use a fake value
     sizeInBytes = size.getOrElse(Int.MaxValue),
     rowCount = Some(rowCount),
-    attributeStats = attributeStats)
+    attributeStats = attributeStats,
+    histograms = histograms)
 }
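Finally, an end-to-end sketch of how the pieces fit together. It assumes the
test helpers above are in scope (i.e. it runs inside a suite extending
StatsEstimationTestBase); FilterEstimation(plan, conf).estimate is the entry
point that Filter's stats computation invokes when CBO is enabled.

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, EqualTo, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Histogram}
    import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.IntegerType

    // A leaf plan whose statistics carry both column stats and a histogram.
    val attrInt = AttributeReference("cint", IntegerType)()
    val child = StatsTestPlan(
      outputList = Seq(attrInt),
      rowCount = 10L,
      attributeStats = AttributeMap(Seq(attrInt -> ColumnStat(distinctCount = 10,
        min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4))),
      histograms = AttributeMap(Seq(attrInt -> Histogram(
        List(1.0, 3.8333333333333335, 7.166666666666667, 10.0),
        List(3, 4, 3, 0), 3.3333333333333335))))

    // Estimate cint = 4: the histogram branch yields selectivity 1/12, so the
    // estimated row count is ceil(10 * 1/12) = 1.
    val estimated = FilterEstimation(Filter(EqualTo(attrInt, Literal(4)), child),
      new SQLConf()).estimate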