+ val OVERWRITE, INSERT, NEW_ENTRY = Value
+ }
+
+ /**
+ * Factory/helper method used during LOAD and INSERT INTO/OVERWRITE analysis. Sets all
+ * necessary fields in the returned SparkLoadWork.
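+ *
+ * For example, analysis of an `INSERT INTO ... PARTITION (...)` statement maps to
+ * `SparkLoadWork(db, conf, hiveTable, Some(partSpec), isOverwrite = false)`.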
+ */
+ def apply(
+ db: Hive,
+ conf: HiveConf,
+ hiveTable: HiveTable,
+ partSpecOpt: Option[JavaMap[String, String]],
+ isOverwrite: Boolean): SparkLoadWork = {
+ val commandType = if (isOverwrite) {
+ SparkLoadWork.CommandTypes.OVERWRITE
+ } else {
+ SparkLoadWork.CommandTypes.INSERT
+ }
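+ // Determine where the table's contents are cached (e.g., MEMORY vs. TACHYON) from its
+ // "shark.cache" table property.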
+ val cacheMode = CacheType.fromString(hiveTable.getProperty("shark.cache"))
+ val sparkLoadWork = new SparkLoadWork(
+ hiveTable.getDbName,
+ hiveTable.getTableName,
+ commandType,
+ cacheMode)
+ partSpecOpt.foreach(sparkLoadWork.addPartSpec(_))
+ if (commandType == SparkLoadWork.CommandTypes.INSERT) {
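+ // For appends, snapshot the files currently in the table (or partition) directory, so
+ // that the subsequent load reads only files created after this point.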
+ if (hiveTable.isPartitioned) {
+ partSpecOpt.foreach { partSpec =>
+ // None if the partition being updated doesn't exist yet.
+ val partitionOpt = Option(db.getPartition(hiveTable, partSpec, false /* forceCreate */))
+ sparkLoadWork.pathFilterOpt = partitionOpt.map(part =>
+ Utils.createSnapshotFilter(part.getPartitionPath, conf))
+ }
+ } else {
+ sparkLoadWork.pathFilterOpt = Some(Utils.createSnapshotFilter(hiveTable.getPath, conf))
+ }
+ }
+ sparkLoadWork
+ }
+}
+
+/**
+ * A Hive task to load data from disk into the Shark cache. Handles INSERT INTO/OVERWRITE,
+ * LOAD INTO/OVERWRITE, CACHE, and CTAS commands.
+ */
+private[shark]
+class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHelper {
+
+ override def execute(driverContext: DriverContext): Int = {
+ logDebug("Executing " + this.getClass.getName)
+
+ // Set the fair scheduler's pool using mapred.fairscheduler.pool if it is defined.
+ Option(conf.get("mapred.fairscheduler.pool")).foreach { pool =>
+ SharkEnv.sc.setLocalProperty("spark.scheduler.pool", pool)
+ }
+
+ val databaseName = work.databaseName
+ val tableName = work.tableName
+ // Set Spark's job description to be this query.
+ SharkEnv.sc.setJobGroup(
+ "shark.job",
+ s"Updating table $databaseName.$tableName for a(n) ${work.commandType}")
+ val hiveTable = Hive.get(conf).getTable(databaseName, tableName)
+ // Use HadoopTableReader to help with table scans. The `conf` passed is reused across HadoopRDD
+ // instantiations.
+ val hadoopReader = new HadoopTableReader(Utilities.getTableDesc(hiveTable), conf)
+ if (hiveTable.isPartitioned) {
+ loadPartitionedMemoryTable(
+ hiveTable,
+ work.partSpecs,
+ hadoopReader,
+ work.pathFilterOpt)
+ } else {
+ loadMemoryTable(
+ hiveTable,
+ hadoopReader,
+ work.pathFilterOpt)
+ }
+ // Success!
+ 0
+ }
+
+ /**
+ * Creates and materializes the in-memory, columnar RDD for a given input RDD.
+ *
+ * @param inputRdd A Hadoop RDD, or a union of Hadoop RDDs if the table is partitioned.
+ * @param serDeProps Properties used to initialize local ColumnarSerDe instantiations. This
+ * contains the output schema of the ColumnarSerDe and is used to create its
+ * output object inspectors.
+ * @param broadcastedHiveConf Allows for sharing a Hive Configuration broadcast used to create
+ * the Hadoop `inputRdd`.
+ * @param inputOI Object inspector used to read rows from `inputRdd`.
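+ * @param tableKey Table identifier created by MemoryMetadataManager#makeTableKey(), used
+ * to address the table in Tachyon.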
+ * @param hivePartitionKeyOpt A defined Hive partition key if the RDD being loaded is part of a
+ * Hive-partitioned table.
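+ * @return The materialized RDD of TablePartitions, paired with the collected
+ * (partition index, stats) entries.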
+ */
+ private def materialize(
+ inputRdd: RDD[_],
+ serDeProps: Properties,
+ broadcastedHiveConf: Broadcast[SerializableWritable[HiveConf]],
+ inputOI: StructObjectInspector,
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]) = {
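+ // Accumulator that ships each partition's (index, stats) pair back to the driver as
+ // tasks complete.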
+ val statsAcc = SharkEnv.sc.accumulableCollection(ArrayBuffer[(Int, TablePartitionStats)]())
+ val tachyonWriter = if (work.cacheMode == CacheType.TACHYON) {
+ // Find the number of columns in the table schema using `serDeProps`.
+ val numColumns = serDeProps.getProperty(Constants.LIST_COLUMNS).split(',').size
+ // Use an additional column to store metadata (e.g., the number of rows in each partition).
+ SharkEnv.tachyonUtil.createTableWriter(tableKey, hivePartitionKeyOpt, numColumns + 1)
+ } else {
+ null
+ }
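+ // Hive ObjectInspectors aren't Java-serializable, so serialize the input OI with Kryo
+ // before shipping it to the tasks below.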
+ val serializedOI = KryoSerializer.serialize(inputOI)
+ var transformedRdd = inputRdd.mapPartitionsWithIndex { case (partIndex, partIter) =>
+ val serde = new ColumnarSerDe
+ serde.initialize(broadcastedHiveConf.value.value, serDeProps)
+ val localInputOI = KryoSerializer.deserialize[ObjectInspector](serializedOI)
+ var builder: Writable = null
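+ // ColumnarSerDe#serialize() appends each row to an internal TablePartitionBuilder and
+ // returns that builder, so after the loop `builder` holds every row in the partition.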
+ partIter.foreach { row =>
+ builder = serde.serialize(row.asInstanceOf[AnyRef], localInputOI)
+ }
+ if (builder == null) {
+ // Empty partition.
+ statsAcc += Tuple2(partIndex, new TablePartitionStats(Array.empty, 0))
+ Iterator(new TablePartition(0, Array()))
+ } else {
+ statsAcc += Tuple2(partIndex, builder.asInstanceOf[TablePartitionBuilder].stats)
+ Iterator(builder.asInstanceOf[TablePartitionBuilder].build())
+ }
+ }
+ // Run a job to materialize the RDD.
+ if (work.cacheMode == CacheType.TACHYON) {
+ // Put the table in Tachyon.
+ logInfo("Putting RDD for %s in Tachyon".format(tableKey))
+ if (work.commandType == SparkLoadWork.CommandTypes.OVERWRITE &&
+ SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)) {
+ // For INSERT OVERWRITE, delete the old table or Hive partition directory, if it exists.
+ SharkEnv.tachyonUtil.dropTable(tableKey, hivePartitionKeyOpt)
+ }
+ tachyonWriter.createTable(ByteBuffer.allocate(0))
+ transformedRdd = transformedRdd.mapPartitionsWithIndex { case(part, iter) =>
+ val partition = iter.next()
+ partition.toTachyon.zipWithIndex.foreach { case(buf, column) =>
+ tachyonWriter.writeColumnPartition(column, part, buf)
+ }
+ Iterator(partition)
+ }
+ } else {
+ transformedRdd.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+ transformedRdd.context.runJob(
+ transformedRdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit))
+ if (work.cacheMode == CacheType.TACHYON) {
+ tachyonWriter.updateMetadata(ByteBuffer.wrap(JavaSerializer.serialize(statsAcc.value.toMap)))
+ }
+ (transformedRdd, statsAcc.value)
+ }
+
+ /**
+ * Returns the created (for CommandTypes.NEW_ENTRY) or fetched (for CommandTypes.INSERT or
+ * OVERWRITE) MemoryTable for the given Hive table.
+ */
+ private def getOrCreateMemoryTable(hiveTable: HiveTable): MemoryTable = {
+ val databaseName = hiveTable.getDbName
+ val tableName = hiveTable.getTableName
+ work.commandType match {
+ case SparkLoadWork.CommandTypes.NEW_ENTRY => {
+ // This is a new entry, e.g. we are caching a new table or partition.
+ // Create a new MemoryTable object and return that.
+ SharkEnv.memoryMetadataManager.createMemoryTable(databaseName, tableName, work.cacheMode)
+ }
+ case _ => {
+ // This is an existing entry (e.g. we are handling an INSERT or INSERT OVERWRITE).
+ // Get the MemoryTable object from the Shark metastore.
+ val tableOpt = SharkEnv.memoryMetadataManager.getTable(databaseName, tableName)
+ assert(tableOpt.exists(_.isInstanceOf[MemoryTable]),
+ "Memory table being updated cannot be found in the Shark metastore.")
+ tableOpt.get.asInstanceOf[MemoryTable]
+ }
+ }
+ }
+
+ /**
+ * Handles loading data from disk into the Shark cache for non-partitioned tables.
+ *
+ * @param hiveTable Hive metadata object representing the target table.
+ * @param hadoopReader Used to create a HadoopRDD from the table's data directory.
+ * @param pathFilterOpt Defined for INSERT update operations (e.g., INSERT INTO) and passed to
+ * hadoopReader#makeRDDForTable() to determine which new files should be read from the table's
+ * data directory - see the SparkLoadWork#apply() factory method for an example of how a
+ * path filter is created.
+ */
+ private def loadMemoryTable(
+ hiveTable: HiveTable,
+ hadoopReader: HadoopTableReader,
+ pathFilterOpt: Option[PathFilter]) {
+ val databaseName = hiveTable.getDbName
+ val tableName = hiveTable.getTableName
+ val tableSchema = hiveTable.getSchema
+ val serDe = hiveTable.getDeserializer
+ serDe.initialize(conf, tableSchema)
+ // Scan the Hive table's data directory.
+ val inputRDD = hadoopReader.makeRDDForTable(hiveTable, serDe.getClass, pathFilterOpt)
+ // Transform the HadoopRDD to an RDD[TablePartition].
+ val (tablePartitionRDD, tableStats) = materialize(
+ inputRDD,
+ tableSchema,
+ hadoopReader.broadcastedHiveConf,
+ serDe.getObjectInspector.asInstanceOf[StructObjectInspector],
+ MemoryMetadataManager.makeTableKey(databaseName, tableName),
+ hivePartitionKeyOpt = None)
+ if (work.cacheMode != CacheType.TACHYON) {
+ val memoryTable = getOrCreateMemoryTable(hiveTable)
+ work.commandType match {
+ case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) =>
+ memoryTable.put(tablePartitionRDD, tableStats.toMap)
+ case SparkLoadWork.CommandTypes.INSERT => {
+ memoryTable.update(tablePartitionRDD, tableStats)
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns the created (for CommandTypes.NEW_ENTRY) or fetched (for CommandTypes.INSERT or
+ * OVERWRITE) PartitionedMemoryTable corresponding to `partSpecs`.
+ *
+ * @param hiveTable The Hive Table.
+ * @param partSpecs A map of (partitioning column -> corresponding value) that uniquely
+ * identifies the partition being created or updated.
+ */
+ private def getOrCreatePartitionedMemoryTable(
+ hiveTable: HiveTable,
+ partSpecs: JavaMap[String, String]): PartitionedMemoryTable = {
+ val databaseName = hiveTable.getDbName
+ val tableName = hiveTable.getTableName
+ work.commandType match {
+ case SparkLoadWork.CommandTypes.NEW_ENTRY => {
+ SharkEnv.memoryMetadataManager.createPartitionedMemoryTable(
+ databaseName,
+ tableName,
+ work.cacheMode,
+ hiveTable.getParameters)
+ }
+ case _ => {
+ // An existing entry is being updated (INSERT or INSERT OVERWRITE). Fetch it from the
+ // Shark metastore and verify that it is a partitioned memory table.
+ val tableOpt = SharkEnv.memoryMetadataManager.getTable(databaseName, tableName)
+ assert(tableOpt.exists(_.isInstanceOf[PartitionedMemoryTable]),
+ "Partitioned memory table being updated cannot be found in the Shark metastore.")
+ tableOpt.get.asInstanceOf[PartitionedMemoryTable]
+ }
+ }
+ }
+
+ /**
+ * Handles loading data from disk into the Shark cache for Hive-partitioned tables.
+ *
+ * @param hiveTable Hive metadata object representing the target table.
+ * @param partSpecs Sequence of partition key specifications that contains either a single key,
+ * or all of the table's partition keys. This is because only one partition specification is
+ * allowed for each append or overwrite command, while new cache entries (i.e., for a CACHE
+ * command) are full table scans.
+ * @param hadoopReader Used to create a HadoopRDD from each partition's data directory.
+ * @param pathFilterOpt Defined for INSERT update operations (e.g., INSERT INTO) and passed to
+ * hadoopReader#makeRDDForTable() to determine which new files should be read from the table
+ * partition's data directory - see the SparkLoadWork#apply() factory method for an example of
+ * how a path filter is created.
+ */
+ private def loadPartitionedMemoryTable(
+ hiveTable: HiveTable,
+ partSpecs: Seq[JavaMap[String, String]],
+ hadoopReader: HadoopTableReader,
+ pathFilterOpt: Option[PathFilter]) {
+ val databaseName = hiveTable.getDbName
+ val tableName = hiveTable.getTableName
+ val partCols = hiveTable.getPartCols.map(_.getName)
+
+ for (partSpec <- partSpecs) {
+ // Read, materialize, and store a columnar-backed RDD for `partSpec`.
+ val partitionKey = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec)
+ val partition = db.getPartition(hiveTable, partSpec, false /* forceCreate */)
+ val partSerDe = partition.getDeserializer()
+ val partSchema = partition.getSchema
+ partSerDe.initialize(conf, partSchema)
+ // Get a UnionStructObjectInspector that unifies the two StructObjectInspectors for the table
+ // columns and the partition columns.
+ val unionOI = HiveUtils.makeUnionOIForPartitionedTable(partSchema, partSerDe)
+ // Create a HadoopRDD for the file scan.
+ val inputRDD = hadoopReader.makeRDDForPartitionedTable(
+ Map(partition -> partSerDe.getClass), pathFilterOpt)
+ val (tablePartitionRDD, tableStats) = materialize(
+ inputRDD,
+ SparkLoadTask.addPartitionInfoToSerDeProps(partCols, partSchema),
+ hadoopReader.broadcastedHiveConf,
+ unionOI,
+ MemoryMetadataManager.makeTableKey(databaseName, tableName),
+ Some(partitionKey))
+ if (work.cacheMode != CacheType.TACHYON) {
+ // Handle appends or overwrites.
+ val partitionedTable = getOrCreatePartitionedMemoryTable(hiveTable, partSpec)
+ if (partitionedTable.containsPartition(partitionKey) &&
+ (work.commandType == SparkLoadWork.CommandTypes.INSERT)) {
+ partitionedTable.updatePartition(partitionKey, tablePartitionRDD, tableStats)
+ } else {
+ partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap)
+ }
+ }
+ }
+ }
+
+ override def getType = StageType.MAPRED
+
+ override def getName = "MAPRED-LOAD-SPARK"
+
+ override def localizeMRTmpFilesImpl(ctx: Context) = Unit
+}
+
+
+object SparkLoadTask {
+
+ /**
+ * Returns a copy of `baseSerDeProps` with the names and types for the table's partitioning
+ * columns appended to respective row metadata properties.
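+ *
+ * For example, with partition columns Seq("dt", "country"), a Constants.LIST_COLUMNS value
+ * of "key,val" becomes "key,val,dt,country", and a Constants.LIST_COLUMN_TYPES value of
+ * "int:bigint" becomes "int:bigint:string:string".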
+ */
+ private def addPartitionInfoToSerDeProps(
+ partCols: Seq[String],
+ baseSerDeProps: Properties): Properties = {
+ val serDeProps = new Properties(baseSerDeProps)
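+ // Note that `baseSerDeProps` becomes the defaults table of the new Properties object;
+ // the setProperty() calls below write only to the copy and leave the base untouched.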
+
+ // Column names specified by the Constants.LIST_COLUMNS key are delimited by ",".
+ // E.g., for a table created from
+ // CREATE TABLE page_views(key INT, val BIGINT) PARTITIONED BY (dt STRING, country STRING),
+ // `columnNameProperties` will be "key,val". We want to append the "dt,country" partition
+ // column names to it, and reset the Constants.LIST_COLUMNS entry in the SerDe properties.
+ var columnNameProperties: String = serDeProps.getProperty(Constants.LIST_COLUMNS)
+ columnNameProperties += "," + partCols.mkString(",")
+ serDeProps.setProperty(Constants.LIST_COLUMNS, columnNameProperties)
+
+ // `columnTypePropertiesOpt` is `None` if column types are missing. By default, Hive
+ // SerDeParameters initialized by the ColumnarSerDe will treat all columns as having
+ // string types.
+ // Column types specified by the Constants.LIST_COLUMN_TYPES key are delimited by ":"
+ // E.g., for the CREATE TABLE example above, if `columnTypeProperties` is defined, then it
+ // will be "int:bigint". Partition columns are strings, so "string:string" should be appended.
+ val columnTypePropertiesOpt = Option(serDeProps.getProperty(Constants.LIST_COLUMN_TYPES))
+ columnTypePropertiesOpt.foreach { columnTypeProperties =>
+ serDeProps.setProperty(Constants.LIST_COLUMN_TYPES,
+ columnTypeProperties + (":" + Constants.STRING_TYPE_NAME) * partCols.size)
+ }
+ serDeProps
+ }
+}
diff --git a/src/main/scala/shark/execution/SparkTask.scala b/src/main/scala/shark/execution/SparkTask.scala
index 32241a47..f878ce0c 100755
--- a/src/main/scala/shark/execution/SparkTask.scala
+++ b/src/main/scala/shark/execution/SparkTask.scala
@@ -54,7 +54,7 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper {
def tableRdd: Option[TableRDD] = _tableRdd
override def execute(driverContext: DriverContext): Int = {
- logInfo("Executing " + this.getClass.getName)
+ logDebug("Executing " + this.getClass.getName)
val ctx = driverContext.getCtx()
@@ -86,17 +86,15 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper {
initializeTableScanTableDesc(tableScanOps)
- // Initialize the Hive query plan. This gives us all the object inspectors.
- initializeAllHiveOperators(terminalOp)
-
terminalOp.initializeMasterOnAll()
// Set Spark's job description to be this query.
- SharkEnv.sc.setJobDescription(work.pctx.getContext.getCmd)
+ SharkEnv.sc.setJobGroup("shark.job", work.pctx.getContext.getCmd)
- // Set the fair scheduler's pool.
- SharkEnv.sc.setLocalProperty("spark.scheduler.cluster.fair.pool",
- conf.get("mapred.fairscheduler.pool"))
+ // Set the fair scheduler's pool using mapred.fairscheduler.pool if it is defined.
+ Option(conf.get("mapred.fairscheduler.pool")).foreach { pool =>
+ SharkEnv.sc.setLocalProperty("spark.scheduler.pool", pool)
+ }
val sinkRdd = terminalOp.execute().asInstanceOf[RDD[Any]]
@@ -116,6 +114,7 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper {
// topToTable maps Hive's TableScanOperator to the Table object.
val topToTable: JHashMap[HiveTableScanOperator, Table] = work.pctx.getTopToTable()
+ val emptyPartnArray = new Array[Partition](0)
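+ // Passed to java.util.Collection#toArray(T[]) below so the results are typed as
+ // Array[Partition] rather than Array[Object].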
// Add table metadata to TableScanOperators
topOps.foreach { op =>
op.table = topToTable.get(op.hiveOp)
@@ -127,7 +126,8 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper {
work.pctx.getOpToPartPruner().get(op.hiveOp),
work.pctx.getConf(), "",
work.pctx.getPrunedPartitions())
- op.parts = ppl.getConfirmedPartns.toArray ++ ppl.getUnknownPartns.toArray
+ op.parts = ppl.getConfirmedPartns.toArray(emptyPartnArray) ++
+ ppl.getUnknownPartns.toArray(emptyPartnArray)
val allParts = op.parts ++ ppl.getDeniedPartns.toArray
if (allParts.size == 0) {
op.firstConfPartDesc = new PartitionDesc(op.tableDesc, null)
@@ -138,28 +138,6 @@ class SparkTask extends HiveTask[SparkWork] with Serializable with LogHelper {
}
}
- def initializeAllHiveOperators(terminalOp: TerminalOperator) {
- // Need to guarantee all parents are initialized before the child.
- val topOpList = new scala.collection.mutable.MutableList[HiveTopOperator]
- val queue = new scala.collection.mutable.Queue[Operator[_]]
- queue.enqueue(terminalOp)
-
- while (!queue.isEmpty) {
- val current = queue.dequeue()
- current match {
- case op: HiveTopOperator => topOpList += op
- case _ => Unit
- }
- queue ++= current.parentOperators
- }
-
- // Run the initialization. This guarantees that upstream operators are
- // initialized before downstream ones.
- topOpList.reverse.foreach { topOp =>
- topOp.initializeHiveTopOperator()
- }
- }
-
override def getType = StageType.MAPRED
override def getName = "MAPRED-SPARK"
diff --git a/src/main/scala/shark/execution/TableReader.scala b/src/main/scala/shark/execution/TableReader.scala
new file mode 100644
index 00000000..0687c7c8
--- /dev/null
+++ b/src/main/scala/shark/execution/TableReader.scala
@@ -0,0 +1,250 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.execution
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+
+import org.apache.spark.rdd.{EmptyRDD, RDD, UnionRDD}
+
+import shark.{LogHelper, SharkEnv}
+import shark.api.QueryExecutionException
+import shark.execution.serialization.JavaSerializer
+import shark.memstore2.{MemoryMetadataManager, Table, TablePartition, TablePartitionStats}
+import shark.tachyon.TachyonException
+
+
+/**
+ * A trait for subclasses that handle table scans. In Shark, there is one subclass for each
+ * type of table storage: HeapTableReader for Shark tables in Spark's block manager,
+ * TachyonTableReader for tables in Tachyon, and HadoopTableReader for Hive tables in a filesystem.
+ */
+trait TableReader extends LogHelper {
+
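+ // A pruning function takes a table's RDD and a map from Spark partition index to that
+ // partition's column stats, and returns the RDD with prunable partitions removed.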
+ type PruningFunctionType = (RDD[_], collection.Map[Int, TablePartitionStats]) => RDD[_]
+
+ def makeRDDForTable(
+ hiveTable: HiveTable,
+ pruningFnOpt: Option[PruningFunctionType] = None
+ ): RDD[_]
+
+ def makeRDDForPartitionedTable(
+ partitions: Seq[HivePartition],
+ pruningFnOpt: Option[PruningFunctionType] = None
+ ): RDD[_]
+}
+
+/** Helper class for scanning tables stored in Tachyon. */
+class TachyonTableReader(@transient _tableDesc: TableDesc) extends TableReader {
+
+ // Split from 'databaseName.tableName'
+ private val _tableNameSplit = _tableDesc.getTableName.split('.')
+ private val _databaseName = _tableNameSplit(0)
+ private val _tableName = _tableNameSplit(1)
+
+ override def makeRDDForTable(
+ hiveTable: HiveTable,
+ pruningFnOpt: Option[PruningFunctionType] = None
+ ): RDD[_] = {
+ val tableKey = MemoryMetadataManager.makeTableKey(_databaseName, _tableName)
+ makeRDD(tableKey, hivePartitionKeyOpt = None, pruningFnOpt)
+ }
+
+ override def makeRDDForPartitionedTable(
+ partitions: Seq[HivePartition],
+ pruningFnOpt: Option[PruningFunctionType] = None): RDD[_] = {
+ val tableKey = MemoryMetadataManager.makeTableKey(_databaseName, _tableName)
+ val hivePartitionRDDs = partitions.map { hivePartition =>
+ val partDesc = Utilities.getPartitionDesc(hivePartition)
+ // Get partition field info
+ val partSpec = partDesc.getPartSpec()
+ val partProps = partDesc.getProperties()
+
+ val partColsDelimited = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
+ // Partitioning columns are delimited by "/"
+ val partCols = partColsDelimited.trim().split("/").toSeq
+ // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'.
+ val partValues = if (partSpec == null) {
+ Array.fill(partCols.size)(new String)
+ } else {
+ partCols.map(col => new String(partSpec.get(col))).toArray
+ }
+ val partitionKeyStr = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec)
+ val hivePartitionRDD = makeRDD(tableKey, Some(partitionKeyStr), pruningFnOpt)
+ hivePartitionRDD.mapPartitions { iter =>
+ if (iter.hasNext) {
+ // Map each tuple to a row object
+ val rowWithPartArr = new Array[Object](2)
+ iter.map { value =>
+ rowWithPartArr.update(0, value.asInstanceOf[Object])
+ rowWithPartArr.update(1, partValues)
+ rowWithPartArr.asInstanceOf[Object]
+ }
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+ if (hivePartitionRDDs.size > 0) {
+ new UnionRDD(hivePartitionRDDs.head.context, hivePartitionRDDs)
+ } else {
+ new EmptyRDD[Object](SharkEnv.sc)
+ }
+ }
+
+ private def makeRDD(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String],
+ pruningFnOpt: Option[PruningFunctionType]): RDD[Any] = {
+ // Check that the table is in Tachyon.
+ if (!SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)) {
+ throw new TachyonException("Table " + tableKey + " does not exist in Tachyon")
+ }
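+ // `createRDD` returns a sequence of (RDD, per-partition stats) pairs, each of which can
+ // be pruned independently before being unioned.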
+ val tableRDDsAndStats = SharkEnv.tachyonUtil.createRDD(tableKey, hivePartitionKeyOpt)
+ val prunedRDDs = if (pruningFnOpt.isDefined) {
+ val pruningFn = pruningFnOpt.get
+ tableRDDsAndStats.map(tableRDDWithStats =>
+ pruningFn(tableRDDWithStats._1, tableRDDWithStats._2).asInstanceOf[RDD[Any]])
+ } else {
+ tableRDDsAndStats.map(tableRDDAndStats => tableRDDAndStats._1.asInstanceOf[RDD[Any]])
+ }
+ val unionedRDD = if (prunedRDDs.isEmpty) {
+ new EmptyRDD[TablePartition](SharkEnv.sc)
+ } else {
+ new UnionRDD(SharkEnv.sc, prunedRDDs)
+ }
+ unionedRDD.asInstanceOf[RDD[Any]]
+ }
+
+}
+
+/** Helper class for scanning tables stored in Spark's block manager */
+class HeapTableReader(@transient _tableDesc: TableDesc) extends TableReader {
+
+ // Split from 'databaseName.tableName'
+ private val _tableNameSplit = _tableDesc.getTableName.split('.')
+ private val _databaseName = _tableNameSplit(0)
+ private val _tableName = _tableNameSplit(1)
+
+ /** Fetches and optionally prunes the RDD for `_tableName` from the Shark metastore. */
+ override def makeRDDForTable(
+ hiveTable: HiveTable,
+ pruningFnOpt: Option[PruningFunctionType] = None
+ ): RDD[_] = {
+ logInfo("Loading table %s.%s from Spark block manager".format(_databaseName, _tableName))
+ val tableOpt = SharkEnv.memoryMetadataManager.getMemoryTable(_databaseName, _tableName)
+ if (tableOpt.isEmpty) {
+ throwMissingTableException()
+ }
+
+ val table = tableOpt.get
+ val tableRdd = table.getRDD.get
+ val tableStats = table.getStats.get
+ // Prune if an applicable function is given.
+ pruningFnOpt.map(_(tableRdd, tableStats)).getOrElse(tableRdd)
+ }
+
+ /**
+ * Fetches an RDD from the Shark metastore for each partition key given. Returns a single, unioned
+ * RDD representing all of the specified partition keys.
+ *
+ * @param partitions A collection of Hive-partition metadata, such as partition columns and
+ * partition key specifications.
+ */
+ override def makeRDDForPartitionedTable(
+ partitions: Seq[HivePartition],
+ pruningFnOpt: Option[PruningFunctionType] = None
+ ): RDD[_] = {
+ val hivePartitionRDDs = partitions.map { partition =>
+ val partDesc = Utilities.getPartitionDesc(partition)
+ // Get partition field info
+ val partSpec = partDesc.getPartSpec()
+ val partProps = partDesc.getProperties()
+
+ val partColsDelimited = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
+ // Partitioning columns are delimited by "/"
+ val partCols = partColsDelimited.trim().split("/").toSeq
+ // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'.
+ val partValues = if (partSpec == null) {
+ Array.fill(partCols.size)(new String)
+ } else {
+ partCols.map(col => new String(partSpec.get(col))).toArray
+ }
+
+ val partitionKeyStr = MemoryMetadataManager.makeHivePartitionKeyStr(partCols, partSpec)
+ val hivePartitionedTableOpt = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ _databaseName, _tableName)
+ if (hivePartitionedTableOpt.isEmpty) {
+ throwMissingTableException()
+ }
+ val hivePartitionedTable = hivePartitionedTableOpt.get
+
+ val rddAndStatsOpt = hivePartitionedTable.getPartitionAndStats(partitionKeyStr)
+ if (rddAndStatsOpt.isEmpty) {
+ throwMissingPartitionException(partitionKeyStr)
+ }
+ val (hivePartitionRDD, hivePartitionStats) = (rddAndStatsOpt.get._1, rddAndStatsOpt.get._2)
+ val prunedPartitionRDD = pruningFnOpt.map(_(hivePartitionRDD, hivePartitionStats))
+ .getOrElse(hivePartitionRDD)
+ prunedPartitionRDD.mapPartitions { iter =>
+ if (iter.hasNext) {
+ // Map each tuple to a row object
+ val rowWithPartArr = new Array[Object](2)
+ iter.map { value =>
+ rowWithPartArr.update(0, value.asInstanceOf[Object])
+ rowWithPartArr.update(1, partValues)
+ rowWithPartArr.asInstanceOf[Object]
+ }
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+ if (hivePartitionRDDs.size > 0) {
+ new UnionRDD(hivePartitionRDDs.head.context, hivePartitionRDDs)
+ } else {
+ new EmptyRDD[Object](SharkEnv.sc)
+ }
+ }
+
+ /**
+ * Thrown if the table identified by the (_databaseName, _tableName) pair cannot be found in
+ * the Shark metastore.
+ */
+ private def throwMissingTableException() {
+ logError("""|Table %s.%s not found in block manager.
+ |Are you trying to access a cached table from a Shark session other than the one
+ |in which it was created?""".stripMargin.format(_databaseName, _tableName))
+ throw new QueryExecutionException("Cached table not found")
+ }
+
+ /**
+ * Thrown if the table partition identified by the (_databaseName, _tableName, partValues) tuple
+ * cannot be found in the Shark metastore.
+ */
+ private def throwMissingPartitionException(partValues: String) {
+ logError("""|Partition %s for table %s.%s not found in block manager.
+ |Are you trying to access a cached table from a Shark session other than the one in
+ |which it was created?""".stripMargin.format(partValues, _databaseName, _tableName))
+ throw new QueryExecutionException("Cached table partition not found")
+ }
+}
diff --git a/src/main/scala/shark/execution/TableScanOperator.scala b/src/main/scala/shark/execution/TableScanOperator.scala
index 9fda4702..eaba7e9b 100755
--- a/src/main/scala/shark/execution/TableScanOperator.scala
+++ b/src/main/scala/shark/execution/TableScanOperator.scala
@@ -18,40 +18,49 @@
package shark.execution
import java.util.{ArrayList, Arrays}
+
+import scala.collection.JavaConversions._
import scala.reflect.BeanProperty
-import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
+
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS
import org.apache.hadoop.hive.ql.exec.{TableScanOperator => HiveTableScanOperator}
import org.apache.hadoop.hive.ql.exec.{MapSplitPruning, Utilities}
-import org.apache.hadoop.hive.ql.io.HiveInputFormat
import org.apache.hadoop.hive.ql.metadata.{Partition, Table}
-import org.apache.hadoop.hive.ql.plan.{PlanUtils, PartitionDesc, TableDesc}
+import org.apache.hadoop.hive.ql.plan.{PartitionDesc, TableDesc, TableScanDesc}
+import org.apache.hadoop.hive.serde.Constants
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory,
StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.io.Writable
-import org.apache.spark.rdd.{PartitionPruningRDD, RDD, UnionRDD}
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
-import shark.{SharkConfVars, SharkEnv, Utils}
-import shark.api.QueryExecutionException
+import shark.{LogHelper, SharkConfVars, SharkEnv}
import shark.execution.optimization.ColumnPruner
-import shark.execution.serialization.{XmlSerializer, JavaSerializer}
-import shark.memstore2.{CacheType, TablePartition, TablePartitionStats}
-import shark.tachyon.TachyonException
+import shark.memstore2.CacheType
+import shark.memstore2.CacheType._
+import shark.memstore2.{ColumnarSerDe, MemoryMetadataManager}
+import shark.memstore2.{TablePartition, TablePartitionStats}
+import shark.util.HiveUtils
/**
* The TableScanOperator is used for scanning any type of Shark or Hive table.
*/
-class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopOperator {
+class TableScanOperator extends TopOperator[TableScanDesc] {
+ // TODO(harvey): Try to use 'TableDesc' for execution and save 'Table' for analysis/planning.
+ // Decouple `Table` from TableReader and ColumnPruner.
@transient var table: Table = _
+ @transient var hiveOp: HiveTableScanOperator = _
+
// Metadata for Hive-partitions (i.e if the table was created from PARTITION BY). NULL if this
// table isn't Hive-partitioned. Set in SparkTask::initializeTableScanTableDesc().
- @transient var parts: Array[Object] = _
+ @transient var parts: Array[Partition] = _
+
+ // For convenience, a local copy of the HiveConf for this task.
+ @transient var localHConf: HiveConf = _
// PartitionDescs are used during planning in Hive. This reference to a single PartitionDesc
// is used to initialize partition ObjectInspectors.
@@ -62,283 +71,210 @@ class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopO
@BeanProperty var firstConfPartDesc: PartitionDesc = _
@BeanProperty var tableDesc: TableDesc = _
- @BeanProperty var localHconf: HiveConf = _
- /**
- * Initialize the hive TableScanOperator. This initialization propagates
- * downstream. When all Hive TableScanOperators are initialized, the entire
- * Hive query plan operators are initialized.
- */
- override def initializeHiveTopOperator() {
+ // True if the table data is stored in the Spark heap.
+ @BeanProperty var isInMemoryTableScan: Boolean = _
- val rowObjectInspector = {
- if (parts == null) {
- val serializer = tableDesc.getDeserializerClass().newInstance()
- serializer.initialize(hconf, tableDesc.getProperties)
- serializer.getObjectInspector()
- } else {
- val partProps = firstConfPartDesc.getProperties()
- val tableDeser = firstConfPartDesc.getDeserializerClass().newInstance()
- tableDeser.initialize(hconf, partProps)
- val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
- val partNames = new ArrayList[String]
- val partObjectInspectors = new ArrayList[ObjectInspector]
- partCols.trim().split("/").foreach{ key =>
- partNames.add(key)
- partObjectInspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
- }
+ @BeanProperty var cacheMode: CacheType.CacheType = _
- // No need to lock this one (see SharkEnv.objectInspectorLock) because
- // this is called on the master only.
- val partObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
- partNames, partObjectInspectors)
- val oiList = Arrays.asList(
- tableDeser.getObjectInspector().asInstanceOf[StructObjectInspector],
- partObjectInspector.asInstanceOf[StructObjectInspector])
- // new oi is union of table + partition object inspectors
- ObjectInspectorFactory.getUnionStructObjectInspector(oiList)
- }
- }
- setInputObjectInspector(0, rowObjectInspector)
- super.initializeHiveTopOperator()
+ override def initializeOnMaster() {
+ // Create a local copy of the HiveConf that will be assigned job properties and, for disk reads,
+ // broadcasted to slaves.
+ localHConf = new HiveConf(super.hconf)
+ cacheMode = CacheType.fromString(
+ tableDesc.getProperties().get("shark.cache").asInstanceOf[String])
+ isInMemoryTableScan = SharkEnv.memoryMetadataManager.containsTable(
+ table.getDbName, table.getTableName)
}
- override def initializeOnMaster() {
- localHconf = super.hconf
+ override def outputObjectInspector() = {
+ if (parts == null) {
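+ // Cached tables are stored in columnar format, so their rows are inspected through a
+ // ColumnarSerDe rather than the table's on-disk deserializer.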
+ val serializer = if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) {
+ new ColumnarSerDe
+ } else {
+ tableDesc.getDeserializerClass().newInstance()
+ }
+ serializer.initialize(hconf, tableDesc.getProperties)
+ serializer.getObjectInspector()
+ } else {
+ val partProps = firstConfPartDesc.getProperties()
+ val partSerDe = if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) {
+ new ColumnarSerDe
+ } else {
+ firstConfPartDesc.getDeserializerClass().newInstance()
+ }
+ partSerDe.initialize(hconf, partProps)
+ HiveUtils.makeUnionOIForPartitionedTable(partProps, partSerDe)
+ }
}
override def execute(): RDD[_] = {
assert(parentOperators.size == 0)
- val tableKey: String = tableDesc.getTableName.split('.')(1)
+
+ val tableNameSplit = tableDesc.getTableName.split('.') // Split from 'databaseName.tableName'
+ val databaseName = tableNameSplit(0)
+ val tableName = tableNameSplit(1)
// There are three places we can load the table from.
- // 1. Tachyon table
- // 2. Spark heap (block manager)
+ // 1. Spark heap (block manager), accessed through the Shark MemoryMetadataManager
+ // 2. Tachyon table
// 3. Hive table on HDFS (or other Hadoop storage)
-
- val cacheMode = CacheType.fromString(
- tableDesc.getProperties().get("shark.cache").asInstanceOf[String])
- if (cacheMode == CacheType.heap) {
- // Table should be in Spark heap (block manager).
- val rdd = SharkEnv.memoryMetadataManager.get(tableKey).getOrElse {
- logError("""|Table %s not found in block manager.
- |Are you trying to access a cached table from a Shark session other than
- |the one in which it was created?""".stripMargin.format(tableKey))
- throw(new QueryExecutionException("Cached table not found"))
+ // TODO(harvey): Pruning Hive-partitioned, cached tables isn't supported yet.
+ if (isInMemoryTableScan || cacheMode == CacheType.TACHYON) {
+ if (isInMemoryTableScan) {
+ assert(cacheMode == CacheType.MEMORY || cacheMode == CacheType.MEMORY_ONLY,
+ "Table %s.%s is in Shark metastore, but its cacheMode (%s) indicates otherwise".
+ format(databaseName, tableName, cacheMode))
}
- logInfo("Loading table " + tableKey + " from Spark block manager")
- createPrunedRdd(tableKey, rdd)
- } else if (cacheMode == CacheType.tachyon) {
- // Table is in Tachyon.
- if (!SharkEnv.tachyonUtil.tableExists(tableKey)) {
- throw new TachyonException("Table " + tableKey + " does not exist in Tachyon")
+ val tableReader = if (cacheMode == CacheType.TACHYON) {
+ new TachyonTableReader(tableDesc)
+ } else {
+ new HeapTableReader(tableDesc)
}
- logInfo("Loading table " + tableKey + " from Tachyon.")
-
- var indexToStats: collection.Map[Int, TablePartitionStats] =
- SharkEnv.memoryMetadataManager.getStats(tableKey).getOrElse(null)
-
- if (indexToStats == null) {
- val statsByteBuffer = SharkEnv.tachyonUtil.getTableMetadata(tableKey)
- indexToStats = JavaSerializer.deserialize[collection.Map[Int, TablePartitionStats]](
- statsByteBuffer.array())
- logInfo("Loading table " + tableKey + " stats from Tachyon.")
- SharkEnv.memoryMetadataManager.putStats(tableKey, indexToStats)
+ if (table.isPartitioned) {
+ tableReader.makeRDDForPartitionedTable(parts, Some(createPrunedRdd _))
+ } else {
+ tableReader.makeRDDForTable(table, Some(createPrunedRdd _))
}
- createPrunedRdd(tableKey, SharkEnv.tachyonUtil.createRDD(tableKey))
} else {
// Table is a Hive table on HDFS (or other Hadoop storage).
- super.execute()
+ makeRDDFromHadoop()
}
}
- private def createPrunedRdd(tableKey: String, rdd: RDD[_]): RDD[_] = {
- // Stats used for map pruning.
- val indexToStats: collection.Map[Int, TablePartitionStats] =
- SharkEnv.memoryMetadataManager.getStats(tableKey).get
-
+ private def createPrunedRdd(
+ rdd: RDD[_],
+ indexToStats: collection.Map[Int, TablePartitionStats]): RDD[_] = {
// Run map pruning if the flag is set, there exists a filter predicate on
// the input table and we have statistics on the table.
val columnsUsed = new ColumnPruner(this, table).columnsUsed
- SharkEnv.tachyonUtil.pushDownColumnPruning(rdd, columnsUsed)
-
- val prunedRdd: RDD[_] =
- if (SharkConfVars.getBoolVar(localHconf, SharkConfVars.MAP_PRUNING) &&
- childOperators(0).isInstanceOf[FilterOperator] &&
- indexToStats.size == rdd.partitions.size) {
-
- val startTime = System.currentTimeMillis
- val printPruneDebug = SharkConfVars.getBoolVar(
- localHconf, SharkConfVars.MAP_PRUNING_PRINT_DEBUG)
-
- // Must initialize the condition evaluator in FilterOperator to get the
- // udfs and object inspectors set.
- val filterOp = childOperators(0).asInstanceOf[FilterOperator]
- filterOp.initializeOnSlave()
-
- def prunePartitionFunc(index: Int): Boolean = {
- if (printPruneDebug) {
- logInfo("\nPartition " + index + "\n" + indexToStats(index))
- }
- // Only test for pruning if we have stats on the column.
- val partitionStats = indexToStats(index)
- if (partitionStats != null && partitionStats.stats != null) {
- MapSplitPruning.test(partitionStats, filterOp.conditionEvaluator)
- } else {
- true
- }
- }
- // Do the pruning.
- val prunedRdd = PartitionPruningRDD.create(rdd, prunePartitionFunc)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Map pruning %d partitions into %s partitions took %d ms".format(
- rdd.partitions.size, prunedRdd.partitions.size, timeTaken))
- prunedRdd
- } else {
- rdd
+ if (!table.isPartitioned && cacheMode == CacheType.TACHYON) {
+ SharkEnv.tachyonUtil.pushDownColumnPruning(rdd, columnsUsed)
+ }
+
+ val shouldPrune = SharkConfVars.getBoolVar(localHConf, SharkConfVars.MAP_PRUNING) &&
+ childOperators(0).isInstanceOf[FilterOperator] &&
+ indexToStats.size == rdd.partitions.size
+
+ val prunedRdd: RDD[_] = if (shouldPrune) {
+ val startTime = System.currentTimeMillis
+ val printPruneDebug = SharkConfVars.getBoolVar(
+ localHConf, SharkConfVars.MAP_PRUNING_PRINT_DEBUG)
+
+ // Must initialize the condition evaluator in FilterOperator to get the
+ // udfs and object inspectors set.
+ val filterOp = childOperators(0).asInstanceOf[FilterOperator]
+ filterOp.initializeOnSlave()
+
+ def prunePartitionFunc(index: Int): Boolean = {
+ if (printPruneDebug) {
+ logInfo("\nPartition " + index + "\n" + indexToStats(index))
+ }
+ // Only test for pruning if we have stats on the column.
+ val partitionStats = indexToStats(index)
+ if (partitionStats != null && partitionStats.stats != null) {
+ MapSplitPruning.test(partitionStats, filterOp.conditionEvaluator)
+ } else {
+ true
+ }
}
+ // Do the pruning.
+ val prunedRdd = PartitionPruningRDD.create(rdd, prunePartitionFunc)
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Map pruning %d partitions into %s partitions took %d ms".format(
+ rdd.partitions.size, prunedRdd.partitions.size, timeTaken))
+ prunedRdd
+ } else {
+ rdd
+ }
+
prunedRdd.mapPartitions { iter =>
if (iter.hasNext) {
- val tablePartition = iter.next.asInstanceOf[TablePartition]
+ val tablePartition = iter.next().asInstanceOf[TablePartition]
tablePartition.prunedIterator(columnsUsed)
- //tablePartition.iterator
} else {
- Iterator()
+ Iterator.empty
}
}
}
/**
- * Create a RDD representing the table (with or without partitions).
+ * Create an RDD for a table stored in Hadoop.
*/
- override def preprocessRdd(rdd: RDD[_]): RDD[_] = {
+ def makeRDDFromHadoop(): RDD[_] = {
+ // Try to have the InputFormats filter predicates.
+ TableScanOperator.addFilterExprToConf(localHConf, hiveOp)
+
+ val hadoopReader = new HadoopTableReader(tableDesc, localHConf)
if (table.isPartitioned) {
- logInfo("Making %d Hive partitions".format(parts.size))
- makePartitionRDD(rdd)
+ logDebug("Making %d Hive partitions".format(parts.size))
+ // The returned RDD contains arrays of size two with the elements as
+ // (deserialized row, column partition value).
+ return hadoopReader.makeRDDForPartitionedTable(parts)
} else {
- val tablePath = table.getPath.toString
- val ifc = table.getInputFormatClass
- .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
- logInfo("Table input: %s".format(tablePath))
- createHadoopRdd(tablePath, ifc)
+ // The returned RDD contains deserialized row Objects.
+ return hadoopReader.makeRDDForTable(table)
}
}
- override def processPartition(index: Int, iter: Iterator[_]): Iterator[_] = {
- val deserializer = tableDesc.getDeserializerClass().newInstance()
- deserializer.initialize(localHconf, tableDesc.getProperties)
- iter.map { value =>
- value match {
- case rowWithPart: Array[Object] => rowWithPart
- case v: Writable => deserializer.deserialize(v)
- case _ => throw new RuntimeException("Failed to match " + value.toString)
- }
- }
- }
+ // All RDD processing is done in execute().
+ override def processPartition(split: Int, iter: Iterator[_]): Iterator[_] =
+ throw new UnsupportedOperationException("TableScanOperator.processPartition()")
+
+}
+
+
+object TableScanOperator extends LogHelper {
/**
- * Create an RDD for every partition column specified in the query. Note that for on-disk Hive
- * tables, a data directory is created for each partition corresponding to keys specified using
- * 'PARTITION BY'.
+ * Add filter expressions and column metadata to the HiveConf. This is meant to be called on the
+ * master - it's impractical to add filters during slave-local JobConf creation in HadoopRDD,
+ * since we would have to serialize the HiveTableScanOperator.
*/
- private def makePartitionRDD[T](rdd: RDD[T]): RDD[_] = {
- val partitions = parts
- val rdds = new Array[RDD[Any]](partitions.size)
-
- var i = 0
- partitions.foreach { part =>
- val partition = part.asInstanceOf[Partition]
- val partDesc = Utilities.getPartitionDesc(partition)
- val tablePath = partition.getPartitionPath.toString
-
- val ifc = partition.getInputFormatClass
- .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
- val parts = createHadoopRdd(tablePath, ifc)
-
- val serializedHconf = XmlSerializer.serialize(localHconf, localHconf)
- val partRDD = parts.mapPartitions { iter =>
- val hconf = XmlSerializer.deserialize(serializedHconf).asInstanceOf[HiveConf]
- val deserializer = partDesc.getDeserializerClass().newInstance()
- deserializer.initialize(hconf, partDesc.getProperties())
-
- // Get partition field info
- val partSpec = partDesc.getPartSpec()
- val partProps = partDesc.getProperties()
-
- val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
- // Partitioning keys are delimited by "/"
- val partKeys = partCols.trim().split("/")
- // 'partValues[i]' contains the value for the partitioning key at 'partKeys[i]'.
- val partValues = new ArrayList[String]
- partKeys.foreach { key =>
- if (partSpec == null) {
- partValues.add(new String)
- } else {
- partValues.add(new String(partSpec.get(key)))
- }
+ private def addFilterExprToConf(hiveConf: HiveConf, hiveTableScanOp: HiveTableScanOperator) {
+ val tableScanDesc = hiveTableScanOp.getConf()
+ if (tableScanDesc == null) return
+
+ val rowSchema = hiveTableScanOp.getSchema
+ if (rowSchema != null) {
+ // Add column names to the HiveConf.
+ val columnNames = new StringBuilder
+ for (columnInfo <- rowSchema.getSignature()) {
+ if (columnNames.length > 0) {
+ columnNames.append(",")
}
-
- val rowWithPartArr = new Array[Object](2)
- // Map each tuple to a row object
- iter.map { value =>
- val deserializedRow = deserializer.deserialize(value) // LazyStruct
- rowWithPartArr.update(0, deserializedRow)
- rowWithPartArr.update(1, partValues)
- rowWithPartArr.asInstanceOf[Object]
+ columnNames.append(columnInfo.getInternalName())
+ }
+ val columnNamesString = columnNames.toString()
+ hiveConf.set(Constants.LIST_COLUMNS, columnNamesString)
+
+ // Add column types to the HiveConf.
+ val columnTypes = new StringBuilder
+ for (columnInfo <- rowSchema.getSignature()) {
+ if (columnTypes.length > 0) {
+ columnTypes.append(",")
}
+ columnTypes.append(columnInfo.getType().getTypeName())
}
- rdds(i) = partRDD.asInstanceOf[RDD[Any]]
- i += 1
+ val columnTypesString = columnTypes.toString()
+ hiveConf.set(Constants.LIST_COLUMN_TYPES, columnTypesString)
}
- // Even if we don't use any partitions, we still need an empty RDD
- if (rdds.size == 0) {
- SharkEnv.sc.makeRDD(Seq[Object]())
- } else {
- new UnionRDD(rdds(0).context, rdds)
- }
- }
- private def createHadoopRdd(path: String, ifc: Class[InputFormat[Writable, Writable]])
- : RDD[Writable] = {
- val conf = new JobConf(localHconf)
- if (tableDesc != null) {
- Utilities.copyTableJobPropertiesToConf(tableDesc, conf)
- }
- new HiveInputFormat() {
- def doPushFilters() {
- pushFilters(conf, hiveOp)
- }
- }.doPushFilters()
- FileInputFormat.setInputPaths(conf, path)
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
-
- // Set s3/s3n credentials. Setting them in conf ensures the settings propagate
- // from Spark's master all the way to Spark's slaves.
- var s3varsSet = false
- val s3vars = Seq("fs.s3n.awsAccessKeyId", "fs.s3n.awsSecretAccessKey",
- "fs.s3.awsAccessKeyId", "fs.s3.awsSecretAccessKey").foreach { variableName =>
- if (localHconf.get(variableName) != null) {
- s3varsSet = true
- conf.set(variableName, localHconf.get(variableName))
- }
- }
+ // Push down predicate filters.
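+ // Input formats and storage handlers that support predicate pushdown read these two
+ // conf entries to skip non-matching data during the scan.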
+ val filterExprNode = tableScanDesc.getFilterExpr()
+ if (filterExprNode != null) {
+ val filterText = filterExprNode.getExprString()
+ hiveConf.set(TableScanDesc.FILTER_TEXT_CONF_STR, filterText)
+ logDebug("Filter text: " + filterText)
- // If none of the s3 credentials are set in Hive conf, try use the environmental
- // variables for credentials.
- if (!s3varsSet) {
- Utils.setAwsCredentials(conf)
+ val filterExprNodeSerialized = Utilities.serializeExpression(filterExprNode)
+ hiveConf.set(TableScanDesc.FILTER_EXPR_CONF_STR, filterExprNodeSerialized)
+ logDebug("Filter expression: " + filterExprNodeSerialized)
}
-
- // Choose the minimum number of splits. If mapred.map.tasks is set, use that unless
- // it is smaller than what Spark suggests.
- val minSplits = math.max(localHconf.getInt("mapred.map.tasks", 1), SharkEnv.sc.defaultMinSplits)
- val rdd = SharkEnv.sc.hadoopRDD(conf, ifc, classOf[Writable], classOf[Writable], minSplits)
-
- // Only take the value (skip the key) because Hive works only with values.
- rdd.map(_._2)
}
+
}
diff --git a/src/main/scala/shark/execution/TerminalOperator.scala b/src/main/scala/shark/execution/TerminalOperator.scala
index 1a6400d7..7aa8afc8 100755
--- a/src/main/scala/shark/execution/TerminalOperator.scala
+++ b/src/main/scala/shark/execution/TerminalOperator.scala
@@ -23,6 +23,7 @@ import scala.reflect.BeanProperty
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
/**
@@ -31,7 +32,7 @@ import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator}
* - cache query output
* - return query as RDD directly (without materializing it)
*/
-class TerminalOperator extends UnaryOperator[HiveFileSinkOperator] {
+class TerminalOperator extends UnaryOperator[FileSinkDesc] {
// Create a local copy of hconf and hiveSinkOp so we can XML serialize it.
@BeanProperty var localHiveOp: HiveFileSinkOperator = _
@@ -39,12 +40,12 @@ class TerminalOperator extends UnaryOperator[HiveFileSinkOperator] {
@BeanProperty val now = new Date()
override def initializeOnMaster() {
+ super.initializeOnMaster()
localHconf = super.hconf
// Set parent to null so we won't serialize the entire query plan.
- hiveOp.setParentOperators(null)
- hiveOp.setChildOperators(null)
- hiveOp.setInputObjInspectors(null)
- localHiveOp = hiveOp
+ localHiveOp.setParentOperators(null)
+ localHiveOp.setChildOperators(null)
+ localHiveOp.setInputObjInspectors(null)
}
override def initializeOnSlave() {
diff --git a/src/main/scala/shark/execution/UDTFOperator.scala b/src/main/scala/shark/execution/UDTFOperator.scala
index db59f9cc..5782f370 100755
--- a/src/main/scala/shark/execution/UDTFOperator.scala
+++ b/src/main/scala/shark/execution/UDTFOperator.scala
@@ -23,14 +23,14 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.reflect.BeanProperty
-import org.apache.hadoop.hive.ql.exec.{UDTFOperator => HiveUDTFOperator}
import org.apache.hadoop.hive.ql.plan.UDTFDesc
import org.apache.hadoop.hive.ql.udf.generic.Collector
-import org.apache.hadoop.hive.serde2.objectinspector.{ ObjectInspector,
- StandardStructObjectInspector, StructField, StructObjectInspector }
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.StructField
-class UDTFOperator extends UnaryOperator[HiveUDTFOperator] {
+class UDTFOperator extends UnaryOperator[UDTFDesc] {
@BeanProperty var conf: UDTFDesc = _
@@ -38,9 +38,14 @@ class UDTFOperator extends UnaryOperator[HiveUDTFOperator] {
@transient var soi: StandardStructObjectInspector = _
@transient var inputFields: JavaList[_ <: StructField] = _
@transient var collector: UDTFCollector = _
+ @transient var outputObjInspector: ObjectInspector = _
override def initializeOnMaster() {
- conf = hiveOp.getConf()
+ super.initializeOnMaster()
+
+ conf = desc
+
+ initializeOnSlave()
}
override def initializeOnSlave() {
@@ -56,9 +61,11 @@ class UDTFOperator extends UnaryOperator[HiveUDTFOperator] {
}.toArray
objToSendToUDTF = new Array[java.lang.Object](inputFields.size)
- val udtfOutputOI = conf.getGenericUDTF().initialize(udtfInputOIs)
+ outputObjInspector = conf.getGenericUDTF().initialize(udtfInputOIs)
}
+ override def outputObjectInspector() = outputObjInspector
+
override def processPartition(split: Int, iter: Iterator[_]): Iterator[_] = {
iter.flatMap { row =>
explode(row)
diff --git a/src/main/scala/shark/execution/UnionOperator.scala b/src/main/scala/shark/execution/UnionOperator.scala
index 2e46a004..e332739e 100755
--- a/src/main/scala/shark/execution/UnionOperator.scala
+++ b/src/main/scala/shark/execution/UnionOperator.scala
@@ -19,15 +19,13 @@ package shark.execution
import java.util.{ArrayList, List => JavaList}
-import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.reflect.BeanProperty
-import org.apache.hadoop.hive.ql.exec.{UnionOperator => HiveUnionOperator}
+import org.apache.hadoop.hive.ql.plan.UnionDesc
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ReturnObjectInspectorResolver
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
-import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.StructField
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
@@ -40,23 +38,29 @@ import shark.execution.serialization.OperatorSerializationWrapper
* A union operator. If the incoming data are of different type, the union
* operator transforms the incoming data into the same type.
*/
-class UnionOperator extends NaryOperator[HiveUnionOperator] {
+class UnionOperator extends NaryOperator[UnionDesc] {
- @transient var parentFields: ArrayBuffer[JavaList[_ <: StructField]] = _
- @transient var parentObjInspectors: ArrayBuffer[StructObjectInspector] = _
+ @transient var parentFields: Seq[JavaList[_ <: StructField]] = _
+ @transient var parentObjInspectors: Seq[StructObjectInspector] = _
@transient var columnTypeResolvers: Array[ReturnObjectInspectorResolver] = _
+ @transient var outputObjInspector: ObjectInspector = _
@BeanProperty var needsTransform: Array[Boolean] = _
@BeanProperty var numParents: Int = _
override def initializeOnMaster() {
+ super.initializeOnMaster()
numParents = parentOperators.size
- // Use reflection to get the needsTransform boolean array.
- val needsTransformField = hiveOp.getClass.getDeclaredField("needsTransform")
- needsTransformField.setAccessible(true)
- needsTransform = needsTransformField.get(hiveOp).asInstanceOf[Array[Boolean]]
-
+ // Determine whether each parent's rows need to be transformed to the output type.
+ val outputOI = outputObjectInspector()
+ needsTransform = Array.tabulate[Boolean](objectInspectors.length) { i =>
+ // ObjectInspectors created by the ObjectInspectorFactory are cached: equal inspectors
+ // share the same reference, so an inequality here means a transform is needed.
+ objectInspectors(i) != outputOI
+ }
+
initializeOnSlave()
}
@@ -82,18 +86,20 @@ class UnionOperator extends NaryOperator[HiveUnionOperator] {
}
val outputFieldOIs = columnTypeResolvers.map(_.get())
- val outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
- columnNames, outputFieldOIs.toList)
+ outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
+ columnNames, outputFieldOIs.toList)
// whether we need to do transformation for each parent
// We reuse needsTransform from Hive because the comparison of object
// inspectors are hard once we send object inspectors over the wire.
needsTransform.zipWithIndex.filter(_._1).foreach { case(transform, p) =>
- logInfo("Union Operator needs to transform row from parent[%d] from %s to %s".format(
- p, objectInspectors(p), outputObjInspector))
+ logDebug("Union Operator needs to transform row from parent[%d] from %s to %s".format(
+ p, objectInspectors(p), outputObjInspector))
}
}
+ override def outputObjectInspector() = outputObjInspector
+
/**
* Override execute. The only thing we need to call is combineMultipleRdds().
*/
diff --git a/src/main/scala/shark/execution/optimization/ColumnPruner.scala b/src/main/scala/shark/execution/optimization/ColumnPruner.scala
index 4ab62194..38efb328 100644
--- a/src/main/scala/shark/execution/optimization/ColumnPruner.scala
+++ b/src/main/scala/shark/execution/optimization/ColumnPruner.scala
@@ -1,9 +1,26 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.execution.optimization
import java.util.BitSet
import java.util.{List => JList}
-import scala.collection.JavaConversions.{asScalaBuffer, bufferAsJavaList, collectionAsScalaIterable}
+import scala.collection.JavaConversions.{asScalaBuffer, collectionAsScalaIterable}
import scala.collection.mutable.{Set, HashSet}
import org.apache.hadoop.hive.ql.exec.GroupByPreShuffleOperator
@@ -14,16 +31,16 @@ import org.apache.hadoop.hive.ql.plan.{FilterDesc, MapJoinDesc, ReduceSinkDesc}
import shark.execution.{FilterOperator, JoinOperator,
MapJoinOperator, Operator, ReduceSinkOperator,
SelectOperator, TopOperator}
-import shark.memstore2.{ColumnarStruct, TablePartitionIterator}
class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends Serializable {
val columnsUsed: BitSet = {
val colsToKeep = computeColumnsToKeep()
- val allColumns = tbl.getAllCols().map(x => x.getName())
- var b = new BitSet()
- for (i <- Range(0, allColumns.size()) if (colsToKeep.contains(allColumns(i)))) {
+ // No need to prune partition columns - Hive does that for us.
+ val allColumns = tbl.getCols().map(x => x.getName())
+ val b = new BitSet()
+ for (i <- Range(0, allColumns.size) if colsToKeep.contains(allColumns(i))) {
b.set(i, true)
}
b
@@ -38,11 +55,15 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends
/**
* Computes the column names that are referenced in the Query
*/
- private def computeColumnsToKeep(op: Operator[_],
- cols: HashSet[String], parentOp: Operator[_] = null): Unit = {
+ private def computeColumnsToKeep(
+ op: Operator[_],
+ cols: HashSet[String],
+ parentOp: Operator[_] = null) {
+
def nullGuard[T](s: JList[T]): Seq[T] = {
if (s == null) Seq[T]() else s
}
+
op match {
case selOp: SelectOperator => {
val cnf:SelectDesc = selOp.getConf
@@ -67,7 +88,7 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends
if (cnf != null) {
val keyEvals = nullGuard(cnf.getKeyCols)
val valEvals = nullGuard(cnf.getValueCols)
- val evals = (HashSet() ++ keyEvals ++ valEvals)
+ val evals = HashSet() ++ keyEvals ++ valEvals
cols ++= evals.flatMap(x => nullGuard(x.getCols))
}
}
@@ -76,7 +97,7 @@ class ColumnPruner(@transient op: TopOperator[_], @transient tbl: Table) extends
if (cnf != null) {
val keyEvals = cnf.getKeys.values
val valEvals = cnf.getExprs.values
- val evals = (HashSet() ++ keyEvals ++ valEvals)
+ val evals = HashSet() ++ keyEvals ++ valEvals
cols ++= evals.flatMap(x => x).flatMap(x => nullGuard(x.getCols))
}
}
diff --git a/src/main/scala/shark/execution/package.scala b/src/main/scala/shark/execution/package.scala
index e99b4766..f8251c8a 100755
--- a/src/main/scala/shark/execution/package.scala
+++ b/src/main/scala/shark/execution/package.scala
@@ -17,17 +17,17 @@
package shark
+import scala.language.implicitConversions
+
import shark.execution.serialization.KryoSerializationWrapper
import shark.execution.serialization.OperatorSerializationWrapper
-
package object execution {
- type HiveOperator = org.apache.hadoop.hive.ql.exec.Operator[_]
+  type HiveDesc = java.io.Serializable // every XXXDesc class in Hive is a subclass of Serializable
- implicit def opSerWrapper2op[T <: Operator[_ <: HiveOperator]](
+ implicit def opSerWrapper2op[T <: Operator[_ <: HiveDesc]](
wrapper: OperatorSerializationWrapper[T]): T = wrapper.value
implicit def kryoWrapper2object[T](wrapper: KryoSerializationWrapper[T]): T = wrapper.value
}
-
diff --git a/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala b/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala
deleted file mode 100644
index db612e4c..00000000
--- a/src/main/scala/shark/execution/serialization/HiveConfSerializer.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Copyright (C) 2012 The Regents of The University California.
- * All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package shark.execution.serialization
-
-import java.io.ByteArrayInputStream
-import java.io.ByteArrayOutputStream
-import java.io.DataInputStream
-import java.io.DataOutputStream
-
-import com.ning.compress.lzf.LZFEncoder
-import com.ning.compress.lzf.LZFDecoder
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.io.Text
-
-
-object HiveConfSerializer {
-
- def serialize(hConf: HiveConf): Array[Byte] = {
- val os = new ByteArrayOutputStream
- val dos = new DataOutputStream(os)
- val auxJars = hConf.getAuxJars()
- Text.writeString(dos, if(auxJars == null) "" else auxJars)
- hConf.write(dos)
- LZFEncoder.encode(os.toByteArray())
- }
-
- def deserialize(b: Array[Byte]): HiveConf = {
- val is = new ByteArrayInputStream(LZFDecoder.decode(b))
- val dis = new DataInputStream(is)
- val auxJars = Text.readString(dis)
- val conf = new HiveConf
- conf.readFields(dis)
- if(auxJars.equals("").unary_!)
- conf.setAuxJars(auxJars)
- conf
- }
-}
diff --git a/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala b/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala
index 9589a1e9..2a54fbf3 100644
--- a/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala
+++ b/src/main/scala/shark/execution/serialization/HiveStructDeserializer.scala
@@ -23,8 +23,6 @@ package org.apache.hadoop.hive.serde2.binarysortable
import java.io.IOException
import java.util.{ArrayList => JArrayList}
-import scala.collection.JavaConversions._
-
import org.apache.hadoop.hive.serde2.SerDeException
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoUtils}
diff --git a/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala b/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala
index 5b99ba09..1f5544fb 100644
--- a/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala
+++ b/src/main/scala/shark/execution/serialization/HiveStructSerializer.scala
@@ -22,8 +22,6 @@ package org.apache.hadoop.hive.serde2.binarysortable
import java.util.{List => JList}
-import scala.collection.JavaConversions._
-
import org.apache.hadoop.hive.serde2.objectinspector.{StructField, StructObjectInspector}
diff --git a/src/main/scala/shark/execution/serialization/JavaSerializer.scala b/src/main/scala/shark/execution/serialization/JavaSerializer.scala
index a98cb95c..df6ab31d 100644
--- a/src/main/scala/shark/execution/serialization/JavaSerializer.scala
+++ b/src/main/scala/shark/execution/serialization/JavaSerializer.scala
@@ -19,11 +19,12 @@ package shark.execution.serialization
import java.nio.ByteBuffer
+import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializer => SparkJavaSerializer}
object JavaSerializer {
- @transient val ser = new SparkJavaSerializer
+ @transient val ser = new SparkJavaSerializer(SparkEnv.get.conf)
def serialize[T](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
diff --git a/src/main/scala/shark/execution/serialization/KryoSerializer.scala b/src/main/scala/shark/execution/serialization/KryoSerializer.scala
index c4764979..0532fbcc 100644
--- a/src/main/scala/shark/execution/serialization/KryoSerializer.scala
+++ b/src/main/scala/shark/execution/serialization/KryoSerializer.scala
@@ -19,8 +19,10 @@ package shark.execution.serialization
import java.nio.ByteBuffer
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer}
+import shark.SharkContext
/**
* Java object serialization using Kryo. This is much more efficient, but Kryo
@@ -29,7 +31,10 @@ import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer}
*/
object KryoSerializer {
- @transient val ser = new SparkKryoSerializer
+ @transient lazy val ser: SparkKryoSerializer = {
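+    // SparkEnv is only available once the SparkContext has been initialized, so resolve the
+    // SparkConf lazily and fall back to a default SparkConf when no SparkEnv exists.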
+ val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ new SparkKryoSerializer(sparkConf)
+ }
def serialize[T](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
diff --git a/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala b/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala
index 858ce182..19e383c4 100644
--- a/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala
+++ b/src/main/scala/shark/execution/serialization/OperatorSerializationWrapper.scala
@@ -17,7 +17,7 @@
package shark.execution.serialization
-import shark.execution.HiveOperator
+import shark.execution.HiveDesc
import shark.execution.Operator
@@ -28,7 +28,7 @@ import shark.execution.Operator
*
* Use OperatorSerializationWrapper(operator) to create a wrapper.
*/
-class OperatorSerializationWrapper[T <: Operator[_ <: HiveOperator]]
+class OperatorSerializationWrapper[T <: Operator[_ <: HiveDesc]]
extends Serializable with shark.LogHelper {
/** The operator we are going to serialize. */
@@ -69,9 +69,9 @@ class OperatorSerializationWrapper[T <: Operator[_ <: HiveOperator]]
object OperatorSerializationWrapper {
- def apply[T <: Operator[_ <: HiveOperator]](value: T): OperatorSerializationWrapper[T] = {
+ def apply[T <: Operator[_ <: HiveDesc]](value: T): OperatorSerializationWrapper[T] = {
val wrapper = new OperatorSerializationWrapper[T]
wrapper.value = value
wrapper
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala b/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala
index b2a2d014..e4eba584 100644
--- a/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala
+++ b/src/main/scala/shark/execution/serialization/ShuffleSerializer.scala
@@ -22,7 +22,9 @@ import java.nio.ByteBuffer
import org.apache.hadoop.io.BytesWritable
-import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance, SerializationStream}
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.DeserializationStream
+import org.apache.spark.serializer.{SerializationStream, Serializer, SerializerInstance}
import shark.execution.{ReduceKey, ReduceKeyReduceSide}
@@ -47,7 +49,11 @@ import shark.execution.{ReduceKey, ReduceKeyReduceSide}
* into a hash table. We want to reduce the size of the hash table. Having the BytesWritable wrapper
* would increase the size of the hash table by another 16 bytes per key-value pair.
*/
-class ShuffleSerializer extends Serializer {
+class ShuffleSerializer(conf: SparkConf) extends Serializer {
+
+ // A no-arg constructor since conf is not needed in this serializer.
+ def this() = this(null)
+
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance
}
diff --git a/src/main/scala/shark/execution/serialization/XmlSerializer.scala b/src/main/scala/shark/execution/serialization/XmlSerializer.scala
index 4c63efab..a533c812 100644
--- a/src/main/scala/shark/execution/serialization/XmlSerializer.scala
+++ b/src/main/scala/shark/execution/serialization/XmlSerializer.scala
@@ -17,8 +17,8 @@
package shark.execution.serialization
-import java.beans.{XMLDecoder, XMLEncoder, PersistenceDelegate}
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutput, ObjectInput}
+import java.beans.{XMLDecoder, XMLEncoder}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.ning.compress.lzf.{LZFEncoder, LZFDecoder}
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities.EnumDelegate
import org.apache.hadoop.hive.ql.plan.GroupByDesc
import org.apache.hadoop.hive.ql.plan.PlanUtils.ExpressionTypes
-import shark.{SharkConfVars, SharkEnvSlave}
+import shark.SharkConfVars
/**
diff --git a/src/main/scala/shark/memstore2/CachePolicy.scala b/src/main/scala/shark/memstore2/CachePolicy.scala
new file mode 100644
index 00000000..27e29ff6
--- /dev/null
+++ b/src/main/scala/shark/memstore2/CachePolicy.scala
@@ -0,0 +1,227 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import java.util.concurrent.ConcurrentHashMap
+import java.util.LinkedHashMap
+import java.util.Map.Entry
+
+import scala.collection.JavaConversions._
+
+
+/**
+ * A general interface for pluggable cache eviction policies in Shark.
+ * One example use is to control the persistence levels of RDDs that represent a table's
+ * Hive-partitions.
+ */
+trait CachePolicy[K, V] {
+
+ protected var _loadFunc: (K => V) = _
+
+ protected var _evictionFunc: (K, V) => Unit = _
+
+ protected var _maxSize: Int = -1
+
+ def initialize(
+ strArgs: Array[String],
+ fallbackMaxSize: Int,
+ loadFunc: K => V,
+ evictionFunc: (K, V) => Unit) {
+ _loadFunc = loadFunc
+ _evictionFunc = evictionFunc
+
+ // By default, only initialize the `maxSize` from user specifications.
+ strArgs.size match {
+ case 0 => _maxSize = fallbackMaxSize
+ case 1 => _maxSize = strArgs.head.toInt
+ case _ =>
+ throw new Exception("Accpted format: %s(maxSize: Int)".format(this.getClass.getName))
+ }
+    require(maxSize > 0, "Size given to cache eviction policy must be > 0")
+ }
+
+ def notifyPut(key: K, value: V): Unit
+
+ def notifyRemove(key: K): Unit
+
+ def notifyGet(key: K): Unit
+
+ def keysOfCachedEntries: Seq[K]
+
+ def maxSize: Int = _maxSize
+
+ // TODO(harvey): Call this in Shark's handling of ALTER TABLE TBLPROPERTIES.
+ def maxSize_= (newMaxSize: Int) = _maxSize = newMaxSize
+
+ def hitRate: Double
+
+ def evictionCount: Long
+}
+
+
+object CachePolicy {
+
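+  /**
+   * Reflectively creates and initializes a CachePolicy from its string representation. A sketch
+   * of the two accepted formats, assuming hypothetical `loadFunc`/`evictionFunc` callbacks:
+   * {{{
+   *   // Bare class name; `maxSize` falls back to `fallbackMaxSize`:
+   *   instantiateWithUserSpecs("shark.memstore2.FIFOCachePolicy", 10, loadFunc, evictionFunc)
+   *   // Class name with an explicit maxSize argument:
+   *   instantiateWithUserSpecs("shark.memstore2.LRUCachePolicy(20)", 10, loadFunc, evictionFunc)
+   * }}}
+   */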
+ def instantiateWithUserSpecs[K, V](
+ str: String,
+ fallbackMaxSize: Int,
+ loadFunc: K => V,
+ evictionFunc: (K, V) => Unit): CachePolicy[K, V] = {
+    val firstParenPos = str.indexOf('(')
+    val (classStr, strArgs) =
+      if (firstParenPos == -1) {
+        (str, Array.empty[String])
+      } else {
+        (str.slice(0, firstParenPos),
+          str.substring(firstParenPos + 1, str.lastIndexOf(')')).split(','))
+      }
+    val policy = Class.forName(classStr).newInstance.asInstanceOf[CachePolicy[K, V]]
+    policy.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc)
+    policy
+ }
+}
+
+
+/**
+ * A cache that never evicts entries.
+ */
+class CacheAllPolicy[K, V] extends CachePolicy[K, V] {
+
+ // Track the entries in the cache, so that keysOfCachedEntries() returns a valid result.
+ var cache = new ConcurrentHashMap[K, V]()
+
+ override def notifyPut(key: K, value: V) = cache.put(key, value)
+
+ override def notifyRemove(key: K) = cache.remove(key)
+
+  override def notifyGet(key: K) = ()
+
+ override def keysOfCachedEntries: Seq[K] = cache.keySet.toSeq
+
+ override def hitRate = 1.0
+
+ override def evictionCount = 0L
+}
+
+
+class LRUCachePolicy[K, V] extends LinkedMapBasedPolicy[K, V] {
+
+ override def initialize(
+ strArgs: Array[String],
+ fallbackMaxSize: Int,
+ loadFunc: K => V,
+ evictionFunc: (K, V) => Unit) {
+ super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc)
+ _cache = new LinkedMapCache(true /* evictUsingAccessOrder */)
+ }
+
+}
+
+
+class FIFOCachePolicy[K, V] extends LinkedMapBasedPolicy[K, V] {
+
+ override def initialize(
+ strArgs: Array[String],
+ fallbackMaxSize: Int,
+ loadFunc: K => V,
+ evictionFunc: (K, V) => Unit) {
+ super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc)
+ _cache = new LinkedMapCache()
+ }
+
+}
+
+
+sealed abstract class LinkedMapBasedPolicy[K, V] extends CachePolicy[K, V] {
+
+ class LinkedMapCache(evictUsingAccessOrder: Boolean = false)
+ extends LinkedHashMap[K, V](maxSize, 0.75F, evictUsingAccessOrder) {
+
+ override def removeEldestEntry(eldest: Entry[K, V]): Boolean = {
+ val shouldRemove = (size() > maxSize)
+ if (shouldRemove) {
+ _evictionFunc(eldest.getKey, eldest.getValue)
+ _evictionCount += 1
+ }
+ return shouldRemove
+ }
+ }
+
+ protected var _cache: LinkedMapCache = _
+ protected var _isInitialized = false
+ protected var _hitCount: Long = 0L
+ protected var _missCount: Long = 0L
+ protected var _evictionCount: Long = 0L
+
+ override def initialize(
+ strArgs: Array[String],
+ fallbackMaxSize: Int,
+ loadFunc: K => V,
+ evictionFunc: (K, V) => Unit) {
+ super.initialize(strArgs, fallbackMaxSize, loadFunc, evictionFunc)
+ _isInitialized = true
+ }
+
+ override def notifyPut(key: K, value: V): Unit = {
+ assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName))
+ this.synchronized {
+ val oldValue = _cache.put(key, value)
+ if (oldValue != null) {
+ _evictionFunc(key, oldValue)
+ _evictionCount += 1
+ }
+ }
+ }
+
+ override def notifyRemove(key: K): Unit = {
+ assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName))
+ this.synchronized { _cache.remove(key) }
+ }
+
+ override def notifyGet(key: K): Unit = {
+ assert(_isInitialized, "Must initialize() %s.".format(this.getClass.getName))
+ this.synchronized {
+ if (_cache.contains(key)) {
+ _cache.get(key)
+ _hitCount += 1L
+ } else {
+ val loadedValue = _loadFunc(key)
+ _cache.put(key, loadedValue)
+ _missCount += 1L
+ }
+ }
+ }
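+
+  // Note that notifyGet() is a load-through read: a hit refreshes the entry (updating access
+  // order for LRU), while a miss invokes `_loadFunc` and caches the loaded value, which may in
+  // turn trigger removeEldestEntry() and an eviction.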
+
+ override def keysOfCachedEntries: Seq[K] = {
+ assert(_isInitialized, "Must initialize() LRUCachePolicy.")
+ this.synchronized {
+ return _cache.keySet.toSeq
+ }
+ }
+
+ override def hitRate: Double = {
+ this.synchronized {
+ val requestCount = _missCount + _hitCount
+ val rate = if (requestCount == 0L) 1.0 else (_hitCount.toDouble / requestCount)
+ return rate
+ }
+ }
+
+ override def evictionCount = _evictionCount
+
+}
diff --git a/src/main/scala/shark/memstore2/CacheType.scala b/src/main/scala/shark/memstore2/CacheType.scala
index 13115415..ed1e1735 100644
--- a/src/main/scala/shark/memstore2/CacheType.scala
+++ b/src/main/scala/shark/memstore2/CacheType.scala
@@ -17,28 +17,50 @@
package shark.memstore2
+import shark.LogHelper
-object CacheType extends Enumeration {
+/*
+ * Enumerations and static helper functions for caches supported by Shark.
+ */
+object CacheType extends Enumeration with LogHelper {
+
+ /*
+ * The CacheTypes:
+ * - MEMORY: Stored in memory and on disk (i.e., cache is write-through). Persistent across Shark
+ * sessions. By default, all such tables are reloaded into memory on restart.
+ * - MEMORY_ONLY: Stored only in memory and dropped at the end of each Shark session.
+ * - TACHYON: A distributed storage system that manages an in-memory cache for sharing files and
+ *     RDDs across cluster frameworks.
+ * - NONE: Stored on disk (e.g., HDFS) and managed by Hive.
+ */
type CacheType = Value
- val none, heap, tachyon = Value
+ val MEMORY, MEMORY_ONLY, TACHYON, NONE = Value
- def shouldCache(c: CacheType): Boolean = (c != none)
+ def shouldCache(c: CacheType): Boolean = (c != NONE)
/** Get the cache type object from a string representation. */
def fromString(name: String): CacheType = {
- if (name == null || name == "") {
- none
+ if (name == null || name == "" || name.toLowerCase == "false") {
+ NONE
} else if (name.toLowerCase == "true") {
- heap
+ MEMORY
} else {
try {
- withName(name.toLowerCase)
+ if (name.toUpperCase == "HEAP") {
+ // Interpret 'HEAP' as 'MEMORY' to ensure backwards compatibility with Shark 0.8.0.
+ logWarning("The 'HEAP' cache type name is deprecated. Use 'MEMORY' instead.")
+ MEMORY
+ } else {
+ // Try to use Scala's Enumeration::withName() to interpret 'name'.
+ withName(name.toUpperCase)
+ }
} catch {
case e: java.util.NoSuchElementException => throw new InvalidCacheTypeException(name)
}
}
}
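+
+  // Examples (sketch): fromString("true") == MEMORY, fromString("HEAP") == MEMORY (deprecated
+  // alias), fromString("memory_only") == MEMORY_ONLY, fromString(null) == NONE; any other
+  // string throws an InvalidCacheTypeException.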
- class InvalidCacheTypeException(name: String) extends Exception("Invalid cache type " + name)
+ class InvalidCacheTypeException(name: String)
+ extends Exception("Invalid string representation of cache type: '%s'".format(name))
}
diff --git a/src/main/scala/shark/memstore2/ColumnarSerDe.scala b/src/main/scala/shark/memstore2/ColumnarSerDe.scala
index 4c8bef76..79c6f282 100644
--- a/src/main/scala/shark/memstore2/ColumnarSerDe.scala
+++ b/src/main/scala/shark/memstore2/ColumnarSerDe.scala
@@ -51,7 +51,7 @@ class ColumnarSerDe extends SerDe with LogHelper {
objectInspector = ColumnarStructObjectInspector(serDeParams)
// This null check is needed because Hive's SemanticAnalyzer.genFileSinkPlan() creates
- // an instance of the table's StructObjectInspector by creating an instance SerDe, which
+ // an instance of the table's StructObjectInspector by creating an instance of SerDe, which
// it initializes by passing a 'null' argument for 'conf'.
if (conf != null) {
var partitionSize = {
diff --git a/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala b/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala
index 02f799fe..67a99612 100644
--- a/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala
+++ b/src/main/scala/shark/memstore2/ColumnarStructObjectInspector.scala
@@ -27,8 +27,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo
-import shark.{SharkConfVars}
-
class ColumnarStructObjectInspector(fields: JList[StructField]) extends StructObjectInspector {
@@ -60,7 +58,8 @@ object ColumnarStructObjectInspector {
for (i <- 0 until columnNames.size) {
val typeInfo = columnTypes.get(i)
val fieldOI = typeInfo.getCategory match {
- case Category.PRIMITIVE => PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
+ case Category.PRIMITIVE =>
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
typeInfo.asInstanceOf[PrimitiveTypeInfo].getPrimitiveCategory)
case _ => LazyFactory.createLazyObjectInspector(
typeInfo, serDeParams.getSeparators(), 1, serDeParams.getNullSequence(),
diff --git a/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala b/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala
new file mode 100644
index 00000000..2211d557
--- /dev/null
+++ b/src/main/scala/shark/memstore2/LazySimpleSerDeWrapper.scala
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import java.util.{List => JList, Properties}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.serde2.{SerDe, SerDeStats}
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.io.Writable
+
+
+class LazySimpleSerDeWrapper extends SerDe {
+
+ val _lazySimpleSerDe = new LazySimpleSerDe()
+
+ override def initialize(conf: Configuration, tbl: Properties) {
+ _lazySimpleSerDe.initialize(conf, tbl)
+ }
+
+ override def deserialize(blob: Writable): Object = _lazySimpleSerDe.deserialize(blob)
+
+ override def getSerDeStats(): SerDeStats = _lazySimpleSerDe.getSerDeStats()
+
+ override def getObjectInspector: ObjectInspector = _lazySimpleSerDe.getObjectInspector
+
+ override def getSerializedClass: Class[_ <: Writable] = _lazySimpleSerDe.getSerializedClass
+
+ override def serialize(obj: Object, objInspector: ObjectInspector): Writable = {
+ _lazySimpleSerDe.serialize(obj, objInspector)
+ }
+
+}
diff --git a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
index c180dd40..9d5ce7ab 100755
--- a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
+++ b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
@@ -17,104 +17,209 @@
package shark.memstore2
+import java.util.{HashMap => JavaHashMap, Map => JavaMap}
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions._
-import scala.collection.mutable.ConcurrentMap
+import scala.collection.concurrent
+
+import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.spark.rdd.{RDD, UnionRDD}
-import org.apache.spark.storage.StorageLevel
-import shark.SharkConfVars
+import shark.{LogHelper, SharkEnv}
+import shark.execution.RDDUtils
+import shark.util.HiveUtils
-class MemoryMetadataManager {
+class MemoryMetadataManager extends LogHelper {
- private val _keyToRdd: ConcurrentMap[String, RDD[_]] =
- new ConcurrentHashMap[String, RDD[_]]()
+  // Map from 'databaseName.tableName' keys to the corresponding Table objects.
+ private val _tables: concurrent.Map[String, Table] =
+ new ConcurrentHashMap[String, Table]()
- private val _keyToStats: ConcurrentMap[String, collection.Map[Int, TablePartitionStats]] =
- new ConcurrentHashMap[String, collection.Map[Int, TablePartitionStats]]
+ def isHivePartitioned(databaseName: String, tableName: String): Boolean = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ _tables.get(tableKey) match {
+ case Some(table) => table.isInstanceOf[PartitionedMemoryTable]
+ case None => false
+ }
+ }
- def contains(key: String) = _keyToRdd.contains(key.toLowerCase)
+ def containsTable(databaseName: String, tableName: String): Boolean = {
+ _tables.contains(MemoryMetadataManager.makeTableKey(databaseName, tableName))
+ }
- def put(key: String, rdd: RDD[_]) {
- _keyToRdd(key.toLowerCase) = rdd
+ def createMemoryTable(
+ databaseName: String,
+ tableName: String,
+ cacheMode: CacheType.CacheType): MemoryTable = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ val newTable = new MemoryTable(databaseName, tableName, cacheMode)
+ _tables.put(tableKey, newTable)
+ newTable
}
- def get(key: String): Option[RDD[_]] = _keyToRdd.get(key.toLowerCase)
+ def createPartitionedMemoryTable(
+ databaseName: String,
+ tableName: String,
+ cacheMode: CacheType.CacheType,
+ tblProps: JavaMap[String, String]
+ ): PartitionedMemoryTable = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode)
+ // Determine the cache policy to use and read any user-specified cache settings.
+ val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname,
+ SharkTblProperties.CACHE_POLICY.defaultVal)
+ val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname,
+ SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt
+ newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize)
+
+ _tables.put(tableKey, newTable)
+ newTable
+ }
- def putStats(key: String, stats: collection.Map[Int, TablePartitionStats]) {
- _keyToStats.put(key.toLowerCase, stats)
+ def getTable(databaseName: String, tableName: String): Option[Table] = {
+ _tables.get(MemoryMetadataManager.makeTableKey(databaseName, tableName))
}
- def getStats(key: String): Option[collection.Map[Int, TablePartitionStats]] = {
- _keyToStats.get(key.toLowerCase)
+ def getMemoryTable(databaseName: String, tableName: String): Option[MemoryTable] = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ val tableOpt = _tables.get(tableKey)
+ if (tableOpt.isDefined) {
+ assert(tableOpt.get.isInstanceOf[MemoryTable],
+ "getMemoryTable() called for a partitioned table.")
+ }
+ tableOpt.asInstanceOf[Option[MemoryTable]]
}
- /**
- * Find all keys that are strings. Used to drop tables after exiting.
- */
- def getAllKeyStrings(): Seq[String] = {
- _keyToRdd.keys.collect { case k: String => k } toSeq
+ def getPartitionedTable(
+ databaseName: String,
+ tableName: String): Option[PartitionedMemoryTable] = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ val tableOpt = _tables.get(tableKey)
+ if (tableOpt.isDefined) {
+ assert(tableOpt.get.isInstanceOf[PartitionedMemoryTable],
+ "getPartitionedTable() called for a non-partitioned table.")
+ }
+ tableOpt.asInstanceOf[Option[PartitionedMemoryTable]]
+ }
+
+ def renameTable(databaseName: String, oldName: String, newName: String) {
+ if (containsTable(databaseName, oldName)) {
+ val oldTableKey = MemoryMetadataManager.makeTableKey(databaseName, oldName)
+ val newTableKey = MemoryMetadataManager.makeTableKey(databaseName, newName)
+
+ val tableValueEntry = _tables.remove(oldTableKey).get
+      tableValueEntry.tableName = newName
+
+ _tables.put(newTableKey, tableValueEntry)
+ }
}
/**
- * Used to drop an RDD from the Spark in-memory cache and/or disk. All metadata
- * (e.g. entry in '_keyToStats') about the RDD that's tracked by Shark is deleted as well.
+   * Used to drop a table from the Spark in-memory cache and/or disk. All metadata is deleted
+   * as well.
+ *
+ * Note that this is always used in conjunction with a dropTableFromMemory() for handling
+   * 'shark.cache' property changes in an ALTER TABLE command, or to finish off a DROP TABLE
+   * command after the table has been deleted from the Hive metastore.
+ * after the table has been deleted from the Hive metastore.
*
- * @param key Used to fetch the an RDD value from '_keyToRDD'.
- * @return Option::isEmpty() is true if there is no RDD value corresponding to 'key' in
- * '_keyToRDD'. Otherwise, returns a reference to the RDD that was unpersist()'ed.
+   * @return Option::isEmpty() is true if there is no Table (and RDD) corresponding to 'key'
+   *     in `_tables`. For tables that are Hive-partitioned, the RDD returned will be a
+ * UnionRDD comprising RDDs that back the table's Hive-partitions.
*/
- def unpersist(key: String): Option[RDD[_]] = {
- def unpersistRDD(rdd: RDD[_]): Unit = {
- rdd match {
- case u: UnionRDD[_] => {
- // Recursively unpersist() all RDDs that compose the UnionRDD.
- u.unpersist()
- u.rdds.foreach {
- r => unpersistRDD(r)
- }
+ def removeTable(
+ databaseName: String,
+ tableName: String): Option[RDD[_]] = {
+ val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+ val tableValueOpt: Option[Table] = _tables.remove(tableKey)
+ tableValueOpt.flatMap(tableValue => MemoryMetadataManager.unpersistRDDsForTable(tableValue))
+ }
+
+ def shutdown() {
+ val db = Hive.get()
+ for (table <- _tables.values) {
+ table.cacheMode match {
+ case CacheType.MEMORY => {
+ dropTableFromMemory(db, table.databaseName, table.tableName)
+ }
+ case CacheType.MEMORY_ONLY => HiveUtils.dropTableInHive(table.tableName, db.getConf)
+ case _ => {
+ // No need to handle Hive or Tachyon tables, which are persistent and managed by their
+ // respective systems.
+          ()
}
- case r => r.unpersist()
}
}
- // Remove RDD's entry from Shark metadata. This also fetches a reference to the RDD object
- // corresponding to the argument for 'key'.
- val rddValue = _keyToRdd.remove(key.toLowerCase())
- _keyToStats.remove(key)
- // Unpersist the RDD using the nested helper fn above.
- rddValue match {
- case Some(rdd) => unpersistRDD(rdd)
- case None => Unit
+ }
+
+ /**
+ * Drops a table from the Shark cache. However, Shark properties needed for table recovery
+ * (see TableRecovery#reloadRdds()) won't be removed.
+ * After this method completes, the table can still be scanned from disk.
+ */
+ def dropTableFromMemory(
+ db: Hive,
+ databaseName: String,
+ tableName: String) {
+ getTable(databaseName, tableName).foreach { sharkTable =>
+ db.setCurrentDatabase(databaseName)
+ val hiveTable = db.getTable(databaseName, tableName)
+ // Refresh the Hive `db`.
+ db.alterTable(tableName, hiveTable)
+ // Unpersist the table's RDDs from memory.
+ removeTable(databaseName, tableName)
}
- rddValue
}
}
object MemoryMetadataManager {
- /** Return a StorageLevel corresponding to its String name. */
- def getStorageLevelFromString(s: String): StorageLevel = {
- if (s == null || s == "") {
- getStorageLevelFromString(SharkConfVars.STORAGE_LEVEL.defaultVal)
- } else {
- s.toUpperCase match {
- case "NONE" => StorageLevel.NONE
- case "DISK_ONLY" => StorageLevel.DISK_ONLY
- case "DISK_ONLY_2" => StorageLevel.DISK_ONLY_2
- case "MEMORY_ONLY" => StorageLevel.MEMORY_ONLY
- case "MEMORY_ONLY_2" => StorageLevel.MEMORY_ONLY_2
- case "MEMORY_ONLY_SER" => StorageLevel.MEMORY_ONLY_SER
- case "MEMORY_ONLY_SER_2" => StorageLevel.MEMORY_ONLY_SER_2
- case "MEMORY_AND_DISK" => StorageLevel.MEMORY_AND_DISK
- case "MEMORY_AND_DISK_2" => StorageLevel.MEMORY_AND_DISK_2
- case "MEMORY_AND_DISK_SER" => StorageLevel.MEMORY_AND_DISK_SER
- case "MEMORY_AND_DISK_SER_2" => StorageLevel.MEMORY_AND_DISK_SER_2
- case _ => throw new IllegalArgumentException("Unrecognized storage level: " + s)
+ def unpersistRDDsForTable(table: Table): Option[RDD[_]] = {
+ table match {
+ case partitionedTable: PartitionedMemoryTable => {
+ // unpersist() all RDDs for all Hive-partitions.
+ val unpersistedRDDs = partitionedTable.keyToPartitions.values.map(rdd =>
+ RDDUtils.unpersistRDD(rdd)).asInstanceOf[Seq[RDD[Any]]]
+ if (unpersistedRDDs.size > 0) {
+ val unionedRDD = new UnionRDD(unpersistedRDDs.head.context, unpersistedRDDs)
+ Some(unionedRDD)
+ } else {
+ None
+ }
}
+ case memoryTable: MemoryTable => Some(RDDUtils.unpersistRDD(memoryTable.getRDD.get))
+ }
+ }
+
+ // Returns a key of the form "databaseName.tableName" that uniquely identifies a Shark table.
+ // For example, it's used to track a table's RDDs in MemoryMetadataManager and table paths in the
+ // Tachyon table warehouse.
+ def makeTableKey(databaseName: String, tableName: String): String = {
+ (databaseName + '.' + tableName).toLowerCase
+ }
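+
+  // Example: makeTableKey("default", "Users") == "default.users".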
+
+ /**
+ * Return a representation of the partition key in the string format:
+ * 'col1=value1/col2=value2/.../colN=valueN'
+ */
+ def makeHivePartitionKeyStr(
+ partitionCols: Seq[String],
+ partColToValue: JavaMap[String, String]): String = {
+ partitionCols.map(col => "%s=%s".format(col, partColToValue(col))).mkString("/")
+ }
+
+ /**
+ * Returns a (partition column name -> value) mapping by parsing a `keyStr` of the format
+ * 'col1=value1/col2=value2/.../colN=valueN', created by makeHivePartitionKeyStr() above.
+ */
+ def parseHivePartitionKeyStr(keyStr: String): JavaMap[String, String] = {
+ val partitionSpec = new JavaHashMap[String, String]()
+ for (pair <- keyStr.split("/")) {
+ val pairSplit = pair.split("=")
+ partitionSpec.put(pairSplit(0), pairSplit(1))
}
+ partitionSpec
}
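+
+  // Round-trip sketch: makeHivePartitionKeyStr(Seq("ds", "hr"), {ds -> 2013-01-01, hr -> 00})
+  // produces "ds=2013-01-01/hr=00", and parseHivePartitionKeyStr() recovers the original
+  // (partition column -> value) mapping.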
}
diff --git a/src/main/scala/shark/memstore2/MemoryTable.scala b/src/main/scala/shark/memstore2/MemoryTable.scala
new file mode 100644
index 00000000..1a971d4c
--- /dev/null
+++ b/src/main/scala/shark/memstore2/MemoryTable.scala
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import scala.collection.mutable.{Buffer, HashMap}
+
+import org.apache.spark.rdd.RDD
+
+import shark.execution.RDDUtils
+
+
+/**
+ * A metadata container for a table in Shark that's backed by an RDD.
+ */
+private[shark] class MemoryTable(
+ databaseName: String,
+ tableName: String,
+ cacheMode: CacheType.CacheType)
+ extends Table(databaseName, tableName, cacheMode) {
+
+ private var _rddValueOpt: Option[RDDValue] = None
+
+ /**
+   * Sets the RDD and stats fields of `_rddValueOpt`. Used for INSERT/LOAD OVERWRITE.
+ * @param newRDD The table's data.
+ * @param newStats Stats for each TablePartition in `newRDD`.
+ * @return The previous (RDD, stats) pair for this table.
+ */
+ def put(
+ newRDD: RDD[TablePartition],
+ newStats: collection.Map[Int, TablePartitionStats] = new HashMap[Int, TablePartitionStats]()
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val prevRDDAndStatsOpt = _rddValueOpt.map(_.toTuple)
+ if (_rddValueOpt.isDefined) {
+ _rddValueOpt.foreach { rddValue =>
+ rddValue.rdd = newRDD
+ rddValue.stats = newStats
+ }
+ } else {
+ _rddValueOpt = Some(new RDDValue(newRDD, newStats))
+ }
+ prevRDDAndStatsOpt
+ }
+
+ /**
+ * Used for append operations, such as INSERT and LOAD INTO.
+ *
+ * @param newRDD Data to append to the table.
+ * @param newStats Stats for each TablePartition in `newRDD`.
+ * @return The previous (RDD, stats) pair for this table.
+ */
+ def update(
+ newRDD: RDD[TablePartition],
+ newStats: Buffer[(Int, TablePartitionStats)]
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val prevRDDAndStatsOpt = _rddValueOpt.map(_.toTuple)
+ if (_rddValueOpt.isDefined) {
+ val (prevRDD, prevStats) = (prevRDDAndStatsOpt.get._1, prevRDDAndStatsOpt.get._2)
+ val updatedRDDValue = _rddValueOpt.get
+ updatedRDDValue.rdd = RDDUtils.unionAndFlatten(prevRDD, newRDD)
+ updatedRDDValue.stats = Table.mergeStats(newStats, prevStats).toMap
+ } else {
+ put(newRDD, newStats.toMap)
+ }
+ prevRDDAndStatsOpt
+ }
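+
+  // Usage sketch (hypothetical names): put() replaces the table's contents, as in INSERT
+  // OVERWRITE, while update() unions the new RDD with the existing one and re-indexes the
+  // appended stats, as in INSERT INTO:
+  //   memoryTable.put(overwrittenRDD)                // returns the previous (RDD, stats), if any
+  //   memoryTable.update(appendedRDD, appendedStats) // contents become previous ++ appended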
+
+ def getRDD = _rddValueOpt.map(_.rdd)
+
+ def getStats = _rddValueOpt.map(_.stats)
+
+}
diff --git a/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala b/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala
new file mode 100644
index 00000000..b6bd8ae6
--- /dev/null
+++ b/src/main/scala/shark/memstore2/PartitionedMemoryTable.scala
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import java.util.concurrent.{ConcurrentHashMap => ConcurrentJavaHashMap}
+
+import scala.collection.JavaConversions._
+import scala.collection.concurrent
+import scala.collection.mutable.{Buffer, HashMap}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+import shark.execution.RDDUtils
+
+
+/**
+ * A metadata container for partitioned Shark table backed by RDDs.
+ *
+ * Note that a Hive-partition of a table is different from an RDD partition. Each Hive-partition
+ * is stored as a subdirectory of the table subdirectory in the warehouse directory
+ * (e.g. '/user/hive/warehouse'). So, every Hive-partition is loaded into Shark as an RDD.
+ */
+private[shark]
+class PartitionedMemoryTable(
+ databaseName: String,
+ tableName: String,
+ cacheMode: CacheType.CacheType)
+ extends Table(databaseName, tableName, cacheMode) {
+
+ // A map from the Hive-partition key to the RDD that contains contents of that partition.
+ // The conventional string format for the partition key, 'col1=value1/col2=value2/...', can be
+ // computed using MemoryMetadataManager#makeHivePartitionKeyStr().
+ private val _keyToPartitions: concurrent.Map[String, RDDValue] =
+ new ConcurrentJavaHashMap[String, RDDValue]()
+
+ // The eviction policy for this table's cached Hive-partitions. An example of how this
+ // can be set from the CLI:
+ // `TBLPROPERTIES("shark.partition.cachePolicy", "LRUCachePolicy")`.
+ // If 'None', then all partitions will be put in memory.
+ //
+ // Since RDDValue is mutable, entries maintained by a CachePolicy's underlying data structure,
+ // such as the LinkedHashMap for LRUCachePolicy, can be updated without causing an eviction.
+  // The value entries for a single key in
+ // `_keyToPartitions` and `_cachePolicy` will reference the same RDDValue object.
+ private var _cachePolicy: CachePolicy[String, RDDValue] = _
+
+ def containsPartition(partitionKey: String): Boolean = _keyToPartitions.contains(partitionKey)
+
+ def getPartition(partitionKey: String): Option[RDD[TablePartition]] = {
+ getPartitionAndStats(partitionKey).map(_._1)
+ }
+
+ def getStats(partitionKey: String): Option[collection.Map[Int, TablePartitionStats]] = {
+ getPartitionAndStats(partitionKey).map(_._2)
+ }
+
+ def getPartitionAndStats(
+ partitionKey: String
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val rddValueOpt: Option[RDDValue] = _keyToPartitions.get(partitionKey)
+ if (rddValueOpt.isDefined) _cachePolicy.notifyGet(partitionKey)
+ rddValueOpt.map(_.toTuple)
+ }
+
+ def putPartition(
+ partitionKey: String,
+ newRDD: RDD[TablePartition],
+ newStats: collection.Map[Int, TablePartitionStats] = new HashMap[Int, TablePartitionStats]()
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val rddValueOpt = _keyToPartitions.get(partitionKey)
+ val prevRDDAndStats = rddValueOpt.map(_.toTuple)
+ val newRDDValue = new RDDValue(newRDD, newStats)
+ _keyToPartitions.put(partitionKey, newRDDValue)
+ _cachePolicy.notifyPut(partitionKey, newRDDValue)
+ prevRDDAndStats
+ }
+
+ def updatePartition(
+ partitionKey: String,
+ newRDD: RDD[TablePartition],
+ newStats: Buffer[(Int, TablePartitionStats)]
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val prevRDDAndStatsOpt = getPartitionAndStats(partitionKey)
+ if (prevRDDAndStatsOpt.isDefined) {
+ val (prevRDD, prevStats) = (prevRDDAndStatsOpt.get._1, prevRDDAndStatsOpt.get._2)
+ // This is an update of an old value, so update the RDDValue's `rdd` entry.
+ // Don't notify the `_cachePolicy`. Assumes that getPartition() has already been called to
+ // obtain the value of the previous RDD.
+ // An RDD update refers to the RDD created from an INSERT.
+ val updatedRDDValue = _keyToPartitions.get(partitionKey).get
+ updatedRDDValue.rdd = RDDUtils.unionAndFlatten(prevRDD, newRDD)
+ updatedRDDValue.stats = Table.mergeStats(newStats, prevStats).toMap
+ } else {
+ // No previous RDDValue entry currently exists for `partitionKey`, so add one.
+ putPartition(partitionKey, newRDD, newStats.toMap)
+ }
+ prevRDDAndStatsOpt
+ }
+
+ def removePartition(
+ partitionKey: String
+ ): Option[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ val rddRemoved = _keyToPartitions.remove(partitionKey)
+ if (rddRemoved.isDefined) {
+ _cachePolicy.notifyRemove(partitionKey)
+ }
+ rddRemoved.map(_.toTuple)
+ }
+
+ /** Returns an immutable view of (partition key -> RDD) mappings to external callers */
+ def keyToPartitions: collection.immutable.Map[String, RDD[TablePartition]] = {
+ _keyToPartitions.mapValues(_.rdd).toMap
+ }
+
+ def setPartitionCachePolicy(cachePolicyStr: String, fallbackMaxSize: Int) {
+ // The loadFunc will upgrade the persistence level of the RDD to the preferred storage level.
+ val loadFunc: String => RDDValue = (partitionKey: String) => {
+ val rddValue = _keyToPartitions.get(partitionKey).get
+ if (cacheMode == CacheType.MEMORY) {
+ rddValue.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+ rddValue
+ }
+ // The evictionFunc will unpersist the RDD.
+ val evictionFunc: (String, RDDValue) => Unit = (partitionKey, rddValue) => {
+ RDDUtils.unpersistRDD(rddValue.rdd)
+ }
+ val newPolicy = CachePolicy.instantiateWithUserSpecs[String, RDDValue](
+ cachePolicyStr, fallbackMaxSize, loadFunc, evictionFunc)
+ _cachePolicy = newPolicy
+ }
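+
+  // Example (sketch): a table created with
+  //   TBLPROPERTIES("shark.cache" = "true",
+  //                 "shark.cache.policy" = "shark.memstore2.LRUCachePolicy(3)")
+  // keeps at most 3 Hive-partitions persisted; loading a 4th evicts the least recently used
+  // partition's RDD via the evictionFunc above.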
+
+ def cachePolicy: CachePolicy[String, RDDValue] = _cachePolicy
+
+}
diff --git a/src/main/scala/shark/memstore2/SharkTblProperties.scala b/src/main/scala/shark/memstore2/SharkTblProperties.scala
new file mode 100644
index 00000000..befc91d1
--- /dev/null
+++ b/src/main/scala/shark/memstore2/SharkTblProperties.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import java.util.{Map => JavaMap}
+
+
+/**
+ * Collection of static fields and helpers for the table properties (i.e., from a
+ * CREATE TABLE ... TBLPROPERTIES(...) clause) used by Shark.
+ */
+object SharkTblProperties {
+
+ case class TableProperty(varname: String, defaultVal: String)
+
+ // Class name of the default cache policy used to manage partition evictions for cached,
+ // Hive-partitioned tables.
+ val CACHE_POLICY = new TableProperty("shark.cache.policy", "shark.memstore2.CacheAllPolicy")
+
+ // Maximum size - in terms of the number of objects - of the cache specified by the
+ // "shark.cache.partition.cachePolicy" property above.
+ val MAX_PARTITION_CACHE_SIZE = new TableProperty("shark.cache.policy.maxSize", "10")
+
+ // Default value for the "shark.cache" table property
+ val CACHE_FLAG = new TableProperty("shark.cache", "true")
+
+ def getOrSetDefault(tblProps: JavaMap[String, String], variable: TableProperty): String = {
+ if (!tblProps.containsKey(variable.varname)) {
+ tblProps.put(variable.varname, variable.defaultVal)
+ }
+ tblProps.get(variable.varname)
+ }
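+
+  // Example: getOrSetDefault(tblProps, CACHE_POLICY) returns the configured policy class name,
+  // first inserting the default "shark.memstore2.CacheAllPolicy" into `tblProps` if unset.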
+
+ /**
+   * Sets the default values for the Shark table properties ('shark.cache', plus the cache
+   * policy property for Hive-partitioned tables) in `tblProps`, and returns the updated map.
+ */
+ def initializeWithDefaults(
+ tblProps: JavaMap[String, String],
+ isPartitioned: Boolean = false): JavaMap[String, String] = {
+ tblProps.put(CACHE_FLAG.varname, CACHE_FLAG.defaultVal)
+ if (isPartitioned) {
+ tblProps.put(CACHE_POLICY.varname, CACHE_POLICY.defaultVal)
+ }
+ tblProps
+ }
+
+ def removeSharkProperties(tblProps: JavaMap[String, String]) {
+ tblProps.remove(CACHE_FLAG.varname)
+ tblProps.remove(CACHE_POLICY.varname)
+ tblProps.remove(MAX_PARTITION_CACHE_SIZE.varname)
+ }
+}
diff --git a/src/main/scala/shark/memstore2/Table.scala b/src/main/scala/shark/memstore2/Table.scala
new file mode 100644
index 00000000..ae7f451f
--- /dev/null
+++ b/src/main/scala/shark/memstore2/Table.scala
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import scala.collection.mutable.{ArrayBuffer, Buffer}
+
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A container for table metadata managed by Shark and Spark. Subclasses are responsible for
+ * how RDDs are set, stored, and accessed.
+ *
+ * @param databaseName Namespace for this table.
+ * @param tableName Name of this table.
+ * @param cacheMode Type of memory storage used for the table (e.g., the Spark block manager).
+ */
+private[shark] abstract class Table(
+ var databaseName: String,
+ var tableName: String,
+ var cacheMode: CacheType.CacheType) {
+
+ /**
+ * A mutable wrapper for an RDD and stats for its partitions.
+ */
+ class RDDValue(
+ var rdd: RDD[TablePartition],
+ var stats: collection.Map[Int, TablePartitionStats]) {
+
+ def toTuple = (rdd, stats)
+ }
+}
+
+object Table {
+
+ /**
+   * Merges the contents of `otherStatsMap` into `targetStatsMap`.
+ */
+ def mergeStats(
+ targetStatsMap: Buffer[(Int, TablePartitionStats)],
+ otherStatsMap: Iterable[(Int, TablePartitionStats)]
+ ): Buffer[(Int, TablePartitionStats)] = {
+ val targetStatsMapSize = targetStatsMap.size
+ for ((otherIndex, tableStats) <- otherStatsMap) {
+ targetStatsMap.append((otherIndex + targetStatsMapSize, tableStats))
+ }
+ targetStatsMap
+ }
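+
+  // Example: if `targetStatsMap` already has 2 entries, an `otherStatsMap` entry (0, stats) is
+  // appended as (2, stats), shifting its index past the existing entries.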
+}
diff --git a/src/main/scala/shark/memstore2/TablePartition.scala b/src/main/scala/shark/memstore2/TablePartition.scala
index 61235e85..ba8370a7 100644
--- a/src/main/scala/shark/memstore2/TablePartition.scala
+++ b/src/main/scala/shark/memstore2/TablePartition.scala
@@ -60,8 +60,6 @@ class TablePartition(private var _numRows: Long, private var _columns: Array[Byt
buffer
}
- // TODO: Add column pruning to TablePartition for creating a TablePartitionIterator.
-
/**
* Return an iterator for the partition.
*/
@@ -76,9 +74,9 @@ class TablePartition(private var _numRows: Long, private var _columns: Array[Byt
def prunedIterator(columnsUsed: BitSet) = {
val columnIterators: Array[ColumnIterator] = _columns.map {
case buffer: ByteBuffer =>
- val iter = ColumnIterator.newIterator(buffer)
- iter
+ ColumnIterator.newIterator(buffer)
case _ =>
+ // The buffer might be null if it is pruned in Tachyon.
null
}
new TablePartitionIterator(_numRows, columnIterators, columnsUsed)
diff --git a/src/main/scala/shark/memstore2/TablePartitionBuilder.scala b/src/main/scala/shark/memstore2/TablePartitionBuilder.scala
index cdd2843d..8614c070 100644
--- a/src/main/scala/shark/memstore2/TablePartitionBuilder.scala
+++ b/src/main/scala/shark/memstore2/TablePartitionBuilder.scala
@@ -18,10 +18,10 @@
package shark.memstore2
import java.io.{DataInput, DataOutput}
-import java.util.{List => JList}
+
+import scala.collection.JavaConversions._
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
-import org.apache.hadoop.hive.serde2.objectinspector.StructField
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
import org.apache.hadoop.io.Writable
@@ -33,19 +33,22 @@ import shark.memstore2.column.ColumnBuilder
* partition of data into columnar format and to generate a TablePartition.
*/
class TablePartitionBuilder(
- oi: StructObjectInspector,
+ ois: Seq[ObjectInspector],
initialColumnSize: Int,
- shouldCompress: Boolean = true)
+ shouldCompress: Boolean)
extends Writable {
- var numRows: Long = 0
- val fields: JList[_ <: StructField] = oi.getAllStructFieldRefs
+ def this(oi: StructObjectInspector, initialColumnSize: Int, shouldCompress: Boolean = true) = {
+ this(oi.getAllStructFieldRefs.map(_.getFieldObjectInspector), initialColumnSize, shouldCompress)
+ }
+
+ private var numRows: Long = 0
- val columnBuilders = Array.tabulate[ColumnBuilder[_]](fields.size) { i =>
- val columnBuilder = ColumnBuilder.create(fields.get(i).getFieldObjectInspector, shouldCompress)
+ private val columnBuilders: Array[ColumnBuilder[_]] = ois.map { oi =>
+ val columnBuilder = ColumnBuilder.create(oi, shouldCompress)
columnBuilder.initialize(initialColumnSize)
columnBuilder
- }
+ }.toArray
def incrementRowCount() {
numRows += 1
@@ -57,7 +60,7 @@ class TablePartitionBuilder(
def stats: TablePartitionStats = new TablePartitionStats(columnBuilders.map(_.stats), numRows)
- def build: TablePartition = new TablePartition(numRows, columnBuilders.map(_.build))
+ def build(): TablePartition = new TablePartition(numRows, columnBuilders.map(_.build()))
// We don't use these, but want to maintain Writable interface for SerDe
override def write(out: DataOutput) {}
diff --git a/src/main/scala/shark/memstore2/TablePartitionIterator.scala b/src/main/scala/shark/memstore2/TablePartitionIterator.scala
index 71aabd7c..947cdd22 100644
--- a/src/main/scala/shark/memstore2/TablePartitionIterator.scala
+++ b/src/main/scala/shark/memstore2/TablePartitionIterator.scala
@@ -17,7 +17,6 @@
package shark.memstore2
-import java.nio.ByteBuffer
import java.util.BitSet
import shark.memstore2.column.ColumnIterator
@@ -45,13 +44,13 @@ class TablePartitionIterator(
private var _position: Long = 0
- def hasNext(): Boolean = _position < numRows
+ def hasNext: Boolean = _position < numRows
def next(): ColumnarStruct = {
_position += 1
var i = columnUsed.nextSetBit(0)
while (i > -1) {
- columnIterators(i).next
+ columnIterators(i).next()
i = columnUsed.nextSetBit(i + 1)
}
_struct
diff --git a/src/main/scala/shark/memstore2/TableRecovery.scala b/src/main/scala/shark/memstore2/TableRecovery.scala
new file mode 100644
index 00000000..adf61061
--- /dev/null
+++ b/src/main/scala/shark/memstore2/TableRecovery.scala
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2
+
+import java.util.{HashMap => JavaHashMap}
+
+import scala.collection.JavaConversions.asScalaBuffer
+
+import org.apache.hadoop.hive.ql.metadata.Hive
+import org.apache.hadoop.hive.ql.session.SessionState
+
+import shark.{LogHelper, SharkEnv}
+import shark.util.QueryRewriteUtils
+
+/**
+ * Singleton used to reload RDDs upon server restarts.
+ */
+object TableRecovery extends LogHelper {
+
+ val db = Hive.get()
+
+ /**
+   * Loads any cached tables that have MEMORY as their `shark.cache` property.
+ * @param cmdRunner The runner that is responsible for taking a cached table query and
+ * a) Creating the table metadata in Hive Meta Store
+ * b) Loading the table as an RDD in memory
+ * @see SharkServer for an example usage.
+ * @param console Optional SessionState.LogHelper used, if present, to log information about
+   *        the tables that get reloaded.
+ */
+ def reloadRdds(cmdRunner: String => Unit, console: Option[SessionState.LogHelper] = None) {
+ // Filter for tables that should be reloaded into the cache.
+ val currentDbName = db.getCurrentDatabase()
+ for (databaseName <- db.getAllDatabases(); tableName <- db.getAllTables(databaseName)) {
+ val hiveTable = db.getTable(databaseName, tableName)
+ val tblProps = hiveTable.getParameters
+ val cacheMode = CacheType.fromString(tblProps.get(SharkTblProperties.CACHE_FLAG.varname))
+ if (cacheMode == CacheType.MEMORY) {
+ val logMessage = "Reloading %s.%s into memory.".format(databaseName, tableName)
+ if (console.isDefined) {
+ console.get.printInfo(logMessage)
+ } else {
+ logInfo(logMessage)
+ }
+ val cmd = QueryRewriteUtils.cacheToAlterTable("CACHE %s".format(tableName))
+ cmdRunner(cmd)
+ }
+ }
+ db.setCurrentDatabase(currentDbName)
+ }
+}
diff --git a/src/main/scala/shark/memstore2/column/ColumnBuilder.scala b/src/main/scala/shark/memstore2/column/ColumnBuilder.scala
index 375ec244..84988be3 100644
--- a/src/main/scala/shark/memstore2/column/ColumnBuilder.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnBuilder.scala
@@ -61,12 +61,12 @@ trait ColumnBuilder[T] {
_buffer.order(ByteOrder.nativeOrder())
_buffer.putInt(t.typeID)
}
-
+
protected def growIfNeeded(orig: ByteBuffer, size: Int): ByteBuffer = {
val capacity = orig.capacity()
if (orig.remaining() < size) {
- //grow in steps of initial size
- var additionalSize = capacity/8 + 1
+ // grow in steps of initial size
+ val additionalSize = capacity / 8 + 1
var newSize = capacity + additionalSize
if (additionalSize < size) {
newSize = capacity + size
@@ -82,7 +82,7 @@ trait ColumnBuilder[T] {
}
}
-class DefaultColumnBuilder[T](val stats: ColumnStats[T], val t: ColumnType[T, _])
+class DefaultColumnBuilder[T](val stats: ColumnStats[T], val t: ColumnType[T, _])
extends CompressedColumnBuilder[T] with NullableColumnBuilder[T]{}
@@ -105,7 +105,7 @@ trait CompressedColumnBuilder[T] extends ColumnBuilder[T] {
override def build() = {
val b = super.build()
-
+
if (compressionSchemes.isEmpty) {
new NoCompression().compress(b, t)
} else {
@@ -136,16 +136,16 @@ object ColumnBuilder {
case PrimitiveCategory.BYTE => new ByteColumnBuilder
case PrimitiveCategory.TIMESTAMP => new TimestampColumnBuilder
case PrimitiveCategory.BINARY => new BinaryColumnBuilder
-
+
// TODO: add decimal column.
- case _ => throw new Exception(
+ case _ => throw new MemoryStoreException(
"Invalid primitive object inspector category" + columnOi.getCategory)
}
}
case _ => new GenericColumnBuilder(columnOi)
}
if (shouldCompress) {
- v.compressionSchemes = Seq(new RLE())
+ v.compressionSchemes = Seq(new RLE(), new BooleanBitSetCompression())
}
v
}
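
To make the scheme selection concrete: the builder gathers stats for every candidate scheme while appending, then compresses with the cheapest one. A standalone sketch of that selection logic (simplified names, not the Shark API):

    // Pick the candidate with the best estimated ratio; fall back to no
    // compression when nothing beats storing the raw data (ratio >= 1.0).
    trait SchemeEstimate {
      def compressedSize: Int
      def uncompressedSize: Int
      def ratio: Double = compressedSize.toDouble / uncompressedSize.toDouble
    }
    def pickScheme(candidates: Seq[SchemeEstimate]): Option[SchemeEstimate] =
      candidates.filter(_.ratio < 1.0).sortBy(_.ratio).headOption
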
diff --git a/src/main/scala/shark/memstore2/column/ColumnBuilders.scala b/src/main/scala/shark/memstore2/column/ColumnBuilders.scala
index 593f8685..6cee1359 100644
--- a/src/main/scala/shark/memstore2/column/ColumnBuilders.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnBuilders.scala
@@ -1,3 +1,20 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
import java.nio.ByteBuffer
@@ -12,18 +29,6 @@ import shark.execution.serialization.KryoSerializer
import shark.memstore2.column.ColumnStats._
-class GenericColumnBuilder(oi: ObjectInspector)
- extends DefaultColumnBuilder[ByteStream.Output](new NoOpStats(), GENERIC) {
-
- override def initialize(initialSize: Int):ByteBuffer = {
- val buffer = super.initialize(initialSize)
- val objectInspectorSerialized = KryoSerializer.serialize(oi)
- buffer.putInt(objectInspectorSerialized.size)
- buffer.put(objectInspectorSerialized)
- buffer
- }
-}
-
class BooleanColumnBuilder extends DefaultColumnBuilder[Boolean](new BooleanColumnStats(), BOOLEAN)
class IntColumnBuilder extends DefaultColumnBuilder[Int](new IntColumnStats(), INT)
@@ -45,4 +50,20 @@ class TimestampColumnBuilder
class BinaryColumnBuilder extends DefaultColumnBuilder[BytesWritable](new NoOpStats(), BINARY)
-class VoidColumnBuilder extends DefaultColumnBuilder[Void](new NoOpStats(), VOID)
\ No newline at end of file
+class VoidColumnBuilder extends DefaultColumnBuilder[Void](new NoOpStats(), VOID)
+
+/**
+ * Generic columns that we can serialize, including maps, structs, and other complex types.
+ */
+class GenericColumnBuilder(oi: ObjectInspector)
+ extends DefaultColumnBuilder[ByteStream.Output](new NoOpStats(), GENERIC) {
+
+ // Complex data types cannot be null; override `initialize` from NullableColumnBuilder to
+ // also store the serialized object inspector in the buffer.
+ override def initialize(initialSize: Int): ByteBuffer = {
+ val buffer = super.initialize(initialSize)
+ val objectInspectorSerialized = KryoSerializer.serialize(oi)
+ buffer.putInt(objectInspectorSerialized.size)
+ buffer.put(objectInspectorSerialized)
+ buffer
+ }
+}
diff --git a/src/main/scala/shark/memstore2/column/ColumnIterator.scala b/src/main/scala/shark/memstore2/column/ColumnIterator.scala
index 5c9b267c..404e456b 100644
--- a/src/main/scala/shark/memstore2/column/ColumnIterator.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnIterator.scala
@@ -17,28 +17,30 @@
package shark.memstore2.column
-import java.nio.ByteBuffer
-import java.nio.ByteOrder
+import scala.language.implicitConversions
+import java.nio.{ByteBuffer, ByteOrder}
trait ColumnIterator {
- private var _initialized = false
-
+ init()
+
def init() {}
- def next() {
- if (!_initialized) {
- init()
- _initialized = true
- }
- computeNext()
- }
+ /**
+ * Advances this iterator to the next element. Use `current` to read it.
+ */
+ def next()
- def computeNext(): Unit
+ /**
+ * Tests whether this iterator can provide another element.
+ */
+ def hasNext: Boolean
- // Should be implemented as a read-only operation by the ColumnIterator
- // Can be called any number of times
+ /**
+ * Returns the current element. The operation should have no side effects, i.e. it can be
+ * invoked multiple times and will return the same value.
+ */
def current: Object
}
@@ -49,25 +51,27 @@ abstract class DefaultColumnIterator[T, V](val buffer: ByteBuffer, val columnTyp
object Implicits {
implicit def intToCompressionType(i: Int): CompressionType = i match {
- case -1 => DefaultCompressionType
- case 0 => RLECompressionType
- case 1 => DictionaryCompressionType
- case _ => throw new UnsupportedOperationException("Compression Type " + i)
+ case DefaultCompressionType.typeID => DefaultCompressionType
+ case RLECompressionType.typeID => RLECompressionType
+ case DictionaryCompressionType.typeID => DictionaryCompressionType
+ case BooleanBitSetCompressionType.typeID => BooleanBitSetCompressionType
+ case _ => throw new MemoryStoreException("Unknown compression type " + i)
}
implicit def intToColumnType(i: Int): ColumnType[_, _] = i match {
- case 0 => INT
- case 1 => LONG
- case 2 => FLOAT
- case 3 => DOUBLE
- case 4 => BOOLEAN
- case 5 => BYTE
- case 6 => SHORT
- case 7 => VOID
- case 8 => STRING
- case 9 => TIMESTAMP
- case 10 => BINARY
- case 11 => GENERIC
+ case INT.typeID => INT
+ case LONG.typeID => LONG
+ case FLOAT.typeID => FLOAT
+ case DOUBLE.typeID => DOUBLE
+ case BOOLEAN.typeID => BOOLEAN
+ case BYTE.typeID => BYTE
+ case SHORT.typeID => SHORT
+ case VOID.typeID => VOID
+ case STRING.typeID => STRING
+ case TIMESTAMP.typeID => TIMESTAMP
+ case BINARY.typeID => BINARY
+ case GENERIC.typeID => GENERIC
+ case _ => throw new MemoryStoreException("Unknown column type " + i)
}
}
@@ -76,9 +80,14 @@ object ColumnIterator {
import shark.memstore2.column.Implicits._
def newIterator(b: ByteBuffer): ColumnIterator = {
+ new NullableColumnIterator(b.duplicate().order(ByteOrder.nativeOrder()))
+ }
+
+ def newNonNullIterator(b: ByteBuffer): ColumnIterator = {
+ // The first 4 bytes in the buffer indicate the column type.
val buffer = b.duplicate().order(ByteOrder.nativeOrder())
val columnType: ColumnType[_, _] = buffer.getInt()
- val v = columnType match {
+ columnType match {
case INT => new IntColumnIterator(buffer)
case LONG => new LongColumnIterator(buffer)
case FLOAT => new FloatColumnIterator(buffer)
@@ -92,6 +101,5 @@ object ColumnIterator {
case TIMESTAMP => new TimestampColumnIterator(buffer)
case GENERIC => new GenericColumnIterator(buffer)
}
- new NullableColumnIterator(v, buffer)
}
}
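
The new interface drops the lazy-init/computeNext pattern in favor of a plain hasNext/next/current protocol. A minimal consumption sketch against the trait above:

    // Drain a column into a buffer; current may be read repeatedly between next() calls.
    def drain(it: ColumnIterator): Seq[Object] = {
      val out = scala.collection.mutable.ArrayBuffer[Object]()
      while (it.hasNext) {
        it.next()
        out += it.current
      }
      out
    }
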
diff --git a/src/main/scala/shark/memstore2/column/ColumnIterators.scala b/src/main/scala/shark/memstore2/column/ColumnIterators.scala
index be9902b5..3060b5b1 100644
--- a/src/main/scala/shark/memstore2/column/ColumnIterators.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnIterators.scala
@@ -1,3 +1,20 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
import java.nio.ByteBuffer
@@ -32,9 +49,9 @@ class BinaryColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buf
class StringColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buffer, STRING)
class GenericColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(buffer, GENERIC) {
-
+
private var _obj: LazyObject[_] = _
-
+
override def init() {
super.init()
val oiSize = buffer.getInt()
@@ -43,7 +60,7 @@ class GenericColumnIterator(buffer: ByteBuffer) extends DefaultColumnIterator(bu
val oi = KryoSerializer.deserialize[ObjectInspector](oiSerialized)
_obj = LazyFactory.createLazyObject(oi)
}
-
+
override def current = {
val v = super.current.asInstanceOf[ByteArrayRef]
_obj.init(v, 0, v.getData().length)
diff --git a/src/main/scala/shark/memstore2/column/ColumnStats.scala b/src/main/scala/shark/memstore2/column/ColumnStats.scala
index 31270fa3..dce811d5 100644
--- a/src/main/scala/shark/memstore2/column/ColumnStats.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnStats.scala
@@ -25,7 +25,8 @@ import org.apache.hadoop.io.Text
/**
- * Column level statistics, including range (min, max).
+ * Column-level statistics, including range (min, max). Null values are expected to be
+ * handled outside of ColumnStats, so none of these stats should ever receive a null value.
*/
sealed trait ColumnStats[@specialized(Boolean, Byte, Short, Int, Long, Float, Double) T]
extends Serializable {
@@ -35,7 +36,6 @@ sealed trait ColumnStats[@specialized(Boolean, Byte, Short, Int, Long, Float, Do
protected def _min: T
protected def _max: T
-
def min: T = _min
def max: T = _max
@@ -67,27 +67,29 @@ object ColumnStats {
class BooleanColumnStats extends ColumnStats[Boolean] {
protected var _max = false
protected var _min = true
+
override def append(v: Boolean) {
if (v) _max = v
else _min = v
}
+
def :=(v: Any): Boolean = {
v match {
- case u:Boolean => _min <= u && _max >= u
+ case u: Boolean => _min <= u && _max >= u
case _ => true
}
}
def :>(v: Any): Boolean = {
v match {
- case u:Boolean => _max > u
+ case u: Boolean => _max > u
case _ => true
}
}
def :<(v: Any): Boolean = {
v match {
- case u:Boolean => _min < u
+ case u: Boolean => _min < u
case _ => true
}
}
@@ -97,6 +99,7 @@ object ColumnStats {
class ByteColumnStats extends ColumnStats[Byte] {
protected var _max = Byte.MinValue
protected var _min = Byte.MaxValue
+
override def append(v: Byte) {
if (v > _max) _max = v
if (v < _min) _min = v
@@ -104,21 +107,21 @@ object ColumnStats {
def :=(v: Any): Boolean = {
v match {
- case u:Byte => _min <= u && _max >= u
+ case u: Byte => _min <= u && _max >= u
case _ => true
}
}
def :>(v: Any): Boolean = {
v match {
- case u:Byte => _max > u
+ case u: Byte => _max > u
case _ => true
}
}
def :<(v: Any): Boolean = {
v match {
- case u:Byte => _min < u
+ case u: Byte => _min < u
case _ => true
}
}
@@ -127,27 +130,29 @@ object ColumnStats {
class ShortColumnStats extends ColumnStats[Short] {
protected var _max = Short.MinValue
protected var _min = Short.MaxValue
+
override def append(v: Short) {
if (v > _max) _max = v
if (v < _min) _min = v
}
+
def :=(v: Any): Boolean = {
v match {
- case u:Short => _min <= u && _max >= u
+ case u: Short => _min <= u && _max >= u
case _ => true
}
}
def :>(v: Any): Boolean = {
v match {
- case u:Short => _max > u
+ case u: Short => _max > u
case _ => true
}
}
def :<(v: Any): Boolean = {
v match {
- case u:Short => _min < u
+ case u: Short => _min < u
case _ => true
}
}
@@ -184,14 +189,14 @@ object ColumnStats {
def :>(v: Any): Boolean = {
v match {
- case u:Int => _max > u
+ case u: Int => _max > u
case _ => true
}
}
def :<(v: Any): Boolean = {
v match {
- case u:Int => _min < u
+ case u: Int => _min < u
case _ => true
}
}
@@ -228,27 +233,29 @@ object ColumnStats {
class LongColumnStats extends ColumnStats[Long] {
protected var _max = Long.MinValue
protected var _min = Long.MaxValue
+
override def append(v: Long) {
if (v > _max) _max = v
if (v < _min) _min = v
}
+
def :=(v: Any): Boolean = {
v match {
- case u:Long => _min <= u && _max >= u
+ case u: Long => _min <= u && _max >= u
case _ => true
}
}
def :>(v: Any): Boolean = {
v match {
- case u:Long => _max > u
+ case u: Long => _max > u
case _ => true
}
}
def :<(v: Any): Boolean = {
v match {
- case u:Long => _min < u
+ case u: Long => _min < u
case _ => true
}
}
@@ -257,20 +264,22 @@ object ColumnStats {
class FloatColumnStats extends ColumnStats[Float] {
protected var _max = Float.MinValue
protected var _min = Float.MaxValue
+
override def append(v: Float) {
if (v > _max) _max = v
if (v < _min) _min = v
}
+
def :=(v: Any): Boolean = {
v match {
- case u:Float => _min <= u && _max >= u
+ case u: Float => _min <= u && _max >= u
case _ => true
}
}
def :>(v: Any): Boolean = {
v match {
- case u:Float => _max > u
+ case u: Float => _max > u
case _ => true
}
}
@@ -286,10 +295,12 @@ object ColumnStats {
class DoubleColumnStats extends ColumnStats[Double] {
protected var _max = Double.MinValue
protected var _min = Double.MaxValue
+
override def append(v: Double) {
if (v > _max) _max = v
if (v < _min) _min = v
}
+
def :=(v: Any): Boolean = {
v match {
case u:Double => _min <= u && _max >= u
@@ -315,10 +326,12 @@ object ColumnStats {
class TimestampColumnStats extends ColumnStats[Timestamp] {
protected var _max = new Timestamp(0)
protected var _min = new Timestamp(Long.MaxValue)
+
override def append(v: Timestamp) {
if (v.compareTo(_max) > 0) _max = v
if (v.compareTo(_min) < 0) _min = v
}
+
def :=(v: Any): Boolean = {
v match {
case u: Timestamp => _min.compareTo(u) <=0 && _max.compareTo(u) >= 0
@@ -345,8 +358,12 @@ object ColumnStats {
// Note: this is not Java serializable because Text is not Java serializable.
protected var _max: Text = null
protected var _min: Text = null
-
+
def :=(v: Any): Boolean = {
+ if (_max eq null) {
+ // This partition doesn't contain any non-null strings in this column. Return false.
+ return false
+ }
v match {
case u: Text => _min.compareTo(u) <= 0 && _max.compareTo(u) >= 0
case u: String => this := new Text(u)
@@ -355,6 +372,10 @@ object ColumnStats {
}
def :>(v: Any): Boolean = {
+ if (_max eq null) {
+ // This partition doesn't contain any non-null strings in this column. Return false.
+ return false
+ }
v match {
case u: Text => _max.compareTo(u) > 0
case u: String => this :> new Text(u)
@@ -363,14 +384,19 @@ object ColumnStats {
}
def :<(v: Any): Boolean = {
+ if (_max eq null) {
+ // This partition doesn't contain any non-null strings in this column. Return false.
+ return false
+ }
v match {
- case u:Text => _min.compareTo(u) < 0
+ case u: Text => _min.compareTo(u) < 0
case u: String => this :< new Text(u)
case _ => true
}
}
override def append(v: Text) {
+ assert(v ne null)
// Need to make a copy of Text since Text is not immutable and we reuse
// the same Text object in serializer to mitigate frequent GC.
if (_max == null) {
@@ -382,7 +408,7 @@ object ColumnStats {
_min = new Text(v)
} else if (v.compareTo(_min) < 0) {
_min.set(v.getBytes(), 0, v.getLength())
- }
+ }
}
override def readExternal(in: ObjectInput) {
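
The `:=`, `:>`, and `:<` operators are conservative may-contain tests over the [min, max] range, used to decide whether a partition can be skipped during a scan. A standalone Int-only sketch of the same semantics (not the Shark classes):

    class IntStatsSketch {
      private var _min = Int.MaxValue
      private var _max = Int.MinValue
      def append(v: Int) { if (v < _min) _min = v; if (v > _max) _max = v }
      def mayEqual(v: Int): Boolean = _min <= v && _max >= v
    }
    val s = new IntStatsSketch
    Seq(3, 7, 5).foreach(s.append)
    assert(s.mayEqual(5) && !s.mayEqual(9))  // 9 lies outside [3, 7]
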
diff --git a/src/main/scala/shark/memstore2/column/ColumnType.scala b/src/main/scala/shark/memstore2/column/ColumnType.scala
index 068efe42..4ca62a19 100644
--- a/src/main/scala/shark/memstore2/column/ColumnType.scala
+++ b/src/main/scala/shark/memstore2/column/ColumnType.scala
@@ -20,39 +20,85 @@ package shark.memstore2.column
import java.nio.ByteBuffer
import java.sql.Timestamp
+import scala.reflect.ClassTag
+
import org.apache.hadoop.hive.serde2.ByteStream
import org.apache.hadoop.hive.serde2.`lazy`.{ByteArrayRef, LazyBinary}
-import org.apache.hadoop.hive.serde2.io.{TimestampWritable, ShortWritable, ByteWritable, DoubleWritable}
+import org.apache.hadoop.hive.serde2.io.ByteWritable
+import org.apache.hadoop.hive.serde2.io.DoubleWritable
+import org.apache.hadoop.hive.serde2.io.ShortWritable
+import org.apache.hadoop.hive.serde2.io.TimestampWritable
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.io._
-abstract class ColumnType[T, V](val typeID: Int, val defaultSize: Int) {
-
- def extract(currentPos: Int, buffer: ByteBuffer): T
-
+/**
+ * @param typeID A unique ID representing the type.
+ * @param defaultSize Default size in bytes for one element of type T (e.g. Int = 4).
+ * @tparam T Scala data type for the column.
+ * @tparam V Writable data type for the column.
+ */
+sealed abstract class ColumnType[T : ClassTag, V : ClassTag](
+ val typeID: Int, val defaultSize: Int) {
+
+ /**
+ * Scala ClassTag. Can be used to create primitive arrays and hash tables.
+ */
+ def scalaTag = implicitly[ClassTag[T]]
+
+ /**
+ * Scala ClassTag for the Writable type. Can be used to create primitive arrays and hash tables.
+ */
+ def writableScalaTag = implicitly[ClassTag[V]]
+
+ /**
+ * Extract a value out of the buffer at the buffer's current position.
+ */
+ def extract(buffer: ByteBuffer): T
+
+ /**
+ * Append the given value v of type T into the given ByteBuffer.
+ */
def append(v: T, buffer: ByteBuffer)
+ /**
+ * Return the Scala data representation of the given object, using an object inspector.
+ */
def get(o: Object, oi: ObjectInspector): T
- def actualSize(v: T) = defaultSize
-
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: V)
-
+ /**
+ * Return the size of the value. This is used to calculate the size of variable length types
+ * such as byte arrays and strings.
+ */
+ def actualSize(v: T): Int = defaultSize
+
+ /**
+ * Extract a value out of the buffer at the buffer's current position, and put it in the writable
+ * object. This is used as an optimization to reduce the temporary objects created, since the
+ * writable object can be reused.
+ */
+ def extractInto(buffer: ByteBuffer, writable: V)
+
+ /**
+ * Create a new writable object corresponding to this type.
+ */
def newWritable(): V
+ /**
+ * Create a duplicated copy of the value.
+ */
def clone(v: T): T = v
}
object INT extends ColumnType[Int, IntWritable](0, 4) {
- override def append(v: Int, buffer: ByteBuffer) = {
+ override def append(v: Int, buffer: ByteBuffer) {
buffer.putInt(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
@@ -60,8 +106,8 @@ object INT extends ColumnType[Int, IntWritable](0, 4) {
oi.asInstanceOf[IntObjectInspector].get(o)
}
- override def extractInto(currentPos: Int, buffer: ByteBuffer, writable: IntWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: IntWritable) {
+ writable.set(extract(buffer))
}
override def newWritable() = new IntWritable
@@ -70,11 +116,11 @@ object INT extends ColumnType[Int, IntWritable](0, 4) {
object LONG extends ColumnType[Long, LongWritable](1, 8) {
- override def append(v: Long, buffer: ByteBuffer) = {
+ override def append(v: Long, buffer: ByteBuffer) {
buffer.putLong(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
@@ -82,21 +128,21 @@ object LONG extends ColumnType[Long, LongWritable](1, 8) {
oi.asInstanceOf[LongObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: LongWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: LongWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new LongWritable
+ override def newWritable() = new LongWritable
}
object FLOAT extends ColumnType[Float, FloatWritable](2, 4) {
- override def append(v: Float, buffer: ByteBuffer) = {
+ override def append(v: Float, buffer: ByteBuffer) {
buffer.putFloat(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
@@ -104,42 +150,43 @@ object FLOAT extends ColumnType[Float, FloatWritable](2, 4) {
oi.asInstanceOf[FloatObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: FloatWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: FloatWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new FloatWritable
+ override def newWritable() = new FloatWritable
}
object DOUBLE extends ColumnType[Double, DoubleWritable](3, 8) {
- override def append(v: Double, buffer: ByteBuffer) = {
+ override def append(v: Double, buffer: ByteBuffer) {
buffer.putDouble(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
+
override def get(o: Object, oi: ObjectInspector): Double = {
oi.asInstanceOf[DoubleObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: DoubleWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: DoubleWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new DoubleWritable
+ override def newWritable() = new DoubleWritable
}
object BOOLEAN extends ColumnType[Boolean, BooleanWritable](4, 1) {
- override def append(v: Boolean, buffer: ByteBuffer) = {
+ override def append(v: Boolean, buffer: ByteBuffer) {
buffer.put(if (v) 1.toByte else 0.toByte)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
if (buffer.get() == 1) true else false
}
@@ -147,42 +194,42 @@ object BOOLEAN extends ColumnType[Boolean, BooleanWritable](4, 1) {
oi.asInstanceOf[BooleanObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: BooleanWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: BooleanWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new BooleanWritable
+ override def newWritable() = new BooleanWritable
}
object BYTE extends ColumnType[Byte, ByteWritable](5, 1) {
- override def append(v: Byte, buffer: ByteBuffer) = {
+ override def append(v: Byte, buffer: ByteBuffer) {
buffer.put(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.get()
}
override def get(o: Object, oi: ObjectInspector): Byte = {
oi.asInstanceOf[ByteObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ByteWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: ByteWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new ByteWritable
+ override def newWritable() = new ByteWritable
}
object SHORT extends ColumnType[Short, ShortWritable](6, 2) {
- override def append(v: Short, buffer: ByteBuffer) = {
+ override def append(v: Short, buffer: ByteBuffer) {
buffer.putShort(v)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
@@ -190,8 +237,8 @@ object SHORT extends ColumnType[Short, ShortWritable](6, 2) {
oi.asInstanceOf[ShortObjectInspector].get(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ShortWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: ShortWritable) {
+ writable.set(extract(buffer))
}
def newWritable() = new ShortWritable
@@ -200,15 +247,15 @@ object SHORT extends ColumnType[Short, ShortWritable](6, 2) {
object VOID extends ColumnType[Void, NullWritable](7, 0) {
- override def append(v: Void, buffer: ByteBuffer) = {}
+ override def append(v: Void, buffer: ByteBuffer) {}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
throw new UnsupportedOperationException()
}
override def get(o: Object, oi: ObjectInspector) = null
- override def extractInto(currentPos: Int, buffer: ByteBuffer, writable: NullWritable) = {}
+ override def extractInto(buffer: ByteBuffer, writable: NullWritable) {}
override def newWritable() = NullWritable.get
}
@@ -234,9 +281,9 @@ object STRING extends ColumnType[Text, Text](8, 8) {
buffer.put(v.getBytes(), 0, length)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
val t = new Text()
- extractInto(currentPos, buffer, t)
+ extractInto(buffer, t)
t
}
@@ -246,7 +293,7 @@ object STRING extends ColumnType[Text, Text](8, 8) {
override def actualSize(v: Text) = v.getLength() + 4
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: Text) = {
+ override def extractInto(buffer: ByteBuffer, writable: Text) {
val length = buffer.getInt()
var b = _bytesFld.get(writable).asInstanceOf[Array[Byte]]
if (b == null || b.length < length) {
@@ -257,7 +304,8 @@ object STRING extends ColumnType[Text, Text](8, 8) {
_lengthFld.set(writable, length)
}
- def newWritable() = new Text
+ override def newWritable() = new Text
+
override def clone(v: Text) = {
val t = new Text()
t.set(v)
@@ -268,12 +316,12 @@ object STRING extends ColumnType[Text, Text](8, 8) {
object TIMESTAMP extends ColumnType[Timestamp, TimestampWritable](9, 12) {
- override def append(v: Timestamp, buffer: ByteBuffer) = {
+ override def append(v: Timestamp, buffer: ByteBuffer) {
buffer.putLong(v.getTime())
buffer.putInt(v.getNanos())
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
val ts = new Timestamp(0)
ts.setTime(buffer.getLong())
ts.setNanos(buffer.getInt())
@@ -284,11 +332,11 @@ object TIMESTAMP extends ColumnType[Timestamp, TimestampWritable](9, 12) {
oi.asInstanceOf[TimestampObjectInspector].getPrimitiveJavaObject(o)
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: TimestampWritable) = {
- writable.set(extract(currentPos, buffer))
+ override def extractInto(buffer: ByteBuffer, writable: TimestampWritable) {
+ writable.set(extract(buffer))
}
- def newWritable() = new TimestampWritable
+ override def newWritable() = new TimestampWritable
}
@@ -306,13 +354,13 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) {
f
}
- override def append(v: BytesWritable, buffer: ByteBuffer) = {
+ override def append(v: BytesWritable, buffer: ByteBuffer) {
val length = v.getLength()
buffer.putInt(length)
buffer.put(v.getBytes(), 0, length)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
throw new UnsupportedOperationException()
}
@@ -324,7 +372,7 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) {
}
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: BytesWritable) = {
+ override def extractInto(buffer: ByteBuffer, writable: BytesWritable) {
val length = buffer.getInt()
var b = _bytesFld.get(writable).asInstanceOf[Array[Byte]]
if (b == null || b.length < length) {
@@ -335,7 +383,7 @@ object BINARY extends ColumnType[BytesWritable, BytesWritable](10, 16) {
_lengthFld.set(writable, length)
}
- def newWritable() = new BytesWritable
+ override def newWritable() = new BytesWritable
override def actualSize(v: BytesWritable) = v.getLength() + 4
}
@@ -349,7 +397,7 @@ object GENERIC extends ColumnType[ByteStream.Output, ByteArrayRef](11, 16) {
buffer.put(v.getData(), 0, length)
}
- override def extract(currentPos: Int, buffer: ByteBuffer) = {
+ override def extract(buffer: ByteBuffer) = {
throw new UnsupportedOperationException()
}
@@ -357,12 +405,14 @@ object GENERIC extends ColumnType[ByteStream.Output, ByteArrayRef](11, 16) {
o.asInstanceOf[ByteStream.Output]
}
- def extractInto(currentPos: Int, buffer: ByteBuffer, writable: ByteArrayRef) = {
+ override def extractInto(buffer: ByteBuffer, writable: ByteArrayRef) {
val length = buffer.getInt()
val a = new Array[Byte](length)
buffer.get(a, 0, length)
writable.setData(a)
}
- def newWritable() = new ByteArrayRef
+ override def newWritable() = new ByteArrayRef
+
+ override def actualSize(v: ByteStream.Output): Int = v.getCount() + 4
}
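
The switch from `extract(currentPos, buffer)` to `extract(buffer)` leans on the buffer's own position instead of an explicit cursor. A plain java.nio round trip showing the pattern INT follows (standalone, not the Shark objects):

    import java.nio.{ByteBuffer, ByteOrder}
    val buf = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder())
    buf.putInt(42)              // what INT.append(42, buf) does
    buf.rewind()
    assert(buf.getInt() == 42)  // what INT.extract(buf) does, at the current position
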
diff --git a/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala b/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala
index 7b4e5ab8..5d74a61c 100644
--- a/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala
+++ b/src/main/scala/shark/memstore2/column/CompressedColumnIterator.scala
@@ -1,8 +1,25 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
import java.nio.ByteBuffer
-import scala.collection.mutable.{Map, HashMap}
+import org.apache.hadoop.io.BooleanWritable
import shark.memstore2.column.Implicits._
@@ -11,9 +28,8 @@ import shark.memstore2.column.Implicits._
* The first element of the buffer at the point of initialization
* is expected to be the type of compression indicator.
*/
-trait CompressedColumnIterator extends ColumnIterator{
+trait CompressedColumnIterator extends ColumnIterator {
- private var _compressionType: CompressionType = _
private var _decoder: Iterator[_] = _
private var _current: Any = _
@@ -22,21 +38,25 @@ trait CompressedColumnIterator extends ColumnIterator{
def columnType: ColumnType[_,_]
override def init() {
- _compressionType = buffer.getInt()
- _decoder = _compressionType match {
+ val compressionType: CompressionType = buffer.getInt()
+ _decoder = compressionType match {
case DefaultCompressionType => new DefaultDecoder(buffer, columnType)
case RLECompressionType => new RLDecoder(buffer, columnType)
case DictionaryCompressionType => new DictDecoder(buffer, columnType)
+ case BooleanBitSetCompressionType => new BooleanBitSetDecoder(buffer, columnType)
case _ => throw new UnsupportedOperationException()
}
}
- override def computeNext() {
+ override def next() {
+ // TODO: can we remove the if branch?
if (_decoder.hasNext) {
_current = _decoder.next()
}
}
-
+
+ override def hasNext = _decoder.hasNext
+
override def current = _current.asInstanceOf[Object]
}
@@ -50,7 +70,7 @@ class DefaultDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extend
override def hasNext = buffer.hasRemaining()
override def next(): V = {
- columnType.extractInto(buffer.position(), buffer, _current)
+ columnType.extractInto(buffer, _current)
_current
}
}
@@ -59,17 +79,17 @@ class DefaultDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extend
* Run Length Decoder, decodes data compressed in RLE format of [element, length]
*/
class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] {
-
+
private var _run: Int = _
private var _count: Int = 0
private val _current: V = columnType.newWritable()
override def hasNext = buffer.hasRemaining()
-
+
override def next(): V = {
if (_count == _run) {
//next run
- columnType.extractInto(buffer.position(), buffer, _current)
+ columnType.extractInto(buffer, _current)
_run = buffer.getInt()
_count = 1
} else {
@@ -79,26 +99,62 @@ class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Ite
}
}
-class DictDecoder[V] (buffer:ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] {
+/**
+ * Dictionary encoding compression.
+ */
+class DictDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Iterator[V] {
- private val _dictionary: Map[Int, V] = {
+ // Dictionary in the form of an array. The index is the encoded value, and the value is the
+ // decompressed value.
+ private val _dictionary: Array[V] = {
val size = buffer.getInt()
- val d = new HashMap[Int, V]()
+ val arr = columnType.writableScalaTag.newArray(size)
var count = 0
while (count < size) {
- //read text, followed by index
- val text = columnType.extract(buffer.position(), buffer)
- val index = buffer.getInt()
- d.put(index, text.asInstanceOf[V])
- count+= 1
+ val writable = columnType.newWritable()
+ columnType.extractInto(buffer, writable)
+ arr(count) = writable.asInstanceOf[V]
+ count += 1
}
- d
+ arr
}
override def hasNext = buffer.hasRemaining()
-
+
+ override def next(): V = {
+ val index = buffer.getShort().toInt
+ _dictionary(index)
+ }
+}
+
+/**
+ * Boolean BitSet encoding.
+ */
+class BooleanBitSetDecoder[V](
+ buffer: ByteBuffer,
+ columnType: ColumnType[_, V],
+ var _pos: Int,
+ var _uncompressedSize: Int,
+ var _curValue: Long,
+ var _writable: BooleanWritable
+ ) extends Iterator[V] {
+
+ def this(buffer: ByteBuffer, columnType: ColumnType[_, V]) =
+ this(buffer, columnType, 0, buffer.getInt(), 0, new BooleanWritable())
+
+ override def hasNext = _pos < _uncompressedSize
+
override def next(): V = {
- val index = buffer.getInt()
- _dictionary.get(index).get
+ val offset = _pos % BooleanBitSetCompression.BOOLEANS_PER_LONG
+
+ if (offset == 0) {
+ _curValue = buffer.getLong()
+ }
+
+ // Use a Long shift: offset can reach 63, and an Int shift would wrap at 32.
+ val retval: Boolean = (_curValue & (1L << offset)) != 0
+ _pos += 1
+ _writable.set(retval)
+ _writable.asInstanceOf[V]
}
}
+
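
One subtlety in the bit unpacking above: with 64 booleans per long, `offset` ranges over [0, 64), so the mask must be built with a Long shift; an Int shift would silently wrap at offset 32. A standalone check of the decode step:

    def unpack(word: Long, offset: Int): Boolean = (word & (1L << offset)) != 0
    assert(unpack(1L << 40, 40) && !unpack(1L << 40, 39))
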
diff --git a/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala b/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala
index a26d2ff5..5db74dee 100644
--- a/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala
+++ b/src/main/scala/shark/memstore2/column/CompressionAlgorithm.scala
@@ -1,9 +1,26 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
-import java.nio.ByteBuffer
-import java.nio.ByteOrder
+import java.nio.{ByteBuffer, ByteOrder}
+
import scala.annotation.tailrec
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{ArrayBuffer, HashMap}
/**
* API for Compression
@@ -12,15 +29,35 @@ trait CompressionAlgorithm {
def compressionType: CompressionType
+ /**
+ * Tests whether the compression algorithm supports a specific column type.
+ */
def supportsType(t: ColumnType[_, _]): Boolean
+ /**
+ * Collect a value so we can update the compression ratio for this compression algorithm.
+ */
def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _])
/**
* Return compression ratio between 0 and 1, smaller score imply higher compressibility.
+ * This is used to pick the compression algorithm to apply at runtime.
+ */
+ def compressionRatio: Double = compressedSize.toDouble / uncompressedSize.toDouble
+
+ /**
+ * The uncompressed size of the input data.
*/
- def compressionRatio: Double
+ def uncompressedSize: Int
+ /**
+ * Estimation of the data size once compressed.
+ */
+ def compressedSize: Int
+
+ /**
+ * Compress the given buffer and return the compressed data as a new buffer.
+ */
def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer
}
@@ -30,19 +67,28 @@ case class CompressionType(typeID: Int)
object DefaultCompressionType extends CompressionType(-1)
object RLECompressionType extends CompressionType(0)
+
object DictionaryCompressionType extends CompressionType(1)
-object RLEVariantCompressionType extends CompressionType(2)
+object BooleanBitSetCompressionType extends CompressionType(2)
+/**
+ * A no-op compression.
+ */
class NoCompression extends CompressionAlgorithm {
+
override def compressionType = DefaultCompressionType
override def supportsType(t: ColumnType[_,_]) = true
- override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) = {}
+ override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) {}
override def compressionRatio: Double = 1.0
+ override def uncompressedSize: Int = 0
+
+ override def compressedSize: Int = 0
+
override def compress[T](b: ByteBuffer, t: ColumnType[T, _]) = {
val len = b.limit()
val newBuffer = ByteBuffer.allocate(len + 4)
@@ -57,52 +103,56 @@ class NoCompression extends CompressionAlgorithm {
}
/**
- * Implements Run Length Encoding
+ * Run-length encoding for columns with a lot of repeated values.
*/
class RLE extends CompressionAlgorithm {
- private var _total: Int = 0
+ private var _uncompressedSize: Int = 0
+ private var _compressedSize: Int = 0
+
+ // Previous element, used to track how many runs and the run lengths.
private var _prev: Any = _
+ // Current run length.
private var _run: Int = 0
- private var _size: Int = 0
override def compressionType = RLECompressionType
override def supportsType(t: ColumnType[_, _]) = {
t match {
- case INT | STRING | SHORT | BYTE | BOOLEAN => true
+ case LONG | INT | STRING | SHORT | BYTE | BOOLEAN => true
case _ => false
}
}
- override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) = {
+ override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) {
val s = t.actualSize(v)
if (_prev == null) {
+ // This is the very first run.
_prev = t.clone(v)
_run = 1
+ _compressedSize += s + 4
} else {
if (_prev.equals(v)) {
+ // Add one to the current run's length.
_run += 1
} else {
- // flush run into size
- _size += (t.actualSize(_prev.asInstanceOf[T]) + 4)
+ // Start a new run. Update the current run length.
+ _compressedSize += s + 4
_prev = t.clone(v)
_run = 1
}
}
- _total += s
+ _uncompressedSize += s
}
+ override def uncompressedSize: Int = _uncompressedSize
+
// Note that we don't actually track the size of the last run into account to simplify the
// logic a little bit.
- override def compressionRatio = _size / (_total + 0.0)
+ override def compressedSize: Int = _compressedSize
- override def compress[T](b: ByteBuffer, t: ColumnType[T,_]) = {
- // Add the size of the last run to the _size
- if (_prev != null) {
- _size += t.actualSize(_prev.asInstanceOf[T]) + 4
- }
-
- val compressedBuffer = ByteBuffer.allocate(_size + 4 + 4)
+ override def compress[T](b: ByteBuffer, t: ColumnType[T,_]): ByteBuffer = {
+ // Leave 4 extra bytes for column type and another 4 for compression type.
+ val compressedBuffer = ByteBuffer.allocate(4 + 4 + _compressedSize)
compressedBuffer.order(ByteOrder.nativeOrder())
compressedBuffer.putInt(b.getInt())
compressedBuffer.putInt(compressionType.typeID)
@@ -112,7 +162,7 @@ class RLE extends CompressionAlgorithm {
}
@tailrec private final def encode[T](currentBuffer: ByteBuffer,
- compressedBuffer: ByteBuffer, currentRun: (T, Int), t: ColumnType[T,_]) {
+ compressedBuffer: ByteBuffer, currentRun: (T, Int), t: ColumnType[T,_]) {
def writeOutRun() {
t.append(currentRun._1, compressedBuffer)
compressedBuffer.putInt(currentRun._2)
@@ -121,7 +171,7 @@ class RLE extends CompressionAlgorithm {
writeOutRun()
return
}
- val elem = t.extract(currentBuffer.position(), currentBuffer)
+ val elem = t.extract(currentBuffer)
val newRun =
if (currentRun == null) {
(elem, 1)
@@ -137,88 +187,175 @@ class RLE extends CompressionAlgorithm {
}
}
+/**
+ * Dictionary encoding for columns with small cardinality. This algorithm encodes values into
+ * short integers (2 byte each). It can support up to 32k distinct values.
+ */
class DictionaryEncoding extends CompressionAlgorithm {
- private val MAX_DICT_SIZE = 4000
- private val _dictionary = new HashMap[Any, Int]()
- private var _dictionarySize = 0
- private var _totalSize = 0
+ // 32K unique values allowed
+ private val MAX_DICT_SIZE = Short.MaxValue - 1
+
+ // The dictionary that maps a value to the encoded short integer.
+ private var _dictionary = new HashMap[Any, Short]()
+
+ // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself.
+ private var _values = new ArrayBuffer[Any](1024)
+
+ // We use a short integer to store the dictionary index, which takes 2 bytes.
+ private val indexSize = 2
+
+ // Size of the dictionary, in bytes. Initialize the dictionary size to 4 since we use an int
+ // to store the number of elements in the dictionary.
+ private var _dictionarySize = 4
+
+ // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
+ // overflows.
+ private var _uncompressedSize = 0
+
+ // Total number of elements.
private var _count = 0
- private var _index = 0
+
+ // If the number of distinct elements is too large, we discard the use of dictionary
+ // encoding and set the overflow flag to true.
private var _overflow = false
override def compressionType = DictionaryCompressionType
override def supportsType(t: ColumnType[_, _]) = t match {
- case STRING => true
+ case STRING | LONG | INT => true
case _ => false
}
- private def encode[T](v: T, t: ColumnType[T, _], sizeFunc:T => Int): Int = {
- _count += 1
- val size = sizeFunc(v)
- _totalSize += size
- if (_dictionary.size < MAX_DICT_SIZE) {
- val s = t.clone(v)
- _dictionary.get(s) match {
- case Some(index) => index
- case None => {
- _dictionary.put(s, _index)
- _index += 1
- _dictionarySize += (size + 4)
- _index
+ override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _]) {
+ // Use this function to build up a dictionary.
+ if (!_overflow) {
+ val size = t.actualSize(v)
+ _count += 1
+ _uncompressedSize += size
+
+ if (!_dictionary.contains(v)) {
+ // The dictionary doesn't contain the value. Add the value to the dictionary if we haven't
+ // overflowed yet.
+ if (_dictionary.size < MAX_DICT_SIZE) {
+ val clone = t.clone(v)
+ _values.append(clone)
+ _dictionary.put(clone, _dictionary.size.toShort)
+ _dictionarySize += size
+ } else {
+ // Overflowed. Release the dictionary immediately to lower memory pressure.
+ _overflow = true
+ _dictionary = null
+ _values = null
}
}
- } else {
- _overflow = true
- -1
}
}
- override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T, _]) = {
- //need an estimate of the # of uniques so we can build an appropriate
- //dictionary if needed. More precisely, we only need a lower bound
- //on # of uniques.
- val size = t.actualSize(v)
- encode(v, t, { _:T => size})
- }
+ override def uncompressedSize: Int = _uncompressedSize
/**
- * return score between 0 and 1, smaller score imply higher compressibility.
+ * Return the compressed data size if encoded with dictionary encoding. If the dictionary
+ * cardinality (i.e. the number of distinct elements) is bigger than 32K, we return
+ * a really large number.
*/
- override def compressionRatio: Double = {
- if (_overflow) 1.0 else (_count*4 + dictionarySize) / (_totalSize + 0.0)
+ override def compressedSize: Int = {
+ // Total compressed size =
+ // size of the dictionary +
+ // the number of elements * dictionary encoded size (short)
+ if (_overflow) Int.MaxValue else _dictionarySize + _count * indexSize
}
- private def writeDictionary[T](compressedBuffer: ByteBuffer, t: ColumnType[T, _]) {
- //store dictionary size
+ override def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer = {
+ if (_overflow) {
+ throw new MemoryStoreException(
+ "Dictionary encoding should not be used because we have overflown the dictionary.")
+ }
+
+ // Create a new buffer and store the compression type and column type.
+ // Leave 4 extra bytes for column type and another 4 for compression type.
+ val compressedBuffer = ByteBuffer.allocate(4 + 4 + compressedSize)
+ compressedBuffer.order(ByteOrder.nativeOrder())
+ compressedBuffer.putInt(b.getInt())
+ compressedBuffer.putInt(compressionType.typeID)
+
+ // Write out the dictionary.
compressedBuffer.putInt(_dictionary.size)
- //store the dictionary
- _dictionary.foreach { x =>
- t.append(x._1.asInstanceOf[T], compressedBuffer)
- compressedBuffer.putInt(x._2)
+ _values.foreach { v =>
+ t.append(v.asInstanceOf[T], compressedBuffer)
}
+
+ // Write out the encoded values, each is represented by a short integer.
+ while (b.hasRemaining()) {
+ val v = t.extract(b)
+ compressedBuffer.putShort(_dictionary(v))
+ }
+
+ // Rewind the compressed buffer and return it.
+ compressedBuffer.rewind()
+ compressedBuffer
}
+}
- private def dictionarySize = _dictionarySize + 4
+/**
+ * BitSet compression for Boolean values.
+ */
+object BooleanBitSetCompression {
+ val BOOLEANS_PER_LONG: Short = 64
+}
- override def compress[T](b: ByteBuffer, t: ColumnType[T, _]): ByteBuffer = {
- //build a dictionary of given size
- val compressedBuffer = ByteBuffer.allocate(_count*4 + dictionarySize + 4 + 4)
+class BooleanBitSetCompression extends CompressionAlgorithm {
+
+ private var _uncompressedSize = 0
+
+ override def compressionType = BooleanBitSetCompressionType
+
+ override def supportsType(t: ColumnType[_, _]) = {
+ t match {
+ case BOOLEAN => true
+ case _ => false
+ }
+ }
+
+ override def gatherStatsForCompressibility[T](v: T, t: ColumnType[T,_]) {
+ val s = t.actualSize(v)
+ _uncompressedSize += s
+ }
+
+ // Booleans are encoded into Longs; in addition, we need one int to store the number of
+ // Booleans contained in the compressed buffer.
+ override def compressedSize: Int = {
+ math.ceil(_uncompressedSize.toFloat / BooleanBitSetCompression.BOOLEANS_PER_LONG).toInt * 8 + 4
+ }
+
+ override def uncompressedSize: Int = _uncompressedSize
+
+ override def compress[T](b: ByteBuffer, t: ColumnType[T,_]): ByteBuffer = {
+ // Leave 4 extra bytes for column type, another 4 for compression type.
+ val compressedBuffer = ByteBuffer.allocate(4 + 4 + compressedSize)
compressedBuffer.order(ByteOrder.nativeOrder())
compressedBuffer.putInt(b.getInt())
compressedBuffer.putInt(compressionType.typeID)
- //store dictionary size
- writeDictionary(compressedBuffer, t)
- //traverse the original buffer
- while (b.hasRemaining()) {
- val v = t.extract(b.position(), b)
- _dictionary.get(v).map { index =>
- compressedBuffer.putInt(index)
+ compressedBuffer.putInt(b.remaining())
+
+ var cur: Long = 0
+ var pos: Int = 0
+ var offset: Int = 0
+
+ while (b.hasRemaining) {
+ offset = pos % BooleanBitSetCompression.BOOLEANS_PER_LONG
+ val elem = t.extract(b).asInstanceOf[Boolean]
+
+ if (elem) {
+ cur = cur | (1L << offset)
}
-
+ if (offset == BooleanBitSetCompression.BOOLEANS_PER_LONG - 1 || !b.hasRemaining) {
+ compressedBuffer.putLong(cur)
+ cur = 0
+ }
+ pos += 1
}
compressedBuffer.rewind()
compressedBuffer
}
-}
\ No newline at end of file
+}
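
A quick sanity check on the BooleanBitSet size formula above, given that BOOLEAN's uncompressed size is one byte per value: n booleans pack into ceil(n / 64) longs of 8 bytes each, plus a 4-byte element count:

    def bitsetCompressedSize(n: Int): Int = math.ceil(n / 64.0).toInt * 8 + 4
    assert(bitsetCompressedSize(1) == 12)   // one long + count
    assert(bitsetCompressedSize(65) == 20)  // two longs + count
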
diff --git a/src/main/scala/shark/memstore2/column/MemoryStoreException.scala b/src/main/scala/shark/memstore2/column/MemoryStoreException.scala
new file mode 100644
index 00000000..5db2631d
--- /dev/null
+++ b/src/main/scala/shark/memstore2/column/MemoryStoreException.scala
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.memstore2.column
+
+
+class MemoryStoreException(message: String) extends Exception(message)
diff --git a/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala b/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala
index 2b544f4e..2d79fd87 100644
--- a/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala
+++ b/src/main/scala/shark/memstore2/column/NullableColumnBuilder.scala
@@ -1,3 +1,20 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
import java.nio.ByteBuffer
@@ -7,19 +24,20 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
/**
- * Builds a nullable column. The byte buffer of a nullable column contains
- * the column type, followed by the null count and the index of nulls, followed
- * finally by the non nulls.
+ * Builds a nullable column. The byte buffer of a nullable column contains:
+ * - 4 bytes for the null count (number of nulls)
+ * - positions for each null, in ascending order
+ * - the non-null data (column data type, compression type, data...)
*/
trait NullableColumnBuilder[T] extends ColumnBuilder[T] {
private var _nulls: ByteBuffer = _
-
+
private var _pos: Int = _
- private var _nullCount:Int = _
+ private var _nullCount: Int = _
override def initialize(initialSize: Int): ByteBuffer = {
- _nulls = ByteBuffer.allocate(1024)
+ _nulls = ByteBuffer.allocate(1024)
_nulls.order(ByteOrder.nativeOrder())
_pos = 0
_nullCount = 0
@@ -38,19 +56,16 @@ trait NullableColumnBuilder[T] extends ColumnBuilder[T] {
}
override def build(): ByteBuffer = {
- val b = super.build()
- if (_pos == 0) {
- b
- } else {
- val v = _nulls.position()
- _nulls.limit(v)
- _nulls.rewind()
- val newBuffer = ByteBuffer.allocate(b.limit + v + 4)
- newBuffer.order(ByteOrder.nativeOrder())
- val colType= b.getInt()
- newBuffer.putInt(colType).putInt(_nullCount).put(_nulls).put(b)
- newBuffer.rewind()
- newBuffer
- }
+ val nonNulls = super.build()
+ val nullDataLen = _nulls.position()
+ _nulls.limit(nullDataLen)
+ _nulls.rewind()
+
+ // 4 bytes for null count + null positions + non nulls
+ val newBuffer = ByteBuffer.allocate(4 + nullDataLen + nonNulls.limit)
+ newBuffer.order(ByteOrder.nativeOrder())
+ newBuffer.putInt(_nullCount).put(_nulls).put(nonNulls)
+ newBuffer.rewind()
+ newBuffer
}
-}
\ No newline at end of file
+}
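
The layout the builder produces can be mocked up by hand. A sketch for a three-row int column where row 1 is null (the real non-null region would begin with the column type and compression type ids, omitted here for brevity):

    import java.nio.{ByteBuffer, ByteOrder}
    val buf = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder())
    buf.putInt(1)    // null count
    buf.putInt(1)    // position of the null (row index 1)
    buf.putInt(10)   // data for row 0
    buf.putInt(30)   // data for row 2
    buf.rewind()
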
diff --git a/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala b/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala
index 66a3adfa..49e0eb20 100644
--- a/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala
+++ b/src/main/scala/shark/memstore2/column/NullableColumnIterator.scala
@@ -1,3 +1,20 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.memstore2.column
import java.nio.ByteBuffer
@@ -9,26 +26,30 @@ import java.nio.ByteOrder
* Reading of non nulls is delegated by setting the buffer position to the first
* non null.
*/
-class NullableColumnIterator(delegate: ColumnIterator, buffer: ByteBuffer) extends ColumnIterator {
+class NullableColumnIterator(buffer: ByteBuffer) extends ColumnIterator {
private var _d: ByteBuffer = _
private var _nullCount: Int = _
private var _nulls = 0
private var _isNull = false
- private var _currentNullIndex:Int = _
+ private var _currentNullIndex: Int = _
private var _pos = 0
+ private var _delegate: ColumnIterator = _
+
override def init() {
_d = buffer.duplicate()
_d.order(ByteOrder.nativeOrder())
_nullCount = _d.getInt()
- buffer.position(buffer.position() + _nullCount * 4 + 4)
_currentNullIndex = if (_nullCount > 0) _d.getInt() else Integer.MAX_VALUE
_pos = 0
- delegate.init()
+
+ // Move the buffer position to the non-null region.
+ buffer.position(buffer.position() + 4 + _nullCount * 4)
+ _delegate = ColumnIterator.newNonNullIterator(buffer)
}
- override def computeNext() {
+ override def next() {
if (_pos == _currentNullIndex) {
_nulls += 1
if (_nulls < _nullCount) {
@@ -37,12 +58,12 @@ class NullableColumnIterator(delegate: ColumnIterator, buffer: ByteBuffer) exten
_isNull = true
} else {
_isNull = false
- delegate.computeNext()
+ _delegate.next()
}
_pos += 1
}
-
- def current: Object = {
- if (_isNull) null else delegate.current
- }
+
+ override def hasNext: Boolean = (_nulls < _nullCount) || _delegate.hasNext
+
+ def current: Object = if (_isNull) null else _delegate.current
}
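
Because NullableColumnIterator splices the nulls back in at their recorded positions, a scan can materialize a column without special-casing them. A minimal sketch against the interface above:

    // current is null at null rows, so the array holds nulls in the right slots.
    def materializeColumn(it: ColumnIterator, numRows: Int): Array[Object] =
      Array.fill[Object](numRows) { it.next(); it.current }
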
diff --git a/src/main/scala/shark/parse/QueryBlock.scala b/src/main/scala/shark/parse/QueryBlock.scala
new file mode 100644
index 00000000..4d79f12a
--- /dev/null
+++ b/src/main/scala/shark/parse/QueryBlock.scala
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.parse
+
+import org.apache.hadoop.hive.ql.parse.{QB => HiveQueryBlock}
+import org.apache.hadoop.hive.ql.plan.CreateTableDesc
+import org.apache.hadoop.hive.ql.plan.TableDesc
+
+import shark.memstore2.CacheType
+import shark.memstore2.CacheType._
+
+
+/**
+ * A container for flags and table metadata. Used in SharkSemanticAnalyzer while parsing
+ * and analyzing ASTs (e.g. in SharkSemanticAnalyzer#analyzeCreateTable()).
+ */
+class QueryBlock(outerID: String, alias: String, isSubQuery: Boolean)
+ extends HiveQueryBlock(outerID, alias, isSubQuery) {
+
+ // The CacheType for the table that will be created from CREATE TABLE/CTAS, or updated for an
+ // INSERT.
+ var cacheMode = CacheType.NONE
+
+ // Descriptor for the table being updated by an INSERT.
+ var targetTableDesc: TableDesc = _
+
+ // Hive's QB uses `tableDesc` to refer to the CreateTableDesc. A direct `createTableDesc`
+ // makes it easier to differentiate from `targetTableDesc`.
+ def createTableDesc: CreateTableDesc = super.getTableDesc
+
+ def createTableDesc_=(desc: CreateTableDesc) = super.setTableDesc(desc)
+}
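
`createTableDesc_=` relies on Scala's setter convention: defining `x` and `x_=` lets callers use plain assignment syntax. A standalone illustration of the convention:

    class Box {
      private var _v: Int = 0
      def v: Int = _v
      def v_=(i: Int) { _v = i }
    }
    val b = new Box
    b.v = 5          // desugars to b.v_=(5)
    assert(b.v == 5)
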
diff --git a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala
index a43a4975..3e5f69b2 100644
--- a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala
+++ b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala
@@ -1,24 +1,186 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.parse
+import java.util.{HashMap => JavaHashMap}
+
+import scala.collection.JavaConversions._
+
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, DDLSemanticAnalyzer, HiveParser}
+import org.apache.hadoop.hive.ql.exec.TaskFactory
+import org.apache.hadoop.hive.ql.parse.ASTNode
+import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer
+import org.apache.hadoop.hive.ql.parse.DDLSemanticAnalyzer
+import org.apache.hadoop.hive.ql.parse.HiveParser
+import org.apache.hadoop.hive.ql.parse.SemanticException
+import org.apache.hadoop.hive.ql.plan.{AlterTableDesc, DDLWork}
import org.apache.spark.rdd.{UnionRDD, RDD}
import shark.{LogHelper, SharkEnv}
+import shark.execution.{SharkDDLWork, SparkLoadWork}
+import shark.memstore2.{CacheType, MemoryMetadataManager, SharkTblProperties}
class SharkDDLSemanticAnalyzer(conf: HiveConf) extends DDLSemanticAnalyzer(conf) with LogHelper {
- override def analyzeInternal(node: ASTNode): Unit = {
- super.analyzeInternal(node)
- //handle drop table query
- if (node.getToken().getType() == HiveParser.TOK_DROPTABLE) {
- SharkEnv.unpersist(getTableName(node))
+ override def analyzeInternal(ast: ASTNode): Unit = {
+ super.analyzeInternal(ast)
+
+ ast.getToken.getType match {
+ case HiveParser.TOK_ALTERTABLE_ADDPARTS => {
+ analyzeAlterTableAddParts(ast)
+ }
+ case HiveParser.TOK_ALTERTABLE_DROPPARTS => {
+ analyzeDropTableOrDropParts(ast)
+ }
+ case HiveParser.TOK_ALTERTABLE_RENAME => {
+ analyzeAlterTableRename(ast)
+ }
+ case HiveParser.TOK_ALTERTABLE_PROPERTIES => {
+ analyzeAlterTableProperties(ast)
+ }
+ case HiveParser.TOK_DROPTABLE => {
+ analyzeDropTableOrDropParts(ast)
+ }
+ case _ => Unit
+ }
+ }
+
+ /**
+ * Handle table property changes.
+ * How Shark-specific changes are handled:
+ * - "shark.cache":
+ *   If the value evaluated by CacheType#shouldCache() is `true`, then create a SparkLoadTask
+ *   to load the Hive table into memory, and set it as a dependent of the Hive DDLTask. A
+ *   SharkDDLTask counterpart isn't created because the HadoopRDD creation and transformation
+ *   isn't a direct Shark metastore operation (unlike the other cases handled in
+ *   SharkDDLSemanticAnalyzer).
+ *   If `false`, then create a SharkDDLTask that will delete the table entry in the Shark
+ * metastore.
+ */
+ def analyzeAlterTableProperties(ast: ASTNode) {
+ val databaseName = db.getCurrentDatabase()
+ val tableName = getTableName(ast)
+ val hiveTable = db.getTable(databaseName, tableName)
+ val newTblProps = getAlterTblDesc().getProps
+ val oldTblProps = hiveTable.getParameters
+
+ val oldCacheMode = CacheType.fromString(oldTblProps.get(SharkTblProperties.CACHE_FLAG.varname))
+ val newCacheMode = CacheType.fromString(newTblProps.get(SharkTblProperties.CACHE_FLAG.varname))
+ if ((oldCacheMode == CacheType.TACHYON && newCacheMode != CacheType.TACHYON) ||
+ (oldCacheMode == CacheType.MEMORY_ONLY && newCacheMode != CacheType.MEMORY_ONLY)) {
+ throw new SemanticException("""Table %s.%s's 'shark.cache' table property is %s. Only changes
+ from "'MEMORY' and 'NONE' are supported. Tables stored in TACHYON and MEMORY_ONLY must be
+ "dropped.""".format(databaseName, tableName, oldCacheMode))
+ } else if (newCacheMode == CacheType.MEMORY) {
+ // The table should be cached (and is not already cached).
+ val partSpecsOpt = if (hiveTable.isPartitioned) {
+ val columnNames = hiveTable.getPartCols.map(_.getName)
+ val partSpecs = db.getPartitions(hiveTable).map { partition =>
+ val partSpec = new JavaHashMap[String, String]()
+ val values = partition.getValues()
+        columnNames.zipWithIndex.foreach { case (name, index) => partSpec.put(name, values(index)) }
+ partSpec
+ }
+ Some(partSpecs)
+ } else {
+ None
+ }
+ newTblProps.put(SharkTblProperties.CACHE_FLAG.varname, newCacheMode.toString)
+ val sparkLoadWork = new SparkLoadWork(
+ databaseName,
+ tableName,
+ SparkLoadWork.CommandTypes.NEW_ENTRY,
+ newCacheMode)
+ partSpecsOpt.foreach(partSpecs => sparkLoadWork.partSpecs = partSpecs)
+ rootTasks.head.addDependentTask(TaskFactory.get(sparkLoadWork, conf))
+ } else if (newCacheMode == CacheType.NONE) {
+ // Uncache the table.
+ SharkEnv.memoryMetadataManager.dropTableFromMemory(db, databaseName, tableName)
}
}
+ def analyzeDropTableOrDropParts(ast: ASTNode) {
+ val databaseName = db.getCurrentDatabase()
+ val tableName = getTableName(ast)
+ val hiveTableOpt = Option(db.getTable(databaseName, tableName, false /* throwException */))
+    // `hiveTableOpt` is `None` for a DROP TABLE IF EXISTS command on a nonexistent table.
+ hiveTableOpt.foreach { hiveTable =>
+ val cacheMode = CacheType.fromString(
+ hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname))
+ // Create a SharkDDLTask only if the table is cached.
+ if (CacheType.shouldCache(cacheMode)) {
+ // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with DDLTasks
+ // and DDLWorks that contain DropTableDesc objects.
+ for (ddlTask <- rootTasks) {
+ val dropTableDesc = ddlTask.getWork.asInstanceOf[DDLWork].getDropTblDesc
+ val sharkDDLWork = new SharkDDLWork(dropTableDesc)
+ sharkDDLWork.cacheMode = cacheMode
+ ddlTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf))
+ }
+ }
+ }
+ }
+
+ def analyzeAlterTableAddParts(ast: ASTNode) {
+ val databaseName = db.getCurrentDatabase()
+ val tableName = getTableName(ast)
+ val hiveTable = db.getTable(databaseName, tableName)
+ val cacheMode = CacheType.fromString(
+ hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname))
+ // Create a SharkDDLTask only if the table is cached.
+ if (CacheType.shouldCache(cacheMode)) {
+ // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with DDLTasks
+ // and DDLWorks that contain AddPartitionDesc objects.
+ for (ddlTask <- rootTasks) {
+ val addPartitionDesc = ddlTask.getWork.asInstanceOf[DDLWork].getAddPartitionDesc
+ val sharkDDLWork = new SharkDDLWork(addPartitionDesc)
+ sharkDDLWork.cacheMode = cacheMode
+ ddlTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf))
+ }
+ }
+ }
+
+ private def analyzeAlterTableRename(astNode: ASTNode) {
+ val databaseName = db.getCurrentDatabase()
+ val oldTableName = getTableName(astNode)
+ val hiveTable = db.getTable(databaseName, oldTableName)
+    val cacheMode = CacheType.fromString(
+      hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname))
+ if (CacheType.shouldCache(cacheMode)) {
+ val alterTableDesc = getAlterTblDesc()
+ val sharkDDLWork = new SharkDDLWork(alterTableDesc)
+ sharkDDLWork.cacheMode = cacheMode
+ rootTasks.head.addDependentTask(TaskFactory.get(sharkDDLWork, conf))
+ }
+ }
+
+ private def getAlterTblDesc(): AlterTableDesc = {
+ // Hive's DDLSemanticAnalyzer#analyzeInternal() will only populate rootTasks with a DDLTask
+ // and DDLWork that contains an AlterTableDesc.
+ assert(rootTasks.size == 1)
+ val ddlTask = rootTasks.head
+ val ddlWork = ddlTask.getWork
+ assert(ddlWork.isInstanceOf[DDLWork])
+ ddlWork.asInstanceOf[DDLWork].getAlterTblDesc
+ }
+
private def getTableName(node: ASTNode): String = {
BaseSemanticAnalyzer.getUnescapedName(node.getChild(0).asInstanceOf[ASTNode])
}
-}
\ No newline at end of file
+}
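The partition specs built in analyzeAlterTableProperties() are plain column-to-value maps, produced by zipping the table's partition column names with each partition's values. A self-contained sketch of that zip (the column names and values are made up):

```scala
import java.util.{HashMap => JavaHashMap}

object PartSpecSketch {
  def main(args: Array[String]) {
    val columnNames = Seq("ds", "hr")              // Table#getPartCols names
    val values = Seq("2013-01-01", "00")           // Partition#getValues()
    val partSpec = new JavaHashMap[String, String]()
    columnNames.zip(values).foreach { case (name, value) => partSpec.put(name, value) }
    println(partSpec)                              // {ds=2013-01-01, hr=00}
  }
}
```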
diff --git a/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala
index c8d69322..e139ac27 100755
--- a/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala
+++ b/src/main/scala/shark/parse/SharkExplainSemanticAnalyzer.scala
@@ -19,13 +19,11 @@ package shark.parse
import java.io.Serializable
import java.util.ArrayList
-import java.util.List
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.parse._
-import org.apache.hadoop.hive.ql.plan.ExplainWork
import shark.execution.SharkExplainWork
diff --git a/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala
new file mode 100644
index 00000000..fc32dbd7
--- /dev/null
+++ b/src/main/scala/shark/parse/SharkLoadSemanticAnalyzer.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.parse
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.exec.{CopyTask, MoveTask, TaskFactory}
+import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, LoadSemanticAnalyzer}
+import org.apache.hadoop.hive.ql.plan._
+
+import shark.{LogHelper, SharkEnv}
+import shark.execution.SparkLoadWork
+import shark.memstore2.{CacheType, SharkTblProperties}
+
+
+class SharkLoadSemanticAnalyzer(conf: HiveConf) extends LoadSemanticAnalyzer(conf) {
+
+ override def analyzeInternal(ast: ASTNode): Unit = {
+ // Delegate to the LoadSemanticAnalyzer parent for error checking the source path formatting.
+ super.analyzeInternal(ast)
+
+ // Children of the AST root created for a LOAD DATA [LOCAL] INPATH ... statement are, in order:
+ // 1. node containing the path specified by INPATH.
+ // 2. internal TOK_TABNAME node that contains the table's name.
+ // 3. (optional) node representing the LOCAL modifier.
+ val tableASTNode = ast.getChild(1).asInstanceOf[ASTNode]
+ val tableName = getTableName(tableASTNode)
+ val hiveTable = db.getTable(tableName)
+ val cacheMode = CacheType.fromString(
+ hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname))
+
+ if (CacheType.shouldCache(cacheMode)) {
+ // Find the arguments needed to instantiate a SparkLoadWork.
+ val tableSpec = new BaseSemanticAnalyzer.tableSpec(db, conf, tableASTNode)
+ val hiveTable = tableSpec.tableHandle
+ val moveTask = getMoveTask()
+ val partSpecOpt = Option(tableSpec.getPartSpec)
+ val sparkLoadWork = SparkLoadWork(
+ db,
+ conf,
+ hiveTable,
+ partSpecOpt,
+ isOverwrite = moveTask.getWork.getLoadTableWork.getReplace)
+
+ // Create a SparkLoadTask that will read from the table's data directory. Make it a dependent
+ // task of the LoadTask so that it's executed only if the LoadTask executes successfully.
+ moveTask.addDependentTask(TaskFactory.get(sparkLoadWork, conf))
+ }
+ }
+
+ private def getMoveTask(): MoveTask = {
+ assert(rootTasks.size == 1)
+
+ // If the execution is local, then the root task is a CopyTask with a MoveTask child.
+ // Otherwise, the root is a MoveTask.
+    val rootTask = rootTasks.head
+ val moveTask = if (rootTask.isInstanceOf[CopyTask]) {
+ val firstChildTask = rootTask.getChildTasks.head
+ assert(firstChildTask.isInstanceOf[MoveTask])
+ firstChildTask
+ } else {
+ rootTask
+ }
+
+ // In Hive, LoadTableDesc is referred to as LoadTableWork ...
+ moveTask.asInstanceOf[MoveTask]
+ }
+
+ private def getTableName(node: ASTNode): String = {
+ BaseSemanticAnalyzer.getUnescapedName(node.getChild(0).asInstanceOf[ASTNode])
+ }
+}
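getMoveTask() relies on the shape Hive gives a LOAD plan: a LOCAL load first copies the file into HDFS, so the MoveTask sits one level below a CopyTask. A standalone sketch of that unwrapping rule, with plain case classes standing in for Hive's task types:

```scala
object MoveTaskSketch {
  sealed trait Task
  case class Copy(child: Move) extends Task
  case class Move(replace: Boolean) extends Task

  def moveTaskOf(root: Task): Move = root match {
    case Copy(m) => m   // LOCAL load: the root CopyTask wraps the MoveTask
    case m: Move => m   // non-local load: the root is already the MoveTask
  }

  def main(args: Array[String]) {
    assert(moveTaskOf(Copy(Move(replace = true))) == Move(replace = true))
    assert(moveTaskOf(Move(replace = false)) == Move(replace = false))
  }
}
```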
diff --git a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala
index 199102bc..f3a7b49a 100755
--- a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala
+++ b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala
@@ -17,30 +17,33 @@
package shark.parse
-import java.lang.reflect.Method
import java.util.ArrayList
import java.util.{List => JavaList}
+import java.util.{Map => JavaMap}
import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.api.{FieldSchema, MetaException}
import org.apache.hadoop.hive.metastore.Warehouse
-import org.apache.hadoop.hive.ql.exec.{DDLTask, FetchTask, MoveTask, TaskFactory}
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, MetaException}
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.hive.ql.exec.{DDLTask, FetchTask}
import org.apache.hadoop.hive.ql.exec.{FileSinkOperator => HiveFileSinkOperator}
-import org.apache.hadoop.hive.ql.metadata.HiveException
+import org.apache.hadoop.hive.ql.exec.MoveTask
+import org.apache.hadoop.hive.ql.exec.{Operator => HiveOperator}
+import org.apache.hadoop.hive.ql.exec.TaskFactory
+import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException}
import org.apache.hadoop.hive.ql.optimizer.Optimizer
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan._
import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.spark.storage.StorageLevel
-
-import shark.{CachedTableRecovery, LogHelper, SharkConfVars, SharkEnv, Utils}
-import shark.execution.{HiveOperator, Operator, OperatorFactory, RDDUtils, ReduceSinkOperator,
- SparkWork, TerminalOperator}
-import shark.memstore2.{CacheType, ColumnarSerDe, MemoryMetadataManager}
+import shark.{LogHelper, SharkConfVars, SharkEnv, Utils}
+import shark.execution.{HiveDesc, Operator, OperatorFactory, RDDUtils, ReduceSinkOperator}
+import shark.execution.{SharkDDLWork, SparkLoadWork, SparkWork, TerminalOperator}
+import shark.memstore2.{CacheType, ColumnarSerDe, LazySimpleSerDeWrapper, MemoryMetadataManager}
+import shark.memstore2.{MemoryTable, PartitionedMemoryTable, SharkTblProperties, TableRecovery}
/**
@@ -60,73 +63,67 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
override def getResultSchema() = _resSchema
/**
- * Override SemanticAnalyzer.analyzeInternal to handle CTAS caching.
+ * Override SemanticAnalyzer.analyzeInternal to handle CTAS caching and INSERT updates.
+ *
+ * Unified views:
+ * For CTAS and INSERT INTO/OVERWRITE the generated Shark query plan matches the one
+ * created if the target table were not cached. Disk => memory loading is done by a
+ * SparkLoadTask that executes _after_ all other tasks (SparkTask, Hive MoveTasks) finish
+ * executing. For INSERT INTO, the SparkLoadTask will be able to determine, using a path filter
+ * based on a snapshot of the table/partition data directory taken in genMapRedTasks(), new files
+ * that should be loaded into the cache. For CTAS, a path filter isn't used - everything in the
+ * data directory is loaded into the cache.
+ *
+ * Non-unified views (i.e., the cached table content is memory-only):
+ * The query plan's FileSinkOperator is replaced by a MemoryStoreSinkOperator. The
+ * MemoryStoreSinkOperator creates a new table (or partition) entry in the Shark metastore
+ * for CTAS, and creates UnionRDDs for INSERT INTO commands.
*/
override def analyzeInternal(ast: ASTNode): Unit = {
reset()
- val qb = new QB(null, null, false)
+ val qb = new QueryBlock(null, null, false)
val pctx = getParseContext()
pctx.setQB(qb)
pctx.setParseTree(ast)
init(pctx)
+    // The ASTNode that will be analyzed by SemanticAnalyzer#doPhase1().
var child: ASTNode = ast
- logInfo("Starting Shark Semantic Analysis")
+ logDebug("Starting Shark Semantic Analysis")
//TODO: can probably reuse Hive code for this
- // analyze create table command
- var cacheMode = CacheType.none
- var isCTAS = false
var shouldReset = false
- if (ast.getToken().getType() == HiveParser.TOK_CREATETABLE) {
+ val astTokenType = ast.getToken().getType()
+ if (astTokenType == HiveParser.TOK_CREATEVIEW || astTokenType == HiveParser.TOK_ANALYZE) {
+ // Delegate create view and analyze to Hive.
super.analyzeInternal(ast)
- for (ch <- ast.getChildren) {
- ch.asInstanceOf[ASTNode].getToken.getType match {
- case HiveParser.TOK_QUERY => {
- isCTAS = true
- child = ch.asInstanceOf[ASTNode]
- }
- case _ =>
- Unit
- }
- }
-
- // If the table descriptor can be null if the CTAS has an
- // "if not exists" condition.
- val td = getParseContext.getQB.getTableDesc
- if (!isCTAS || td == null) {
- return
- } else {
- val checkTableName = SharkConfVars.getBoolVar(conf, SharkConfVars.CHECK_TABLENAME_FLAG)
- val cacheType = CacheType.fromString(td.getTblProps().get("shark.cache"))
- if (cacheType == CacheType.heap ||
- (td.getTableName.endsWith("_cached") && checkTableName)) {
- cacheMode = CacheType.heap
- td.getTblProps().put("shark.cache", cacheMode.toString)
- } else if (cacheType == CacheType.tachyon ||
- (td.getTableName.endsWith("_tachyon") && checkTableName)) {
- cacheMode = CacheType.tachyon
- td.getTblProps().put("shark.cache", cacheMode.toString)
+ return
+ } else if (astTokenType == HiveParser.TOK_CREATETABLE) {
+ // Use Hive to do a first analysis pass.
+ super.analyzeInternal(ast)
+      // Do post-Hive analysis of the CREATE TABLE (e.g., detect the caching mode).
+ analyzeCreateTable(ast, qb) match {
+ case Some(queryStmtASTNode) => {
+          // Set `child` to reference the root node of the SELECT statement, which is a
+          // HiveParser.TOK_QUERY node.
+ child = queryStmtASTNode
+ // Hive's super.analyzeInternal() might generate MapReduce tasks. Avoid executing those
+ // tasks by reset()-ing some Hive SemanticAnalyzer state after doPhase1() is called below.
+ shouldReset = true
}
-
- if (CacheType.shouldCache(cacheMode)) {
- td.setSerName(classOf[ColumnarSerDe].getName)
+ case None => {
+ // Done with semantic analysis if the CREATE TABLE statement isn't a CTAS.
+ return
}
-
- qb.setTableDesc(td)
- shouldReset = true
}
} else {
SessionState.get().setCommandType(HiveOperation.QUERY)
}
- // Delegate create view and analyze to Hive.
- val astTokenType = ast.getToken().getType()
- if (astTokenType == HiveParser.TOK_CREATEVIEW || astTokenType == HiveParser.TOK_ANALYZE) {
- return super.analyzeInternal(ast)
- }
+ // Invariant: At this point, the command will execute a query (i.e., its AST contains a
+ // HiveParser.TOK_QUERY node).
// Continue analyzing from the child ASTNode.
if (!doPhase1(child, qb, initPhase1Ctx())) {
@@ -136,12 +133,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
// Used to protect against recursive views in getMetaData().
SharkSemanticAnalyzer.viewsExpandedField.set(this, new ArrayList[String]())
- logInfo("Completed phase 1 of Shark Semantic Analysis")
+ logDebug("Completed phase 1 of Shark Semantic Analysis")
getMetaData(qb)
- logInfo("Completed getting MetaData in Shark Semantic Analysis")
+ logDebug("Completed getting MetaData in Shark Semantic Analysis")
// Reset makes sure we don't run the mapred jobs generated by Hive.
- if (shouldReset) reset()
+ if (shouldReset) {
+ reset()
+ }
// Save the result schema derived from the sink operator produced
// by genPlan. This has the correct column names, which clients
@@ -169,61 +168,89 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
// TODO: clean the following code. It's too messy to understand...
val terminalOpSeq = {
- if (qb.getParseInfo.isInsertToTable && !qb.isCTAS) {
+ val qbParseInfo = qb.getParseInfo
+ if (qbParseInfo.isInsertToTable && !qb.isCTAS) {
+ // Handle INSERT. There can be multiple Hive sink operators if the single command comprises
+ // multiple INSERTs.
hiveSinkOps.map { hiveSinkOp =>
- val tableName = hiveSinkOp.asInstanceOf[HiveFileSinkOperator].getConf().getTableInfo()
- .getTableName()
-
+ val tableDesc = hiveSinkOp.asInstanceOf[HiveFileSinkOperator].getConf().getTableInfo()
+ val tableName = tableDesc.getTableName
if (tableName == null || tableName == "") {
// If table name is empty, it is an INSERT (OVERWRITE) DIRECTORY.
OperatorFactory.createSharkFileOutputPlan(hiveSinkOp)
} else {
// Otherwise, check if we are inserting into a table that was cached.
- val cachedTableName = tableName.split('.')(1) // Ignore the database name
- SharkEnv.memoryMetadataManager.get(cachedTableName) match {
- case Some(rdd) => {
- if (hiveSinkOps.size == 1) {
- // If useUnionRDD is false, the sink op is for INSERT OVERWRITE.
- val useUnionRDD = qb.getParseInfo.isInsertIntoTable(cachedTableName)
- val storageLevel = RDDUtils.getStorageLevelOfCachedTable(rdd)
+ val tableNameSplit = tableName.split('.') // Split from 'databaseName.tableName'
+ val cachedTableName = tableNameSplit(1)
+ val databaseName = tableNameSplit(0)
+            val hiveTable = Hive.get().getTable(databaseName, cachedTableName)
+ val cacheMode = CacheType.fromString(
+ hiveTable.getProperty(SharkTblProperties.CACHE_FLAG.varname))
+ if (CacheType.shouldCache(cacheMode)) {
+ if (hiveSinkOps.size == 1) {
+ // INSERT INTO or OVERWRITE update on a cached table.
+ qb.targetTableDesc = tableDesc
+ // If isInsertInto is true, the sink op is for INSERT INTO.
+ val isInsertInto = qbParseInfo.isInsertIntoTable(cachedTableName)
+ val isPartitioned = hiveTable.isPartitioned
+              val hivePartitionKeyOpt = if (isPartitioned) {
+ Some(SharkSemanticAnalyzer.getHivePartitionKey(qb))
+ } else {
+ None
+ }
+ if (cacheMode == CacheType.MEMORY) {
+                // The table being updated is stored in memory and backed by disk; a
+                // SparkLoadTask will be created by the genMapRedTasks() call below. Set
+                // `qb.cacheMode` for it here (`qb.targetTableDesc` was already set above).
+                qb.cacheMode = cacheMode
+ OperatorFactory.createSharkFileOutputPlan(hiveSinkOp)
+ } else {
OperatorFactory.createSharkMemoryStoreOutputPlan(
hiveSinkOp,
cachedTableName,
- storageLevel,
- _resSchema.size, // numColumns
- cacheMode == CacheType.tachyon, // use tachyon
- useUnionRDD)
- } else {
- throw new SemanticException(
- "Shark does not support updating cached table(s) with multiple INSERTs")
+ databaseName,
+ _resSchema.size, /* numColumns */
+ hivePartitionKeyOpt,
+ cacheMode,
+ isInsertInto)
}
+ } else {
+ throw new SemanticException(
+ "Shark does not support updating cached table(s) with multiple INSERTs")
}
- case None => OperatorFactory.createSharkFileOutputPlan(hiveSinkOp)
+ } else {
+ OperatorFactory.createSharkFileOutputPlan(hiveSinkOp)
}
}
}
} else if (hiveSinkOps.size == 1) {
- // For a single output, we have the option of choosing the output
- // destination (e.g. CTAS with table property "shark.cache" = "true").
Seq {
- if (qb.isCTAS && qb.getTableDesc != null && CacheType.shouldCache(cacheMode)) {
- val storageLevel = MemoryMetadataManager.getStorageLevelFromString(
- qb.getTableDesc().getTblProps.get("shark.cache.storageLevel"))
- qb.getTableDesc().getTblProps().put(CachedTableRecovery.QUERY_STRING, ctx.getCmd())
- OperatorFactory.createSharkMemoryStoreOutputPlan(
- hiveSinkOps.head,
- qb.getTableDesc.getTableName,
- storageLevel,
- _resSchema.size, // numColumns
- cacheMode == CacheType.tachyon, // use tachyon
- false)
+ // For a single output, we have the option of choosing the output
+ // destination (e.g. CTAS with table property "shark.cache" = "true").
+ if (qb.isCTAS && qb.createTableDesc != null && CacheType.shouldCache(qb.cacheMode)) {
+ // The table being created from CTAS should be cached.
+          if (qb.cacheMode == CacheType.MEMORY) {
+            // Write the query output to disk as usual; a SparkLoadTask created in
+            // genMapRedTasks() will load the new table's data directory into the cache.
+ OperatorFactory.createSharkFileOutputPlan(hiveSinkOps.head)
+ } else {
+ OperatorFactory.createSharkMemoryStoreOutputPlan(
+ hiveSinkOps.head,
+ qb.createTableDesc.getTableName,
+ qb.createTableDesc.getDatabaseName,
+ numColumns = _resSchema.size,
+ hivePartitionKeyOpt = None,
+ qb.cacheMode,
+ isInsertInto = false)
+ }
} else if (pctx.getContext().asInstanceOf[QueryContext].useTableRddSink && !qb.isCTAS) {
OperatorFactory.createSharkRddOutputPlan(hiveSinkOps.head)
} else {
OperatorFactory.createSharkFileOutputPlan(hiveSinkOps.head)
}
}
-
// A hack for the query plan dashboard to get the query plan. This was
// done for SIGMOD demo. Turn it off by default.
//shark.dashboard.QueryPlanDashboardHandler.terminalOperator = terminalOp
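The INSERT branch above keys everything off the fully qualified sink table name. A tiny sketch of the `databaseName.tableName` split it performs (the sample name is hypothetical):

```scala
object NameSplitSketch {
  def main(args: Array[String]) {
    // Hive hands the sink operator a fully qualified table name.
    val Array(databaseName, cachedTableName) = "default.users_cached".split('.')
    assert(databaseName == "default" && cachedTableName == "users_cached")
  }
}
```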
@@ -237,15 +264,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
SharkSemanticAnalyzer.breakHivePlanByStages(terminalOpSeq)
genMapRedTasks(qb, pctx, terminalOpSeq)
- logInfo("Completed plan generation")
+ logDebug("Completed plan generation")
}
/**
* Generate tasks for executing the query, including the SparkTask to do the
* select, the MoveTask for updates, and the DDLTask for CTAS.
*/
- def genMapRedTasks(qb: QB, pctx: ParseContext, terminalOps: Seq[TerminalOperator]) {
-
+ def genMapRedTasks(qb: QueryBlock, pctx: ParseContext, terminalOps: Seq[TerminalOperator]) {
// Create the spark task.
terminalOps.foreach { terminalOp =>
val task = TaskFactory.get(new SparkWork(pctx, terminalOp, _resSchema), conf)
@@ -253,6 +279,7 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
}
if (qb.getIsQuery) {
+ // Note: CTAS isn't considered a query - it's handled in the 'else' block below.
// Configure FetchTask (used for fetching results to CLIDriver).
val loadWork = getParseContext.getLoadFileWork.get(0)
val cols = loadWork.getColumns
@@ -268,9 +295,10 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
setFetchTask(fetchTask)
} else {
- // Configure MoveTasks for table updates (e.g. CTAS, INSERT).
+ // Configure MoveTasks for CTAS, INSERT.
val mvTasks = new ArrayList[MoveTask]()
+ // For CTAS, `fileWork` contains a single LoadFileDesc (called "LoadFileWork" in Hive).
val fileWork = getParseContext.getLoadFileWork
val tableWork = getParseContext.getLoadTableWork
tableWork.foreach { ltd =>
@@ -280,13 +308,14 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
fileWork.foreach { lfd =>
if (qb.isCTAS) {
+ // For CTAS, `lfd.targetDir` references the data directory of the table being created.
var location = qb.getTableDesc.getLocation
if (location == null) {
try {
- val dumpTable = db.newTable(qb.getTableDesc.getTableName)
+ val tableToCreate = db.newTable(qb.getTableDesc.getTableName)
val wh = new Warehouse(conf)
- location = wh.getTablePath(db.getDatabase(dumpTable.getDbName()), dumpTable
- .getTableName()).toString;
+            location = wh.getTablePath(
+              db.getDatabase(tableToCreate.getDbName()), tableToCreate.getTableName()).toString
} catch {
case e: HiveException => throw new SemanticException(e)
case e: MetaException => throw new SemanticException(e)
@@ -299,9 +328,13 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
new MoveWork(null, null, null, lfd, false), conf).asInstanceOf[MoveTask])
}
- // The move task depends on all root tasks. In the case of multi outputs,
+ // The move task depends on all root tasks. In the case of multiple outputs,
// the moves are only started once all outputs are executed.
- val hiveFileSinkOp = terminalOps.head.hiveOp
+      // Note: For a CTAS for a memory-only cached table, a MoveTask is still added as a child of
+      // the main SparkTask. However, its execution has no effect, since the SELECT query
+      // output is piped to Shark's in-memory columnar storage builder instead of a Hive tmp
+      // directory.
+ // TODO(harvey): Don't create a MoveTask in this case.
mvTasks.foreach { moveTask =>
rootTasks.foreach { rootTask =>
rootTask.addDependentTask(moveTask)
@@ -321,6 +354,44 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
}
*/
}
+
+ if (qb.cacheMode == CacheType.MEMORY) {
+ // Create a SparkLoadTask used to scan and load disk contents into the cache.
+ val sparkLoadWork = if (qb.isCTAS) {
+        // For cached tables, Shark-specific table properties are set in analyzeCreateTable().
+        // No need to create a path filter, since the entire table data directory should be
+        // loaded, nor to pass partition specifications, since partitioned tables can't be
+        // created from CTAS.
+        new SparkLoadWork(
+          qb.createTableDesc.getDatabaseName,
+          qb.createTableDesc.getTableName,
+          SparkLoadWork.CommandTypes.NEW_ENTRY,
+          qb.cacheMode)
+ } else {
+ // Split from 'databaseName.tableName'
+ val tableNameSplit = qb.targetTableDesc.getTableName.split('.')
+ val databaseName = tableNameSplit(0)
+ val cachedTableName = tableNameSplit(1)
+ val hiveTable = db.getTable(databaseName, cachedTableName)
+ // None if the table isn't partitioned, or if the partition specified doesn't exist.
+ val partSpecOpt = Option(qb.getMetaData.getDestPartitionForAlias(
+ qb.getParseInfo.getClauseNamesForDest.head)).map(_.getSpec)
+ SparkLoadWork(
+ db,
+ conf,
+ hiveTable,
+ partSpecOpt,
+ isOverwrite = !qb.getParseInfo.isInsertIntoTable(cachedTableName))
+ }
+ // Add a SparkLoadTask as a dependent of all MoveTasks, so that when executed, the table's
+ // (or table partition's) data directory will already contain updates that should be
+ // loaded into memory.
+ val sparkLoadTask = TaskFactory.get(sparkLoadWork, conf)
+ mvTasks.foreach(_.addDependentTask(sparkLoadTask))
+ }
}
// For CTAS, generate a DDL task to create the table. This task should be a
@@ -344,11 +415,108 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with
rootTasks.head.addDependentTask(crtTblTask)
}
}
+
+ def analyzeCreateTable(rootAST: ASTNode, queryBlock: QueryBlock): Option[ASTNode] = {
+ // If we detect that the CREATE TABLE is part of a CTAS, then this is set to the root node of
+ // the query command (i.e., the root node of the SELECT statement).
+ var queryStmtASTNode: Option[ASTNode] = None
+
+ // TODO(harvey): We might be able to reuse the QB passed into this method, as long as it was
+ // created after the super.analyzeInternal() call. That QB and the createTableDesc
+ // should have everything (e.g. isCTAS(), partCols). Note that the QB might not be
+ // accessible from getParseContext(), since the SemanticAnalyzer#analyzeInternal()
+ // doesn't set (this.qb = qb) for a non-CTAS.
+ // True if the command is a CREATE TABLE, but not a CTAS.
+ var isRegularCreateTable = true
+ var isHivePartitioned = false
+
+ for (ch <- rootAST.getChildren) {
+ ch.asInstanceOf[ASTNode].getToken.getType match {
+ case HiveParser.TOK_QUERY => {
+ isRegularCreateTable = false
+ queryStmtASTNode = Some(ch.asInstanceOf[ASTNode])
+ }
+ case _ => Unit
+ }
+ }
+
+ var ddlTasks: Seq[DDLTask] = Nil
+ val createTableDesc = if (isRegularCreateTable) {
+ // Unfortunately, we have to comb the root tasks because for CREATE TABLE,
+      // SemanticAnalyzer#analyzeCreateTable() doesn't set the CreateTableDesc in its QB.
+ ddlTasks = rootTasks.filter(_.isInstanceOf[DDLTask]).asInstanceOf[Seq[DDLTask]]
+ if (ddlTasks.isEmpty) null else ddlTasks.head.getWork.getCreateTblDesc
+ } else {
+ getParseContext.getQB.getTableDesc
+ }
+
+ // Update the QueryBlock passed into this method.
+ // TODO(harvey): Remove once the TODO above is fixed.
+ queryBlock.setTableDesc(createTableDesc)
+
+    // `createTableDesc` is null if there is an IF NOT EXISTS condition and the target table
+    // already exists.
+ if (createTableDesc != null) {
+ val tableName = createTableDesc.getTableName
+ val checkTableName = SharkConfVars.getBoolVar(conf, SharkConfVars.CHECK_TABLENAME_FLAG)
+ // Note that the CreateTableDesc's table properties are Java Maps, but the TableDesc's table
+ // properties, which are used during execution, are Java Properties.
+ val createTableProperties: JavaMap[String, String] = createTableDesc.getTblProps()
+
+ // There are two cases that will enable caching:
+ // 1) Table name includes "_cached" or "_tachyon".
+ // 2) The "shark.cache" table property is "true", or the string representation of a supported
+ // cache mode (memory, memory-only, Tachyon).
+ var cacheMode = CacheType.fromString(
+ createTableProperties.get(SharkTblProperties.CACHE_FLAG.varname))
+ if (checkTableName) {
+ if (tableName.endsWith("_cached")) {
+ cacheMode = CacheType.MEMORY
+ } else if (tableName.endsWith("_tachyon")) {
+ cacheMode = CacheType.TACHYON
+ }
+ }
+
+ // Continue planning based on the 'cacheMode' read.
+ val shouldCache = CacheType.shouldCache(cacheMode)
+ if (shouldCache) {
+ if (cacheMode == CacheType.MEMORY_ONLY || cacheMode == CacheType.TACHYON) {
+ val serDeName = createTableDesc.getSerName
+ if (serDeName == null || serDeName == classOf[LazySimpleSerDe].getName) {
+ // Hive's SemanticAnalyzer optimizes based on checks for LazySimpleSerDe, which causes
+ // casting exceptions for cached table scans during runtime. Use a simple SerDe wrapper
+ // to guard against these optimizations.
+ createTableDesc.setSerName(classOf[LazySimpleSerDeWrapper].getName)
+ }
+ }
+ createTableProperties.put(SharkTblProperties.CACHE_FLAG.varname, cacheMode.toString)
+ }
+
+ // For CTAS ('isRegularCreateTable' is false), the MemoryStoreSinkOperator creates a new
+ // table metadata entry in the MemoryMetadataManager. The SparkTask that encloses the
+ // MemoryStoreSinkOperator will have a child Hive DDLTask, which creates a new table metadata
+ // entry in the Hive metastore. See genMapRedTasks() for SparkTask creation.
+ if (isRegularCreateTable && shouldCache) {
+ // In Hive, a CREATE TABLE command is handled by a DDLTask, created by
+ // SemanticAnalyzer#analyzeCreateTable(), in 'rootTasks'. The DDL tasks' execution succeeds
+ // only if the CREATE TABLE is valid. So, hook a SharkDDLTask as a child of the Hive DDLTask
+ // so that Shark metadata is updated only if the Hive task execution is successful.
+        val hiveDDLTask = ddlTasks.head
+ val sharkDDLWork = new SharkDDLWork(createTableDesc)
+ sharkDDLWork.cacheMode = cacheMode
+ hiveDDLTask.addDependentTask(TaskFactory.get(sharkDDLWork, conf))
+ }
+
+ queryBlock.cacheMode = cacheMode
+ queryBlock.setTableDesc(createTableDesc)
+ }
+ queryStmtASTNode
+ }
+
}
object SharkSemanticAnalyzer extends LogHelper {
-
/**
* The reflection object used to invoke convertRowSchemaToViewSchema.
*/
@@ -363,13 +531,22 @@ object SharkSemanticAnalyzer extends LogHelper {
private val viewsExpandedField = classOf[SemanticAnalyzer].getDeclaredField("viewsExpanded")
viewsExpandedField.setAccessible(true)
+ private def getHivePartitionKey(qb: QB): String = {
+ val selectClauseKey = qb.getParseInfo.getClauseNamesForDest.head
+ val destPartition = qb.getMetaData.getDestPartitionForAlias(selectClauseKey)
+ val partitionColumns = destPartition.getTable.getPartCols.map(_.getName)
+ val partitionColumnToValue = destPartition.getSpec
+ MemoryMetadataManager.makeHivePartitionKeyStr(partitionColumns, partitionColumnToValue)
+ }
+
/**
* Given a Hive top operator (e.g. TableScanOperator), find all the file sink
* operators (aka file output operator).
*/
- private def findAllHiveFileSinkOperators(op: HiveOperator): Seq[HiveOperator] = {
+  private def findAllHiveFileSinkOperators(op: HiveOperator[_ <: HiveDesc])
+    : Seq[HiveOperator[_ <: HiveDesc]] = {
if (op.getChildOperators() == null || op.getChildOperators().size() == 0) {
- Seq[HiveOperator](op)
+      Seq[HiveOperator[_ <: HiveDesc]](op)
} else {
op.getChildOperators().flatMap(findAllHiveFileSinkOperators(_)).distinct
}
@@ -384,7 +561,7 @@ object SharkSemanticAnalyzer extends LogHelper {
*/
private def breakHivePlanByStages(terminalOps: Seq[TerminalOperator]) = {
val reduceSinks = new scala.collection.mutable.HashSet[ReduceSinkOperator]
- val queue = new scala.collection.mutable.Queue[Operator[_]]
+ val queue = new scala.collection.mutable.Queue[Operator[_ <: HiveDesc]]
queue ++= terminalOps
while (!queue.isEmpty) {
@@ -399,15 +576,5 @@ object SharkSemanticAnalyzer extends LogHelper {
}
logDebug("Found %d ReduceSinkOperator's.".format(reduceSinks.size))
-
- reduceSinks.foreach { op =>
- val hiveOp = op.asInstanceOf[Operator[HiveOperator]].hiveOp
- if (hiveOp.getChildOperators() != null) {
- hiveOp.getChildOperators().foreach { child =>
- logDebug("Removing child %s from %s".format(child, hiveOp))
- hiveOp.removeChild(child)
- }
- }
- }
}
}
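breakHivePlanByStages() (last hunk above) walks the operator graph bottom-up from the terminal operators, collecting every ReduceSinkOperator it reaches. A generic, self-contained sketch of that traversal, with the boundary test abstracted into a predicate (the helper name is invented):

```scala
import scala.collection.mutable

object StageBoundarySketch {
  // Walk from terminal operators toward parents, visiting each node once and
  // collecting every stage boundary (ReduceSinkOperator in the real plan).
  def collectStageBoundaries[T](
      terminals: Seq[T],
      parents: T => Seq[T],
      isBoundary: T => Boolean): Set[T] = {
    val boundaries = mutable.HashSet[T]()
    val visited = mutable.HashSet[T]()
    val queue = mutable.Queue[T](terminals: _*)
    while (queue.nonEmpty) {
      val op = queue.dequeue()
      if (visited.add(op)) {
        if (isBoundary(op)) boundaries += op
        queue ++= parents(op)
      }
    }
    boundaries.toSet
  }
}
```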
diff --git a/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala b/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala
index 91215988..721ce115 100755
--- a/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala
+++ b/src/main/scala/shark/parse/SharkSemanticAnalyzerFactory.scala
@@ -19,7 +19,7 @@ package shark.parse
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.parse.{ASTNode, BaseSemanticAnalyzer, DDLSemanticAnalyzer,
- SemanticAnalyzerFactory, ExplainSemanticAnalyzer, SemanticAnalyzer}
+ ExplainSemanticAnalyzer, LoadSemanticAnalyzer, SemanticAnalyzerFactory, SemanticAnalyzer}
import shark.SharkConfVars
@@ -30,18 +30,19 @@ object SharkSemanticAnalyzerFactory {
* Return a semantic analyzer for the given ASTNode.
*/
def get(conf: HiveConf, tree:ASTNode): BaseSemanticAnalyzer = {
- val baseSem = SemanticAnalyzerFactory.get(conf, tree)
-
- if (baseSem.isInstanceOf[SemanticAnalyzer]) {
- new SharkSemanticAnalyzer(conf)
- } else if (baseSem.isInstanceOf[ExplainSemanticAnalyzer] &&
- SharkConfVars.getVar(conf, SharkConfVars.EXPLAIN_MODE) == "shark") {
- new SharkExplainSemanticAnalyzer(conf)
- } else if (baseSem.isInstanceOf[DDLSemanticAnalyzer]) {
- new SharkDDLSemanticAnalyzer(conf)
- } else {
- baseSem
+ val explainMode = SharkConfVars.getVar(conf, SharkConfVars.EXPLAIN_MODE) == "shark"
+
+ SemanticAnalyzerFactory.get(conf, tree) match {
+ case _: SemanticAnalyzer =>
+ new SharkSemanticAnalyzer(conf)
+ case _: ExplainSemanticAnalyzer if explainMode =>
+ new SharkExplainSemanticAnalyzer(conf)
+ case _: DDLSemanticAnalyzer =>
+ new SharkDDLSemanticAnalyzer(conf)
+ case _: LoadSemanticAnalyzer =>
+ new SharkLoadSemanticAnalyzer(conf)
+ case sem: BaseSemanticAnalyzer =>
+ sem
}
}
}
-
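A hypothetical driver-side view of the dispatch above (ParseDriver and ParseUtils are Hive's): a LOAD statement now routes to SharkLoadSemanticAnalyzer rather than the stock LoadSemanticAnalyzer.

```scala
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.parse.{ParseDriver, ParseUtils}

import shark.parse.SharkSemanticAnalyzerFactory

object FactorySketch {
  def main(args: Array[String]) {
    val conf = new HiveConf()
    // Parse a LOAD statement the way Hive's Driver does before semantic analysis.
    val tree = ParseUtils.findRootNonNullToken(
      new ParseDriver().parse("LOAD DATA INPATH '/tmp/kv1.txt' INTO TABLE src"))
    val analyzer = SharkSemanticAnalyzerFactory.get(conf, tree)
    println(analyzer.getClass.getSimpleName)  // expected: SharkLoadSemanticAnalyzer
  }
}
```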
diff --git a/src/main/scala/shark/repl/Main.scala b/src/main/scala/shark/repl/Main.scala
index 1fa22da5..890a74ef 100755
--- a/src/main/scala/shark/repl/Main.scala
+++ b/src/main/scala/shark/repl/Main.scala
@@ -17,11 +17,21 @@
package shark.repl
+import org.apache.hadoop.hive.common.LogUtils
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+
+
/**
* Shark's REPL entry point.
*/
object Main {
+ try {
+ LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException => // Ignore the error.
+ }
+
private var _interp: SharkILoop = null
def interp = _interp
diff --git a/src/main/scala/shark/tachyon/TachyonUtil.scala b/src/main/scala/shark/tachyon/TachyonUtil.scala
index 3a50eead..25207d91 100644
--- a/src/main/scala/shark/tachyon/TachyonUtil.scala
+++ b/src/main/scala/shark/tachyon/TachyonUtil.scala
@@ -22,8 +22,7 @@ import java.util.BitSet
import org.apache.spark.rdd.RDD
-import shark.memstore2.TablePartition
-
+import shark.memstore2.{TablePartition, TablePartitionStats}
/**
@@ -32,17 +31,27 @@ import shark.memstore2.TablePartition
* even without Tachyon jars.
*/
abstract class TachyonUtil {
+
def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean
def tachyonEnabled(): Boolean
- def tableExists(tableName: String): Boolean
+ def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean
+
+ def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean
- def dropTable(tableName: String): Boolean
+ def createDirectory(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean
- def getTableMetadata(tableName: String): ByteBuffer
+ def renameDirectory(oldName: String, newName: String): Boolean
- def createRDD(tableName: String): RDD[TablePartition]
+ def createRDD(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]
+ ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])]
- def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter
+ def createTableWriter(
+ tableKey: String,
+ hivePartitionKey: Option[String],
+ numColumns: Int
+ ): TachyonTableWriter
}
diff --git a/src/main/scala/shark/tgf/TGF.scala b/src/main/scala/shark/tgf/TGF.scala
new file mode 100644
index 00000000..b57d4053
--- /dev/null
+++ b/src/main/scala/shark/tgf/TGF.scala
@@ -0,0 +1,303 @@
+/*
+ * Copyright (C) 2013 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.tgf
+
+import java.sql.Timestamp
+import java.util.Date
+
+import scala.language.implicitConversions
+import scala.reflect.{classTag, ClassTag}
+import scala.util.parsing.combinator._
+
+import org.apache.spark.rdd.RDD
+
+import shark.api._
+import shark.SharkContext
+import java.lang.reflect.Method
+
+/**
+ * This object is responsible for handling TGF (Table Generating Function) commands.
+ *
+ * {{{
+ * -- TGF Commands --
+ * GENERATE tgfname(param1, param2, ... , param_n)
+ * GENERATE tgfname(param1, param2, ... , param_n) AS tablename
+ * }}}
+ *
+ * Parameters can either be of primitive types, e.g. int, or of type RDD[Product]. TGF.execute()
+ * will use reflection looking for an object of name "tgfname", invoking apply() with the
+ * primitive values. If the type of a parameter to apply() is RDD[Product], it will assume the
+ * parameter is the name of a table, which it will turn into an RDD before invoking apply().
+ *
+ * For example, "GENERATE MyObj(25, emp)" will invoke
+ * MyObj.apply(25, sc.sql2rdd("select * from emp"))
+ * , assuming the TGF object (MyObj) has an apply function that takes an int and an RDD[Product].
+ *
+ * The "as" version of the command saves the output in a new table named "tablename",
+ * whereas the other version returns a ResultSet.
+ *
+ * -- Defining TGF objects --
+ * TGF objects need to have an apply() function and take an arbitrary number of either primitive
+ * or RDD[Product] typed parameters. The apply() function should either return an RDD[Product]
+ * or RDDSchema. When the former case is used, the returned table's schema and column names need
+ * to be defined through a Java annotation called @Schema. Here is a short example:
+ * {{{
+ * object MyTGF1 {
+ * \@Schema(spec = "name string, age int")
+ * def apply(table1: RDD[(String, String, Int)]): RDD[Product] = {
+ * // code that manipulates table1 and returns a new RDD of tuples
+ * }
+ * }
+ * }}}
+ *
+ * Sometimes, the TGF dynamically determines the number or types of columns returned. In this case,
+ * the TGF can use the RDDSchema return type instead of Java annotations. RDDSchema simply contains
+ * a schema string and an RDD of results. For example:
+ * {{{
+ * object MyTGF2 {
+ *   def apply(table1: RDD[(String, String, Int)]): RDDSchema = {
+ *     // code that manipulates table1 and creates a result rdd
+ *     RDDSchema(rdd.asInstanceOf[RDD[Seq[_]]], "name string, age int")
+ * }
+ * }
+ * }}}
+ *
+ * Sometimes the TGF needs to internally make SQL calls. For that, it needs access to a
+ * SharkContext object. Therefore,
+ * {{{
+ * def apply(sc: SharkContext, table1: RDD[(String, String, Int)]): RDD[Product] = {
+ * // code that can use sc, for example by calling sc.sql2rdd()
+ * // code that manipulates table1 and returns a new RDD of tuples
+ * }
+ * }}}
+ */
+
+object TGF {
+ private val parser = new TGFParser
+
+ /**
+ * Executes a TGF command and gives back the ResultSet.
+ * Mainly to be used from SharkContext (e.g. runSql())
+ *
+ * @param sql TGF command, e.g. "GENERATE name(params) AS tablename"
+ * @param sc SharkContext
+ * @return ResultSet containing the results of the command
+ */
+ def execute(sql: String, sc: SharkContext): ResultSet = {
+ val ast = parser.parseAll(parser.tgf, sql).getOrElse(
+ throw new QueryExecutionException("TGF parse error: "+ sql))
+
+ val (tableNameOpt, tgfName, params) = ast match {
+ case (tgfName, params) =>
+ (None, tgfName.asInstanceOf[String], params.asInstanceOf[List[String]])
+ case (tableName, tgfName, params) =>
+ (Some(tableName.asInstanceOf[String]), tgfName.asInstanceOf[String],
+ params.asInstanceOf[List[String]])
+ }
+
+ val obj = reflectInvoke(tgfName, params, sc)
+ val (rdd, schema) = getSchema(obj, tgfName)
+
+ val (sharkSchema, resultArr) = tableNameOpt match {
+ case Some(tableName) => // materialize results
+ val helper = new RDDTableFunctions(rdd, schema.map { case (_, tpe) => toClassTag(tpe) })
+        helper.saveAsTable(tableName, schema.map { case (name, _) => name })
+ (Array[ColumnDesc](), Array[Array[Object]]())
+ case None => // return results
+ val newSchema = schema.map { case (name, tpe) =>
+ new ColumnDesc(name, DataTypes.fromClassTag(toClassTag(tpe)))
+ }
+        val res = rdd.collect().map(_.map(_.asInstanceOf[Object]).toArray)
+ (newSchema.toArray, res)
+ }
+ new ResultSet(sharkSchema, resultArr)
+ }
+
+ private def getMethod(tgfName: String, methodName: String): Option[Method] = {
+ val tgfClazz = try {
+ Thread.currentThread().getContextClassLoader.loadClass(tgfName)
+ } catch {
+ case ex: ClassNotFoundException =>
+ throw new QueryExecutionException("Couldn't find TGF class: " + tgfName)
+ }
+
+ val methods = tgfClazz.getDeclaredMethods.filter(_.getName == methodName)
+    methods.headOption
+ }
+
+ private def getSchema(tgfOutput: Object, tgfName: String): (RDD[Seq[_]], Seq[(String,String)]) = {
+ tgfOutput match {
+ case rddSchema: RDDSchema =>
+ val schema = parser.parseAll(parser.schema, rddSchema.schema)
+
+ (rddSchema.rdd, schema.get)
+ case rdd: RDD[Product] =>
+ val applyMethod = getMethod(tgfName, "apply")
+        if (applyMethod.isEmpty) {
+ throw new QueryExecutionException("TGF lacking apply() method")
+ }
+
+ val annotations = applyMethod.get.getAnnotation(classOf[Schema])
+ if (annotations == null || annotations.spec() == null) {
+ throw new QueryExecutionException("No schema annotation found for TGF")
+ }
+
+ // TODO: How can we compare schema with None?
+ val schema = parser.parseAll(parser.schema, annotations.spec())
+ if (schema.isEmpty) {
+ throw new QueryExecutionException(
+ "Error parsing TGF schema annotation (@Schema(spec=...)")
+ }
+
+ (rdd.map(_.productIterator.toList), schema.get)
+ case _ =>
+ throw new QueryExecutionException("TGF output needs to be of type RDD or RDDSchema")
+ }
+ }
+
+ private def reflectInvoke(tgfName: String, paramStrs: Seq[String], sc: SharkContext) = {
+
+ val applyMethodOpt = getMethod(tgfName, "apply")
+ if (applyMethodOpt.isEmpty) {
+ throw new QueryExecutionException("TGF " + tgfName + " needs to implement apply()")
+ }
+
+ val applyMethod = applyMethodOpt.get
+
+ val typeNames: Seq[String] = applyMethod.getParameterTypes.toList.map(_.toString)
+
+ val augParams =
+ if (!typeNames.isEmpty && typeNames.head.startsWith("class shark.SharkContext")) {
+ Seq("sc") ++ paramStrs
+ } else {
+ paramStrs
+ }
+
+ if (augParams.length != typeNames.length) {
+ throw new QueryExecutionException("Expecting " + typeNames.length +
+ " parameters to " + tgfName + ", got " + augParams.length)
+ }
+
+ val params = (augParams.toList zip typeNames.toList).map {
+ case (param: String, tpe: String) if tpe.startsWith("class shark.SharkContext") =>
+ sc
+ case (param: String, tpe: String) if tpe.startsWith("class org.apache.spark.rdd.RDD") =>
+ tableRdd(sc, param)
+ case (param: String, tpe: String) if tpe.startsWith("long") =>
+ param.toLong
+ case (param: String, tpe: String) if tpe.startsWith("int") =>
+ param.toInt
+ case (param: String, tpe: String) if tpe.startsWith("double") =>
+ param.toDouble
+ case (param: String, tpe: String) if tpe.startsWith("float") =>
+ param.toFloat
+ case (param: String, tpe: String) if tpe.startsWith("class java.lang.String") ||
+ tpe.startsWith("class String") =>
+ param.stripPrefix("\"").stripSuffix("\"")
+ case (param: String, tpe: String) =>
+ throw new QueryExecutionException(s"Expected TGF parameter type: $tpe ($param)")
+ }
+
+ applyMethod.invoke(null, params.asInstanceOf[List[Object]] : _*)
+ }
+
+ private def toClassTag(tpe: String): ClassTag[_] = {
+ if (tpe == "boolean") classTag[Boolean]
+ else if (tpe == "tinyint") classTag[Byte]
+ else if (tpe == "smallint") classTag[Short]
+ else if (tpe == "int") classTag[Integer]
+ else if (tpe == "bigint") classTag[Long]
+ else if (tpe == "float") classTag[Float]
+ else if (tpe == "double") classTag[Double]
+ else if (tpe == "string") classTag[String]
+ else if (tpe == "timestamp") classTag[Timestamp]
+ else if (tpe == "date") classTag[Date]
+ else {
+ throw new QueryExecutionException("Unknown column type specified in schema (" + tpe + ")")
+ }
+ }
+
+ def tableRdd(sc: SharkContext, tableName: String): RDD[_] = {
+ val rdd = sc.sql2rdd("SELECT * FROM " + tableName)
+ rdd.schema.size match {
+ case 2 => new TableRDD2(rdd, Seq())
+ case 3 => new TableRDD3(rdd, Seq())
+ case 4 => new TableRDD4(rdd, Seq())
+ case 5 => new TableRDD5(rdd, Seq())
+ case 6 => new TableRDD6(rdd, Seq())
+ case 7 => new TableRDD7(rdd, Seq())
+ case 8 => new TableRDD8(rdd, Seq())
+ case 9 => new TableRDD9(rdd, Seq())
+ case 10 => new TableRDD10(rdd, Seq())
+ case 11 => new TableRDD11(rdd, Seq())
+ case 12 => new TableRDD12(rdd, Seq())
+ case 13 => new TableRDD13(rdd, Seq())
+ case 14 => new TableRDD14(rdd, Seq())
+ case 15 => new TableRDD15(rdd, Seq())
+ case 16 => new TableRDD16(rdd, Seq())
+ case 17 => new TableRDD17(rdd, Seq())
+ case 18 => new TableRDD18(rdd, Seq())
+ case 19 => new TableRDD19(rdd, Seq())
+ case 20 => new TableRDD20(rdd, Seq())
+ case 21 => new TableRDD21(rdd, Seq())
+ case 22 => new TableRDD22(rdd, Seq())
+ case _ => new TableSeqRDD(rdd)
+ }
+ }
+}
+
+case class RDDSchema(rdd: RDD[Seq[_]], schema: String)
+
+private class TGFParser extends JavaTokenParsers {
+
+ // Code to enable case-insensitive modifiers to strings, e.g.
+ // "Berkeley".ci will match "berkeley"
+ class MyString(str: String) {
+ def ci: Parser[String] = ("(?i)" + str).r
+ }
+
+ implicit def stringToRichString(str: String): MyString = new MyString(str)
+
+ def tgf: Parser[Any] = saveTgf | basicTgf
+
+ /**
+ * @return Tuple2 containing a TGF method name and a List of parameters as strings
+ */
+ def basicTgf: Parser[(String, List[String])] = {
+ ("GENERATE".ci ~> methodName) ~ (("(" ~> repsep(param, ",")) <~ ")") ^^
+ { case id1 ~ x => (id1, x.asInstanceOf[List[String]]) }
+ }
+
+ /**
+ * @return Tuple3 containing a table name, TGF method name and a List of parameters as strings
+ */
+ def saveTgf: Parser[(String, String, List[String])] = {
+ (("GENERATE".ci ~> methodName) ~ (("(" ~> repsep(param, ",")) <~ ")")) ~ (("AS".ci) ~>
+ ident) ^^ { case id1 ~ x ~ id2 => (id2, id1, x.asInstanceOf[List[String]]) }
+ }
+
+ def schema: Parser[Seq[(String,String)]] = repsep(nameType, ",")
+
+  def nameType: Parser[(String, String)] = ident ~ ident ^^ { case name ~ tpe => (name, tpe) }
+
+ def param: Parser[Any] = stringLiteral | floatingPointNumber | decimalNumber | ident |
+ failure("Expected a string, number, or identifier as parameters in TGF")
+
+ def methodName: Parser[String] = """[a-zA-Z_][\w\.]*""".r
+}
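Putting the scaladoc above into practice, here is a complete annotation-style TGF. The object and table names are invented, and the @Schema annotation is assumed to be importable from shark.api (its package isn't shown in this diff):

```scala
import org.apache.spark.rdd.RDD

import shark.api.Schema

object DoubleAge {
  // Output schema is declared via the annotation, as in the MyTGF1 example above.
  @Schema(spec = "name string, age int")
  def apply(people: RDD[(String, Int)]): RDD[Product] =
    // Each output row is a tuple matching the declared schema.
    people.map { case (name, age) => (name, age * 2): Product }
}
```

It would be invoked from SQL as `GENERATE DoubleAge(people) AS people_doubled`, which reflectively calls `DoubleAge.apply(sc.sql2rdd("select * from people"))` and saves the result as a new table.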
diff --git a/src/main/scala/shark/util/BloomFilter.scala b/src/main/scala/shark/util/BloomFilter.scala
index 6a26b9e5..3a798d28 100644
--- a/src/main/scala/shark/util/BloomFilter.scala
+++ b/src/main/scala/shark/util/BloomFilter.scala
@@ -1,9 +1,27 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.util
import java.util.BitSet
import java.nio.charset.Charset
-import scala.math._
-import com.google.common.primitives.Bytes
+
+import scala.math.{ceil, log}
+
import com.google.common.primitives.Ints
import com.google.common.primitives.Longs
@@ -16,13 +34,12 @@ import com.google.common.primitives.Longs
* @param expectedSize is the number of elements to be contained in the filter.
* @param numHashes is the number of hash functions.
* @author Ram Sriharsha (harshars at yahoo-inc dot com)
- * @date 07/07/2013
*/
class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int)
- extends AnyRef with Serializable{
+ extends AnyRef with Serializable {
val SEED = System.getProperty("shark.bloomfilter.seed","1234567890").toInt
- val bitSetSize = ceil(numBitsPerElement * expectedSize).toInt
+ val bitSetSize = math.ceil(numBitsPerElement * expectedSize).toInt
val bitSet = new BitSet(bitSetSize)
/**
@@ -51,7 +68,7 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int)
* Optimization to allow reusing the same input buffer by specifying
* the length of the buffer that contains the bytes to be hashed.
* @param data is the bytes to be hashed.
- * @param length is the length of the buffer to examine.
+ * @param len is the length of the buffer to examine.
*/
def add(data: Array[Byte], len: Int) {
val hashes = hash(data, numHashes, len)
@@ -96,9 +113,9 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int)
* Optimization to allow reusing the same input buffer by specifying
* the length of the buffer that contains the bytes to be hashed.
* @param data is the bytes to be hashed.
- * @param length is the length of the buffer to examine.
+ * @param len is the length of the buffer to examine.
* @return true with some false positive probability and false if the
- * bytes is not contained in the bloom filter.
+ * bytes is not contained in the bloom filter.
*/
def contains(data: Array[Byte], len: Int): Boolean = {
!hash(data,numHashes, len).exists {
@@ -119,14 +136,17 @@ class BloomFilter(numBitsPerElement: Double, expectedSize: Int, numHashes: Int)
MurmurHash3_x86_128.hash(data, SEED + i, len, results)
a(i) = results(0).abs
var j = i + 1
- if (j < n)
+ if (j < n) {
a(j) = results(1).abs
+ }
j += 1
- if (j < n)
+ if (j < n) {
a(j) = results(2).abs
+ }
j += 1
- if (j < n)
+ if (j < n) {
a(j) = results(3).abs
+ }
i += 1
}
a
@@ -139,4 +159,4 @@ object BloomFilter {
def numHashes(fpp: Double, expectedSize: Int) = ceil(-(log(fpp) / log(2))).toInt
-}
\ No newline at end of file
+}
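A usage sketch of the buffer-length overloads documented above; the sizing parameters (bits per element, expected size, hash count) are illustrative, not tuned values:

```scala
import shark.util.BloomFilter

object BloomFilterSketch {
  def main(args: Array[String]) {
    val bf = new BloomFilter(10.0, 1000, 7)
    val key = "berkeley".getBytes("UTF-8")
    bf.add(key, key.length)
    // Absent keys may still match: the filter gives false positives, never false negatives.
    assert(bf.contains(key, key.length))
  }
}
```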
diff --git a/src/main/scala/shark/util/HiveUtils.scala b/src/main/scala/shark/util/HiveUtils.scala
new file mode 100644
index 00000000..46465993
--- /dev/null
+++ b/src/main/scala/shark/util/HiveUtils.scala
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.util
+
+import java.util.{Arrays => JArrays, ArrayList => JArrayList}
+import java.util.{HashMap => JHashMap, HashSet => JHashSet}
+import java.util.Properties
+
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.metastore.api.Constants.META_TABLE_PARTITION_COLUMNS
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.UnionStructObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.ql.exec.DDLTask
+import org.apache.hadoop.hive.ql.hooks.{ReadEntity, WriteEntity}
+import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, DDLWork, DropTableDesc}
+
+import shark.api.{DataType, DataTypes}
+import shark.memstore2.SharkTblProperties
+
+
+private[shark] object HiveUtils {
+
+ def getJavaPrimitiveObjectInspector(c: ClassTag[_]): PrimitiveObjectInspector = {
+ getJavaPrimitiveObjectInspector(DataTypes.fromClassTag(c))
+ }
+
+ def getJavaPrimitiveObjectInspector(t: DataType): PrimitiveObjectInspector = t match {
+ case DataTypes.BOOLEAN => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+ case DataTypes.TINYINT => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+ case DataTypes.SMALLINT => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+ case DataTypes.INT => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+ case DataTypes.BIGINT => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+ case DataTypes.FLOAT => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+ case DataTypes.DOUBLE => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+ case DataTypes.TIMESTAMP => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+ case DataTypes.STRING => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+ }
+
+ /**
+ * Return a UnionStructObjectInspector that combines the StructObjectInspectors for the table
+ * schema and the partition columns, which are virtual in Hive.
+ */
+ def makeUnionOIForPartitionedTable(
+ partProps: Properties,
+ partSerDe: Deserializer): UnionStructObjectInspector = {
+ val partCols = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
+ val partColNames = new JArrayList[String]
+ val partColObjectInspectors = new JArrayList[ObjectInspector]
+ partCols.trim().split("/").foreach { colName =>
+ partColNames.add(colName)
+ partColObjectInspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
+ }
+
+ val partColObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
+ partColNames, partColObjectInspectors)
+ val oiList = JArrays.asList(
+ partSerDe.getObjectInspector.asInstanceOf[StructObjectInspector],
+ partColObjectInspector.asInstanceOf[StructObjectInspector])
+ // New oi is union of table + partition object inspectors
+ ObjectInspectorFactory.getUnionStructObjectInspector(oiList)
+ }
+
+ /**
+ * Execute the create table DDL operation against Hive's metastore.
+ */
+ def createTableInHive(
+ tableName: String,
+ columnNames: Seq[String],
+ columnTypes: Seq[ClassTag[_]],
+ hiveConf: HiveConf = new HiveConf): Boolean = {
+ val schema = columnNames.zip(columnTypes).map { case (colName, classTag) =>
+ new FieldSchema(colName, DataTypes.fromClassTag(classTag).hiveName, "")
+ }
+
+ // Setup the create table descriptor with necessary information.
+ val createTableDesc = new CreateTableDesc()
+ createTableDesc.setTableName(tableName)
+ createTableDesc.setCols(new JArrayList[FieldSchema](schema))
+ createTableDesc.setTblProps(
+ SharkTblProperties.initializeWithDefaults(new JHashMap[String, String]()))
+ createTableDesc.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
+ createTableDesc.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
+ createTableDesc.setSerName(classOf[shark.memstore2.ColumnarSerDe].getName)
+ createTableDesc.setNumBuckets(-1)
+
+ // Execute the create table against the Hive metastore.
+ val work = new DDLWork(new JHashSet[ReadEntity], new JHashSet[WriteEntity], createTableDesc)
+ val taskExecutionStatus = executeDDLTaskDirectly(work, hiveConf)
+ taskExecutionStatus == 0
+ }
+
+ def dropTableInHive(tableName: String, hiveConf: HiveConf = new HiveConf): Boolean = {
+ // Setup the drop table descriptor with necessary information.
+ val dropTblDesc = new DropTableDesc(
+ tableName,
+ false /* expectView */,
+ false /* ifExists */,
+ false /* stringPartitionColumns */)
+
+ // Execute the drop table against the metastore.
+ val work = new DDLWork(new JHashSet[ReadEntity], new JHashSet[WriteEntity], dropTblDesc)
+ val taskExecutionStatus = executeDDLTaskDirectly(work, hiveConf)
+ taskExecutionStatus == 0
+ }
+
+ /**
+ * Creates a DDLTask from the DDLWork given, and directly calls DDLTask#execute(). Returns 0 if
+ * the create table command is executed successfully.
+ * This is safe to use for all DDL commands except for AlterTableTypes.ARCHIVE, which actually
+ * requires the DriverContext created in Hive Driver#execute().
+ */
+ def executeDDLTaskDirectly(ddlWork: DDLWork, hiveConf: HiveConf): Int = {
+ val task = new DDLTask()
+ task.initialize(hiveConf, null /* queryPlan */, null /* ctx: DriverContext */)
+ task.setWork(ddlWork)
+ task.execute(null /* driverContext */)
+ }
+}
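A usage sketch for the helpers above (the table name and schema here are made up for illustration; both calls return true only if the underlying DDLTask exits with status 0):

    import scala.reflect.classTag
    import shark.util.HiveUtils

    // Register a two-column table backed by ColumnarSerDe in the metastore...
    val created = HiveUtils.createTableInHive(
      "example_cached",
      Seq("key", "value"),
      Seq(classTag[Int], classTag[String]))
    // ... and drop it again.
    val dropped = HiveUtils.dropTableInHive("example_cached")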
diff --git a/src/main/scala/shark/util/MurmurHash3_x86_128.scala b/src/main/scala/shark/util/MurmurHash3_x86_128.scala
index 5dcc6068..ff230ee5 100644
--- a/src/main/scala/shark/util/MurmurHash3_x86_128.scala
+++ b/src/main/scala/shark/util/MurmurHash3_x86_128.scala
@@ -1,7 +1,23 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package shark.util
import java.lang.Integer.{ rotateLeft => rotl }
-import scala.math._
/**
* The MurmurHash3_x86_128(...) is a fast, non-cryptographic, 128-bit hash
@@ -109,7 +125,7 @@ object MurmurHash3_x86_128 {
* @param seed is the seed for the murmurhash algorithm.
* @param length is the length of the buffer to use for hashing.
* @param results is the output buffer to store the four ints that are returned,
- * should have size atleast 4.
+ * should have size at least 4.
*/
@inline final def hash(data: Array[Byte], seed: Int, length: Int,
results: Array[Int]): Unit = {
@@ -177,18 +193,18 @@ object MurmurHash3_x86_128 {
* @param rem is the remainder of the byte array to examine.
*/
@inline final def getInt(data: Array[Byte], index: Int, rem: Int): Int = {
- rem match {
+ rem match {
case 3 => data(index) << 24 |
- (data(index + 1) & 0xFF) << 16 |
- (data(index + 2) & 0xFF) << 8
+ (data(index + 1) & 0xFF) << 16 |
+ (data(index + 2) & 0xFF) << 8
case 2 => data(index) << 24 |
- (data(index + 1) & 0xFF) << 16
+ (data(index + 1) & 0xFF) << 16
case 1 => data(index) << 24
case 0 => 0
case _ => data(index) << 24 |
- (data(index + 1) & 0xFF) << 16 |
- (data(index + 2) & 0xFF) << 8 |
- (data(index + 3) & 0xFF)
+ (data(index + 1) & 0xFF) << 16 |
+ (data(index + 2) & 0xFF) << 8 |
+ (data(index + 3) & 0xFF)
}
}
-}
\ No newline at end of file
+}
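A short usage sketch for the hash(...) entry point touched above (the input bytes are arbitrary; the output array must hold at least 4 ints, per the corrected doc comment):

    import shark.util.MurmurHash3_x86_128

    val data = "shark".getBytes("UTF-8")
    val out = new Array[Int](4)
    MurmurHash3_x86_128.hash(data, 0 /* seed */, data.length, out)
    // out(0) through out(3) now hold the four 32-bit words of the 128-bit hash;
    // trailing bytes that don't fill a whole word are folded in via getInt(...).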
diff --git a/src/main/scala/shark/util/QueryRewriteUtils.scala b/src/main/scala/shark/util/QueryRewriteUtils.scala
new file mode 100644
index 00000000..8d44f8a8
--- /dev/null
+++ b/src/main/scala/shark/util/QueryRewriteUtils.scala
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark.util
+
+import org.apache.hadoop.hive.ql.parse.SemanticException
+
+import shark.memstore2.SharkTblProperties
+
+
+object QueryRewriteUtils {
+
+ def cacheToAlterTable(cmd: String): String = {
+ val cmdSplit = cmd.split(' ')
+ if (cmdSplit.size == 2) {
+ val tableName = cmdSplit(1)
+ "ALTER TABLE %s SET TBLPROPERTIES ('shark.cache' = 'true')".format(tableName)
+ } else {
+ throw new SemanticException(
+ s"CACHE accepts a single table name: 'CACHE
' (received command: '$cmd')")
+ }
+ }
+
+ def uncacheToAlterTable(cmd: String): String = {
+ val cmdSplit = cmd.split(' ')
+ if (cmdSplit.size == 2) {
+ val tableName = cmdSplit(1)
+ "ALTER TABLE %s SET TBLPROPERTIES ('shark.cache' = 'false')".format(tableName)
+ } else {
+ throw new SemanticException(
+ s"UNCACHE accepts a single table name: 'UNCACHE ' (received command: '$cmd')")
+ }
+ }
+}
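The two rewrites above reduce the CACHE/UNCACHE shorthand to plain ALTER TABLE statements that flip the 'shark.cache' table property, e.g.:

    import shark.util.QueryRewriteUtils._

    cacheToAlterTable("CACHE users")
    // => ALTER TABLE users SET TBLPROPERTIES ('shark.cache' = 'true')
    uncacheToAlterTable("UNCACHE users")
    // => ALTER TABLE users SET TBLPROPERTIES ('shark.cache' = 'false')
    cacheToAlterTable("CACHE too many args")   // throws SemanticException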
diff --git a/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala b/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala
index 3f1d2eba..dbdf1ff6 100644
--- a/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala
+++ b/src/tachyon_disabled/scala/shark/tachyon/TachyonUtilImpl.scala
@@ -22,35 +22,51 @@ import java.util.BitSet
import org.apache.spark.rdd.RDD
-import shark.memstore2.TablePartition
-
+import shark.memstore2.{Table, TablePartition, TablePartitionStats}
class TachyonUtilImpl(val master: String, val warehousePath: String) extends TachyonUtil {
+
override def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean = false
override def tachyonEnabled(): Boolean = false
- override def tableExists(tableName: String): Boolean = {
+ override def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = {
+ throw new UnsupportedOperationException(
+ "This version of Shark is not compiled with Tachyon support.")
+ }
+
+ override def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = {
throw new UnsupportedOperationException(
"This version of Shark is not compiled with Tachyon support.")
}
- override def dropTable(tableName: String): Boolean = {
+ override def createDirectory(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]): Boolean = {
throw new UnsupportedOperationException(
"This version of Shark is not compiled with Tachyon support.")
}
- override def getTableMetadata(tableName: String): ByteBuffer = {
+ override def renameDirectory(
+ oldName: String,
+ newName: String): Boolean = {
throw new UnsupportedOperationException(
"This version of Shark is not compiled with Tachyon support.")
}
- override def createRDD(tableName: String): RDD[TablePartition] = {
+ override def createRDD(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]
+ ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
throw new UnsupportedOperationException(
"This version of Shark is not compiled with Tachyon support.")
}
- override def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter = {
+ override def createTableWriter(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String],
+ numColumns: Int
+ ): TachyonTableWriter = {
throw new UnsupportedOperationException(
"This version of Shark is not compiled with Tachyon support.")
}
diff --git a/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala b/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala
index 32f27dee..8e4eab8d 100644
--- a/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala
+++ b/src/tachyon_enabled/scala/shark/tachyon/TachyonUtilImpl.scala
@@ -19,67 +19,127 @@ package shark.tachyon
import java.nio.ByteBuffer
import java.util.BitSet
+import java.util.concurrent.{ConcurrentHashMap => ConcurrentJavaHashMap}
-import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{EmptyRDD, RDD, UnionRDD}
import tachyon.client.TachyonFS
import tachyon.client.table.{RawTable, RawColumn}
-import shark.SharkEnv
-import shark.memstore2.TablePartition
+import shark.{LogHelper, SharkEnv}
+import shark.execution.serialization.JavaSerializer
+import shark.memstore2.{MemoryMetadataManager, TablePartition, TablePartitionStats}
/**
* An abstraction for the Tachyon APIs.
*/
-class TachyonUtilImpl(val master: String, val warehousePath: String) extends TachyonUtil {
+class TachyonUtilImpl(
+ val master: String,
+ val warehousePath: String)
+ extends TachyonUtil
+ with LogHelper {
+
+ private val INSERT_FILE_PREFIX = "insert_"
+
+ private val _fileNameMappings = new ConcurrentJavaHashMap[String, Int]()
val client = if (master != null && master != "") TachyonFS.get(master) else null
+ private def getUniqueFilePath(parentDirectory: String): String = {
+ val parentDirectoryLower = parentDirectory.toLowerCase
+ val currentInsertNum = if (_fileNameMappings.containsKey(parentDirectoryLower)) {
+ _fileNameMappings.get(parentDirectoryLower)
+ } else {
+ 0
+ }
+ var nextInsertNum = currentInsertNum + 1
+ var filePath = parentDirectoryLower + "/" + INSERT_FILE_PREFIX
+ // Make sure there aren't file conflicts. This could occur if the directory was created in a
+ // previous Shark session.
+ while (client.exist(filePath + nextInsertNum)) {
+ nextInsertNum = nextInsertNum + 1
+ }
+ _fileNameMappings.put(parentDirectoryLower, nextInsertNum)
+ filePath + nextInsertNum
+ }
+
if (master != null && warehousePath == null) {
throw new TachyonException("TACHYON_MASTER is set. However, TACHYON_WAREHOUSE_PATH is not.")
}
- def getPath(tableName: String): String = warehousePath + "/" + tableName
+ private def getPath(tableKey: String, hivePartitionKeyOpt: Option[String]): String = {
+ val hivePartitionKey = if (hivePartitionKeyOpt.isDefined) {
+ "/" + hivePartitionKeyOpt.get
+ } else {
+ ""
+ }
+ warehousePath + "/" + tableKey + hivePartitionKey
+ }
override def pushDownColumnPruning(rdd: RDD[_], columnUsed: BitSet): Boolean = {
- if (rdd.isInstanceOf[TachyonTableRDD]) {
+ val isTachyonTableRdd = rdd.isInstanceOf[TachyonTableRDD]
+ if (isTachyonTableRdd) {
rdd.asInstanceOf[TachyonTableRDD].setColumnUsed(columnUsed)
- true
- } else {
- false
}
+ isTachyonTableRdd
}
+ override def tachyonEnabled(): Boolean =
+ (master != null && warehousePath != null && client.isConnected)
- override def tachyonEnabled(): Boolean = (master != null && warehousePath != null)
-
- override def tableExists(tableName: String): Boolean = {
- client.exist(getPath(tableName))
+ override def tableExists(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = {
+ client.exist(getPath(tableKey, hivePartitionKeyOpt))
}
- override def dropTable(tableName: String): Boolean = {
+ override def dropTable(tableKey: String, hivePartitionKeyOpt: Option[String]): Boolean = {
// The second parameter (true) means recursive deletion.
- client.delete(getPath(tableName), true)
+ client.delete(getPath(tableKey, hivePartitionKeyOpt), true)
}
- override def getTableMetadata(tableName: String): ByteBuffer = {
- if (!tableExists(tableName)) {
- throw new TachyonException("Table " + tableName + " does not exist in Tachyon")
- }
- client.getRawTable(getPath(tableName)).getMetadata()
+ override def createDirectory(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]): Boolean = {
+ client.mkdir(getPath(tableKey, hivePartitionKeyOpt))
+ }
+
+ override def renameDirectory(
+ oldTableKey: String,
+ newTableKey: String): Boolean = {
+ val oldPath = getPath(oldTableKey, hivePartitionKeyOpt = None)
+ val newPath = getPath(newTableKey, hivePartitionKeyOpt = None)
+ client.rename(oldPath, newPath)
}
- override def createRDD(tableName: String): RDD[TablePartition] = {
- new TachyonTableRDD(getPath(tableName), SharkEnv.sc)
+ override def createRDD(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String]
+ ): Seq[(RDD[TablePartition], collection.Map[Int, TablePartitionStats])] = {
+ // Create a TachyonTableRDD for each raw table file in the directory.
+ val tableDirectory = getPath(tableKey, hivePartitionKeyOpt)
+ val files = client.ls(tableDirectory, false /* recursive */)
+ // The first path is just "{tableDirectory}/", so ignore it.
+ val rawTableFiles = files.subList(1, files.size)
+ val tableRDDsAndStats = rawTableFiles.map { filePath =>
+ val serializedMetadata = client.getRawTable(client.getFileId(filePath)).getMetadata
+ val indexToStats = JavaSerializer.deserialize[collection.Map[Int, TablePartitionStats]](
+ serializedMetadata.array())
+ (new TachyonTableRDD(filePath, SharkEnv.sc), indexToStats)
+ }
+ tableRDDsAndStats
}
- override def createTableWriter(tableName: String, numColumns: Int): TachyonTableWriter = {
+ override def createTableWriter(
+ tableKey: String,
+ hivePartitionKeyOpt: Option[String],
+ numColumns: Int): TachyonTableWriter = {
if (!client.exist(warehousePath)) {
client.mkdir(warehousePath)
}
- new TachyonTableWriterImpl(getPath(tableName), numColumns)
+ val parentDirectory = getPath(tableKey, hivePartitionKeyOpt)
+ val filePath = getUniqueFilePath(parentDirectory)
+ new TachyonTableWriterImpl(filePath, numColumns)
}
}
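A worked example of the file layout produced by getPath(...) and getUniqueFilePath(...) above, assuming warehousePath = "/tachyon/warehouse" (all paths illustrative):

    // getPath("default.users", Some("keypart=1"))
    //   => /tachyon/warehouse/default.users/keypart=1
    // First INSERT in this session writes:   .../keypart=1/insert_1
    // Next INSERT in the same session:       .../keypart=1/insert_2
    // If insert_2 already exists from an earlier session, the while loop in
    // getUniqueFilePath skips ahead and writes insert_3 instead.

This is why createRDD returns a Seq: each INSERT leaves its own raw table file under the (table, partition) directory, and one TachyonTableRDD is built per file.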
diff --git a/src/test/java/shark/JavaAPISuite.java b/src/test/java/shark/JavaAPISuite.java
index 01f6fe58..49b0d2e8 100644
--- a/src/test/java/shark/JavaAPISuite.java
+++ b/src/test/java/shark/JavaAPISuite.java
@@ -48,13 +48,9 @@ public static void oneTimeSetUp() {
// Intentionally leaving this here since SBT doesn't seem to display junit tests well ...
System.out.println("running JavaAPISuite ================================================");
- sc = SharkEnv.initWithJavaSharkContext("JavaAPISuite", "local");
-
- sc.sql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" +
- METASTORE_PATH + ";create=true");
- sc.sql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH);
-
- sc.sql("set shark.test.data.path=" + TestUtils$.MODULE$.dataFilePath());
+ // Check if the SharkEnv's SharkContext has already been initialized. If so, use that to
+ // instantiate a JavaSharkContext.
+ sc = SharkRunner.initWithJava();
// test
sc.sql("drop table if exists test_java");
diff --git a/src/test/scala/shark/ColumnStatsSQLSuite.scala b/src/test/scala/shark/ColumnStatsSQLSuite.scala
new file mode 100644
index 00000000..f0aa5931
--- /dev/null
+++ b/src/test/scala/shark/ColumnStatsSQLSuite.scala
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2012 The Regents of The University of California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark
+
+import org.apache.hadoop.io.BytesWritable
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME
+
+import org.apache.spark.rdd.RDD
+
+import shark.memstore2.MemoryMetadataManager
+
+
+class ColumnStatsSQLSuite extends FunSuite with BeforeAndAfterAll {
+
+ val sc: SharkContext = SharkRunner.init()
+ val sharkMetastore = SharkEnv.memoryMetadataManager
+
+ // import expectSql() shortcut methods
+ import shark.SharkRunner._
+
+ override def beforeAll() {
+ sc.runSql("drop table if exists srcpart_cached")
+ sc.runSql("create table srcpart_cached(key int, val string) partitioned by (keypart int)")
+ sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt'
+ into table srcpart_cached partition (keypart = 1)""")
+ }
+
+ override def afterAll() {
+ sc.runSql("drop table if exists srcpart_cached")
+ }
+
+ test("Hive partition stats are tracked") {
+ val tableOpt = sharkMetastore.getPartitionedTable(DEFAULT_DATABASE_NAME, "srcpart_cached")
+ assert(tableOpt.isDefined)
+ val partitionToStatsOpt = tableOpt.get.getStats("keypart=1")
+ assert(partitionToStatsOpt.isDefined)
+ val partitionToStats = partitionToStatsOpt.get
+ // The 'kv1.txt' file loaded into 'keypart=1' in beforeAll() spans 2 RDD partitions.
+ assert(partitionToStats.size == 2)
+ }
+
+ test("Hive partition stats are tracked after LOADs and INSERTs") {
+ // Load more data into srcpart_cached
+ sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt'
+ into table srcpart_cached partition (keypart = 1)""")
+ val tableOpt = sharkMetastore.getPartitionedTable(DEFAULT_DATABASE_NAME, "srcpart_cached")
+ assert(tableOpt.isDefined)
+ var partitionToStatsOpt = tableOpt.get.getStats("keypart=1")
+ assert(partitionToStatsOpt.isDefined)
+ var partitionToStats = partitionToStatsOpt.get
+ // The 'kv1.txt' file loaded into 'keypart=1' spans 2 RDD partitions, and we've loaded it
+ // twice at this point.
+ assert(partitionToStats.size == 4)
+
+ // Append using INSERT command
+ sc.runSql("insert into table srcpart_cached partition(keypart = 1) select * from test")
+ partitionToStatsOpt = tableOpt.get.getStats("keypart=1")
+ assert(partitionToStatsOpt.isDefined)
+ partitionToStats = partitionToStatsOpt.get
+ assert(partitionToStats.size == 6)
+
+ // INSERT OVERWRITE should overwrite the old table stats. This also restores srcpart_cached
+ // to the contents it held before this test.
+ sc.runSql("""insert overwrite table srcpart_cached partition(keypart = 1)
+ select * from test""")
+ partitionToStatsOpt = tableOpt.get.getStats("keypart=1")
+ assert(partitionToStatsOpt.isDefined)
+ partitionToStats = partitionToStatsOpt.get
+ assert(partitionToStats.size == 2)
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // End-to-end sanity checks
+ //////////////////////////////////////////////////////////////////////////////
+ test("column pruning filters") {
+ expectSql("select count(*) from test_cached where key > -1", "500")
+ }
+
+ test("column pruning group by") {
+ expectSql("select key, count(*) from test_cached group by key order by key limit 1", "0\t3")
+ }
+
+ test("column pruning group by with single filter") {
+ expectSql("select key, count(*) from test_cached where val='val_484' group by key", "484\t1")
+ }
+
+ test("column pruning aggregate function") {
+ expectSql("select val, sum(key) from test_cached group by val order by val desc limit 1",
+ "val_98\t196")
+ }
+
+ test("column pruning filters for a Hive partition") {
+ expectSql("select count(*) from srcpart_cached where key > -1", "500")
+ expectSql("select count(*) from srcpart_cached where key > -1 and keypart = 1", "500")
+ }
+
+ test("column pruning group by for a Hive partition") {
+ expectSql("select key, count(*) from srcpart_cached group by key order by key limit 1", "0\t3")
+ }
+
+ test("column pruning group by with single filter for a Hive partition") {
+ expectSql("select key, count(*) from srcpart_cached where val='val_484' group by key", "484\t1")
+ }
+
+ test("column pruning aggregate function for a Hive partition") {
+ expectSql("select val, sum(key) from srcpart_cached group by val order by val desc limit 1",
+ "val_98\t196")
+ }
+
+}
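The counts asserted above follow from one TablePartitionStats entry being recorded per RDD partition: kv1.txt materializes as 2 RDD partitions, so each LOAD or INSERT INTO appends 2 entries, while INSERT OVERWRITE replaces them. A sketch of inspecting the tracked stats directly, using the same API the suite does:

    val statsOpt = SharkEnv.memoryMetadataManager
      .getPartitionedTable(DEFAULT_DATABASE_NAME, "srcpart_cached")
      .flatMap(_.getStats("keypart=1"))
    // 2 entries after the first load, 4 after the second, 6 after INSERT INTO,
    // and back down to 2 after INSERT OVERWRITE.
    statsOpt.foreach(stats => println(stats.size))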
diff --git a/src/test/scala/shark/SQLSuite.scala b/src/test/scala/shark/SQLSuite.scala
index 9751bcb3..746e3c18 100644
--- a/src/test/scala/shark/SQLSuite.scala
+++ b/src/test/scala/shark/SQLSuite.scala
@@ -17,87 +17,91 @@
package shark
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
-
-import shark.api.QueryExecutionException
-
-
-class SQLSuite extends FunSuite with BeforeAndAfterAll {
-
- val WAREHOUSE_PATH = TestUtils.getWarehousePath()
- val METASTORE_PATH = TestUtils.getMetastorePath()
- val MASTER = "local"
+import scala.collection.JavaConversions._
- var sc: SharkContext = _
-
- override def beforeAll() {
- sc = SharkEnv.initWithSharkContext("shark-sql-suite-testing", MASTER)
+import org.scalatest.FunSuite
- sc.runSql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" +
- METASTORE_PATH + ";create=true")
- sc.runSql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
+import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME
+import org.apache.hadoop.hive.ql.metadata.Hive
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.storage.StorageLevel
- sc.runSql("set shark.test.data.path=" + TestUtils.dataFilePath)
+import shark.api.QueryExecutionException
+import shark.memstore2.{CacheType, MemoryMetadataManager, PartitionedMemoryTable}
+import shark.tgf.{RDDSchema, Schema}
+// import expectSql() shortcut methods
+import shark.SharkRunner._
- // test
- sc.runSql("drop table if exists test")
- sc.runSql("CREATE TABLE test (key INT, val STRING)")
- sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv1.txt' INTO TABLE test")
- sc.runSql("drop table if exists test_cached")
- sc.runSql("CREATE TABLE test_cached AS SELECT * FROM test")
- // test_null
- sc.runSql("drop table if exists test_null")
- sc.runSql("CREATE TABLE test_null (key INT, val STRING)")
- sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv3.txt' INTO TABLE test_null")
- sc.runSql("drop table if exists test_null_cached")
- sc.runSql("CREATE TABLE test_null_cached AS SELECT * FROM test_null")
+class SQLSuite extends FunSuite {
- // clicks
- sc.runSql("drop table if exists clicks")
- sc.runSql("""create table clicks (id int, click int)
- row format delimited fields terminated by '\t'""")
- sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/clicks.txt'
- OVERWRITE INTO TABLE clicks""")
- sc.runSql("drop table if exists clicks_cached")
- sc.runSql("create table clicks_cached as select * from clicks")
+ val DEFAULT_DB_NAME = DEFAULT_DATABASE_NAME
+ val KV1_TXT_PATH = "${hiveconf:shark.test.data.path}/kv1.txt"
- // users
- sc.runSql("drop table if exists users")
- sc.runSql("""create table users (id int, name string)
- row format delimited fields terminated by '\t'""")
- sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/users.txt'
- OVERWRITE INTO TABLE users""")
- sc.runSql("drop table if exists users_cached")
- sc.runSql("create table users_cached as select * from users")
+ var sc: SharkContext = SharkRunner.init()
+ var sharkMetastore: MemoryMetadataManager = SharkEnv.memoryMetadataManager
- // test1
- sc.sql("drop table if exists test1")
- sc.sql("""CREATE TABLE test1 (id INT, test1val ARRAY)
- row format delimited fields terminated by '\t'""")
- sc.sql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/test1.txt' INTO TABLE test1")
- sc.sql("drop table if exists test1_cached")
- sc.sql("CREATE TABLE test1_cached AS SELECT * FROM test1")
+ private def createCachedPartitionedTable(
+ tableName: String,
+ numPartitionsToCreate: Int,
+ maxCacheSize: Int = 10,
+ cachePolicyClassName: String = "shark.memstore2.LRUCachePolicy"
+ ): PartitionedMemoryTable = {
+ sc.runSql("drop table if exists %s".format(tableName))
+ sc.runSql("""
+ create table %s(key int, value string)
+ partitioned by (keypart int)
+ tblproperties('shark.cache' = 'true',
+ 'shark.cache.policy.maxSize' = '%d',
+ 'shark.cache.policy' = '%s')
+ """.format(
+ tableName,
+ maxCacheSize,
+ cachePolicyClassName))
+ var partitionNum = 1
+ while (partitionNum <= numPartitionsToCreate) {
+ sc.runSql("""insert into table %s partition(keypart = %d)
+ select * from test_cached""".format(tableName, partitionNum))
+ partitionNum += 1
+ }
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ DEFAULT_DB_NAME, tableName).get
+ partitionedTable
}
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
+ def isFlattenedUnionRDD(unionRDD: UnionRDD[_]) = {
+ unionRDD.rdds.find(_.isInstanceOf[UnionRDD[_]]).isEmpty
}
- private def expectSql(sql: String, expectedResults: Array[String], sort: Boolean = true) {
- val sharkResults: Array[String] = sc.runSql(sql).results.map(_.mkString("\t")).toArray
- val results = if (sort) sharkResults.sortWith(_ < _) else sharkResults
- val expected = if (sort) expectedResults.sortWith(_ < _) else expectedResults
- assert(results.corresponds(expected)(_.equals(_)),
- "In SQL: " + sql + "\n" +
- "Expected: " + expected.mkString("\n") + "; got " + results.mkString("\n"))
- }
+ // Takes a sum over the table's 'key' column, for both the cached contents and the copy on disk.
+ def expectUnifiedKVTable(
+ cachedTableName: String,
+ partSpecOpt: Option[Map[String, String]] = None) {
+ // Check that the table is in memory and is a unified view.
+ val sharkTableOpt = sharkMetastore.getTable(DEFAULT_DB_NAME, cachedTableName)
+ assert(sharkTableOpt.isDefined,
+ "Table %s cannot be found in the Shark metastore".format(cachedTableName))
+ assert(sharkTableOpt.get.cacheMode == CacheType.MEMORY,
+ "'shark.cache' field for table %s is not CacheType.MEMORY".format(cachedTableName))
- // A shortcut for single row results.
- private def expectSql(sql: String, expectedResult: String) {
- expectSql(sql, Array(expectedResult))
+ // Sum the key column of the table's cached, in-memory contents.
+ val cacheSum = sc.sql("select sum(key) from %s".format(cachedTableName))(0)
+ val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, cachedTableName)
+ val location = partSpecOpt match {
+ case Some(partSpec) => {
+ val partition = Hive.get().getPartition(hiveTable, partSpec, false /* forceCreate */)
+ partition.getDataLocation.toString
+ }
+ case None => hiveTable.getDataLocation.toString
+ }
+ // Create a table with contents loaded from the table's data directory.
+ val diskTableName = "%s_disk_copy".format(cachedTableName)
+ sc.sql("drop table if exists %s".format(diskTableName))
+ sc.sql("create table %s (key int, value string)".format(diskTableName))
+ sc.sql("load data local inpath '%s' into table %s".format(location, diskTableName))
+ val diskSum = sc.sql("select sum(key) from %s".format(diskTableName))(0)
+ assert(diskSum == cacheSum, "Sum of keys from cached and disk contents differ")
}
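In short, expectUnifiedKVTable cross-checks a unified-view table by summing the key column once from the cache and once from a fresh table loaded off the cached table's disk location. The LOAD/INSERT tests below use it like this:

    expectUnifiedKVTable("unified_view_cached")                                   // unpartitioned
    expectUnifiedKVTable("unified_view_part_cached", Some(Map("keypart" -> "1"))) // one partition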
//////////////////////////////////////////////////////////////////////////////
@@ -166,26 +170,6 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sort = false)
}
- //////////////////////////////////////////////////////////////////////////////
- // column pruning
- //////////////////////////////////////////////////////////////////////////////
- test("column pruning filters") {
- expectSql("select count(*) from test_cached where key > -1", "500")
- }
-
- test("column pruning group by") {
- expectSql("select key, count(*) from test_cached group by key order by key limit 1", "0\t3")
- }
-
- test("column pruning group by with single filter") {
- expectSql("select key, count(*) from test_cached where val='val_484' group by key", "484\t1")
- }
-
- test("column pruning aggregate function") {
- expectSql("select val, sum(key) from test_cached group by val order by val desc limit 1",
- "val_98\t196")
- }
-
//////////////////////////////////////////////////////////////////////////////
// join
//////////////////////////////////////////////////////////////////////////////
@@ -221,6 +205,42 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
//////////////////////////////////////////////////////////////////////////////
// cache DDL
//////////////////////////////////////////////////////////////////////////////
+ test("Use regular CREATE TABLE and '_cached' suffix to create cached table") {
+ sc.runSql("drop table if exists empty_table_cached")
+ sc.runSql("create table empty_table_cached(key string, value string)")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "empty_table_cached"))
+ assert(!SharkEnv.memoryMetadataManager.isHivePartitioned(DEFAULT_DB_NAME, "empty_table_cached"))
+ }
+
+ test("Use regular CREATE TABLE and table properties to create cached table") {
+ sc.runSql("drop table if exists empty_table_cached_tbl_props")
+ sc.runSql("""create table empty_table_cached_tbl_props(key string, value string)
+ TBLPROPERTIES('shark.cache' = 'true')""")
+ assert(SharkEnv.memoryMetadataManager.containsTable(
+ DEFAULT_DB_NAME, "empty_table_cached_tbl_props"))
+ assert(!SharkEnv.memoryMetadataManager.isHivePartitioned(
+ DEFAULT_DB_NAME, "empty_table_cached_tbl_props"))
+ }
+
+ test("Insert into empty cached table") {
+ sc.runSql("drop table if exists new_table_cached")
+ sc.runSql("create table new_table_cached(key string, value string)")
+ sc.runSql("insert into table new_table_cached select * from test where key > -1 limit 499")
+ expectSql("select count(*) from new_table_cached", "499")
+ }
+
+ test("rename cached table") {
+ sc.runSql("drop table if exists test_oldname_cached")
+ sc.runSql("drop table if exists test_rename")
+ sc.runSql("create table test_oldname_cached as select * from test")
+ sc.runSql("alter table test_oldname_cached rename to test_rename")
+
+ assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "test_oldname_cached"))
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "test_rename"))
+
+ expectSql("select count(*) from test_rename", "500")
+ }
+
test("insert into cached tables") {
sc.runSql("drop table if exists test1_cached")
sc.runSql("create table test1_cached as select * from test")
@@ -249,22 +269,24 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
}
}
- ignore("drop partition") {
- sc.runSql("create table foo_cached(key int, val string) partitioned by (dt string)")
- sc.runSql("insert overwrite table foo_cached partition(dt='100') select * from test")
- expectSql("select count(*) from foo_cached", "500")
- sc.runSql("alter table foo_cached drop partition(dt='100')")
- expectSql("select count(*) from foo_cached", "0")
- }
-
- test("create cached table with table properties") {
+ test("create cached table with 'shark.cache' flag in table properties") {
sc.runSql("drop table if exists ctas_tbl_props")
sc.runSql("""create table ctas_tbl_props TBLPROPERTIES ('shark.cache'='true') as
select * from test""")
- assert(SharkEnv.memoryMetadataManager.contains("ctas_tbl_props"))
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "ctas_tbl_props"))
expectSql("select * from ctas_tbl_props where key=407", "407\tval_407")
}
+ test("default to Hive table creation when 'shark.cache' flag is false in table properties") {
+ sc.runSql("drop table if exists ctas_tbl_props_should_not_be_cached")
+ sc.runSql("""
+ CREATE TABLE ctas_tbl_props_should_not_be_cached
+ TBLPROPERTIES ('shark.cache'='false')
+ AS select * from test""")
+ assert(!SharkEnv.memoryMetadataManager.containsTable(
+ DEFAULT_DB_NAME, "ctas_tbl_props_should_not_be_cached"))
+ }
+
test("cached tables with complex types") {
sc.runSql("drop table if exists test_complex_types")
sc.runSql("drop table if exists test_complex_types_cached")
@@ -286,7 +308,8 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
assert(sc.sql("select d from test_complex_types_cached where a = 'a0'").head ===
"""{"d01":["d011","d012"],"d02":["d021","d022"]}""")
- assert(SharkEnv.memoryMetadataManager.contains("test_complex_types_cached"))
+ assert(SharkEnv.memoryMetadataManager.containsTable(
+ DEFAULT_DB_NAME, "test_complex_types_cached"))
}
test("disable caching by default") {
@@ -294,7 +317,8 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.runSql("drop table if exists should_not_be_cached")
sc.runSql("create table should_not_be_cached as select * from test")
expectSql("select key from should_not_be_cached where key = 407", "407")
- assert(!SharkEnv.memoryMetadataManager.contains("should_not_be_cached"))
+ assert(!SharkEnv.memoryMetadataManager.containsTable(
+ DEFAULT_DB_NAME, "should_not_be_cached"))
sc.runSql("set shark.cache.flag.checkTableName=true")
}
@@ -303,7 +327,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as
select * from test""")
expectSql("select val from sharktest5Cached where key = 407", "val_407")
- assert(SharkEnv.memoryMetadataManager.contains("sharkTest5Cached"))
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "sharkTest5Cached"))
}
test("dropping cached tables should clean up RDDs") {
@@ -311,7 +335,325 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as
select * from test""")
sc.runSql("drop table sharkTest5Cached")
- assert(!SharkEnv.memoryMetadataManager.contains("sharkTest5Cached"))
+ assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "sharkTest5Cached"))
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Caching Hive-partitioned tables
+ // Note: references to 'partition' for this section refer to a Hive-partition.
+ //////////////////////////////////////////////////////////////////////////////
+ test("Use regular CREATE TABLE and '_cached' suffix to create cached, partitioned table") {
+ sc.runSql("drop table if exists empty_part_table_cached")
+ sc.runSql("""create table empty_part_table_cached(key int, value string)
+ partitioned by (keypart int)""")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "empty_part_table_cached"))
+ assert(SharkEnv.memoryMetadataManager.isHivePartitioned(
+ DEFAULT_DB_NAME, "empty_part_table_cached"))
+ }
+
+ test("Use regular CREATE TABLE and table properties to create cached, partitioned table") {
+ sc.runSql("drop table if exists empty_part_table_cached_tbl_props")
+ sc.runSql("""create table empty_part_table_cached_tbl_props(key int, value string)
+ partitioned by (keypart int) tblproperties('shark.cache' = 'true')""")
+ assert(SharkEnv.memoryMetadataManager.containsTable(
+ DEFAULT_DB_NAME, "empty_part_table_cached_tbl_props"))
+ assert(SharkEnv.memoryMetadataManager.isHivePartitioned(
+ DEFAULT_DB_NAME, "empty_part_table_cached_tbl_props"))
+ }
+
+ test("alter cached table by adding a new partition") {
+ sc.runSql("drop table if exists alter_part_cached")
+ sc.runSql("""create table alter_part_cached(key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""alter table alter_part_cached add partition(keypart = 1)""")
+ val tableName = "alter_part_cached"
+ val partitionColumn = "keypart=1"
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ DEFAULT_DB_NAME, tableName).get
+ assert(partitionedTable.containsPartition(partitionColumn))
+ }
+
+ test("alter cached table by dropping a partition") {
+ sc.runSql("drop table if exists alter_drop_part_cached")
+ sc.runSql("""create table alter_drop_part_cached(key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""alter table alter_drop_part_cached add partition(keypart = 1)""")
+ val tableName = "alter_drop_part_cached"
+ val partitionColumn = "keypart=1"
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ DEFAULT_DB_NAME, tableName).get
+ assert(partitionedTable.containsPartition(partitionColumn))
+ sc.runSql("""alter table alter_drop_part_cached drop partition(keypart = 1)""")
+ assert(!partitionedTable.containsPartition(partitionColumn))
+ }
+
+ test("insert into a partition of a cached table") {
+ val tableName = "insert_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 1 /* numPartitionsToCreate */)
+ expectSql("select value from insert_part_cached where key = 407 and keypart = 1", "val_407")
+
+ }
+
+ test("insert overwrite a partition of a cached table") {
+ val tableName = "insert_over_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 1 /* numPartitionsToCreate */)
+ expectSql("""select value from insert_over_part_cached
+ where key = 407 and keypart = 1""", "val_407")
+ sc.runSql("""insert overwrite table insert_over_part_cached partition(keypart = 1)
+ select key, -1 from test""")
+ expectSql("select value from insert_over_part_cached where key = 407 and keypart = 1", "-1")
+ }
+
+ test("scan cached, partitioned table that's empty") {
+ sc.runSql("drop table if exists empty_part_table_cached")
+ sc.runSql("""create table empty_part_table_cached(key int, value string)
+ partitioned by (keypart int)""")
+ expectSql("select count(*) from empty_part_table_cached", "0")
+ }
+
+ test("scan cached, partitioned table that has a single partition") {
+ val tableName = "scan_single_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 1 /* numPartitionsToCreate */)
+ expectSql("select * from scan_single_part_cached where key = 407", "407\tval_407\t1")
+ }
+
+ test("scan cached, partitioned table that has multiple partitions") {
+ val tableName = "scan_mult_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */)
+ expectSql("select * from scan_mult_part_cached where key = 407 order by keypart",
+ Array("407\tval_407\t1", "407\tval_407\t2", "407\tval_407\t3"))
+ }
+
+ test("drop/unpersist cached, partitioned table that has multiple partitions") {
+ val tableName = "drop_mult_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */)
+ val keypart1RDD = partitionedTable.getPartition("keypart=1")
+ val keypart2RDD = partitionedTable.getPartition("keypart=2")
+ val keypart3RDD = partitionedTable.getPartition("keypart=3")
+ sc.runSql("drop table drop_mult_part_cached ")
+ assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ // All RDDs should have been unpersisted.
+ assert(keypart1RDD.get.getStorageLevel == StorageLevel.NONE)
+ assert(keypart2RDD.get.getStorageLevel == StorageLevel.NONE)
+ assert(keypart3RDD.get.getStorageLevel == StorageLevel.NONE)
+ }
+
+ test("drop cached partition represented by a UnionRDD (i.e., the result of multiple inserts)") {
+ val tableName = "drop_union_part_cached"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 1 /* numPartitionsToCreate */)
+ sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test")
+ sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test")
+ sc.runSql("insert into table drop_union_part_cached partition(keypart = 1) select * from test")
+ val keypart1RDD = partitionedTable.getPartition("keypart=1")
+ sc.runSql("drop table drop_union_part_cached")
+ assert(!SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ // All RDDs should have been unpersisted.
+ assert(keypart1RDD.get.getStorageLevel == StorageLevel.NONE)
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // RDD(partition) eviction policy for cached Hive-partitioned tables
+ //////////////////////////////////////////////////////////////////////////////
+
+ test("shark.memstore2.CacheAllPolicy is the default policy") {
+ val tableName = "default_policy_cached"
+ sc.runSql("""create table default_policy_cached(key int, value string)
+ partitioned by (keypart int)""")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ DEFAULT_DB_NAME, tableName).get
+ val cachePolicy = partitionedTable.cachePolicy
+ assert(cachePolicy.isInstanceOf[shark.memstore2.CacheAllPolicy[_, _]])
+ }
+
+ test("LRU: RDDs are not evicted if the cache isn't full.") {
+ val tableName = "evict_partitions_maxSize"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 2 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("LRU: RDDs are evicted when the max size is reached.") {
+ val tableName = "evict_partitions_maxSize"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ sc.runSql("""insert into table evict_partitions_maxSize partition(keypart = 4)
+ select * from test""")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.NONE)
+ }
+
+ test("LRU: RDD eviction accounts for partition scans - a cache.get()") {
+ val tableName = "evict_partitions_with_get"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ sc.runSql("select count(1) from evict_partitions_with_get where keypart = 1")
+ sc.runSql("""insert into table evict_partitions_with_get partition(keypart = 4)
+ select * from test""")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+
+ assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.NONE)
+ }
+
+ test("LRU: RDD eviction accounts for INSERT INTO - a cache.get().") {
+ val tableName = "evict_partitions_insert_into"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val oldKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2")
+ assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ sc.runSql("""insert into table evict_partitions_insert_into partition(keypart = 1)
+ select * from test""")
+ sc.runSql("""insert into table evict_partitions_insert_into partition(keypart = 4)
+ select * from test""")
+ assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ val newKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ assert(TestUtils.getStorageLevelOfRDD(newKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+
+ val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get)
+ assert(keypart2StorageLevel == StorageLevel.NONE)
+ }
+
+ test("LRU: RDD eviction accounts for INSERT OVERWRITE - a cache.put()") {
+ val tableName = "evict_partitions_insert_overwrite"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val oldKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2")
+ assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.MEMORY_AND_DISK)
+ sc.runSql("""insert overwrite table evict_partitions_insert_overwrite partition(keypart = 1)
+ select * from test""")
+ sc.runSql("""insert into table evict_partitions_insert_overwrite partition(keypart = 4)
+ select * from test""")
+ assert(TestUtils.getStorageLevelOfRDD(oldKeypart1RDD.get) == StorageLevel.NONE)
+ val newKeypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ assert(TestUtils.getStorageLevelOfRDD(newKeypart1RDD.get) == StorageLevel.MEMORY_AND_DISK)
+
+ val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get)
+ assert(keypart2StorageLevel == StorageLevel.NONE)
+ }
+
+ test("LRU: RDD eviction accounts for ALTER TABLE DROP PARTITION - a cache.remove()") {
+ val tableName = "evict_partitions_removals"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ sc.runSql("alter table evict_partitions_removals drop partition(keypart = 1)")
+ sc.runSql("""insert into table evict_partitions_removals partition(keypart = 4)
+ select * from test""")
+ sc.runSql("""insert into table evict_partitions_removals partition(keypart = 5)
+ select * from test""")
+ val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2")
+ assert(TestUtils.getStorageLevelOfRDD(keypart2RDD.get) == StorageLevel.NONE)
+ }
+
+ test("LRU: get() reloads an RDD previously unpersist()'d.") {
+ val tableName = "reload_evicted_partition"
+ val partitionedTable = createCachedPartitionedTable(
+ tableName,
+ 3 /* numPartitionsToCreate */,
+ 3 /* maxCacheSize */,
+ "shark.memstore2.LRUCachePolicy")
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, tableName))
+ val keypart1RDD = partitionedTable.keyToPartitions.get("keypart=1")
+ val lvl = TestUtils.getStorageLevelOfRDD(keypart1RDD.get)
+ assert(lvl == StorageLevel.MEMORY_AND_DISK, "got: " + lvl)
+ sc.runSql("""insert into table reload_evicted_partition partition(keypart = 4)
+ select * from test""")
+ assert(TestUtils.getStorageLevelOfRDD(keypart1RDD.get) == StorageLevel.NONE)
+
+ // Scanning partition (keypart = 1) should reload the corresponding RDD into the cache, and
+ // cause eviction of the RDD for partition (keypart = 2).
+ sc.runSql("select count(1) from reload_evicted_partition where keypart = 1")
+ assert(keypart1RDD.get.getStorageLevel == StorageLevel.MEMORY_AND_DISK)
+ val keypart2RDD = partitionedTable.keyToPartitions.get("keypart=2")
+ val keypart2StorageLevel = TestUtils.getStorageLevelOfRDD(keypart2RDD.get)
+ assert(keypart2StorageLevel == StorageLevel.NONE,
+ "StorageLevel for partition(keypart=2) should be NONE, but got: " + keypart2StorageLevel)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Prevent nested UnionRDDs - those should be "flattened" in MemoryStoreSinkOperator.
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ test("flatten UnionRDDs") {
+ sc.sql("create table flat_cached as select * from test_cached")
+ sc.sql("insert into table flat_cached select * from test")
+ val tableName = "flat_cached"
+ var memoryTable = SharkEnv.memoryMetadataManager.getMemoryTable(DEFAULT_DB_NAME, tableName).get
+ var unionRDD = memoryTable.getRDD.get.asInstanceOf[UnionRDD[_]]
+ val numParentRDDs = unionRDD.rdds.size
+ assert(isFlattenedUnionRDD(unionRDD))
+
+ // Insert another set of query results. The flattening should kick in here.
+ sc.sql("insert into table flat_cached select * from test")
+ unionRDD = memoryTable.getRDD.get.asInstanceOf[UnionRDD[_]]
+ assert(isFlattenedUnionRDD(unionRDD))
+ assert(unionRDD.rdds.size == numParentRDDs + 1)
+ }
+
+ test("flatten UnionRDDs for partitioned tables") {
+ sc.sql("drop table if exists part_table_cached")
+ sc.sql("""create table part_table_cached(key int, value string)
+ partitioned by (keypart int)""")
+ sc.sql("alter table part_table_cached add partition(keypart = 1)")
+ sc.sql("insert into table part_table_cached partition(keypart = 1) select * from flat_cached")
+ val tableName = "part_table_cached"
+ val partitionKey = "keypart=1"
+ var partitionedTable = SharkEnv.memoryMetadataManager.getPartitionedTable(
+ DEFAULT_DB_NAME, tableName).get
+ var unionRDD = partitionedTable.keyToPartitions.get(partitionKey).get.asInstanceOf[UnionRDD[_]]
+ val numParentRDDs = unionRDD.rdds.size
+ assert(isFlattenedUnionRDD(unionRDD))
+
+ // Insert another set of query results into the same partition.
+ // The flattening should kick in here.
+ sc.runSql("insert into table part_table_cached partition(keypart = 1) select * from flat_cached")
+ unionRDD = partitionedTable.getPartition(partitionKey).get.asInstanceOf[UnionRDD[_]]
+ assert(isFlattenedUnionRDD(unionRDD))
+ assert(unionRDD.rdds.size == numParentRDDs + 1)
}
//////////////////////////////////////////////////////////////////////////////
@@ -322,11 +664,11 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.sql("drop table if exists adw")
sc.sql("""create table adw TBLPROPERTIES ("shark.cache" = "true") as
select cast(key as int) as k, val from test""")
- expectSql("select count(k) from adw where val='val_487' group by 1 having count(1) > 0","1")
+ expectSql("select count(k) from adw where val='val_487' group by 1 having count(1) > 0", "1")
}
//////////////////////////////////////////////////////////////////////////////
- // Sel Star
+ // Partition pruning
//////////////////////////////////////////////////////////////////////////////
test("sel star pruning") {
@@ -336,11 +678,45 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
expectSql("select * from selstar where val='val_487'","487 val_487")
}
+ test("map pruning with functions in between clause") {
+ sc.sql("drop table if exists mapsplitfunc")
+ sc.sql("drop table if exists mapsplitfunc_cached")
+ sc.sql("create table mapsplitfunc(k bigint, v string)")
+ sc.sql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt'
+ OVERWRITE INTO TABLE mapsplitfunc""")
+ sc.sql("create table mapsplitfunc_cached as select * from mapsplitfunc")
+ expectSql("""select count(*) from mapsplitfunc_cached
+ where month(from_unixtime(k)) between "1" and "12" """, Array[String]("500"))
+ expectSql("""select count(*) from mapsplitfunc_cached
+ where year(from_unixtime(k)) between "2013" and "2014" """, Array[String]("0"))
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // SharkContext APIs (e.g. sql2rdd, sql)
+ //////////////////////////////////////////////////////////////////////////////
+
+ test("cached table in different new database") {
+ sc.sql("drop table if exists selstar")
+ sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as
+ select * from default.test """)
+ sc.sql("use seconddb")
+ sc.sql("drop table if exists selstar")
+ sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as
+ select * from default.test where key != 'val_487' """)
+
+ sc.sql("use default")
+ expectSql("select * from selstar where val='val_487'","487 val_487")
+
+ assert(SharkEnv.memoryMetadataManager.containsTable(DEFAULT_DB_NAME, "selstar"))
+ assert(SharkEnv.memoryMetadataManager.containsTable("seconddb", "selstar"))
+ }
+
//////////////////////////////////////////////////////////////////////////////
// various data types
//////////////////////////////////////////////////////////////////////////////
- test("various data types") {
+ test("boolean data type") {
sc.sql("drop table if exists checkboolean")
sc.sql("""create table checkboolean TBLPROPERTIES ("shark.cache" = "true") as
select key, val, true as flag from test where key < "300" """)
@@ -348,7 +724,9 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
select key, val, false as flag from test where key > "300" """)
expectSql("select flag, count(*) from checkboolean group by flag order by flag asc",
Array[String]("false\t208", "true\t292"))
+ }
+ test("byte data type") {
sc.sql("drop table if exists checkbyte")
sc.sql("drop table if exists checkbyte_cached")
sc.sql("""create table checkbyte (key string, val string, flag tinyint) """)
@@ -359,7 +737,10 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.sql("""create table checkbyte_cached as select * from checkbyte""")
expectSql("select flag, count(*) from checkbyte_cached group by flag order by flag asc",
Array[String]("0\t208", "1\t292"))
+ }
+ test("binary data type") {
+
sc.sql("drop table if exists checkbinary")
sc.sql("drop table if exists checkbinary_cached")
sc.sql("""create table checkbinary (key string, flag binary) """)
@@ -370,7 +751,9 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
sc.sql("create table checkbinary_cached as select key, flag from checkbinary")
expectSql("select cast(flag as string) as f from checkbinary_cached order by f asc limit 2",
Array[String]("val_0", "val_0"))
+ }
+ test("short data type") {
sc.sql("drop table if exists checkshort")
sc.sql("drop table if exists checkshort_cached")
sc.sql("""create table checkshort (key string, val string, flag smallint) """)
@@ -419,4 +802,288 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll {
val e = intercept[QueryExecutionException] { sc.sql2rdd("asdfasdfasdfasdf") }
e.getMessage.contains("semantic")
}
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Default cache mode is CacheType.MEMORY (unified view)
+ //////////////////////////////////////////////////////////////////////////////
+ test ("Table created by CREATE TABLE, with table properties, is CacheType.MEMORY by default") {
+ sc.runSql("drop table if exists test_unify_creation")
+ sc.runSql("""create table test_unify_creation (key int, val string)
+ tblproperties('shark.cache'='true')""")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_creation").get
+ assert(table.cacheMode == CacheType.MEMORY)
+ sc.runSql("drop table if exists test_unify_creation")
+ }
+
+ test ("Table created by CREATE TABLE, with '_cached', is CacheType.MEMORY by default") {
+ sc.runSql("drop table if exists test_unify_creation_cached")
+ sc.runSql("create table test_unify_creation_cached(key int, val string)")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_creation_cached").get
+ assert(table.cacheMode == CacheType.MEMORY)
+ sc.runSql("drop table if exists test_unify_creation_cached")
+ }
+
+ test ("Table created by CTAS, with table properties, is CacheType.MEMORY by default") {
+ sc.runSql("drop table if exists test_unify_ctas")
+ sc.runSql("""create table test_unify_ctas tblproperties('shark.cache' = 'true')
+ as select * from test""")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_ctas").get
+ assert(table.cacheMode == CacheType.MEMORY)
+ expectSql("select count(*) from test_unify_ctas", "500")
+ sc.runSql("drop table if exists test_unify_ctas")
+ }
+
+ test ("Table created by CTAS, with '_cached', is CacheType.MEMORY by default") {
+ sc.runSql("drop table if exists test_unify_ctas_cached")
+ sc.runSql("create table test_unify_ctas_cached as select * from test")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_unify_ctas_cached").get
+ assert(table.cacheMode == CacheType.MEMORY)
+ expectSql("select count(*) from test_unify_ctas_cached", "500")
+ sc.runSql("drop table if exists test_unify_ctas_cached")
+ }
+
+ test ("CREATE TABLE when 'shark.cache' is CacheType.MEMORY_ONLY") {
+ sc.runSql("drop table if exists test_non_unify_creation")
+ sc.runSql("""create table test_non_unify_creation(key int, val string)
+ tblproperties('shark.cache' = 'memory_only')""")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_non_unify_creation").get
+ assert(table.cacheMode == CacheType.MEMORY_ONLY)
+ sc.runSql("drop table if exists test_non_unify_creation")
+ }
+
+ test ("CTAS when 'shark.cache' is CacheType.MEMORY_ONLY") {
+ sc.runSql("drop table if exists test_non_unify_ctas")
+ sc.runSql("""create table test_non_unify_ctas tblproperties
+ ('shark.cache' = 'memory_only') as select * from test""")
+ val table = sharkMetastore.getTable(DEFAULT_DB_NAME, "test_non_unify_ctas").get
+ assert(table.cacheMode == CacheType.MEMORY_ONLY)
+ sc.runSql("drop table if exists test_non_unify_ctas")
+ }
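+
+ // A small helper sketch (hypothetical; the tests above inline the same three lines):
+ // asserts on the cache mode the Shark metastore records for a table. Assumes CacheType
+ // is an Enumeration, so CacheType.CacheType is its value type.
+ private def assertCacheMode(tableName: String, expected: CacheType.CacheType) {
+ val tableOpt = sharkMetastore.getTable(DEFAULT_DB_NAME, tableName)
+ assert(tableOpt.isDefined, tableName + " is not tracked by the Shark metastore")
+ assert(tableOpt.get.cacheMode == expected,
+ "expected cache mode " + expected + ", got " + tableOpt.get.cacheMode)
+ }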
+
+ //////////////////////////////////////////////////////////////////////////////
+ // LOAD for tables cached in memory and stored on disk (unified view)
+ //////////////////////////////////////////////////////////////////////////////
+ test ("LOAD INTO unified view") {
+ sc.runSql("drop table if exists unified_view_cached")
+ sc.runSql("create table unified_view_cached (key int, value string)")
+ sc.runSql("load data local inpath '%s' into table unified_view_cached".format(KV1_TXT_PATH))
+ expectUnifiedKVTable("unified_view_cached")
+ expectSql("select count(*) from unified_view_cached", "500")
+ sc.runSql("drop table if exists unified_view_cached")
+ }
+
+ test ("LOAD OVERWRITE unified view") {
+ sc.runSql("drop table if exists unified_overwrite_cached")
+ sc.runSql("create table unified_overwrite_cached (key int, value string)")
+ sc.runSql("load data local inpath '%s' into table unified_overwrite_cached".
+ format("${hiveconf:shark.test.data.path}/kv3.txt"))
+ expectSql("select count(*) from unified_overwrite_cached", "25")
+ sc.runSql("load data local inpath '%s' overwrite into table unified_overwrite_cached".
+ format(KV1_TXT_PATH))
+ // Make sure the cached contents match the disk contents.
+ expectUnifiedKVTable("unified_overwrite_cached")
+ expectSql("select count(*) from unified_overwrite_cached", "500")
+ sc.runSql("drop table if exists unified_overwrite_cached")
+ }
+
+ test ("LOAD INTO partitioned unified view") {
+ sc.runSql("drop table if exists unified_view_part_cached")
+ sc.runSql("""create table unified_view_part_cached (key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""load data local inpath '%s' into table unified_view_part_cached
+ partition(keypart = 1)""".format(KV1_TXT_PATH))
+ expectUnifiedKVTable("unified_view_part_cached", Some(Map("keypart" -> "1")))
+ expectSql("select count(*) from unified_view_part_cached", "500")
+ sc.runSql("drop table if exists unified_view_part_cached")
+ }
+
+ test ("LOAD OVERWRITE partitioned unified view") {
+ sc.runSql("drop table if exists unified_overwrite_part_cached")
+ sc.runSql("""create table unified_overwrite_part_cached (key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""load data local inpath '%s' overwrite into table unified_overwrite_part_cached
+ partition(keypart = 1)""".format(KV1_TXT_PATH))
+ expectUnifiedKVTable("unified_overwrite_part_cached", Some(Map("keypart" -> "1")))
+ expectSql("select count(*) from unified_overwrite_part_cached", "500")
+ sc.runSql("drop table if exists unified_overwrite_part_cached")
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // INSERT for tables cached in memory and stored on disk (unified view)
+ //////////////////////////////////////////////////////////////////////////////
+ test ("INSERT INTO unified view") {
+ sc.runSql("drop table if exists unified_view_cached")
+ sc.runSql("create table unified_view_cached as select * from test_cached")
+ sc.runSql("insert into table unified_view_cached select * from test_cached")
+ expectUnifiedKVTable("unified_view_cached")
+ expectSql("select count(*) from unified_view_cached", "1000")
+ sc.runSql("drop table if exists unified_view_cached")
+ }
+
+ test ("INSERT OVERWRITE unified view") {
+ sc.runSql("drop table if exists unified_overwrite_cached")
+ sc.runSql("create table unified_overwrite_cached as select * from test")
+ sc.runSql("insert overwrite table unified_overwrite_cached select * from test_cached")
+ expectUnifiedKVTable("unified_overwrite_cached")
+ expectSql("select count(*) from unified_overwrite_cached", "500")
+ sc.runSql("drop table if exists unified_overwrite_cached")
+ }
+
+ test ("INSERT INTO partitioned unified view") {
+ sc.runSql("drop table if exists unified_view_part_cached")
+ sc.runSql("""create table unified_view_part_cached (key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""insert into table unified_view_part_cached partition (keypart = 1)
+ select * from test_cached""")
+ expectUnifiedKVTable("unified_view_part_cached", Some(Map("keypart" -> "1")))
+ expectSql("select count(*) from unified_view_part_cached where keypart = 1", "500")
+ sc.runSql("drop table if exists unified_view_part_cached")
+ }
+
+ test ("INSERT OVERWRITE partitioned unified view") {
+ sc.runSql("drop table if exists unified_overwrite_part_cached")
+ sc.runSql("""create table unified_overwrite_part_cached (key int, value string)
+ partitioned by (keypart int)""")
+ sc.runSql("""insert overwrite table unified_overwrite_part_cached partition (keypart = 1)
+ select * from test_cached""")
+ expectUnifiedKVTable("unified_overwrite_part_cached", Some(Map("keypart" -> "1")))
+ expectSql("select count(*) from unified_overwrite_part_cached", "500")
+ sc.runSql("drop table if exists unified_overwrite_part_cached")
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // CACHE and ALTER TABLE commands
+ //////////////////////////////////////////////////////////////////////////////
+ test ("ALTER TABLE caches non-partitioned table if 'shark.cache' is set to true") {
+ sc.runSql("drop table if exists unified_load")
+ sc.runSql("create table unified_load as select * from test")
+ sc.runSql("alter table unified_load set tblproperties('shark.cache' = 'true')")
+ expectUnifiedKVTable("unified_load")
+ sc.runSql("drop table if exists unified_load")
+ }
+
+ test ("ALTER TABLE caches partitioned table if 'shark.cache' is set to true") {
+ sc.runSql("drop table if exists unified_part_load")
+ sc.runSql("create table unified_part_load (key int, value string) partitioned by (keypart int)")
+ sc.runSql("insert into table unified_part_load partition (keypart=1) select * from test_cached")
+ sc.runSql("alter table unified_part_load set tblproperties('shark.cache' = 'true')")
+ expectUnifiedKVTable("unified_part_load", Some(Map("keypart" -> "1")))
+ sc.runSql("drop table if exists unified_part_load")
+ }
+
+ test ("ALTER TABLE uncaches non-partitioned table if 'shark.cache' is set to false") {
+ sc.runSql("drop table if exists unified_load")
+ sc.runSql("create table unified_load as select * from test")
+ sc.runSql("alter table unified_load set tblproperties('shark.cache' = 'false')")
+ assert(!sharkMetastore.containsTable(DEFAULT_DB_NAME, "unified_load"))
+ expectSql("select count(*) from unified_load", "500")
+ sc.runSql("drop table if exists unified_load")
+ }
+
+ test ("ALTER TABLE uncaches partitioned table if 'shark.cache' is set to false") {
+ sc.runSql("drop table if exists unified_part_load")
+ sc.runSql("create table unified_part_load (key int, value string) partitioned by (keypart int)")
+ sc.runSql("insert into table unified_part_load partition (keypart=1) select * from test_cached")
+ sc.runSql("alter table unified_part_load set tblproperties('shark.cache' = 'false')")
+ assert(!sharkMetastore.containsTable(DEFAULT_DB_NAME, "unified_part_load"))
+ expectSql("select count(*) from unified_part_load", "500")
+ sc.runSql("drop table if exists unified_part_load")
+ }
+
+ test ("UNCACHE behaves like ALTER TABLE SET TBLPROPERTIES ...") {
+ sc.runSql("drop table if exists unified_load")
+ sc.runSql("create table unified_load as select * from test")
+ sc.runSql("cache unified_load")
+ // Double check the table properties.
+ val tableName = "unified_load"
+ val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName)
+ assert(hiveTable.getProperty("shark.cache") == "MEMORY")
+ // Check that the cache and disk contents are synchronized.
+ expectUnifiedKVTable(tableName)
+ sc.runSql("drop table if exists unified_load")
+ }
+
+ test ("CACHE behaves like ALTER TABLE SET TBLPROPERTIES ...") {
+ sc.runSql("drop table if exists unified_load")
+ sc.runSql("create table unified_load as select * from test")
+ sc.runSql("cache unified_load")
+ // Double check the table properties.
+ val tableName = "unified_load"
+ val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName)
+ assert(hiveTable.getProperty("shark.cache") == "MEMORY")
+ // Check that the cache and disk contents are synchronized.
+ expectUnifiedKVTable(tableName)
+ sc.runSql("drop table if exists unified_load")
+ }
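+
+ // Companion sketch (hypothetical): the property read performed by the two tests above,
+ // factored out. Reads 'shark.cache' straight from the Hive metastore for a table in the
+ // default database.
+ private def cacheProperty(tableName: String): String =
+ Hive.get().getTable(DEFAULT_DB_NAME, tableName).getProperty("shark.cache")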
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Cached table persistence
+ //////////////////////////////////////////////////////////////////////////////
+ test ("Cached tables persist across Shark metastore shutdowns.") {
+ val globalCachedTableNames = Seq("test_cached", "test_null_cached", "clicks_cached",
+ "users_cached", "test1_cached")
+
+ // Number of rows, as returned strings, for each cached table.
+ val cachedTableCounts = new Array[String](globalCachedTableNames.size)
+ for ((tableName, i) <- globalCachedTableNames.zipWithIndex) {
+ val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName)
+ val cachedCount = sc.sql("select count(*) from %s".format(tableName))(0)
+ cachedTableCounts(i) = cachedCount
+ }
+ sharkMetastore.shutdown()
+ for ((tableName, i) <- globalCachedTableNames.zipWithIndex) {
+ val hiveTable = Hive.get().getTable(DEFAULT_DB_NAME, tableName)
+ // Check that the number of rows from the table on disk remains the same.
+ val onDiskCount = sc.sql("select count(*) from %s".format(tableName))(0)
+ val cachedCount = cachedTableCounts(i)
+ assert(onDiskCount == cachedCount, """Num rows for %s differ across Shark metastore restart.
+ (rows cached = %s, rows on disk = %s)""".format(tableName, cachedCount, onDiskCount))
+ // Check that we're able to materialize a row - i.e., make sure that the table scan
+ // operator doesn't try to use a ColumnarSerDe when scanning contents on disk (for our
+ // test tables, LazySimpleSerDes should be used).
+ sc.sql("select * from %s limit 1".format(tableName))
+ }
+ // Finally, reload all tables.
+ SharkRunner.loadTables()
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Table Generating Functions (TGFs)
+ //////////////////////////////////////////////////////////////////////////////
+
+ test("Simple TGFs") {
+ expectSql("generate shark.TestTGF1(test, 15)", Array(15,15,15,17,19).map(_.toString).toArray)
+ }
+
+ test("Saving simple TGFs") {
+ sc.sql("drop table if exists TGFTestTable")
+ sc.runSql("generate shark.TestTGF1(test, 15) as TGFTestTable")
+ expectSql("select * from TGFTestTable", Array(15,15,15,17,19).map(_.toString).toArray)
+ sc.sql("drop table if exists TGFTestTable")
+ }
+
+ test("Advanced TGFs") {
+ expectSql("generate shark.TestTGF2(test, 25)", Array(25,25,25,27,29).map(_.toString).toArray)
+ }
+
+ test("Saving advanced TGFs") {
+ sc.sql("drop table if exists TGFTestTable2")
+ sc.runSql("generate shark.TestTGF2(test, 25) as TGFTestTable2")
+ expectSql("select * from TGFTestTable2", Array(25,25,25,27,29).map(_.toString).toArray)
+ sc.sql("drop table if exists TGFTestTable2")
+ }
+}
+
+object TestTGF1 {
+ @Schema(spec = "values int")
+ def apply(test: RDD[(Int, String)], integer: Int) = {
+ test.map{ case Tuple2(k, v) => Tuple1(k + integer) }.filter{ case Tuple1(v) => v < 20 }
+ }
+}
+
+object TestTGF2 {
+ def apply(sc: SharkContext, test: RDD[(Int, String)], integer: Int) = {
+ val rdd = test.map{ case Tuple2(k, v) => Seq(k + integer) }.filter{ case Seq(v) => v < 30 }
+ RDDSchema(rdd.asInstanceOf[RDD[Seq[_]]], "myvalues int")
+ }
}
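+
+// A further sketch (hypothetical; mirrors TestTGF2's RDDSchema pattern and is not
+// referenced by any test above): a TGF that emits two columns, with the schema written
+// as a Hive-style "name type" list.
+object TestTGF3 {
+ def apply(sc: SharkContext, test: RDD[(Int, String)], integer: Int) = {
+ val rdd = test.map { case (k, v) => Seq(k + integer, v) }
+ RDDSchema(rdd.asInstanceOf[RDD[Seq[_]]], "shifted int, val string")
+ }
+}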
diff --git a/src/test/scala/shark/SharkRunner.scala b/src/test/scala/shark/SharkRunner.scala
new file mode 100644
index 00000000..573ecec2
--- /dev/null
+++ b/src/test/scala/shark/SharkRunner.scala
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2012 The Regents of The University California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark
+
+import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME
+
+import shark.api.JavaSharkContext
+import shark.memstore2.MemoryMetadataManager
+
+
+object SharkRunner {
+
+ val WAREHOUSE_PATH = TestUtils.getWarehousePath()
+ val METASTORE_PATH = TestUtils.getMetastorePath()
+ val MASTER = "local"
+
+ var sc: SharkContext = _
+
+ var javaSc: JavaSharkContext = _
+
+ def init(): SharkContext = synchronized {
+ if (sc == null) {
+ sc = SharkEnv.initWithSharkContext("shark-sql-suite-testing", MASTER)
+
+ sc.runSql("set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=" +
+ METASTORE_PATH + ";create=true")
+ sc.runSql("set hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
+ sc.runSql("set shark.test.data.path=" + TestUtils.dataFilePath)
+
+ // second db
+ sc.sql("create database if not exists seconddb")
+
+ loadTables()
+ }
+ sc
+ }
+
+ def initWithJava(): JavaSharkContext = synchronized {
+ if (javaSc == null) {
+ javaSc = new JavaSharkContext(init())
+ }
+ javaSc
+ }
+
+ /**
+ * Tables accessible by any test. Their properties should remain constant across
+ * tests.
+ */
+ def loadTables() = synchronized {
+ require(sc != null, "call init() to instantiate a SharkContext first")
+
+ // Use the default namespace
+ sc.runSql("USE " + DEFAULT_DATABASE_NAME)
+
+ // test
+ sc.runSql("drop table if exists test")
+ sc.runSql("CREATE TABLE test (key INT, val STRING)")
+ sc.runSql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv1.txt' INTO TABLE test")
+ sc.runSql("drop table if exists test_cached")
+ sc.runSql("CREATE TABLE test_cached AS SELECT * FROM test")
+
+ // test_null
+ sc.runSql("drop table if exists test_null")
+ sc.runSql("CREATE TABLE test_null (key INT, val STRING)")
+ sc.runSql("""LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/kv3.txt'
+ INTO TABLE test_null""")
+ sc.runSql("drop table if exists test_null_cached")
+ sc.runSql("CREATE TABLE test_null_cached AS SELECT * FROM test_null")
+
+ // clicks
+ sc.runSql("drop table if exists clicks")
+ sc.runSql("""create table clicks (id int, click int)
+ row format delimited fields terminated by '\t'""")
+ sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/clicks.txt'
+ OVERWRITE INTO TABLE clicks""")
+ sc.runSql("drop table if exists clicks_cached")
+ sc.runSql("create table clicks_cached as select * from clicks")
+
+ // users
+ sc.runSql("drop table if exists users")
+ sc.runSql("""create table users (id int, name string)
+ row format delimited fields terminated by '\t'""")
+ sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/users.txt'
+ OVERWRITE INTO TABLE users""")
+ sc.runSql("drop table if exists users_cached")
+ sc.runSql("create table users_cached as select * from users")
+
+ // test1
+ sc.sql("drop table if exists test1")
+ sc.sql("""CREATE TABLE test1 (id INT, test1val ARRAY)
+ row format delimited fields terminated by '\t'""")
+ sc.sql("LOAD DATA LOCAL INPATH '${hiveconf:shark.test.data.path}/test1.txt' INTO TABLE test1")
+ sc.sql("drop table if exists test1_cached")
+ sc.sql("CREATE TABLE test1_cached AS SELECT * FROM test1")
+ Unit
+ }
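+
+ // Note (behavior assumed from the suites that consume these tables): the *_cached tables
+ // rely solely on the '_cached' name suffix to be cached with the default CacheType.MEMORY;
+ // no TBLPROPERTIES clause is needed.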
+
+ def expectSql(sql: String, expectedResults: Array[String], sort: Boolean = true) {
+ val sharkResults: Array[String] = sc.runSql(sql).results.map(_.mkString("\t")).toArray
+ val results = if (sort) sharkResults.sortWith(_ < _) else sharkResults
+ val expected = if (sort) expectedResults.sortWith(_ < _) else expectedResults
+ assert(results.corresponds(expected)(_.equals(_)),
+ "In SQL: " + sql + "\n" +
+ "Expected: " + expected.mkString("\n") + "; got " + results.mkString("\n"))
+ }
+
+ // A shortcut for single row results.
+ def expectSql(sql: String, expectedResult: String) {
+ expectSql(sql, Array(expectedResult))
+ }
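+
+ // Usage sketch: results are sorted before comparison by default, so expected rows may be
+ // listed in any order. Pass sort = false to assert on ordering, e.g. (rows assumed from
+ // the kv1-backed 'test' table, whose smallest key 0 appears three times):
+ // expectSql("select key from test order by key limit 2", Array("0", "0"), sort = false)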
+
+}
diff --git a/src/test/scala/shark/SharkServerSuite.scala b/src/test/scala/shark/SharkServerSuite.scala
index e5df4f98..1310ca04 100644
--- a/src/test/scala/shark/SharkServerSuite.scala
+++ b/src/test/scala/shark/SharkServerSuite.scala
@@ -10,7 +10,8 @@ import scala.collection.JavaConversions._
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
-import scala.concurrent.ops._
+import scala.concurrent._
+import ExecutionContext.Implicits.global
/**
* Test for the Shark server.
@@ -57,7 +58,7 @@ class SharkServerSuite extends FunSuite with BeforeAndAfterAll with ShouldMatche
// Spawn a thread to read the output from the forked process.
// Note that this is necessary since in some configurations, log4j could be blocked
// if its output to stderr are not read, and eventually blocking the entire test suite.
- spawn {
+ future {
while (true) {
val stdout = readFrom(inputReader)
val stderr = readFrom(errorReader)
@@ -78,6 +79,7 @@ class SharkServerSuite extends FunSuite with BeforeAndAfterAll with ShouldMatche
}
test("test query execution against a shark server") {
+ Thread.sleep(5*1000) // I know... Gross. However, without this the tests fail non-deterministically.
val dataFilePath = TestUtils.dataFilePath + "/kv1.txt"
val stmt = createStatement()
diff --git a/src/test/scala/shark/SortSuite.scala b/src/test/scala/shark/SortSuite.scala
index 4e7e9c05..df948a54 100644
--- a/src/test/scala/shark/SortSuite.scala
+++ b/src/test/scala/shark/SortSuite.scala
@@ -31,28 +31,23 @@ class SortSuite extends FunSuite {
TestUtils.init()
+ var sc: SparkContext = SharkRunner.init()
+
test("order by limit") {
- var sc: SparkContext = null
- try {
- sc = new SparkContext("local", "test")
- val data = Array((4, 14), (1, 11), (7, 17), (0, 10))
- val expected = data.sortWith(_._1 < _._1).toSeq
- val rdd: RDD[(ReduceKey, BytesWritable)] = sc.parallelize(data, 50).map { x =>
- (new ReduceKeyMapSide(new BytesWritable(Array[Byte](x._1.toByte))),
- new BytesWritable(Array[Byte](x._2.toByte)))
- }
- for (k <- 0 to 5) {
- val sortedRdd = RDDUtils.topK(rdd, k).asInstanceOf[RDD[(ReduceKeyReduceSide, Array[Byte])]]
- val output = sortedRdd.map { case(k, v) =>
- (k.byteArray(0).toInt, v(0).toInt)
- }.collect().toSeq
- assert(output.size === math.min(k, 4))
- assert(output === expected.take(math.min(k, 4)))
- }
- } finally {
- sc.stop()
+ val data = Array((4, 14), (1, 11), (7, 17), (0, 10))
+ val expected = data.sortWith(_._1 < _._1).toSeq
+ val rdd: RDD[(ReduceKey, BytesWritable)] = sc.parallelize(data, 50).map { x =>
+ (new ReduceKeyMapSide(new BytesWritable(Array[Byte](x._1.toByte))),
+ new BytesWritable(Array[Byte](x._2.toByte)))
+ }
+ for (k <- 0 to 5) {
+ val sortedRdd = RDDUtils.topK(rdd, k).asInstanceOf[RDD[(ReduceKeyReduceSide, Array[Byte])]]
+ val output = sortedRdd.map { case(k, v) =>
+ (k.byteArray(0).toInt, v(0).toInt)
+ }.collect().toSeq
+ assert(output.size === math.min(k, 4))
+ assert(output === expected.take(math.min(k, 4)))
}
- sc.stop()
- System.clearProperty("spark.driver.port")
}
+
}
diff --git a/src/test/scala/shark/TachyonSQLSuite.scala b/src/test/scala/shark/TachyonSQLSuite.scala
new file mode 100644
index 00000000..899bc1d4
--- /dev/null
+++ b/src/test/scala/shark/TachyonSQLSuite.scala
@@ -0,0 +1,437 @@
+/*
+ * Copyright (C) 2012 The Regents of The University California.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package shark
+
+import java.util.{HashMap => JavaHashMap}
+
+import scala.collection.JavaConversions._
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.hive.metastore.MetaStoreUtils.DEFAULT_DATABASE_NAME
+import org.apache.hadoop.hive.ql.metadata.Hive
+import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.storage.StorageLevel
+
+import shark.api.QueryExecutionException
+import shark.memstore2.{CacheType, MemoryMetadataManager, PartitionedMemoryTable}
+// import expectSql() shortcut methods
+import shark.SharkRunner._
+
+
+class TachyonSQLSuite extends FunSuite with BeforeAndAfterAll {
+
+ val DEFAULT_DB_NAME = DEFAULT_DATABASE_NAME
+ val KV1_TXT_PATH = "${hiveconf:shark.test.data.path}/kv1.txt"
+
+ var sc: SharkContext = SharkRunner.init()
+ var sharkMetastore: MemoryMetadataManager = SharkEnv.memoryMetadataManager
+
+ // Determine whether Tachyon is enabled at runtime.
+ val isTachyonEnabled = SharkEnv.tachyonUtil.tachyonEnabled()
+
+
+ override def beforeAll() {
+ if (isTachyonEnabled) {
+ sc.runSql("create table test_tachyon as select * from test")
+ }
+ }
+
+ override def afterAll() {
+ if (isTachyonEnabled) {
+ sc.runSql("drop table test_tachyon")
+ }
+ }
+
+ private def isTachyonTable(
+ dbName: String,
+ tableName: String,
+ hivePartitionKeyOpt: Option[String] = None): Boolean = {
+ val tableKey = MemoryMetadataManager.makeTableKey(dbName, tableName)
+ SharkEnv.tachyonUtil.tableExists(tableKey, hivePartitionKeyOpt)
+ }
+
+ private def createPartitionedTachyonTable(tableName: String, numPartitionsToCreate: Int) {
+ sc.runSql("drop table if exists %s".format(tableName))
+ sc.runSql("""
+ create table %s(key int, value string)
+ partitioned by (keypart int)
+ tblproperties('shark.cache' = 'tachyon')
+ """.format(tableName))
+ var partitionNum = 1
+ while (partitionNum <= numPartitionsToCreate) {
+ sc.runSql("""insert into table %s partition(keypart = %d)
+ select * from test_tachyon""".format(tableName, partitionNum))
+ partitionNum += 1
+ }
+ assert(isTachyonTable(DEFAULT_DB_NAME, tableName))
+ }
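+
+ // Usage sketch (hypothetical call): creates Hive partitions keypart=1..3, each loaded
+ // with a full copy of test_tachyon's 500 rows and materialized in Tachyon.
+ // createPartitionedTachyonTable("sketch_part_tachyon", numPartitionsToCreate = 3)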
+
+ if (isTachyonEnabled) {
+ //////////////////////////////////////////////////////////////////////////////
+ // basic SQL
+ //////////////////////////////////////////////////////////////////////////////
+ test("count") {
+ expectSql("select count(*) from test_tachyon", "500")
+ }
+
+ test("filter") {
+ expectSql("select * from test_tachyon where key=100 or key=497",
+ Array("100\tval_100", "100\tval_100", "497\tval_497"))
+ }
+
+ test("count distinct") {
+ sc.runSql("set mapred.reduce.tasks=3")
+ expectSql("select count(distinct key) from test_tachyon", "309")
+ expectSql(
+ """|SELECT substr(key,1,1), count(DISTINCT substr(val,5)) from test_tachyon
+ |GROUP BY substr(key,1,1)""".stripMargin,
+ Array("0\t1", "1\t71", "2\t69", "3\t62", "4\t74", "5\t6", "6\t5", "7\t6", "8\t8", "9\t7"))
+ }
+
+ test("count bigint") {
+ sc.runSql("drop table if exists test_bigint")
+ sc.runSql("create table test_bigint (key bigint, val string)")
+ sc.runSql("""load data local inpath '${hiveconf:shark.test.data.path}/kv1.txt'
+ OVERWRITE INTO TABLE test_bigint""")
+ sc.runSql("drop table if exists test_bigint_tachyon")
+ sc.runSql("create table test_bigint_tachyon as select * from test_bigint")
+ expectSql("select val, count(*) from test_bigint_tachyon where key=484 group by val",
+ "val_484\t1")
+
+ sc.runSql("drop table if exists test_bigint_tachyon")
+ }
+
+ test("limit") {
+ assert(sc.runSql("select * from test_tachyon limit 10").results.length === 10)
+ assert(sc.runSql("select * from test_tachyon limit 501").results.length === 500)
+ sc.runSql("drop table if exists test_limit0_tachyon")
+ assert(sc.runSql("select * from test_tachyon limit 0").results.length === 0)
+ assert(sc.runSql("create table test_limit0_tachyon as select * from test_tachyon limit 0")
+ .results.length === 0)
+ assert(sc.runSql("select * from test_limit0_tachyon limit 0").results.length === 0)
+ assert(sc.runSql("select * from test_limit0_tachyon limit 1").results.length === 0)
+
+ sc.runSql("drop table if exists test_limit0_tachyon")
+ }
+
+ //////////////////////////////////////////////////////////////////////////////
+ // cache DDL
+ //////////////////////////////////////////////////////////////////////////////
+ test("Use regular CREATE TABLE and '_tachyon' suffix to create Tachyon table") {
+ sc.runSql("drop table if exists empty_table_tachyon")
+ sc.runSql("create table empty_table_tachyon(key string, value string)")
+ assert(isTachyonTable(DEFAULT_DB_NAME, "empty_table_tachyon"))
+
+ sc.runSql("drop table if exists empty_table_tachyon")
+ }
+
+ test("Use regular CREATE TABLE and table properties to create Tachyon table") {
+ sc.runSql("drop table if exists empty_table_tachyon_tbl_props")
+ sc.runSql("""create table empty_table_tachyon_tbl_props(key string, value string)
+ TBLPROPERTIES('shark.cache' = 'tachyon')""")
+ assert(isTachyonTable(DEFAULT_DB_NAME, "empty_table_tachyon_tbl_props"))
+
+ sc.runSql("drop table if exists empty_table_tachyon_tbl_props")
+ }
+
+ test("Insert into empty Tachyon table") {
+ sc.runSql("drop table if exists new_table_tachyon")
+ sc.runSql("create table new_table_tachyon(key string, value string)")
+ sc.runSql("insert into table new_table_tachyon select * from test where key > -1 limit 499")
+ expectSql("select count(*) from new_table_tachyon", "499")
+
+ sc.runSql("drop table if exists new_table_tachyon")
+ }
+
+ test("rename Tachyon table") {
+ sc.runSql("drop table if exists test_oldname_tachyon")
+ sc.runSql("drop table if exists test_rename")
+ sc.runSql("create table test_oldname_tachyon as select * from test")
+ sc.runSql("alter table test_oldname_tachyon rename to test_rename")
+
+ assert(!isTachyonTable(DEFAULT_DB_NAME, "test_oldname_tachyon"))
+ assert(isTachyonTable(DEFAULT_DB_NAME, "test_rename"))
+
+ expectSql("select count(*) from test_rename", "500")
+
+ sc.runSql("drop table if exists test_rename")
+ }
+
+ test("insert into tachyon tables") {
+ sc.runSql("drop table if exists test1_tachyon")
+ sc.runSql("create table test1_tachyon as select * from test")
+ expectSql("select count(*) from test1_tachyon", "500")
+ sc.runSql("insert into table test1_tachyon select * from test where key > -1 limit 499")
+ expectSql("select count(*) from test1_tachyon", "999")
+
+ sc.runSql("drop table if exists test1_tachyon")
+ }
+
+ test("insert overwrite") {
+ sc.runSql("drop table if exists test2_tachyon")
+ sc.runSql("create table test2_tachyon as select * from test")
+ expectSql("select count(*) from test2_tachyon", "500")
+ sc.runSql("insert overwrite table test2_tachyon select * from test where key > -1 limit 499")
+ expectSql("select count(*) from test2_tachyon", "499")
+
+ sc.runSql("drop table if exists test2_tachyon")
+ }
+
+ test("error when attempting to update Tachyon table(s) using command with multiple INSERTs") {
+ sc.runSql("drop table if exists multi_insert_test")
+ sc.runSql("drop table if exists multi_insert_test_tachyon")
+ sc.runSql("create table multi_insert_test as select * from test")
+ sc.runSql("create table multi_insert_test_tachyon as select * from test")
+ intercept[QueryExecutionException] {
+ sc.runSql("""from test
+ insert into table multi_insert_test select *
+ insert into table multi_insert_test_tachyon select *""")
+ }
+
+ sc.runSql("drop table if exists multi_insert_test")
+ sc.runSql("drop table if exists multi_insert_test_tachyon")
+ }
+
+ test("create Tachyon table with 'shark.cache' flag in table properties") {
+ sc.runSql("drop table if exists ctas_tbl_props")
+ sc.runSql("""create table ctas_tbl_props TBLPROPERTIES ('shark.cache'='tachyon') as
+ select * from test""")
+ assert(isTachyonTable(DEFAULT_DB_NAME, "ctas_tbl_props"))
+ expectSql("select * from ctas_tbl_props where key=407", "407\tval_407")
+
+ sc.runSql("drop table if exists ctas_tbl_props")
+ }
+
+ test("tachyon tables with complex types") {
+ sc.runSql("drop table if exists test_complex_types")
+ sc.runSql("drop table if exists test_complex_types_tachyon")
+ sc.runSql("""CREATE TABLE test_complex_types (
+ a STRING, b ARRAY<STRING>, c ARRAY