apache#10 histogram can be used to calculate EqualTo
GH-Gloway committed Oct 18, 2017
1 parent 4aa79d6 commit 35090fa
Showing 4 changed files with 122 additions and 29 deletions.
@@ -52,6 +52,7 @@ case class Statistics(
    sizeInBytes: BigInt,
    rowCount: Option[BigInt] = None,
    attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
+    histograms: AttributeMap[Histogram] = AttributeMap(Nil),
    hints: HintInfo = HintInfo()) {

  override def toString: String = "Statistics(" + simpleString + ")"
@@ -279,3 +280,30 @@ object ColumnStat extends Logging {
  }

}

+/**
+ * An equi-height histogram. bucket holds the bucket boundaries in ascending
+ * order, so there are bucket.size - 1 buckets. distinctCount(i) is the number
+ * of distinct values in the bucket that starts at boundary i (the entry for the
+ * last boundary is a trailing 0), and height is the common number of rows per
+ * bucket.
+ */
+case class Histogram(
+    bucket: List[Double],
+    distinctCount: List[Long],
+    height: Double) {
+
+  val min: Double = bucket.head
+  val max: Double = bucket.last
+
+  /**
+   * Binary-searches the boundaries for the bucket containing `point` and returns
+   * (the bucket's lower boundary, its distinct count, its index). Callers are
+   * expected to check that `point` lies within [min, max] first.
+   */
+  def getInterval(point: Double): (Double, Long, Int) = {
+    var start = 0
+    var end = bucket.size - 1
+    while (start <= end) {
+      val index = (start + end) / 2
+      if (bucket(index) > point) {
+        end = index - 1
+      } else {
+        start = index + 1
+      }
+    }
+    // The loop always exits with start > end, and `end` is then the index of
+    // the largest boundary that is <= point.
+    (bucket(end), distinctCount(end), end)
+  }
+}
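
For orientation, here is a minimal sketch of how getInterval behaves on the equi-height histogram that the test suite below builds (boundaries 1.0 / 3.83… / 7.16… / 10.0, per-bucket distinct counts 3 / 4 / 3 with a trailing 0 sentinel). The snippet is illustrative only and not part of the commit; it assumes the Histogram class above is in scope.

// Illustrative only: exercise Histogram.getInterval as defined above.
val hist = Histogram(
  bucket = List(1.0, 3.8333333333333335, 7.166666666666667, 10.0),
  distinctCount = List(3L, 4L, 3L, 0L),
  height = 3.3333333333333335)

// 4.0 falls between boundaries 3.83... and 7.16..., i.e. into bucket index 1.
val (lowerBound, distinct, index) = hist.getInterval(4.0)
// lowerBound == 3.8333333333333335, distinct == 4L, index == 1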
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -34,6 +34,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging

  private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)

+  private val histogramMap = new HistogramMap(childStats.histograms)
+
  /**
   * Returns an option of Statistics for a Filter logical plan node.
   * For a given compound expression condition, this method computes filter selectivity
@@ -307,37 +309,62 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
      attr: Attribute,
      literal: Literal,
      update: Boolean): Option[BigDecimal] = {
-    if (!colStatsMap.contains(attr)) {
+    if (!colStatsMap.contains(attr) && !histogramMap.contains(attr)) {
      logDebug("[CBO] No statistics for " + attr)
      return None
    }
-    val colStat = colStatsMap(attr)
-    val ndv = colStat.distinctCount
-
-    // decide if the value is in [min, max] of the column.
-    // We currently don't store min/max for binary/string type.
-    // Hence, we assume it is in boundary for binary/string type.
-    val statsRange = Range(colStat.min, colStat.max, attr.dataType)
-    if (statsRange.contains(literal)) {
-      if (update) {
-        // We update ColumnStat structure after apply this equality predicate:
-        // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
-        // value.
-        val newStats = attr.dataType match {
-          case StringType | BinaryType =>
-            colStat.copy(distinctCount = 1, nullCount = 0)
-          case _ =>
-            colStat.copy(distinctCount = 1, min = Some(literal.value),
-              max = Some(literal.value), nullCount = 0)
-        }
-        colStatsMap.update(attr, newStats)
-      }
-
-      Some(1.0 / BigDecimal(ndv))
-    } else {
-      Some(0.0)
-    }
+    if (histogramMap.contains(attr)) {
+      // Prefer the histogram: it carries per-bucket distinct counts instead of
+      // a single table-wide NDV.
+      val histogram: Histogram = histogramMap(attr)
+      val statsRange = Range(Option(histogram.min), Option(histogram.max), attr.dataType)
+      if (statsRange.contains(literal)) {
+        if (update) {
+          // Approximate per-value byte length for the data type.
+          // TODO: handle more types; unknown types fall back to 4 bytes.
+          val avgLen = attr.dataType match {
+            case _: DecimalType | DoubleType | LongType => 8
+            case FloatType | IntegerType => 4
+            case BinaryType => 1
+            case _ => 4
+          }
+          val newStats = ColumnStat(distinctCount = 1, min = Some(literal.value),
+            max = Some(literal.value), nullCount = 0, avgLen = avgLen, maxLen = avgLen)
+          colStatsMap.update(attr, newStats)
+        }
+        val (_, distinctCount, _) = histogram.getInterval(
+          EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble)
+        if (distinctCount == 0) {
+          Some(0.0)
+        } else {
+          // Each of the (bucket.size - 1) equi-height buckets holds roughly the
+          // same number of rows, and the literal is assumed to be one of
+          // distinctCount equally frequent values inside its bucket.
+          Some(1.0 / (distinctCount * (histogram.bucket.size - 1)))
+        }
+      } else {
+        Some(0.0)
+      }
+    } else {
+      val colStat = colStatsMap(attr)
+      val ndv = colStat.distinctCount
+      // decide if the value is in [min, max] of the column.
+      // We currently don't store min/max for binary/string type.
+      // Hence, we assume it is in boundary for binary/string type.
+      val statsRange = Range(colStat.min, colStat.max, attr.dataType)
+      if (statsRange.contains(literal)) {
+        if (update) {
+          // We update the ColumnStat structure after applying this equality
+          // predicate: set distinctCount to 1, nullCount to 0, and min/max
+          // values (if they exist) to the literal value.
+          val newStats = attr.dataType match {
+            case StringType | BinaryType =>
+              colStat.copy(distinctCount = 1, nullCount = 0)
+            case _ =>
+              colStat.copy(distinctCount = 1, min = Some(literal.value),
+                max = Some(literal.value), nullCount = 0)
+          }
+          colStatsMap.update(attr, newStats)
+        }
+        Some(1.0 / BigDecimal(ndv))
+      } else {
+        Some(0.0)
+      }
+    }
  }

  /**
@@ -793,3 +820,23 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
    AttributeMap(newColumnStats.toSeq)
  }
}

+case class HistogramMap(originalMap: AttributeMap[Histogram]) {
+
+  /** This map maintains the latest histogram for each column. */
+  private val updatedMap: mutable.Map[ExprId, (Attribute, Histogram)] = mutable.HashMap.empty
+
+  def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a)
+
+  def apply(a: Attribute): Histogram = {
+    if (updatedMap.contains(a.exprId)) {
+      updatedMap(a.exprId)._2
+    } else {
+      originalMap(a)
+    }
+  }
+
+  /** Updates the histogram of the given attribute in updatedMap. */
+  def update(a: Attribute, stats: Histogram): Unit = updatedMap.update(a.exprId, a -> stats)
+}
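
The selectivity the histogram branch returns above is 1 / (distinctCount × numBuckets): an equi-height histogram with bucket.size - 1 buckets puts roughly 1/numBuckets of the rows into each bucket, and the literal is assumed to be one of distinctCount equally frequent values inside its bucket. A worked sketch of the arithmetic for the cint = 4 test below, assuming (as that test's expected row count suggests) that the estimator rounds the filtered row count up:

// Worked example: selectivity of `cint = 4` under the 3-bucket test histogram.
val numBuckets = 4 - 1        // bucket.size - 1: four boundaries -> three buckets
val distinctInBucket = 4L     // the bucket containing 4.0 holds 4 distinct values
val selectivity = 1.0 / (distinctInBucket * numBuckets)  // = 1/12, about 0.083
val estimatedRows = math.ceil(10 * selectivity).toLong   // 10 rows * 1/12, rounded up -> 1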
@@ -22,7 +22,7 @@ import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -39,6 +39,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
  val attrInt = AttributeReference("cint", IntegerType)()
  val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
    nullCount = 0, avgLen = 4, maxLen = 4)
+  val histogramStatInt = Histogram(List(1.0, 3.8333333333333335,
+    7.166666666666667, 10.0), List(3, 4, 3, 0), 3.3333333333333335)

  // column cbool has only 2 distinct values
  val attrBool = AttributeReference("cbool", BooleanType)()
@@ -103,6 +105,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
    attrInt4 -> colStatInt4
  ))

+  val histograms = AttributeMap(Seq(
+    attrInt -> histogramStatInt
+  ))

test("true") {
validateEstimatedStats(
Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)),
@@ -186,6 +192,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
      expectedRowCount = 1)
  }

test("cint = 4") {
validateEstimatedStats(
Filter(EqualTo(attrInt, Literal(4)), childStatsTestPlan(Seq(attrInt), 10L)),
Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(4), max = Some(4),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 1)
}

test("cint <=> 2") {
validateEstimatedStats(
Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
@@ -234,6 +248,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
      expectedRowCount = 5)
  }


test("cint > 10") {
// This is a corner case since max value is 10.
validateEstimatedStats(
@@ -582,7 +597,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
    StatsTestPlan(
      outputList = outList,
      rowCount = tableRowCount,
-      attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))))
+      attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))),
+      // Not every column has a histogram, so attach only the ones that exist.
+      histograms = AttributeMap(outList.flatMap(a => histograms.get(a).map(a -> _))))
  }

  private def validateEstimatedStats(
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
import org.apache.spark.sql.types.{IntegerType, StringType}
@@ -53,11 +53,13 @@ case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
+    histograms: AttributeMap[Histogram] = AttributeMap(Nil),
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(conf: SQLConf): Statistics = Statistics(
    // sizeInBytes is not exercised in these tests, so just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
-    attributeStats = attributeStats)
+    attributeStats = attributeStats,
+    histograms = histograms)
}
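
Taken together, a plan's Statistics can now carry histograms next to the per-column stats. Below is a self-contained sketch of that wiring, mirroring StatsTestPlan.computeStats above; it assumes a Spark build that contains this patch, and the value names are invented for the example.

// Sketch: attaching a histogram to Statistics, as StatsTestPlan.computeStats now does.
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Statistics}
import org.apache.spark.sql.types.IntegerType

val cint = AttributeReference("cint", IntegerType)()
val colStatCint = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
  nullCount = 0, avgLen = 4, maxLen = 4)
val histCint = Histogram(List(1.0, 3.8333333333333335, 7.166666666666667, 10.0),
  List(3L, 4L, 3L, 0L), 3.3333333333333335)

val stats = Statistics(
  sizeInBytes = BigInt(Int.MaxValue),
  rowCount = Some(BigInt(10)),
  attributeStats = AttributeMap(Seq(cint -> colStatCint)),
  histograms = AttributeMap(Seq(cint -> histCint)))
// With these stats, FilterEstimation takes the histogram path for `cint = 4`
// and estimates about one matching row out of the ten.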
