Fix the TableScan bug when the partition serde differs
chenghao-intel committed Jul 21, 2014
1 parent 40a24a7 commit 27540ba
Showing 2 changed files with 81 additions and 114 deletions.
105 changes: 74 additions & 31 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -24,20 +24,25 @@ import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde2.Deserializer
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector

import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}

import org.apache.spark.SerializableWritable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}

import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
import org.apache.spark.sql.catalyst.types.DataType

/**
* A trait for subclasses that handle table scans.
*/
private[hive] sealed trait TableReader {
def makeRDDForTable(hiveTable: HiveTable): RDD[_]
def makeRDDForTable(hiveTable: HiveTable): RDD[Row]

def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_]
def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row]
}


@@ -46,7 +51,8 @@ private[hive] sealed trait TableReader {
* data warehouse directory.
*/
private[hive]
class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
class HadoopTableReader(@transient attributes: Seq[Attribute],
@transient relation: MetastoreRelation, @transient sc: HiveContext)
extends TableReader {

// Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
@@ -63,10 +69,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon

def hiveConf = _broadcastedHiveConf.value.value

override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] =
makeRDDForTable(
hiveTable,
_tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
filterOpt = None)

/**
@@ -81,14 +87,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
def makeRDDForTable(
hiveTable: HiveTable,
deserializerClass: Class[_ <: Deserializer],
filterOpt: Option[PathFilter]): RDD[_] = {
filterOpt: Option[PathFilter]): RDD[Row] = {

assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
since input formats may differ across partitions. Use makeRDDForPartitionedTable() instead.""")

// Create local references to member variables, so that the entire `this` object won't be
// serialized in the closure below.
val tableDesc = _tableDesc
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf

val tablePath = hiveTable.getPath
@@ -99,23 +105,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
.asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)

val attrsWithIndex = attributes.zipWithIndex
val mutableRow = new GenericMutableRow(attrsWithIndex.length)
val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)

// Deserialize each Writable to get the row value.
iter.map {
case v: Writable => deserializer.deserialize(v)
case value =>
sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
}
HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
}

deserializedHadoopRDD
}
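The mapPartitions block above initializes the deserializer once per partition, outside the per-record map, so the hot loop stays cheap. A minimal plain-Scala sketch of that pattern (no Spark or Hive types; mapPartitions and the CSV-ish splitting below are illustrative stand-ins, not the real APIs):

// Hypothetical stand-in for RDD.mapPartitions: apply a function to each partition's iterator.
def mapPartitions[A, B](partitions: Seq[Iterator[A]])(f: Iterator[A] => Iterator[B]): Seq[Iterator[B]] =
  partitions.map(f)

val partitions = Seq(Iterator("1,a", "2,b"), Iterator("3,c"))
val rows = mapPartitions(partitions) { iter =>
  // One-time, per-partition setup: the analogue of newInstance() + initialize() above.
  val splitter = java.util.regex.Pattern.compile(",")
  // Cheap per-record work.
  iter.map(line => splitter.split(line).toSeq)
}
rows.foreach(_.foreach(println))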

override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = {
val partitionToDeserializer = partitions.map(part =>
(part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
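This wrapper is the core of the fix: the deserializer class now comes from each partition rather than from the table-level descriptor, so partitions stored with different serdes are each decoded correctly. A stand-alone sketch of that idea, using hypothetical placeholder types instead of Hive's Partition and Deserializer APIs:

// Hypothetical, simplified stand-ins for Hive's Deserializer and Partition types.
trait Deserializer { def deserialize(bytes: Array[Byte]): Seq[Any] }
class CsvLikeDeserializer extends Deserializer {
  override def deserialize(bytes: Array[Byte]): Seq[Any] = new String(bytes).split(",").toSeq
}
class UpperCaseDeserializer extends Deserializer {
  override def deserialize(bytes: Array[Byte]): Seq[Any] = Seq(new String(bytes).toUpperCase)
}

case class Partition(name: String, deserializerClass: Class[_ <: Deserializer])

// Keep the partition -> serde-class association instead of assuming one table-level serde.
def partitionToDeserializer(parts: Seq[Partition]): Map[Partition, Class[_ <: Deserializer]] =
  parts.map(p => p -> p.deserializerClass).toMap

// Usage: two partitions of the same table stored with different serdes.
val parts = Seq(
  Partition("ds=2014-07-20", classOf[CsvLikeDeserializer]),
  Partition("ds=2014-07-21", classOf[UpperCaseDeserializer]))
val sample = "a,b".getBytes
partitionToDeserializer(parts).foreach { case (p, cls) =>
  val deserializer = cls.newInstance()   // instantiated per partition, as in the patch
  println(s"${p.name} -> ${deserializer.deserialize(sample)}")
}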
@@ -133,8 +136,7 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
*/
def makeRDDForPartitionedTable(
partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
filterOpt: Option[PathFilter]): RDD[_] = {

filterOpt: Option[PathFilter]): RDD[Row] = {
val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) =>
val partDesc = Utilities.getPartitionDesc(partition)
val partPath = partition.getPartitionPath
@@ -156,33 +158,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
}

// Create local references so that the outer object isn't serialized.
val tableDesc = _tableDesc
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val localDeserializer = partDeserializer
val mutableRow = new GenericMutableRow(attributes.length)

// Split the attributes (output schema) into 2 categories:
// (partition key, ordinal) and (normal attribute, ordinal), where the ordinal is the
// index of the attribute in the output Row.
val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
relation.partitionKeys.indexOf(attr._1) >= 0
})

def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
partitionKeys.foreach { case (attr, ordinal) =>
// get the partition key ordinal for the given attribute
val partOrdinal = relation.partitionKeys.indexOf(attr)
row(ordinal) = Cast(Literal(parts(partOrdinal)), attr.dataType).eval(null)
}
}
// fill the partition key values into the reusable MutableRow
fillPartitionKeys(partValues, mutableRow)

val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
hivePartitionRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val rowWithPartArr = new Array[Object](2)

// The update and deserializer initialization are intentionally
// kept out of the below iter.map loop to save performance.
rowWithPartArr.update(1, partValues)
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)

// Map each tuple to a row object
iter.map { value =>
val deserializedRow = deserializer.deserialize(value)
rowWithPartArr.update(0, deserializedRow)
rowWithPartArr.asInstanceOf[Object]
}
// fill the non-partition-key attributes
HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
}
}.toSeq

// Even if we don't use any partitions, we still need an empty RDD
if (hivePartitionRDDs.size == 0) {
new EmptyRDD[Object](sc.sparkContext)
new EmptyRDD[Row](sc.sparkContext)
} else {
new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
}
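The attribute-splitting logic above tags every output attribute with its ordinal in the result row, fills the partition-key slots once per partition, and leaves the remaining slots to per-row deserialization. A simplified, self-contained sketch of that split (plain Scala types; the real code casts the partition-key strings with Cast(Literal(...)) instead of keeping them as strings):

// Simplified sketch: partition-key columns are filled once per partition, ordinary
// columns are left for the per-row deserialization step.
case class Attr(name: String)

val output        = Seq(Attr("key"), Attr("value"), Attr("ds"))   // requested output schema
val partitionKeys = Seq(Attr("ds"))                               // the table's partition columns

// (attribute, ordinal in the output row), split into partition keys vs. data columns.
val (partKeyAttrs, dataAttrs) =
  output.zipWithIndex.partition { case (attr, _) => partitionKeys.contains(attr) }

val mutableRow = new Array[Any](output.length)   // stand-in for GenericMutableRow

def fillPartitionKeys(partValues: Array[String]): Unit =
  partKeyAttrs.foreach { case (attr, ordinal) =>
    val partOrdinal = partitionKeys.indexOf(attr)
    mutableRow(ordinal) = partValues(partOrdinal)  // real code: Cast(Literal(...)).eval(null)
  }

fillPartitionKeys(Array("2014-07-21"))
println(mutableRow.toSeq)                  // only the partition-key slot is set; the rest stay null
println(dataAttrs.map(_._1.name))          // these columns would be handed to fillObject per row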
@@ -225,10 +236,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
// Only take the value (skip the key) because Hive works only with values.
rdd.map(_._2)
}

}

private[hive] object HadoopTableReader {
private[hive] object HadoopTableReader extends HiveInspectors {
/**
* Curried. After being given an argument for 'path', the resulting JobConf => Unit closure is used to
* instantiate a HadoopRDD.
@@ -241,4 +251,37 @@ private[hive] object HadoopTableReader {
val bufferSize = System.getProperty("spark.buffer.size", "65536")
jobConf.set("io.file.buffer.size", bufferSize)
}
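As the doc comment above says, the initializer is curried: applying the path argument yields a JobConf => Unit closure that can be handed to a HadoopRDD and run lazily on the executors. A plain-Scala sketch of the pattern, with FakeJobConf as a hypothetical stand-in for Hadoop's JobConf:

// FakeJobConf is a hypothetical stand-in for org.apache.hadoop.mapred.JobConf.
class FakeJobConf {
  private val props = scala.collection.mutable.Map[String, String]()
  def set(key: String, value: String): Unit = props(key) = value
  override def toString: String = props.mkString(", ")
}

// Curried: the first parameter list is applied now, the second later by the RDD machinery.
def initializeLocalJobConfFunc(path: String)(jobConf: FakeJobConf): Unit = {
  jobConf.set("mapred.input.dir", path)   // roughly what FileInputFormat.setInputPaths does
  jobConf.set("io.file.buffer.size", sys.props.getOrElse("spark.buffer.size", "65536"))
}

// Partially applying the path gives the closure that would be handed to the (hypothetical) RDD.
val initFunc: FakeJobConf => Unit = initializeLocalJobConfFunc("/user/hive/warehouse/src")
val conf = new FakeJobConf
initFunc(conf)
println(conf)   // prints the two properties that were set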

/**
* Transforms the raw data (Writable objects) of an input iterator into Row objects.
* @param iter iterator over the input records, each represented as a Writable object
* @param deserializer deserializer associated with the input Writable objects
* @param attrs the requested attributes, each paired with its zero-based position in the MutableRow
* @param row reusable MutableRow object
*
* @return an iterator of Row objects transformed from the given input.
*/
def fillObject(iter: Iterator[Writable], deserializer: Deserializer,
attrs: Seq[(Attribute, Int)], row: GenericMutableRow): Iterator[Row] = {
val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
// get the field references for the required attributes (i.e. the output of the reader)
val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }

// Map each tuple to a row object
iter.map { value =>
val raw = deserializer.deserialize(value)
var idx = 0
while (idx < fieldRefs.length) {
val fieldRef = fieldRefs(idx)._1
val fieldIdx = fieldRefs(idx)._2
val fieldValue = soi.getStructFieldData(raw, fieldRef)

row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())

idx += 1
}

row: Row
}
}
}
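fillObject resolves the requested struct fields once and then reuses a single mutable row for every record, keeping allocation off the per-row path. A simplified plain-Scala sketch of that pattern (hypothetical Field/Record types instead of Hive's ObjectInspector machinery):

// Hypothetical, simplified sketch of the fillObject pattern: resolve field positions once,
// then copy each deserialized record's requested fields into one reusable row buffer.
case class Field(name: String)
type Record = Map[String, Any]                    // stands in for a deserialized Hive struct

def fillObject(
    iter: Iterator[Record],
    attrs: Seq[(Field, Int)],                     // (field, ordinal in the output row)
    row: Array[Any]): Iterator[Array[Any]] = {
  iter.map { record =>
    var i = 0
    while (i < attrs.length) {                    // while loop on the hot path, as in the patch
      val (field, ordinal) = attrs(i)
      row(ordinal) = record.getOrElse(field.name, null)
      i += 1
    }
    row                                           // the same buffer is returned for every record
  }
}

// Usage: consume each row before advancing the iterator, since the buffer is shared.
val records = Iterator(Map("key" -> 1, "value" -> "a"), Map("key" -> 2, "value" -> "b"))
val rows = fillObject(records, Seq((Field("value"), 0), (Field("key"), 1)), new Array[Any](2))
rows.foreach(r => println(r.toSeq))

With this helper in place, the second changed file, HiveTableScan (shown next), can drop its own objectInspector/attributeFunctions row materialization and return the reader's RDD[Row] directly from execute().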
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
import org.apache.spark.util.MutablePair

/**
* :: DeveloperApi ::
@@ -50,8 +49,7 @@ case class HiveTableScan(
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
@transient val context: HiveContext)
extends LeafNode
with HiveInspectors {
extends LeafNode {

require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
@@ -67,42 +65,7 @@ case class HiveTableScan(
}

@transient
private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context)

/**
* The hive object inspector for this table, which can be used to extract values from the
* serialized row representation.
*/
@transient
private[this] lazy val objectInspector =
relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]

/**
* Functions that extract the requested attributes from the hive output. Partitioned values are
* cast from strings to their declared data types.
*/
@transient
protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
attributes.map { a =>
val ordinal = relation.partitionKeys.indexOf(a)
if (ordinal >= 0) {
val dataType = relation.partitionKeys(ordinal).dataType
(_: Any, partitionKeys: Array[String]) => {
castFromString(partitionKeys(ordinal), dataType)
}
} else {
val ref = objectInspector.getAllStructFieldRefs
.find(_.getFieldName == a.name)
.getOrElse(sys.error(s"Can't find attribute $a"))
val fieldObjectInspector = ref.getFieldObjectInspector

(row: Any, _: Array[String]) => {
val data = objectInspector.getStructFieldData(row, ref)
unwrapData(data, fieldObjectInspector)
}
}
}
}
private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context)

private[this] def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
@@ -114,6 +77,7 @@ case class HiveTableScan(
val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")

if (attributes.size == relation.output.size) {
// TODO: what if duplicate attributes are queried?
ColumnProjectionUtils.setFullyReadColumns(hiveConf)
} else {
ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
@@ -140,12 +104,6 @@ case class HiveTableScan(

addColumnMetadataToConf(context.hiveconf)

private def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}

/**
* Prunes partitions that are not involved in the query plan.
*
@@ -169,44 +127,10 @@ case class HiveTableScan(
}
}

override def execute() = {
inputRdd.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
} else {
val mutableRow = new GenericMutableRow(attributes.length)
val mutablePair = new MutablePair[Any, Array[String]]()
val buffered = iterator.buffered

// NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern
// matching are avoided intentionally.
val rowsAndPartitionKeys = buffered.head match {
// With partition keys
case _: Array[Any] =>
buffered.map { case array: Array[Any] =>
val deserializedRow = array(0)
val partitionKeys = array(1).asInstanceOf[Array[String]]
mutablePair.update(deserializedRow, partitionKeys)
}

// Without partition keys
case _ =>
val emptyPartitionKeys = Array.empty[String]
buffered.map { deserializedRow =>
mutablePair.update(deserializedRow, emptyPartitionKeys)
}
}

rowsAndPartitionKeys.map { pair =>
var i = 0
while (i < attributes.length) {
mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
i += 1
}
mutableRow: Row
}
}
}
override def execute() = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}

override def output = attributes
